Repository: spark
Updated Branches:
  refs/heads/master 53ba6d6e5 -> 9634e17d0


[SPARK-13242] [SQL] codegen fallback in case-when if there many branches

## What changes were proposed in this pull request?

If there are many branches in a CaseWhen expression, the generated code could 
exceed the 64K bytecode limit for a single Java method and fail to compile. 
This PR changes it to fall back to interpreted mode if there are more than 20 branches.

This PR is based on #11243 and #11221, thanks to joehalliwell

Closes #11243
Closes #11221

## How was this patch tested?

Add a test with 50 branches.

Author: Davies Liu <dav...@databricks.com>

Closes #11592 from davies/fix_when.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9634e17d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9634e17d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9634e17d

Branch: refs/heads/master
Commit: 9634e17d0183d43606a96fbba630e4c6ad720f7c
Parents: 53ba6d6
Author: Davies Liu <dav...@databricks.com>
Authored: Wed Mar 9 09:27:28 2016 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Wed Mar 9 09:27:28 2016 -0800

----------------------------------------------------------------------
 .../expressions/conditionalExpressions.scala    | 14 ++++++++++++-
 .../sql/catalyst/expressions/literals.scala     |  6 +++---
 .../expressions/CodeGenerationSuite.scala       | 21 ++++++++++++++++++++
 .../spark/sql/execution/WholeStageCodegen.scala |  1 +
 4 files changed, 38 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9634e17d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index c3e9fa3..5ceb365 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -86,7 +86,7 @@ case class If(predicate: Expression, trueValue: Expression, 
falseValue: Expressi
  * @param elseValue optional value for the else branch
  */
 case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: 
Option[Expression] = None)
-  extends Expression {
+  extends Expression with CodegenFallback {
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 
:: Nil) ++ elseValue
 
@@ -136,7 +136,16 @@ case class CaseWhen(branches: Seq[(Expression, 
Expression)], elseValue: Option[E
     }
   }
 
+  def shouldCodegen: Boolean = {
+    branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
+  }
+
   override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+    if (!shouldCodegen) {
+      // Fallback to interpreted mode if there are too many branches, as it 
may reach the
+      // 64K limit (limit on bytecode size for a single function).
+      return super[CodegenFallback].genCode(ctx, ev)
+    }
     // Generate code that looks like:
     //
     // condA = ...
@@ -205,6 +214,9 @@ case class CaseWhen(branches: Seq[(Expression, 
Expression)], elseValue: Option[E
 /** Factory methods for CaseWhen. */
 object CaseWhen {
 
+  // The maxium number of switches supported with codegen.
+  val MAX_NUM_CASES_FOR_CODEGEN = 20
+
   def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): 
CaseWhen = {
     CaseWhen(branches, Option(elseValue))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9634e17d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 37bfe98..a76517a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -203,7 +203,7 @@ case class Literal protected (value: Any, dataType: 
DataType)
         case FloatType =>
           val v = value.asInstanceOf[Float]
           if (v.isNaN || v.isInfinite) {
-            super.genCode(ctx, ev)
+            super[CodegenFallback].genCode(ctx, ev)
           } else {
             ev.isNull = "false"
             ev.value = s"${value}f"
@@ -212,7 +212,7 @@ case class Literal protected (value: Any, dataType: 
DataType)
         case DoubleType =>
           val v = value.asInstanceOf[Double]
           if (v.isNaN || v.isInfinite) {
-            super.genCode(ctx, ev)
+            super[CodegenFallback].genCode(ctx, ev)
           } else {
             ev.isNull = "false"
             ev.value = s"${value}D"
@@ -232,7 +232,7 @@ case class Literal protected (value: Any, dataType: 
DataType)
           ""
         // eval() version may be faster for non-primitive types
         case other =>
-          super.genCode(ctx, ev)
+          super[CodegenFallback].genCode(ctx, ev)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/9634e17d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b5413fb..260dfb3 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     }
   }
 
+  test("SPARK-13242: case-when expression with large number of branches (or 
cases)") {
+    val cases = 50
+    val clauses = 20
+
+    // Generate an individual case
+    def generateCase(n: Int): (Expression, Expression) = {
+      val condition = (1 to clauses)
+        .map(c => EqualTo(BoundReference(0, StringType, false), 
Literal(s"$c:$n")))
+        .reduceLeft[Expression]((l, r) => Or(l, r))
+      (condition, Literal(n))
+    }
+
+    val expression = CaseWhen((1 to cases).map(generateCase(_)))
+
+    val plan = GenerateMutableProjection.generate(Seq(expression))()
+    val input = new 
GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
+    val actual = plan(input).toSeq(Seq(expression.dataType))
+
+    assert(actual(0) == cases)
+  }
+
   test("test generated safe and unsafe projection") {
     val schema = new StructType(Array(
       StructField("a", StringType, true),

http://git-wip-us.apache.org/repos/asf/spark/blob/9634e17d/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 45578d5..dd831e6 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -416,6 +416,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: 
SQLContext) extends Ru
 
   private def supportCodegen(e: Expression): Boolean = e match {
     case e: LeafExpression => true
+    case e: CaseWhen => e.shouldCodegen
     // CodegenFallback requires the input to be an InternalRow
     case e: CodegenFallback => false
     case _ => true


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to