Github user eyalfa commented on a diff in the pull request: https://github.com/apache/spark/pull/19054#discussion_r165861581 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala --- @@ -220,45 +220,99 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { operator.withNewChildren(children) } + private def isSubset(biggerSet: Seq[Expression], smallerSet: Seq[Expression]): Boolean = + smallerSet.length <= biggerSet.length && + smallerSet.forall(x => biggerSet.exists(_.semanticEquals(x))) + + /** + * Reorders `leftKeys` and `rightKeys` by aligning `currentOrderOfKeys` to be a prefix of + * `expectedOrderOfKeys` + */ private def reorder( leftKeys: Seq[Expression], rightKeys: Seq[Expression], - expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - val leftKeysBuffer = ArrayBuffer[Expression]() - val rightKeysBuffer = ArrayBuffer[Expression]() + expectedOrderOfKeys: Seq[Expression], // comes from child's output partitioning + currentOrderOfKeys: Seq[Expression]): // comes from join predicate + (Seq[Expression], Seq[Expression], Seq[Expression], Seq[Expression]) = { + + assert(leftKeys.length == rightKeys.length) + + val allLeftKeys = ArrayBuffer[Expression]() + val allRightKeys = ArrayBuffer[Expression]() + val reorderedLeftKeys = ArrayBuffer[Expression]() + val reorderedRightKeys = ArrayBuffer[Expression]() + + // Tracking indicies here to track to which keys are accounted. Using a set based approach + // won't work because its possible that some keys are repeated in the join clause + // eg. 
a.key1 = b.key1 AND a.key1 = b.key2 + val processedIndicies = mutable.Set[Int]() expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) - leftKeysBuffer.append(leftKeys(index)) - rightKeysBuffer.append(rightKeys(index)) + val index = currentOrderOfKeys.zipWithIndex.find { case (currKey, i) => + !processedIndicies.contains(i) && currKey.semanticEquals(expression) + }.get._2 --- End diff -- Is the `find` guaranteed to always succeed? If so, it is worth a comment on the method's pre/post conditions. Replacing `.get` with `getOrElse(sys.error("..."))` might also be a good way of documenting this.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org