cloud-fan commented on a change in pull request #34444:
URL: https://github.com/apache/spark/pull/34444#discussion_r740046205



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -332,6 +327,266 @@ case class ShuffledHashJoinExec(
     HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
 
+  override def doProduce(ctx: CodegenContext): String = {
+    // Specialize `doProduce` code for full outer join, because full outer 
join needs to
+    // iterate streamed and build side separately.
+    if (joinType != FullOuter) {
+      return super.doProduce(ctx)
+    }
+
+    val HashedRelationInfo(relationTerm, _, _) = prepareRelation(ctx)
+
+    // Inline mutable state since not many join operations in a task
+    val keyIsUnique = ctx.addMutableState("boolean", "keyIsUnique",
+      v => s"$v = $relationTerm.keyIsUnique();", forceInline = true)
+    val streamedInput = ctx.addMutableState("scala.collection.Iterator", 
"streamedInput",
+      v => s"$v = inputs[0];", forceInline = true)
+    val buildInput = ctx.addMutableState("scala.collection.Iterator", 
"buildInput",
+      v => s"$v = $relationTerm.valuesWithKeyIndex();", forceInline = true)
+    val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", 
forceInline = true)
+    val buildRow = ctx.addMutableState("InternalRow", "buildRow", forceInline 
= true)
+
+    // Generate variables and related code from streamed side
+    val streamedVars = genOneSideJoinVars(ctx, streamedRow, streamedPlan, 
setDefaultValue = false)
+    val streamedKeyVariables = evaluateRequiredVariables(streamedOutput, 
streamedVars,
+      
AttributeSet.fromAttributeSets(HashJoin.rewriteKeyExpr(streamedKeys).map(_.references)))
+    ctx.currentVars = streamedVars
+    val streamedKeyExprCode = GenerateUnsafeProjection.createCode(ctx, 
streamedBoundKeys)
+    val streamedKeyEv =
+      s"""
+         |$streamedKeyVariables
+         |${streamedKeyExprCode.code}
+       """.stripMargin
+    val streamedKeyAnyNull = s"${streamedKeyExprCode.value}.anyNull()"
+
+    // Generate code for join condition
+    val (_, conditionCheck, _) =
+      getJoinCondition(ctx, streamedVars, streamedPlan, buildPlan, 
Some(buildRow))
+
+    // Generate code for result output in separate function, as we need to 
output result from
+    // multiple places in join code.
+    val streamedResultVars = genOneSideJoinVars(
+      ctx, streamedRow, streamedPlan, setDefaultValue = true)
+    val buildResultVars = genOneSideJoinVars(
+      ctx, buildRow, buildPlan, setDefaultValue = true)
+    val resultVars = buildSide match {
+      case BuildLeft => buildResultVars ++ streamedResultVars
+      case BuildRight => streamedResultVars ++ buildResultVars
+    }
+    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
+    ctx.addNewFunction(consumeFullOuterJoinRow,
+      s"""
+         |private void $consumeFullOuterJoinRow() {
+         |  ${metricTerm(ctx, "numOutputRows")}.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
+       """.stripMargin)
+    val stopCheck = "if (shouldStop()) return;"
+
+    val joinWithUniqueKey = codegenFullOuterJoinWithUniqueKey(
+      ctx, (streamedRow, buildRow), (streamedInput, buildInput), 
streamedKeyEv, streamedKeyAnyNull,
+      streamedKeyExprCode.value, relationTerm, conditionCheck, stopCheck, 
consumeFullOuterJoinRow)
+    val joinWithNonUniqueKey = codegenFullOuterJoinWithNonUniqueKey(
+      ctx, (streamedRow, buildRow), (streamedInput, buildInput), 
streamedKeyEv, streamedKeyAnyNull,
+      streamedKeyExprCode.value, relationTerm, conditionCheck, stopCheck, 
consumeFullOuterJoinRow)
+
+    s"""
+       |if ($keyIsUnique) {
+       |  $joinWithUniqueKey
+       |} else {
+       |  $joinWithNonUniqueKey
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for full outer join with unique join keys.
+   * This is code-gen version of `fullOuterJoinWithUniqueKey()`.
+   */
+  private def codegenFullOuterJoinWithUniqueKey(
+      ctx: CodegenContext,
+      rows: (String, String),
+      inputs: (String, String),
+      streamedKeyEv: String,
+      streamedKeyAnyNull: String,
+      streamedKeyValue: ExprValue,
+      relationTerm: String,
+      conditionCheck: String,
+      stopCheck: String,
+      consumeFullOuterJoinRow: String): String = {
+    // Inline mutable state since not many join operations in a task
+    val matchedKeySetClsName = classOf[BitSet].getName
+    val matchedKeySet = ctx.addMutableState(matchedKeySetClsName, 
"matchedKeySet",
+      v => s"$v = new 
$matchedKeySetClsName($relationTerm.maxNumKeysIndex());", forceInline = true)
+    val rowWithIndexClsName = classOf[ValueRowWithKeyIndex].getName
+    val rowWithIndex = ctx.freshName("rowWithIndex")
+    val foundMatch = ctx.freshName("foundMatch")
+    val (streamedRow, buildRow) = rows
+    val (streamedInput, buildInput) = inputs
+
+    val joinStreamSide =
+      s"""
+         |while ($streamedInput.hasNext()) {
+         |  $streamedRow = (InternalRow) $streamedInput.next();
+         |
+         |  // generate join key for stream side
+         |  $streamedKeyEv
+         |
+         |  // find matches from HashedRelation
+         |  boolean $foundMatch = false;
+         |  $buildRow = null;
+         |  $rowWithIndexClsName $rowWithIndex = $streamedKeyAnyNull ? null:
+         |    $relationTerm.getValueWithKeyIndex($streamedKeyValue);
+         |
+         |  if ($rowWithIndex != null) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    // check join condition
+         |    $conditionCheck {
+         |      // set key index in matched keys set
+         |      $matchedKeySet.set($rowWithIndex.getKeyIndex());
+         |      $foundMatch = true;
+         |    }
+         |
+         |    if (!$foundMatch) {
+         |      $buildRow = null;
+         |    }
+         |  }
+         |
+         |  $consumeFullOuterJoinRow();
+         |  $stopCheck
+         |}
+       """.stripMargin
+
+    val filterBuildSide =
+      s"""
+         |$streamedRow = null;
+         |
+         |// find non-matched rows from HashedRelation
+         |while ($buildInput.hasNext()) {
+         |  $rowWithIndexClsName $rowWithIndex = ($rowWithIndexClsName) 
$buildInput.next();
+         |
+         |  // check if key index is not in matched keys set
+         |  if (!$matchedKeySet.get($rowWithIndex.getKeyIndex())) {
+         |    $buildRow = $rowWithIndex.getValue();
+         |    $consumeFullOuterJoinRow();
+         |  }
+         |
+         |  $stopCheck
+         |}
+       """.stripMargin
+
+    s"""
+       |$joinStreamSide
+       |$filterBuildSide
+     """.stripMargin
+  }
+
+  /**
+   * Generates the code for full outer join with non-unique join keys.
+   * This is code-gen version of `fullOuterJoinWithNonUniqueKey()`.
+   */
+  private def codegenFullOuterJoinWithNonUniqueKey(
+      ctx: CodegenContext,
+      rows: (String, String),
+      inputs: (String, String),
+      streamedKeyEv: String,
+      streamedKeyAnyNull: String,
+      streamedKeyValue: ExprValue,
+      relationTerm: String,
+      conditionCheck: String,
+      stopCheck: String,
+      consumeFullOuterJoinRow: String): String = {
+    // Inline mutable state since not many join operations in a task
+    val matchedRowSetClsName = classOf[OpenHashSet[_]].getName
+    val matchedRowSet = ctx.addMutableState(matchedRowSetClsName, 
"matchedRowSet",
+      v => s"$v = new 
$matchedRowSetClsName(scala.reflect.ClassTag$$.MODULE$$.Long());",
+      forceInline = true)
+    val prevKeyIndex = ctx.addMutableState("int", "prevKeyIndex",
+      v => s"$v = -1;", forceInline = true)
+    val valueIndex = ctx.addMutableState("int", "valueIndex",
+      v => s"$v = -1;", forceInline = true)
+    val rowWithIndexClsName = classOf[ValueRowWithKeyIndex].getName
+    val rowWithIndex = ctx.freshName("rowWithIndex")
+    val buildIterator = ctx.freshName("buildIterator")
+    val foundMatch = ctx.freshName("foundMatch")
+    val keyIndex = ctx.freshName("keyIndex")
+    val (streamedRow, buildRow) = rows
+    val (streamedInput, buildInput) = inputs
+
+    val rowIndex = s"(((long)$keyIndex) << 32) | $valueIndex"
+    val markRowMatched = s"$matchedRowSet.add($rowIndex);"

Review comment:
       Ditto: let's inline this short code snippet (`markRowMatched`) directly where it is used, instead of binding it to a separate `val`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to