GitHub user tejasapatil commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19054#discussion_r162768446
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
 ---
    @@ -220,45 +220,76 @@ case class EnsureRequirements(conf: SQLConf) extends 
Rule[SparkPlan] {
         operator.withNewChildren(children)
       }
     
    +  private def isSubset(biggerSet: Seq[Expression], smallerSet: 
Seq[Expression]): Boolean =
    +    smallerSet.length <= biggerSet.length &&
    +      smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))
    +
       private def reorder(
           leftKeys: Seq[Expression],
           rightKeys: Seq[Expression],
    -      expectedOrderOfKeys: Seq[Expression],
    -      currentOrderOfKeys: Seq[Expression]): (Seq[Expression], 
Seq[Expression]) = {
    -    val leftKeysBuffer = ArrayBuffer[Expression]()
    -    val rightKeysBuffer = ArrayBuffer[Expression]()
    +      expectedOrderOfKeys: Seq[Expression], // comes from child's output 
partitioning
    +      currentOrderOfKeys: Seq[Expression]): // comes from join predicate
    +  (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
    +
    +    assert(leftKeys.length == rightKeys.length)
    +
    +    val allLeftKeys = ArrayBuffer[Expression]()
    +    val allRightKeys = ArrayBuffer[Expression]()
    +    val reorderedLeftKeys = ArrayBuffer[Expression]()
    +    val reorderedRightKeys = ArrayBuffer[Expression]()
    +    val processedIndicies = mutable.Set[Int]()
     
         expectedOrderOfKeys.foreach(expression => {
    -      val index = currentOrderOfKeys.indexWhere(e => 
e.semanticEquals(expression))
    -      leftKeysBuffer.append(leftKeys(index))
    -      rightKeysBuffer.append(rightKeys(index))
    +      val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) 
=>
    +        !processedIndicies.contains(i) && 
currKey.semanticEquals(expression)
    +      }.get._2
    +      processedIndicies.add(index)
    +
    +      reorderedLeftKeys.append(leftKeys(index))
    +      allLeftKeys.append(leftKeys(index))
    +
    +      reorderedRightKeys.append(rightKeys(index))
    +      allRightKeys.append(rightKeys(index))
         })
    -    (leftKeysBuffer, rightKeysBuffer)
    +
    +    // If len(currentOrderOfKeys) > len(expectedOrderOfKeys), then the 
re-ordering won't have
    +    // all the keys. Append the remaining keys to the end so that we are 
covering all the keys
    +    for (i <- leftKeys.indices) {
    +      if (!processedIndicies.contains(i)) {
    +        allLeftKeys.append(leftKeys(i))
    +        allRightKeys.append(rightKeys(i))
    +      }
    +    }
    +
    +    assert(allLeftKeys.length == leftKeys.length)
    +    assert(allRightKeys.length == rightKeys.length)
    +    assert(reorderedLeftKeys.length == reorderedRightKeys.length)
    +
    +    (allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys)
       }
     
       private def reorderJoinKeys(
           leftKeys: Seq[Expression],
           rightKeys: Seq[Expression],
           leftPartitioning: Partitioning,
    -      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) 
= {
    +      rightPartitioning: Partitioning):
    +  (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
    --- End diff --
    
    Added more documentation. I wasn't sure how to make it easier to understand; hopefully
the example helps with that.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to