Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21853#discussion_r204609532
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
---
    @@ -970,4 +973,300 @@ class SubquerySuite extends QueryTest with 
SharedSQLContext {
             Row("3", "b") :: Row("4", "b") :: Nil)
         }
       }
    +
    +  private def getNumSortsInQuery(query: String): Int = {
    +    val plan = sql(query).queryExecution.optimizedPlan
    +    getNumSorts(plan) + getSubqueryExpressions(plan).map{s => 
getNumSorts(s.plan)}.sum
    +  }
    +
    +  private def getSubqueryExpressions(plan: LogicalPlan): 
Seq[SubqueryExpression] = {
    +    val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression]
    +    plan transformAllExpressions {
    +      case s: SubqueryExpression =>
    +        subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s)
    +        s
    +    }
    +    subqueryExpressions
    +  }
    +
    +  private def getNumSorts(plan: LogicalPlan): Int = {
    +    plan.collect { case s: Sort => s }.size
    +  }
    +
    +  test("SPARK-23957 Remove redundant sort from subquery plan(in 
subquery)") {
    +    withTempView("t1", "t2", "t3") {
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
    +      Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", 
"c3").createOrReplaceTempView("t3")
    +
    +      // Simple order by
    +      val query1 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN (SELECT c1 FROM t2 ORDER BY c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query1) == 0)
    +
    +      // Nested order bys
    +      val query2 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  c1 IN (SELECT c1
    +           |              FROM   (SELECT *
    +           |                      FROM   t2
    +           |                      ORDER  BY c2)
    +           |              ORDER  BY c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query2) == 0)
    +
    +
    +      // nested IN
    +      val query3 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  c1 IN (SELECT c1
    +           |              FROM   t2
    +           |              WHERE  c1 IN (SELECT c1
    +           |                            FROM   t3
    +           |                            WHERE  c1 = 1
    +           |                            ORDER  BY c3)
    +           |              ORDER  BY c2)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query3) == 0)
    +
    +      // Complex subplan and multiple sorts
    +      val query4 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  c1 IN (SELECT c1
    +           |              FROM   (SELECT c1, c2, count(*)
    +           |                      FROM   t2
    +           |                      GROUP BY c1, c2
    +           |                      HAVING count(*) > 0
    +           |                      ORDER BY c2)
    +           |              ORDER  BY c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query4) == 0)
    +
    +      // Join in subplan
    +      val query5 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN (SELECT t2.c1 FROM t2, t3
    +           |       WHERE t2.c1 = t3.c1
    +           |       ORDER BY t2.c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query5) == 0)
    +
    +      val query6 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  (c1, c2) IN (SELECT c1, max(c2)
    +           |                    FROM   (SELECT c1, c2, count(*)
    +           |                            FROM   t2
    +           |                            GROUP BY c1, c2
    +           |                            HAVING count(*) > 0
    +           |                            ORDER BY c2)
    +           |                    GROUP BY c1
    +           |                    HAVING max(c2) > 0
    +           |                    ORDER  BY c1)
    +        """.stripMargin
    +      // The rule to remove redundant sorts is not able to remove the 
inner sort under
    +      // an Aggregate operator. We only remove the top level sort.
    +      assert(getNumSortsInQuery(query6) == 1)
    +
    +      // Cases when sort is not removed from the plan
    +      // Limit on top of sort
    +      val query7 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query7) == 1)
    +
    +      // Sort below a set operation (intersect, union)
    +      val query8 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |c1 IN ((
    +           |        SELECT c1 FROM t2
    +           |        ORDER BY c1
    +           |       )
    +           |       UNION
    +           |       (
    +           |         SELECT c1 FROM t2
    +           |         ORDER BY c1
    +           |       ))
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query8) == 2)
    +    }
    +  }
    +
    +  test("SPARK-23957 Remove redundant sort from subquery plan(exists 
subquery)") {
    +    withTempView("t1", "t2", "t3") {
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
    +      Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", 
"c3").createOrReplaceTempView("t3")
    +
    +      // Simple order by exists correlated
    +      val query1 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY 
t2.c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query1) == 0)
    +
    +      // Nested order by and correlated.
    +      val query2 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  EXISTS (SELECT c1
    +           |               FROM   (SELECT *
    +           |                      FROM   t2
    +           |                      WHERE t2.c1 = t1.c1
    +           |                      ORDER  BY t2.c2) t2
    +           |              ORDER  BY t2.c1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query2) == 0)
    +
    +      // nested EXISTS
    +      val query3 =
    +        """
    +           |SELECT c1
    +           |FROM   t1
    +           |WHERE  EXISTS (SELECT c1
    +           |              FROM   t2
    +           |              WHERE  EXISTS (SELECT c1
    +           |                            FROM   t3
    +           |                            WHERE  t3.c1 = t2.c1
    +           |                            ORDER  BY c3)
    +           |              AND t2.c1 = t1.c1
    +           |              ORDER  BY c2)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query3) == 0)
    +
    +      // Cases when sort is not removed from the plan
    +      // Limit on top of sort
    +      val query4 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 
limit 1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query4) == 1)
    +
    +      // Sort below a set operation (intersect, union)
    +      val query5 =
    +        """
    +           |SELECT c1 FROM t1
    +           |WHERE
    +           |EXISTS ((
    +           |        SELECT c1 FROM t2
    +           |        WHERE t2.c1 = 1
    +           |        ORDER BY t2.c1
    +           |       )
    +           |       UNION
    +           |       (
    +           |         SELECT c1 FROM t2
    +           |         WHERE t2.c1 = 2
    +           |         ORDER BY t2.c1
    +           |       ))
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query5) == 2)
    +    }
    +  }
    +
    +  test("SPARK-23957 Remove redundant sort from subquery plan(scalar 
subquery)") {
    +    withTempView("t1", "t2", "t3") {
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
    +      Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
    +      Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", 
"c3").createOrReplaceTempView("t3")
    +
    +      // Two scalar subqueries in OR
    +      val query1 =
    +        """
    +          |SELECT * FROM t1
    +          |WHERE  c1 = (SELECT max(t2.c1)
    +          |             FROM   t2
    +          |             ORDER BY max(t2.c1))
    +          |OR     c2 = (SELECT min(t3.c2)
    +          |             FROM   t3
    +          |             WHERE  t3.c1 = 1
    +          |             ORDER BY min(t3.c2))
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query1) == 0)
    +
    +      // scalar subquery - groupby and having
    +      val query2 =
    +        """
    +          |SELECT *
    +          |FROM   t1
    +          |WHERE  c1 = (SELECT   max(t2.c1)
    +          |              FROM     t2
    +          |              GROUP BY t2.c1
    +          |              HAVING   count(*) >= 1
    +          |              ORDER BY max(t2.c1))
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query2) == 0)
    +
    +      // nested scalar subquery
    +      val query3 =
    +        """
    +          |SELECT *
    +          |FROM   t1
    +          |WHERE  c1 = (SELECT   max(t2.c1)
    +          |              FROM     t2
    +          |              WHERE c1 = (SELECT max(t3.c1)
    +          |                          FROM t3
    +          |                          WHERE t3.c1 = 1
    +          |                          GROUP BY t3.c1
    +          |                          ORDER BY max(t3.c1)
    +          |                          )
    +          |              GROUP BY t2.c1
    +          |              HAVING   count(*) >= 1
    +          |              ORDER BY max(t2.c1))
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query3) == 0)
    +
    +      // Scalar subquery in projection
    +      val query4 =
    +        """
    +          |SELECT (SELECT min(c1) from t1 group by c1 order by c1)
    +          |FROM t1
    +          |WHERE t1.c1 = 1
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query4) == 0)
    +
    +      // Limit on top of sort prevents it from being pruned.
    +      val query5 =
    +        """
    +          |SELECT *
    +          |FROM   t1
    +          |WHERE  c1 = (SELECT   max(t2.c1)
    +          |             FROM     t2
    +          |             WHERE c1 = (SELECT max(t3.c1)
    +          |                         FROM t3
    +          |                         WHERE t3.c1 = 1
    +          |                         GROUP BY t3.c1
    +          |                         ORDER BY max(t3.c1)
    +          |                         )
    +          |              GROUP BY t2.c1
    +          |              HAVING   count(*) >= 1
    +          |              ORDER BY max(t2.c1)
    +          |              LIMIT 1)
    +        """.stripMargin
    +      assert(getNumSortsInQuery(query5) == 1)
    +    }
    +  }
     }
    +
    --- End diff --
    
    super nit: remove this blank line (the empty line added after the closing `}`)


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to