Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/10989#discussion_r222235960
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
 ---
    @@ -117,6 +120,87 @@ case class BroadcastHashJoin(
           hashJoin(streamedIter, numStreamedRows, hashedRelation, 
numOutputRows)
         }
       }
    +
    +  // the term for hash relation
    +  private var relationTerm: String = _
    +
    +  override def upstream(): RDD[InternalRow] = {
    +    streamedPlan.asInstanceOf[CodegenSupport].upstream()
    +  }
    +
    +  override def doProduce(ctx: CodegenContext): String = {
    +    // create a name for HashRelation
    +    val broadcastRelation = Await.result(broadcastFuture, timeout)
    +    val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
    +    relationTerm = ctx.freshName("relation")
    +    // TODO: create specialized HashRelation for single join key
    +    val clsName = classOf[UnsafeHashedRelation].getName
    +    ctx.addMutableState(clsName, relationTerm,
    +      s"""
    +         | $relationTerm = ($clsName) $broadcast.value();
    +         | incPeakExecutionMemory($relationTerm.getUnsafeSize());
    +       """.stripMargin)
    +
    +    s"""
    +       | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
    +     """.stripMargin
    +  }
    +
    +  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
    +    // generate the key as UnsafeRow
    +    ctx.currentVars = input
    +    val keyExpr = streamedKeys.map(BindReferences.bindReference(_, 
streamedPlan.output))
    +    val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
    +    val keyTerm = keyVal.value
    +    val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" 
else "false"
    +
    +    // find the matches from HashedRelation
    +    val matches = ctx.freshName("matches")
    +    val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
    +    val i = ctx.freshName("i")
    +    val size = ctx.freshName("size")
    +    val row = ctx.freshName("row")
    +
    +    // create variables for output
    +    ctx.currentVars = null
    +    ctx.INPUT_ROW = row
    +    val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
    +      BoundReference(i, a.dataType, a.nullable).gen(ctx)
    +    }
    +    val resultVars = buildSide match {
    +      case BuildLeft => buildColumns ++ input
    +      case BuildRight => input ++ buildColumns
    +    }
    +
    +    val ouputCode = if (condition.isDefined) {
    +      // filter the output via condition
    +      ctx.currentVars = resultVars
    +      val ev = BindReferences.bindReference(condition.get, 
this.output).gen(ctx)
    +      s"""
    +         | ${ev.code}
    +         | if (!${ev.isNull} && ${ev.value}) {
    +         |   ${consume(ctx, resultVars)}
    +         | }
    +       """.stripMargin
    +    } else {
    +      consume(ctx, resultVars)
    +    }
    +
    +    s"""
    +       | // generate join key
    +       | ${keyVal.code}
    +       | // find matches from HashRelation
    +       | $bufferType $matches = $anyNull ? null : ($bufferType) 
$relationTerm.get($keyTerm);
    +       | if ($matches != null) {
    +       |   int $size = $matches.size();
    +       |   for (int $i = 0; $i < $size; $i++) {
    --- End diff --
    
    I see. Then we have to keep `matches` as global state instead of a local 
variable, so we can go over the remaining matched rows in the next iterations. 
And we shouldn't fetch the next row from the streaming side, but reuse the 
previous row from that side. This way it can produce a single result row at a 
time without buffering all matched rows into `currentRows`, though it might add 
some complexity to the generated code.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to