Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/10989#discussion_r221972056 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala --- @@ -117,6 +120,87 @@ case class BroadcastHashJoin( hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) } } + + // the term for hash relation + private var relationTerm: String = _ + + override def upstream(): RDD[InternalRow] = { + streamedPlan.asInstanceOf[CodegenSupport].upstream() + } + + override def doProduce(ctx: CodegenContext): String = { + // create a name for HashRelation + val broadcastRelation = Await.result(broadcastFuture, timeout) + val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) + relationTerm = ctx.freshName("relation") + // TODO: create specialized HashRelation for single join key + val clsName = classOf[UnsafeHashedRelation].getName + ctx.addMutableState(clsName, relationTerm, + s""" + | $relationTerm = ($clsName) $broadcast.value(); + | incPeakExecutionMemory($relationTerm.getUnsafeSize()); + """.stripMargin) + + s""" + | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} + """.stripMargin + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // generate the key as UnsafeRow + ctx.currentVars = input + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val keyTerm = keyVal.value + val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" + + // find the matches from HashedRelation + val matches = ctx.freshName("matches") + val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val i = ctx.freshName("i") + val size = ctx.freshName("size") + val row = ctx.freshName("row") + + // create variables for output + ctx.currentVars = null + ctx.INPUT_ROW = row + val buildColumns = buildPlan.output.zipWithIndex.map { case 
(a, i) => + BoundReference(i, a.dataType, a.nullable).gen(ctx) + } + val resultVars = buildSide match { + case BuildLeft => buildColumns ++ input + case BuildRight => input ++ buildColumns + } + + val ouputCode = if (condition.isDefined) { + // filter the output via condition + ctx.currentVars = resultVars + val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) + s""" + | ${ev.code} + | if (!${ev.isNull} && ${ev.value}) { + | ${consume(ctx, resultVars)} + | } + """.stripMargin + } else { + consume(ctx, resultVars) + } + + s""" + | // generate join key + | ${keyVal.code} + | // find matches from HashRelation + | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm); + | if ($matches != null) { + | int $size = $matches.size(); + | for (int $i = 0; $i < $size; $i++) { --- End diff -- hmm, yeah, this code has changed a lot since this PR; it looks like at that time this `BroadcastHashJoin` only supported inner joins. I also don't really understand the idea of interrupting this loop early, since it looks like we need to go through all matched rows here?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org