Repository: spark
Updated Branches:
  refs/heads/master 10691d36d -> e9c91badc


[SPARK-20010][SQL] Sort information is lost after sort merge join

## What changes were proposed in this pull request?

After a sort merge join for an inner join, we currently keep only the left
keys' ordering. However, after an inner join, the right keys have the same
values and order as the left keys, so if another sort merge join on the right
keys is needed later, we unnecessarily add a sort, which incurs additional
cost.

As a more complicated example, consider `A join B on A.key = B.key join C on
B.key = C.key join D on A.key = D.key`. We unnecessarily add a sort on B.key
when joining {A, B} with C, and a sort on A.key when joining {A, B, C} with D.

To fix this, we need to propagate all sort information (equivalent
expressions) bottom-up through `outputOrdering` and `SortOrder`.
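
For illustration, a minimal DataFrame sketch of the join chain above (table and column names are assumed, and each join is presumed to be planned as a sort merge join):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical sketch of the A/B/C/D chain described above. With this
// change, the inner SMJ on a("key") === b("key") reports both keys as
// sorted, so the later joins on b.key and a.key need no extra Sort.
val spark = SparkSession.builder().appName("smj-ordering").getOrCreate()
val Seq(a, b, c, d) = Seq("a", "b", "c", "d").map(spark.table)

val joined = a.join(b, a("key") === b("key"))
  .join(c, b("key") === c("key"))
  .join(d, a("key") === d("key"))
```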

## How was this patch tested?

Test cases are added in `PlannerSuite`.

Author: wangzhenhua <wangzhen...@huawei.com>

Closes #17339 from wzhfy/sortEnhance.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9c91bad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9c91bad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9c91bad

Branch: refs/heads/master
Commit: e9c91badce64731ffd3e53cbcd9f044a7593e6b8
Parents: 10691d3
Author: wangzhenhua <wangzhen...@huawei.com>
Authored: Tue Mar 21 10:43:17 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Mar 21 10:43:17 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  4 +--
 .../analysis/SubstituteUnresolvedOrdinals.scala |  2 +-
 .../apache/spark/sql/catalyst/dsl/package.scala |  4 +--
 .../sql/catalyst/expressions/SortOrder.scala    | 21 ++++++++++++--
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  2 +-
 .../scala/org/apache/spark/sql/Column.scala     |  8 +++---
 .../execution/exchange/EnsureRequirements.scala |  2 +-
 .../sql/execution/joins/SortMergeJoinExec.scala | 26 +++++++++++++++--
 .../spark/sql/execution/PlannerSuite.scala      | 30 +++++++++++++++++++-
 9 files changed, 81 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8cf4073..574f91b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -966,9 +966,9 @@ class Analyzer(
       case s @ Sort(orders, global, child)
         if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
         val newOrders = orders map {
-          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
+          case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
             if (index > 0 && index <= child.output.size) {
-              SortOrder(child.output(index - 1), direction, nullOrdering)
+              SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
             } else {
               s.failAnalysis(
                 s"ORDER BY position $index is not in select list " +

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
index af0a565..38a3d3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala
@@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan]
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
       val newOrders = s.order.map {
-        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
+        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
           val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
           withOrigin(order.origin)(order.copy(child = newOrdinal))
         case other => other

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 35ca2a0..75bf780 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -109,9 +109,9 @@ package object dsl {
     def cast(to: DataType): Expression = Cast(expr, to)
 
     def asc: SortOrder = SortOrder(expr, Ascending)
-    def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast)
+    def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
     def desc: SortOrder = SortOrder(expr, Descending)
-    def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst)
+    def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty)
     def as(alias: String): NamedExpression = Alias(expr, alias)()
     def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3bebd55..abcb9a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{
 /**
 * An expression that can be used to sort a tuple.  This class extends expression primarily so that
  * transformations over expression will descend into its child.
+ * `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is
+ * derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge
+ * join.
  */
-case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
+case class SortOrder(
+    child: Expression,
+    direction: SortDirection,
+    nullOrdering: NullOrdering,
+    sameOrderExpressions: Set[Expression])
   extends UnaryExpression with Unevaluable {
 
   /** Sort order is not foldable because we don't have an eval for it. */
@@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
   override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql
 
   def isAscending: Boolean = direction == Ascending
+
+  def satisfies(required: SortOrder): Boolean = {
+    (sameOrderExpressions + child).exists(required.child.semanticEquals) &&
+      direction == required.direction && nullOrdering == required.nullOrdering
+  }
 }
 
 object SortOrder {
-  def apply(child: Expression, direction: SortDirection): SortOrder = {
-    new SortOrder(child, direction, direction.defaultNullOrdering)
+  def apply(
+     child: Expression,
+     direction: SortDirection,
+     sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
+    new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
   }
 }
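
A rough usage sketch of the new `satisfies` check (assumed attributes for illustration only; this snippet is not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Assumed attributes, for illustration only.
val aKey = AttributeReference("a_key", IntegerType)()
val bKey = AttributeReference("b_key", IntegerType)()

// An ordering on a_key that also records b_key as a same-order expression,
// as an inner sort merge join on a_key = b_key would produce.
val produced = SortOrder(aKey, Ascending, sameOrderExpressions = Set(bKey))

produced.satisfies(SortOrder(aKey, Ascending))  // true: direct match
produced.satisfies(SortOrder(bKey, Ascending))  // true: same-order expression
```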
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4c9fb2e..cd238e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     } else {
       direction.defaultNullOrdering
     }
-    SortOrder(expression(ctx.expression), direction, nullOrdering)
+    SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3802955..ae07035 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }
+  def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) }
 
   /**
   * Returns a descending ordering used in sorting, where null values appear after non-null values.
@@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) }
+  def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) }
 
   /**
    * Returns an ascending ordering used in sorting.
@@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) }
+  def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) }
 
   /**
   * Returns an ordering used in sorting, where null values appear after non-null values.
@@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.1.0
    */
-  def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }
+  def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) }
 
   /**
    * Prints the expression to the console for debugging purpose.

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index f170499..b91d077 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         } else {
           requiredOrdering.zip(child.outputOrdering).forall {
             case (requiredOrder, childOutputOrder) =>
-              requiredOrder.semanticEquals(childOutputOrder)
+              childOutputOrder.satisfies(requiredOrder)
           }
         }
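
For context, a simplified sketch of the enclosing check (paraphrased with assumed surrounding names such as `requiredOrdering` and `child`, not verbatim from the file): a sort is inserted only when the child's output ordering cannot cover the required prefix, now judged by `satisfies` rather than strict semantic equality.

```scala
import org.apache.spark.sql.execution.SortExec

// Sketch: requiredOrdering and child are assumed in scope as in
// EnsureRequirements. Each required SortOrder is matched positionally
// against the child's output ordering via the new satisfies check.
val orderingMatched =
  if (requiredOrdering.length > child.outputOrdering.length) {
    false
  } else {
    requiredOrdering.zip(child.outputOrdering).forall {
      case (requiredOrder, childOutputOrder) =>
        childOutputOrder.satisfies(requiredOrder)
    }
  }

// Insert a local sort only when the requirement is not already met.
val newChild =
  if (orderingMatched) child
  else SortExec(requiredOrdering, global = false, child = child)
```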
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 02f4f55..c6aae1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -81,17 +81,37 @@ case class SortMergeJoinExec(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   override def outputOrdering: Seq[SortOrder] = joinType match {
+    // For inner join, orders of both sides keys should be kept.
+    case Inner =>
+      val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
+      val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
+      leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
+        // Also add the right key and its `sameOrderExpressions`
+        SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
+          .sameOrderExpressions)
+      }
     // For left and right outer joins, the output is ordered by the streamed input's join keys.
-    case LeftOuter => requiredOrders(leftKeys)
-    case RightOuter => requiredOrders(rightKeys)
+    case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
+    case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering)
     // There are null rows in both streams, so there is no order.
     case FullOuter => Nil
-    case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys)
+    case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering)
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
   }
 
+  /**
+   * For SMJ, child's output must have been sorted on key or expressions with the same order as
+   * key, so we can get ordering for key from child's output ordering.
+   */
+  private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
+    : Seq[SortOrder] = {
+    keys.zip(childOutputOrdering).map { case (key, childOrder) =>
+      SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
+    }
+  }
+
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
     requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
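
As a worked example of the inner-join branch above (assumed attributes; this mirrors the `outputOrdering` logic rather than calling the private helper):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Assumed join keys of an inner SMJ on a_key = b_key, for illustration only.
val aKey = AttributeReference("a_key", IntegerType)()
val bKey = AttributeReference("b_key", IntegerType)()

// Each child is sorted on its own key, as requiredChildOrdering guarantees.
val leftKeyOrdering  = Seq(SortOrder(aKey, Ascending))
val rightKeyOrdering = Seq(SortOrder(bKey, Ascending))

// Same combination as the Inner case: key on the left key, and carry the
// right key (plus its same-order expressions) as equivalent expressions.
val outputOrdering = leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
  SortOrder(lKey.child, Ascending,
    lKey.sameOrderExpressions + rKey.child ++ rKey.sameOrderExpressions)
}

outputOrdering.head.satisfies(SortOrder(bKey, Ascending))  // true: no re-sort
```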
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9c91bad/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index f2232fc..4d155d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext {
 
   private val exprA = Literal(1)
   private val exprB = Literal(2)
+  private val exprC = Literal(3)
   private val orderingA = SortOrder(exprA, Ascending)
   private val orderingB = SortOrder(exprB, Ascending)
+  private val orderingC = SortOrder(exprC, Ascending)
   private val planA = DummySparkPlan(outputOrdering = Seq(orderingA),
     outputPartitioning = HashPartitioning(exprA :: Nil, 5))
   private val planB = DummySparkPlan(outputOrdering = Seq(orderingB),
     outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+  private val planC = DummySparkPlan(outputOrdering = Seq(orderingC),
+    outputPartitioning = HashPartitioning(exprC :: Nil, 5))
 
-  assert(orderingA != orderingB)
+  assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC)
 
   private def assertSortRequirementsAreSatisfied(
       childPlan: SparkPlan,
@@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("EnsureRequirements skips sort when either side of join keys is 
required after inner SMJ") {
+    val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, 
planA, planB)
+    // Both left and right keys should be sorted after the SMJ.
+    Seq(orderingA, orderingB).foreach { ordering =>
+      assertSortRequirementsAreSatisfied(
+        childPlan = innerSmj,
+        requiredOrdering = Seq(ordering),
+        shouldHaveSort = false)
+    }
+  }
+
+  test("EnsureRequirements skips sort when key order of a parent SMJ is 
propagated from its " +
+    "child SMJ") {
+    val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
+    val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC)
+    // After the second SMJ, exprA, exprB and exprC should all be sorted.
+    Seq(orderingA, orderingB, orderingC).foreach { ordering =>
+      assertSortRequirementsAreSatisfied(
+        childPlan = parentSmj,
+        requiredOrdering = Seq(ordering),
+        shouldHaveSort = false)
+    }
+  }
+
   test("EnsureRequirements for sort operator after left outer sort merge 
join") {
     // Only left key is sorted after left outer SMJ (thus doesn't need a sort).
     val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB)

