Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/10989#discussion_r221972056
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
 ---
    @@ -117,6 +120,87 @@ case class BroadcastHashJoin(
           hashJoin(streamedIter, numStreamedRows, hashedRelation, 
numOutputRows)
         }
       }
    +
    +  // the term for hash relation
    +  private var relationTerm: String = _
    +
    +  override def upstream(): RDD[InternalRow] = {
    +    streamedPlan.asInstanceOf[CodegenSupport].upstream()
    +  }
    +
    +  override def doProduce(ctx: CodegenContext): String = {
    +    // create a name for HashRelation
    +    val broadcastRelation = Await.result(broadcastFuture, timeout)
    +    val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
    +    relationTerm = ctx.freshName("relation")
    +    // TODO: create specialized HashRelation for single join key
    +    val clsName = classOf[UnsafeHashedRelation].getName
    +    ctx.addMutableState(clsName, relationTerm,
    +      s"""
    +         | $relationTerm = ($clsName) $broadcast.value();
    +         | incPeakExecutionMemory($relationTerm.getUnsafeSize());
    +       """.stripMargin)
    +
    +    s"""
    +       | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
    +     """.stripMargin
    +  }
    +
    +  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
    +    // generate the key as UnsafeRow
    +    ctx.currentVars = input
    +    val keyExpr = streamedKeys.map(BindReferences.bindReference(_, 
streamedPlan.output))
    +    val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
    +    val keyTerm = keyVal.value
    +    val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" 
else "false"
    +
    +    // find the matches from HashedRelation
    +    val matches = ctx.freshName("matches")
    +    val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
    +    val i = ctx.freshName("i")
    +    val size = ctx.freshName("size")
    +    val row = ctx.freshName("row")
    +
    +    // create variables for output
    +    ctx.currentVars = null
    +    ctx.INPUT_ROW = row
    +    val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
    +      BoundReference(i, a.dataType, a.nullable).gen(ctx)
    +    }
    +    val resultVars = buildSide match {
    +      case BuildLeft => buildColumns ++ input
    +      case BuildRight => input ++ buildColumns
    +    }
    +
    +    val ouputCode = if (condition.isDefined) {
    +      // filter the output via condition
    +      ctx.currentVars = resultVars
    +      val ev = BindReferences.bindReference(condition.get, 
this.output).gen(ctx)
    +      s"""
    +         | ${ev.code}
    +         | if (!${ev.isNull} && ${ev.value}) {
    +         |   ${consume(ctx, resultVars)}
    +         | }
    +       """.stripMargin
    +    } else {
    +      consume(ctx, resultVars)
    +    }
    +
    +    s"""
    +       | // generate join key
    +       | ${keyVal.code}
    +       | // find matches from HashRelation
    +       | $bufferType $matches = $anyNull ? null : ($bufferType) 
$relationTerm.get($keyTerm);
    +       | if ($matches != null) {
    +       |   int $size = $matches.size();
    +       |   for (int $i = 0; $i < $size; $i++) {
    --- End diff --
    
    Hmm, yeah, this code has changed a lot since this PR. It looks like at that 
point this `BroadcastHashJoin` only supported inner joins. I also don't really 
get the idea of interrupting this loop early, since it looks like we need to go 
through all the matched rows here?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to