c21 commented on a change in pull request #29277: URL: https://github.com/apache/spark/pull/29277#discussion_r462401733
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala ########## @@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec { resultProj(r) } } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + protected def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { + ctx.currentVars = input + if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedBoundKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + /** + * Generates the code for variable of build side. + */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { + ctx.currentVars = null + ctx.INPUT_ROW = matched + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) + if (joinType.isInstanceOf[InnerLike]) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) + val code = code""" + |boolean $isNull = true; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) + } + } + } + + /** + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. 
+ */ + protected def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // filter the output via condition + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + val skipRow = s"${ev.isNull} || !${ev.value}" + s""" + |$eval + |${ev.code} + |if (!($skipRow)) + """.stripMargin + } else { + "" + } + (matched, checkCondition, buildVars) + } + + /** + * Generates the code for Inner join. + */ + protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + + if (keyIsKnownUnique) { Review comment: @cloud-fan - sorry, I meant that I just replaced [`broadcastRelation.value.keyIsUnique`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala#L325) with the variable `keyIsKnownUnique`, because `ShuffledHashJoinExec` does not have a `HashedRelation` at codegen time. The unique-key code path for `BroadcastHashJoinExec` is the existing logic; I didn't add anything new here. 
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala ########## @@ -316,6 +318,387 @@ trait HashJoin extends BaseJoinExec { resultProj(r) } } + + /** + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. + */ + protected def genStreamSideJoinKey( + ctx: CodegenContext, + input: Seq[ExprCode]): (ExprCode, String) = { + ctx.currentVars = input + if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) { + // generate the join key as Long + val ev = streamedBoundKeys.head.genCode(ctx) + (ev, ev.isNull) + } else { + // generate the join key as UnsafeRow + val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys) + (ev, s"${ev.value}.anyNull()") + } + } + + /** + * Generates the code for variable of build side. + */ + private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { + ctx.currentVars = null + ctx.INPUT_ROW = matched + buildPlan.output.zipWithIndex.map { case (a, i) => + val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) + if (joinType.isInstanceOf[InnerLike]) { + ev + } else { + // the variables are needed even there is no matched rows + val isNull = ctx.freshName("isNull") + val value = ctx.freshName("value") + val javaType = CodeGenerator.javaType(a.dataType) + val code = code""" + |boolean $isNull = true; + |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; + |if ($matched != null) { + | ${ev.code} + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + |} + """.stripMargin + ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)) + } + } + } + + /** + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. 
+ */ + protected def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // filter the output via condition + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + val skipRow = s"${ev.isNull} || !${ev.value}" + s""" + |$eval + |${ev.code} + |if (!($skipRow)) + """.stripMargin + } else { + "" + } + (matched, checkCondition, buildVars) + } + + /** + * Generates the code for Inner join. + */ + protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + + if (keyIsKnownUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |if ($matched != null) { + | $checkCondition { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? 
+ | null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |if ($matches != null) { + | while ($matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | } + |} + """.stripMargin + } + } + + /** + * Generates the code for left or right outer join. + */ + protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (relationTerm, keyIsKnownUnique) = prepareRelation(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val numOutput = metricTerm(ctx, "numOutputRows") + + // filter the output via condition + val conditionPassed = ctx.freshName("conditionPassed") + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + s""" + |boolean $conditionPassed = true; + |${eval.trim} + |if ($matched != null) { + | ${ev.code} + | $conditionPassed = !${ev.isNull} && ${ev.value}; + |} + """.stripMargin + } else { + s"final boolean $conditionPassed = true;" + } + + val resultVars = buildSide match { + case BuildLeft => buildVars ++ input + case BuildRight => input ++ buildVars + } + + if (keyIsKnownUnique) { Review comment: @cloud-fan - ditto; copying my earlier reply here: sorry, I meant that I just replaced [`broadcastRelation.value.keyIsUnique`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala#L325) with the variable `keyIsKnownUnique`, because `ShuffledHashJoinExec` does not have a `HashedRelation` at codegen time.
The unique-key code path for `BroadcastHashJoinExec` is the existing logic; I didn't add anything new here. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala ########## @@ -70,4 +74,69 @@ case class ShuffledHashJoinExec( join(streamIter, hashed, numOutputRows) } } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.execute() :: buildPlan.execute() :: Nil + } + + override def needCopyResult: Boolean = true + + override protected def doProduce(ctx: CodegenContext): String = { + // inline mutable state since not many join operations in a task + val streamedInput = ctx.addMutableState( + "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", forceInline = true) + val buildInput = ctx.addMutableState( + "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];", forceInline = true) + val initRelation = ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;", forceInline = true) + val streamedRow = ctx.addMutableState( + "InternalRow", "streamedRow", forceInline = true) + + val thisPlan = ctx.addReferenceObj("plan", this) + val (relationTerm, _) = prepareRelation(ctx) + val buildRelation = s"$relationTerm = $thisPlan.buildHashedRelation($buildInput);" Review comment: @cloud-fan - sorry, if we include `buildRelation` inside `prepareRelation`, how do we use `buildRelation` in the final [code-gen code](https://github.com/apache/spark/pull/29277/files/7fbd3a8d94f7b6fbc09c5401922b400fd0432ac3#diff-db4ffe4f0196a9d7cf1f04c350ee3381R114)? Do you mean creating a private var to hold `buildRelation` after `prepareRelation` is called? 
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala ########## @@ -70,4 +74,69 @@ case class ShuffledHashJoinExec( join(streamIter, hashed, numOutputRows) } } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.execute() :: buildPlan.execute() :: Nil + } + + override def needCopyResult: Boolean = true + + override protected def doProduce(ctx: CodegenContext): String = { + // inline mutable state since not many join operations in a task + val streamedInput = ctx.addMutableState( + "scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", forceInline = true) + val buildInput = ctx.addMutableState( + "scala.collection.Iterator", "buildInput", v => s"$v = inputs[1];", forceInline = true) + val initRelation = ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, "initRelation", v => s"$v = false;", forceInline = true) + val streamedRow = ctx.addMutableState( + "InternalRow", "streamedRow", forceInline = true) + + val thisPlan = ctx.addReferenceObj("plan", this) + val (relationTerm, _) = prepareRelation(ctx) + val buildRelation = s"$relationTerm = $thisPlan.buildHashedRelation($buildInput);" + val (streamInputVar, streamInputVarDecl) = createVars(ctx, streamedRow, streamedPlan.output) + + val join = joinType match { + case _: InnerLike => codegenInner(ctx, streamInputVar) + case LeftOuter | RightOuter => codegenOuter(ctx, streamInputVar) + case LeftSemi => codegenSemi(ctx, streamInputVar) + case LeftAnti => codegenAnti(ctx, streamInputVar) + case _: ExistenceJoin => codegenExistence(ctx, streamInputVar) + case x => + throw new IllegalArgumentException( + s"ShuffledHashJoin should not take $x as the JoinType") + } + + s""" + |// construct hash map for shuffled hash join build side + |if (!$initRelation) { + | $buildRelation + | $initRelation = true; + |} + | + |while ($streamedInput.hasNext()) { + | $streamedRow = (InternalRow) $streamedInput.next(); + | 
${streamInputVarDecl.mkString("\n")} + | $join + | + | if (shouldStop()) return; + |} + """.stripMargin + } + + /** + * Returns a tuple of variable name for HashedRelation, + * and boolean false to indicate key not to be known unique in code-gen time. + */ + protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = { + if (relationTerm == null) { Review comment: @cloud-fan - as you may have already found out, I need the same `relationTerm` to generate the code that builds the relation in [`doProduce()`](https://github.com/apache/spark/pull/29277/files/7fbd3a8d94f7b6fbc09c5401922b400fd0432ac3#diff-db4ffe4f0196a9d7cf1f04c350ee3381R97). ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org