Github user eyalfa commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19054#discussion_r165861581
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
 ---
    @@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] {
         operator.withNewChildren(children)
       }
     
    +  private def isSubset(biggerSet: Seq[Expression], smallerSet: 
Seq[Expression]): Boolean =
    +    smallerSet.length <= biggerSet.length &&
    +      smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))
    +
    +  /**
    +   * Reorders `leftKeys` and `rightKeys` by aligning `currentOrderOfKeys` 
to be a prefix of
    +   * `expectedOrderOfKeys`
    +   */
       private def reorder(
           leftKeys: Seq[Expression],
           rightKeys: Seq[Expression],
    -      expectedOrderOfKeys: Seq[Expression],
    -      currentOrderOfKeys: Seq[Expression]): (Seq[Expression], 
Seq[Expression]) = {
    -    val leftKeysBuffer = ArrayBuffer[Expression]()
    -    val rightKeysBuffer = ArrayBuffer[Expression]()
    +      expectedOrderOfKeys: Seq[Expression], // comes from child's output 
partitioning
    +      currentOrderOfKeys: Seq[Expression]): // comes from join predicate
    +  (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
    +
    +    assert(leftKeys.length == rightKeys.length)
    +
    +    val allLeftKeys = ArrayBuffer[Expression]()
    +    val allRightKeys = ArrayBuffer[Expression]()
    +    val reorderedLeftKeys = ArrayBuffer[Expression]()
    +    val reorderedRightKeys = ArrayBuffer[Expression]()
    +
    +    // Tracking indicies here to track to which keys are accounted. Using 
a set based approach
    +    // won't work because its possible that some keys are repeated in the 
join clause
    +    // eg. a.key1 = b.key1 AND a.key1 = b.key2
    +    val processedIndicies = mutable.Set[Int]()
     
         expectedOrderOfKeys.foreach(expression => {
    -      val index = currentOrderOfKeys.indexWhere(e => 
e.semanticEquals(expression))
    -      leftKeysBuffer.append(leftKeys(index))
    -      rightKeysBuffer.append(rightKeys(index))
    +      val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) 
=>
    +        !processedIndicies.contains(i) && 
currKey.semanticEquals(expression)
    +      }.get._2
    --- End diff --
    
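    Is the `find` guaranteed to always succeed? If it isn't, `.get` throws a bare
    `NoSuchElementException`. For example, this happens when `expectedOrderOfKeys`
    contains a key with no unprocessed semantic match left in `currentOrderOfKeys`.
    If it is guaranteed, it's worth a comment stating the method's pre/post-conditions.
    
    A `getOrElse(sys.error("..."))` might also be a good way of documenting this.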


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to