Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21853#discussion_r204612114
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala ---
    @@ -970,4 +973,300 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
             Row("3", "b") :: Row("4", "b") :: Nil)
         }
       }
    +
    +  private def getNumSortsInQuery(query: String): Int = {
    +    val plan = sql(query).queryExecution.optimizedPlan
    +    getNumSorts(plan) + getSubqueryExpressions(plan).map{s => 
getNumSorts(s.plan)}.sum
    +  }
    +
    +  private def getSubqueryExpressions(plan: LogicalPlan): 
Seq[SubqueryExpression] = {
    +    val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression]
    +    plan transformAllExpressions {
    +      case s: SubqueryExpression =>
    +        subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s)
    +        s
    +    }
    +    subqueryExpressions
    +  }
    +
    +  private def getNumSorts(plan: LogicalPlan): Int = {
    +    plan.collect { case s: Sort => s }.size
    +  }
    +
    +  test("SPARK-23957 Remove redundant sort from subquery plan(in 
subquery)") {
    +    withTempView("t1", "t2", "t3") {
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
    +      Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", 
"c3").createOrReplaceTempView("t3")
    +
    +      // Simple order by
    +      val query1 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN (SELECT c1 FROM t2 ORDER BY c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query1) == 0)
    +
    +      // Nested order bys
    +      val query2 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  c1 IN (SELECT c1
    +           |              FROM   (SELECT *
    +           |                      FROM   t2
    +           |                      ORDER  BY c2)
    +           |              ORDER  BY c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query2) == 0)
    +
    +
    +      // nested IN
    +      val query3 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  c1 IN (SELECT c1
    +           |              FROM   t2
    +           |              WHERE  c1 IN (SELECT c1
    +           |                            FROM   t3
    +           |                            WHERE  c1 = 1
    +           |                            ORDER  BY c3)
    +           |              ORDER  BY c2)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query3) == 0)
    +
    +      // Complex subplan and multiple sorts
    +      val query4 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  c1 IN (SELECT c1
    +           |              FROM   (SELECT c1, c2, count(*)
    +           |                      FROM   t2
    +           |                      GROUP BY c1, c2
    +           |                      HAVING count(*) > 0
    +           |                      ORDER BY c2)
    +           |              ORDER  BY c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query4) == 0)
    +
    +      // Join in subplan
    +      val query5 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN (SELECT t2.c1 FROM t2, t3
    +           |       WHERE t2.c1 = t3.c1
    +           |       ORDER BY t2.c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query5) == 0)
    +
    +      val query6 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  (c1, c2) IN (SELECT c1, max(c2)
    +           |                    FROM   (SELECT c1, c2, count(*)
    +           |                            FROM   t2
    +           |                            GROUP BY c1, c2
    +           |                            HAVING count(*) > 0
    +           |                            ORDER BY c2)
    +           |                    GROUP BY c1
    +           |                    HAVING max(c2) > 0
    +           |                    ORDER  BY c1)
    +        """.stripMargin
    +      // The rule to remove redundant sorts is not able to remove the 
inner sort under
    +      // an Aggregate operator. We only remove the top level sort.
    +      assert(getNumSortsInQuery(query6) == 1)
    +
    +      // Cases when sort is not removed from the plan
    +      // Limit on top of sort
    +      val query7 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query7) == 1)
    +
    +      // Sort below a set operations (intersect, union)
    +      val query8 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN ((
    +           |        SELECT c1 FROM t2
    +           |        ORDER BY c1
    +           |       )
    +           |       UNION
    +           |       (
    +           |         SELECT c1 FROM t2
    +           |         ORDER BY c1
    +           |       ))
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query8) == 2)
    +    }
    +  }
    +
    +  test("SPARK-23957 Remove redundant sort from subquery plan(exists 
subquery)") {
    +    withTempView("t1", "t2", "t3") {
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
    +      Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", 
"c3").createOrReplaceTempView("t3")
    +
    +      // Simple order by exists correlated
    +      val query1 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY 
t2.c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query1) == 0)
    +
    +      // Nested order by and correlated.
    +      val query2 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  EXISTS (SELECT c1
    +           |               FROM   (SELECT *
    +           |                      FROM   t2
    +           |                      WHERE t2.c1 = t1.c1
    +           |                      ORDER  BY t2.c2) t2
    +           |              ORDER  BY t2.c1)
    --- End diff --
    
    Super nit: add one space before `ORDER` so that it lines up with `SELECT`/`FROM` of the enclosing `EXISTS (...)` subquery.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to