Github user tejasapatil commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19054#discussion_r162768516

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala ---
@@ -220,45 +220,76 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     operator.withNewChildren(children)
   }
 
+  private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean =
+    smallerSet.length <= biggerSet.length &&
+      smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x)))
+
   private def reorder(
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
-      expectedOrderOfKeys: Seq[Expression],
-      currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    val leftKeysBuffer = ArrayBuffer[Expression]()
-    val rightKeysBuffer = ArrayBuffer[Expression]()
+      expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning
+      currentOrderOfKeys: Seq[Expression]): // comes from join predicate
+    (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
+
+    assert(leftKeys.length == rightKeys.length)
+
+    val allLeftKeys = ArrayBuffer[Expression]()
+    val allRightKeys = ArrayBuffer[Expression]()
+    val reorderedLeftKeys = ArrayBuffer[Expression]()
+    val reorderedRightKeys = ArrayBuffer[Expression]()
+    val processedIndicies = mutable.Set[Int]()
 
     expectedOrderOfKeys.foreach(expression => {
-      val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
-      leftKeysBuffer.append(leftKeys(index))
-      rightKeysBuffer.append(rightKeys(index))
+      val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) =>
+        !processedIndicies.contains(i) && currKey.semanticEquals(expression)
+      }.get._2
+      processedIndicies.add(index)
+
+      reorderedLeftKeys.append(leftKeys(index))
+      allLeftKeys.append(leftKeys(index))
+
+      reorderedRightKeys.append(rightKeys(index))
+      allRightKeys.append(rightKeys(index))
     })
-    (leftKeysBuffer, rightKeysBuffer)
+
+    // If len(currentOrderOfKeys) > len(expectedOrderOfKeys), then the re-ordering won't have
+    // all the keys. Append the remaining keys to the end so that we are covering all the keys
+    for (i <- leftKeys.indices) {
+      if (!processedIndicies.contains(i)) {
+        allLeftKeys.append(leftKeys(i))
+        allRightKeys.append(rightKeys(i))
+      }
+    }
+
+    assert(allLeftKeys.length == leftKeys.length)
+    assert(allRightKeys.length == rightKeys.length)
+    assert(reorderedLeftKeys.length == reorderedRightKeys.length)
+
+    (allLeftKeys, reorderedLeftKeys, allRightKeys, reorderedRightKeys)
   }
 
   private def reorderJoinKeys(
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
       leftPartitioning: Partitioning,
-      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+      rightPartitioning: Partitioning):
+    (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = {
+
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
       leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+        case HashPartitioning(leftExpressions, _) if isSubset(leftKeys, leftExpressions) =>
           reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
--- End diff --

Given that this was only done over `SortMergeJoinExec` and `ShuffledHashJoinExec`, where both partitionings are `HashPartitioning`, things worked fine. I have changed this to use a stricter check.
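To make the behavior change concrete, here is a minimal, self-contained sketch in plain Scala with no Spark dependency. The object and helper names (`ReorderSketch`, `oldCheck`) and the string stand-ins are illustrative only, and plain equality stands in for Catalyst's `semanticEquals`. It contrasts the old exact-length guard with the new `isSubset`-style guard from the diff, and shows the covered-keys-first reordering with leftover keys appended:

```scala
// Sketch only: strings stand in for Catalyst expressions.
object ReorderSketch {
  // Old guard: the partitioning expressions had to match the join keys in count.
  def oldCheck(joinKeys: Seq[String], partExprs: Seq[String]): Boolean =
    partExprs.length == joinKeys.length &&
      joinKeys.forall(partExprs.contains)

  // New guard (mirrors `isSubset`): every partitioning expression must appear
  // among the join keys, but the partitioning may cover fewer keys.
  def isSubset(biggerSet: Seq[String], smallerSet: Seq[String]): Boolean =
    smallerSet.length <= biggerSet.length &&
      smallerSet.forall(biggerSet.contains)

  def main(args: Array[String]): Unit = {
    val joinKeys  = Seq("a", "b", "c")
    val partExprs = Seq("b", "a") // child hash-partitioned on a subset of the keys

    println(oldCheck(joinKeys, partExprs)) // false: lengths differ, no reordering
    println(isSubset(joinKeys, partExprs)) // true: {b, a} is a subset of {a, b, c}

    // Reordering: keys covered by the partitioning come first, in its order,
    // then the leftover keys are appended, as the new `reorder` does.
    val covered  = partExprs
    val leftover = joinKeys.filterNot(partExprs.contains)
    println(covered ++ leftover) // List(b, a, c)
  }
}
```

Under the old guard, a child hash-partitioned on only a subset of the join keys never matched, so reordering was skipped entirely; the subset guard lets the match fire, and the uncovered keys are simply appended after the reordered prefix, which is why `reorder` now returns both the full key lists and the reordered prefixes.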