Repository: spark
Updated Branches:
  refs/heads/master 57626a557 -> a814eeac6


[SPARK-18125][SQL] Fix a compilation error in codegen due to splitExpression

## What changes were proposed in this pull request?

As reported in the JIRA, the Java code generated by codegen sometimes fails
to compile.

Code snippet to reproduce it:

    case class Route(src: String, dest: String, cost: Int)
    case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])

    val ds = sc.parallelize(Array(
      Route("a", "b", 1),
      Route("a", "b", 2),
      Route("a", "c", 2),
      Route("a", "d", 10),
      Route("b", "a", 1),
      Route("b", "a", 5),
      Route("b", "c", 6))
    ).toDF.as[Route]

    val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
      .groupByKey(r => (r.src, r.dest))
      .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
        GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
      }.map(_._2)

The problem is that `ReferenceToExpressions` evaluates its children into local
variables, and the result expression is then generated to reference those
variables. In the case above, the generated code for the result expression is
too long, so `CodegenContext.splitExpression` splits it into separate methods.
Those methods cannot access the caller's local variables, which causes the
compilation error. This patch fixes it by copying each child's value and
null flag into class-level mutable state, which the split methods can access.

## How was this patch tested?

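Added a regression test to `DatasetSuite` ("SPARK-18125: Spark generated code
causes CompileException"); Jenkins tests.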

Please review 
https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before 
opening a pull request.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #15693 from viirya/fix-codege-compilation-error.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a814eeac
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a814eeac
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a814eeac

Branch: refs/heads/master
Commit: a814eeac6b3c38d1294b88c60cd083fc4d01bd25
Parents: 57626a5
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Mon Nov 7 12:18:19 2016 +0100
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Mon Nov 7 12:18:19 2016 +0100

----------------------------------------------------------------------
 .../expressions/ReferenceToExpressions.scala    | 27 ++++++++++----
 .../org/apache/spark/sql/DatasetSuite.scala     | 37 ++++++++++++++++++++
 2 files changed, 58 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a814eeac/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 127797c..6c75a7a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -63,15 +63,30 @@ case class ReferenceToExpressions(result: Expression, 
children: Seq[Expression])
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     val childrenGen = children.map(_.genCode(ctx))
-    val childrenVars = childrenGen.zip(children).map {
-      case (childGen, child) => LambdaVariable(childGen.value, 
childGen.isNull, child.dataType)
-    }
+    val (classChildrenVars, initClassChildrenVars) = 
childrenGen.zip(children).map {
+      case (childGen, child) =>
+        // SPARK-18125: The children vars are local variables. If the result 
expression uses
+        // splitExpression, those variables cannot be accessed so compilation 
fails.
+        // To fix it, we use class variables to hold those local variables.
+        val classChildVarName = ctx.freshName("classChildVar")
+        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
+        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, 
"")
+        ctx.addMutableState("boolean", classChildVarIsNull, "")
+
+        val classChildVar =
+          LambdaVariable(classChildVarName, classChildVarIsNull, 
child.dataType)
+
+        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
+          s"${classChildVar.isNull} = ${childGen.isNull};"
+
+        (classChildVar, initCode)
+    }.unzip
 
     val resultGen = result.transform {
-      case b: BoundReference => childrenVars(b.ordinal)
+      case b: BoundReference => classChildrenVars(b.ordinal)
     }.genCode(ctx)
 
-    ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + 
resultGen.code,
-      isNull = resultGen.isNull, value = resultGen.value)
+    ExprCode(code = childrenGen.map(_.code).mkString("\n") + 
initClassChildrenVars.mkString("\n") +
+      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a814eeac/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6fa7b04..a8dd422 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -923,6 +923,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext 
{
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
 
+  test("SPARK-18125: Spark generated code causes CompileException") {
+    val data = Array(
+      Route("a", "b", 1),
+      Route("a", "b", 2),
+      Route("a", "c", 2),
+      Route("a", "d", 10),
+      Route("b", "a", 1),
+      Route("b", "a", 5),
+      Route("b", "c", 6))
+    val ds = sparkContext.parallelize(data).toDF.as[Route]
+
+    val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
+      .groupByKey(r => (r.src, r.dest))
+      .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
+        GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
+      }.map(_._2)
+
+    val expected = Seq(
+      GroupedRoutes("a", "d", Seq(Route("a", "d", 10))),
+      GroupedRoutes("b", "c", Seq(Route("b", "c", 6))),
+      GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))),
+      GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))),
+      GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
+    )
+
+    implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new 
Ordering[GroupedRoutes] {
+      override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
+        x.toString.compareTo(y.toString)
+      }
+    }
+
+    checkDatasetUnorderly(grped, expected: _*)
+  }
+
   test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
     val resultValue = 12345
     val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
@@ -1071,3 +1105,6 @@ object DatasetTransform {
     ds.map(_ + 1)
   }
 }
+
+case class Route(src: String, dest: String, cost: Int)
+case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to