Repository: spark
Updated Branches:
  refs/heads/master 813c0f945 -> e98f9647f


[SPARK-22695][SQL] ScalaUDF should not use global variables

## What changes were proposed in this pull request?

The code generated for ScalaUDF declares global variables (mutable state 
registered through the codegen context) that are not needed. These can 
generate unneeded entries in the constant pool.

This PR replaces the unneeded global variables with local variables in the generated code.

## How was this patch tested?

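Added a unit test to ScalaUDFSuite ("SPARK-22695: ScalaUDF should not use 
global variables") that generates code for a ScalaUDF and asserts that the 
CodegenContext has no mutable states.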

Author: Marco Gaido <mga...@hortonworks.com>
Author: Marco Gaido <marcogaid...@gmail.com>

Closes #19900 from mgaido91/SPARK-22695.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e98f9647
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e98f9647
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e98f9647

Branch: refs/heads/master
Commit: e98f9647f44d1071a6b070db070841b8cda6bd7a
Parents: 813c0f9
Author: Marco Gaido <mga...@hortonworks.com>
Authored: Thu Dec 7 00:50:49 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Thu Dec 7 00:50:49 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/ScalaUDF.scala     | 88 ++++++++++----------
 .../catalyst/expressions/ScalaUDFSuite.scala    |  6 ++
 2 files changed, 51 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e98f9647/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 1798530..4d26d98 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -982,35 +982,28 @@ case class ScalaUDF(
 
   // scalastyle:on line.size.limit
 
-  // Generate codes used to convert the arguments to Scala type for 
user-defined functions
-  private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): 
String = {
-    val converterClassName = classOf[Any => Any].getName
-    val typeConvertersClassName = CatalystTypeConverters.getClass.getName + 
".MODULE$"
-    val expressionClassName = classOf[Expression].getName
-    val scalaUDFClassName = classOf[ScalaUDF].getName
+  private val converterClassName = classOf[Any => Any].getName
+  private val scalaUDFClassName = classOf[ScalaUDF].getName
+  private val typeConvertersClassName = 
CatalystTypeConverters.getClass.getName + ".MODULE$"
 
+  // Generate codes used to convert the arguments to Scala type for 
user-defined functions
+  private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): 
(String, String) = {
     val converterTerm = ctx.freshName("converter")
     val expressionIdx = ctx.references.size - 1
-    ctx.addMutableState(converterClassName, converterTerm,
-      s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
-        
s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
-          
s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
-    converterTerm
+    (converterTerm,
+      s"$converterClassName $converterTerm = 
($converterClassName)$typeConvertersClassName" +
+        s".createToScalaConverter(((Expression)((($scalaUDFClassName)" +
+        
s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
   }
 
   override def doGenCode(
       ctx: CodegenContext,
       ev: ExprCode): ExprCode = {
+    val scalaUDF = ctx.freshName("scalaUDF")
+    val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName)
 
-    val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
-    val converterClassName = classOf[Any => Any].getName
-    val typeConvertersClassName = CatalystTypeConverters.getClass.getName + 
".MODULE$"
-
-    // Generate codes used to convert the returned value of user-defined 
functions to Catalyst type
+    // Object to convert the returned value of user-defined functions to 
Catalyst type
     val catalystConverterTerm = ctx.freshName("catalystConverter")
-    ctx.addMutableState(converterClassName, catalystConverterTerm,
-      s"$catalystConverterTerm = 
($converterClassName)$typeConvertersClassName" +
-        s".createToCatalystConverter($scalaUDF.dataType());")
 
     val resultTerm = ctx.freshName("result")
 
@@ -1022,8 +1015,6 @@ case class ScalaUDF(
     val funcClassName = s"scala.Function${children.size}"
 
     val funcTerm = ctx.freshName("udf")
-    ctx.addMutableState(funcClassName, funcTerm,
-      s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
 
     // codegen for children expressions
     val evals = children.map(_.genCode(ctx))
@@ -1033,34 +1024,45 @@ case class ScalaUDF(
     // such as IntegerType, its javaType is `int` and the returned type of 
user-defined
     // function is Object. Trying to convert an Object to `int` will cause 
casting exception.
     val evalCode = evals.map(_.code).mkString
-    val (converters, funcArguments) = converterTerms.zipWithIndex.map { case 
(converter, i) =>
-      val eval = evals(i)
-      val argTerm = ctx.freshName("arg")
-      val convert = s"Object $argTerm = ${eval.isNull} ? null : 
$converter.apply(${eval.value});"
-      (convert, argTerm)
+    val (converters, funcArguments) = converterTerms.zipWithIndex.map {
+      case ((convName, convInit), i) =>
+        val eval = evals(i)
+        val argTerm = ctx.freshName("arg")
+        val convert =
+          s"""
+             |$convInit
+             |Object $argTerm = ${eval.isNull} ? null : 
$convName.apply(${eval.value});
+           """.stripMargin
+        (convert, argTerm)
     }.unzip
 
     val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
     val callFunc =
       s"""
-         ${ctx.boxedType(dataType)} $resultTerm = null;
-         try {
-           $resultTerm = 
(${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
-         } catch (Exception e) {
-           throw new 
org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
-         }
-       """
+         |${ctx.boxedType(dataType)} $resultTerm = null;
+         |$scalaUDFClassName $scalaUDF = $scalaUDFRef;
+         |try {
+         |  $funcClassName $funcTerm = 
($funcClassName)$scalaUDF.userDefinedFunc();
+         |  $converterClassName $catalystConverterTerm = ($converterClassName)
+         |    
$typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType());
+         |  $resultTerm = 
(${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
+         |} catch (Exception e) {
+         |  throw new 
org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
+         |}
+       """.stripMargin
 
-    ev.copy(code = s"""
-      $evalCode
-      ${converters.mkString("\n")}
-      $callFunc
-
-      boolean ${ev.isNull} = $resultTerm == null;
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.value} = $resultTerm;
-      }""")
+    ev.copy(code =
+      s"""
+         |$evalCode
+         |${converters.mkString("\n")}
+         |$callFunc
+         |
+         |boolean ${ev.isNull} = $resultTerm == null;
+         |${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+       """.stripMargin)
   }
 
   private[this] val converter = 
CatalystTypeConverters.createToCatalystConverter(dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/e98f9647/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 13bd363..70dea4b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.Locale
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     assert(e2.getMessage.contains("Failed to execute user defined function"))
   }
 
+  test("SPARK-22695: ScalaUDF should not use global variables") {
+    val ctx = new CodegenContext
+    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: 
Nil).genCode(ctx)
+    assert(ctx.mutableStates.isEmpty)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to