Repository: spark
Updated Branches:
  refs/heads/master a04f2bea6 -> c235b5f97


[SPARK-22746][SQL] Avoid the generation of useless mutable states by 
SortMergeJoin

## What changes were proposed in this pull request?

This PR reduces the number of global mutable variables in the generated code of 
`SortMergeJoin`.

Before this PR, global mutable variables were used to extend the lifetime of 
variables in the nested loop. The same effect can be achieved by declaring each 
variable at the outermost loop level where it is used.
In the following example, `smj_value8`, `smj_isNull6`, and `smj_value9` are 
declared as local variables at lines 145-147 in `With this PR`.

This PR also fixes a potential assertion error caused by #19865. Without this 
PR, a global mutable variable could be passed as an argument to a split 
function in the generated code.

Without this PR
```
/* 010 */   int smj_value8;
/* 011 */   boolean smj_isNull6;
/* 012 */   int smj_value9;
..
/* 143 */   protected void processNext() throws java.io.IOException {
/* 144 */     while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) {
/* 145 */       boolean smj_loaded = false;
/* 146 */       smj_isNull6 = smj_leftRow.isNullAt(1);
/* 147 */       smj_value9 = smj_isNull6 ? -1 : (smj_leftRow.getInt(1));
/* 148 */       scala.collection.Iterator<UnsafeRow> smj_iterator = 
smj_matches.generateIterator();
/* 149 */       while (smj_iterator.hasNext()) {
/* 150 */         InternalRow smj_rightRow1 = (InternalRow) smj_iterator.next();
/* 151 */         boolean smj_isNull8 = smj_rightRow1.isNullAt(1);
/* 152 */         int smj_value11 = smj_isNull8 ? -1 : 
(smj_rightRow1.getInt(1));
/* 153 */
/* 154 */         boolean smj_value12 = (smj_isNull6 && smj_isNull8) ||
/* 155 */         (!smj_isNull6 && !smj_isNull8 && smj_value9 == smj_value11);
/* 156 */         if (false || !smj_value12) continue;
/* 157 */         if (!smj_loaded) {
/* 158 */           smj_loaded = true;
/* 159 */           smj_value8 = smj_leftRow.getInt(0);
/* 160 */         }
/* 161 */         int smj_value10 = smj_rightRow1.getInt(0);
/* 162 */         smj_numOutputRows.add(1);
/* 163 */
/* 164 */         smj_rowWriter.zeroOutNullBytes();
/* 165 */
/* 166 */         smj_rowWriter.write(0, smj_value8);
/* 167 */
/* 168 */         if (smj_isNull6) {
/* 169 */           smj_rowWriter.setNullAt(1);
/* 170 */         } else {
/* 171 */           smj_rowWriter.write(1, smj_value9);
/* 172 */         }
/* 173 */
/* 174 */         smj_rowWriter.write(2, smj_value10);
/* 175 */
/* 176 */         if (smj_isNull8) {
/* 177 */           smj_rowWriter.setNullAt(3);
/* 178 */         } else {
/* 179 */           smj_rowWriter.write(3, smj_value11);
/* 180 */         }
/* 181 */         append(smj_result.copy());
/* 182 */
/* 183 */       }
/* 184 */       if (shouldStop()) return;
/* 185 */     }
/* 186 */   }
```

With this PR
```
/* 143 */   protected void processNext() throws java.io.IOException {
/* 144 */     while (findNextInnerJoinRows(smj_leftInput, smj_rightInput)) {
/* 145 */       int smj_value8 = -1;
/* 146 */       boolean smj_isNull6 = false;
/* 147 */       int smj_value9 = -1;
/* 148 */       boolean smj_loaded = false;
/* 149 */       smj_isNull6 = smj_leftRow.isNullAt(1);
/* 150 */       smj_value9 = smj_isNull6 ? -1 : (smj_leftRow.getInt(1));
/* 151 */       scala.collection.Iterator<UnsafeRow> smj_iterator = 
smj_matches.generateIterator();
/* 152 */       while (smj_iterator.hasNext()) {
/* 153 */         InternalRow smj_rightRow1 = (InternalRow) smj_iterator.next();
/* 154 */         boolean smj_isNull8 = smj_rightRow1.isNullAt(1);
/* 155 */         int smj_value11 = smj_isNull8 ? -1 : 
(smj_rightRow1.getInt(1));
/* 156 */
/* 157 */         boolean smj_value12 = (smj_isNull6 && smj_isNull8) ||
/* 158 */         (!smj_isNull6 && !smj_isNull8 && smj_value9 == smj_value11);
/* 159 */         if (false || !smj_value12) continue;
/* 160 */         if (!smj_loaded) {
/* 161 */           smj_loaded = true;
/* 162 */           smj_value8 = smj_leftRow.getInt(0);
/* 163 */         }
/* 164 */         int smj_value10 = smj_rightRow1.getInt(0);
/* 165 */         smj_numOutputRows.add(1);
/* 166 */
/* 167 */         smj_rowWriter.zeroOutNullBytes();
/* 168 */
/* 169 */         smj_rowWriter.write(0, smj_value8);
/* 170 */
/* 171 */         if (smj_isNull6) {
/* 172 */           smj_rowWriter.setNullAt(1);
/* 173 */         } else {
/* 174 */           smj_rowWriter.write(1, smj_value9);
/* 175 */         }
/* 176 */
/* 177 */         smj_rowWriter.write(2, smj_value10);
/* 178 */
/* 179 */         if (smj_isNull8) {
/* 180 */           smj_rowWriter.setNullAt(3);
/* 181 */         } else {
/* 182 */           smj_rowWriter.write(3, smj_value11);
/* 183 */         }
/* 184 */         append(smj_result.copy());
/* 185 */
/* 186 */       }
/* 187 */       if (shouldStop()) return;
/* 188 */     }
/* 189 */   }
```

## How was this patch tested?

Existing test cases

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #19937 from kiszk/SPARK-22746.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c235b5f9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c235b5f9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c235b5f9

Branch: refs/heads/master
Commit: c235b5f9772be57122ca5109b9295cf8c85926df
Parents: a04f2be
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Mon Dec 11 13:15:45 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Dec 11 13:15:45 2017 -0800

----------------------------------------------------------------------
 .../sql/execution/joins/SortMergeJoinExec.scala | 27 ++++++++++++--------
 1 file changed, 17 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c235b5f9/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 9c08ec7..554b731 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -507,32 +507,38 @@ case class SortMergeJoinExec(
   }
 
   /**
-   * Creates variables for left part of result row.
+   * Creates variables and declarations for left part of result row.
    *
    * In order to defer the access after condition and also only access once in 
the loop,
    * the variables should be declared separately from accessing the columns, 
we can't use the
    * codegen of BoundReference here.
    */
-  private def createLeftVars(ctx: CodegenContext, leftRow: String): 
Seq[ExprCode] = {
+  private def createLeftVars(ctx: CodegenContext, leftRow: String): 
(Seq[ExprCode], Seq[String]) = {
     ctx.INPUT_ROW = leftRow
     left.output.zipWithIndex.map { case (a, i) =>
       val value = ctx.freshName("value")
       val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
-      // declare it as class member, so we can access the column before or in 
the loop.
-      ctx.addMutableState(ctx.javaType(a.dataType), value)
+      val javaType = ctx.javaType(a.dataType)
+      val defaultValue = ctx.defaultValue(a.dataType)
       if (a.nullable) {
         val isNull = ctx.freshName("isNull")
-        ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull)
         val code =
           s"""
              |$isNull = $leftRow.isNullAt($i);
-             |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : 
($valueCode);
+             |$value = $isNull ? $defaultValue : ($valueCode);
            """.stripMargin
-        ExprCode(code, isNull, value)
+        val leftVarsDecl =
+          s"""
+             |boolean $isNull = false;
+             |$javaType $value = $defaultValue;
+           """.stripMargin
+        (ExprCode(code, isNull, value), leftVarsDecl)
       } else {
-        ExprCode(s"$value = $valueCode;", "false", value)
+        val code = s"$value = $valueCode;"
+        val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
+        (ExprCode(code, "false", value), leftVarsDecl)
       }
-    }
+    }.unzip
   }
 
   /**
@@ -580,7 +586,7 @@ case class SortMergeJoinExec(
     val (leftRow, matches) = genScanner(ctx)
 
     // Create variables for row from both sides.
-    val leftVars = createLeftVars(ctx, leftRow)
+    val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow)
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
 
@@ -617,6 +623,7 @@ case class SortMergeJoinExec(
 
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
+       |  ${leftVarDecl.mkString("\n")}
        |  ${beforeLoop.trim}
        |  scala.collection.Iterator<UnsafeRow> $iterator = 
$matches.generateIterator();
        |  while ($iterator.hasNext()) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to