Repository: spark
Updated Branches:
  refs/heads/branch-2.2 23eb4d70a -> 94f9227d8


[SPARK-22550][SQL] Fix 64KB JVM bytecode limit problem with elt

This PR changes `elt` code generation to place the generated code for the
argument expressions into separate methods if its size could be large.
This PR resolves the case of `elt` with a lot of arguments.

Added a new test case to `StringExpressionsSuite`.

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #19778 from kiszk/SPARK-22550.

(cherry picked from commit 9bdff0bcd83e730aba8dc1253da24a905ba07ae3)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94f9227d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94f9227d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94f9227d

Branch: refs/heads/branch-2.2
Commit: 94f9227d810698e7c16b2dbbfd3d84ef8aeef75b
Parents: 23eb4d7
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Tue Nov 21 12:19:11 2017 +0100
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Nov 21 12:53:55 2017 +0100

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     | 38 ++++++++++------
 .../expressions/stringExpressions.scala         | 48 ++++++++++++++++----
 .../expressions/StringExpressionsSuite.scala    |  7 +++
 3 files changed, 70 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94f9227d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 50112d5..b61ad42 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -660,20 +660,7 @@ class CodegenContext {
       returnType: String = "void",
       makeSplitFunction: String => String = identity,
       foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): 
String = {
-    val blocks = new ArrayBuffer[String]()
-    val blockBuilder = new StringBuilder()
-    for (code <- expressions) {
-      // We can't know how many bytecode will be generated, so use the length 
of source code
-      // as metric. A method should not go beyond 8K, otherwise it will not be 
JITted, should
-      // also not be too small, or it will have many function calls (for wide 
table), see the
-      // results in BenchmarkWideTable.
-      if (blockBuilder.length > 1024) {
-        blocks += blockBuilder.toString()
-        blockBuilder.clear()
-      }
-      blockBuilder.append(code)
-    }
-    blocks += blockBuilder.toString()
+    val blocks = buildCodeBlocks(expressions)
 
     if (blocks.length == 1) {
       // inline execution if only one block
@@ -697,6 +684,29 @@ class CodegenContext {
   }
 
   /**
+   * Splits the generated code of expressions into multiple sequences of String
+   * based on a threshold of length of a String
+   *
+   * @param expressions the codes to evaluate expressions.
+   */
+  def buildCodeBlocks(expressions: Seq[String]): Seq[String] = {
+    val blocks = new ArrayBuffer[String]()
+    val blockBuilder = new StringBuilder()
+    for (code <- expressions) {
+      // We can't know how many bytecode will be generated, so use the length 
of source code
+      // as metric. A method should not go beyond 8K, otherwise it will not be 
JITted, should
+      // also not be too small, or it will have many function calls (for wide 
table), see the
+      // results in BenchmarkWideTable.
+      if (blockBuilder.length > 1024) {
+        blocks += blockBuilder.toString()
+        blockBuilder.clear()
+      }
+      blockBuilder.append(code)
+    }
+    blocks += blockBuilder.toString()
+  }
+
+  /**
    * Wrap the generated code of expression, which was created from a row 
object in INPUT_ROW,
    * by a function. ev.isNull and ev.value are passed by global variables
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/94f9227d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index cb8b36b..014ac77 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -287,22 +287,52 @@ case class Elt(children: Seq[Expression])
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     val index = indexExpr.genCode(ctx)
     val strings = stringExprs.map(_.genCode(ctx))
+    val indexVal = ctx.freshName("index")
+    val stringVal = ctx.freshName("stringVal")
     val assignStringValue = strings.zipWithIndex.map { case (eval, index) =>
       s"""
         case ${index + 1}:
-          ${ev.value} = ${eval.isNull} ? null : ${eval.value};
+          ${eval.code}
+          $stringVal = ${eval.isNull} ? null : ${eval.value};
           break;
       """
-    }.mkString("\n")
-    val indexVal = ctx.freshName("index")
-    val stringArray = ctx.freshName("strings");
+    }
 
-    ev.copy(index.code + "\n" + strings.map(_.code).mkString("\n") + s"""
-      final int $indexVal = ${index.value};
-      UTF8String ${ev.value} = null;
-      switch ($indexVal) {
-        $assignStringValue
+    val cases = ctx.buildCodeBlocks(assignStringValue)
+    val codes = if (cases.length == 1) {
+      s"""
+        UTF8String $stringVal = null;
+        switch ($indexVal) {
+          ${cases.head}
+        }
+       """
+    } else {
+      var prevFunc = "null"
+      for (c <- cases.reverse) {
+        val funcName = ctx.freshName("eltFunc")
+        val funcBody = s"""
+         private UTF8String $funcName(InternalRow ${ctx.INPUT_ROW}, int 
$indexVal) {
+           UTF8String $stringVal = null;
+           switch ($indexVal) {
+             $c
+             default:
+               return $prevFunc;
+           }
+           return $stringVal;
+         }
+        """
+        ctx.addNewFunction(funcName, funcBody)
+        prevFunc = s"$funcName(${ctx.INPUT_ROW}, $indexVal)"
       }
+      s"UTF8String $stringVal = $prevFunc;"
+    }
+
+    ev.copy(
+      s"""
+      ${index.code}
+      final int $indexVal = ${index.value};
+      $codes
+      UTF8String ${ev.value} = $stringVal;
       final boolean ${ev.isNull} = ${ev.value} == null;
     """)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/94f9227d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 312a88d..7adf967 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -117,6 +117,13 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
   }
 
+  test("SPARK-22550: Elt should not generate codes beyond 64KB") {
+    val N = 10000
+    val strings = (1 to N).map(x => s"s$x")
+    val args = Literal.create(N, IntegerType) +: strings.map(Literal.create(_, 
StringType))
+    checkEvaluation(Elt(args), s"s$N")
+  }
+
   test("StringComparison") {
     val row = create_row("abc", null)
     val c1 = 'a.string.at(0)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to