Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/21853#discussion_r204609532 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala --- @@ -970,4 +973,300 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row("3", "b") :: Row("4", "b") :: Nil) } } + + private def getNumSortsInQuery(query: String): Int = { + val plan = sql(query).queryExecution.optimizedPlan + getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum + } + + private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { + val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] + plan transformAllExpressions { + case s: SubqueryExpression => + subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) + s + } + subqueryExpressions + } + + private def getNumSorts(plan: LogicalPlan): Int = { + plan.collect { case s: Sort => s }.size + } + + test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // Nested order bys + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT * + | FROM t2 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + + // nested IN + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM t2 + | WHERE c1 IN (SELECT c1 + | FROM t3 + | WHERE c1 = 1 + | ORDER BY c3) + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Complex subplan and multiple sorts + val query4 = + """ + 
|SELECT c1 + |FROM t1 + |WHERE c1 IN (SELECT c1 + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | ORDER BY c1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Join in subplan + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT t2.c1 FROM t2, t3 + | WHERE t2.c1 = t3.c1 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 0) + + val query6 = + """ + |SELECT c1 + |FROM t1 + |WHERE (c1, c2) IN (SELECT c1, max(c2) + | FROM (SELECT c1, c2, count(*) + | FROM t2 + | GROUP BY c1, c2 + | HAVING count(*) > 0 + | ORDER BY c2) + | GROUP BY c1 + | HAVING max(c2) > 0 + | ORDER BY c1) + """.stripMargin + // The rule to remove redundant sorts is not able to remove the inner sort under + // an Aggregate operator. We only remove the top level sort. + assert(getNumSortsInQuery(query6) == 1) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query7 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query7) == 1) + + // Sort below a set operations (intersect, union) + val query8 = + """ + |SELECT c1 FROM t1 + |WHERE + |c1 IN (( + | SELECT c1 FROM t2 + | ORDER BY c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | ORDER BY c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query8) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Simple order by exists correlated + val query1 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // 
Nested order by and correlated. + val query2 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM (SELECT * + | FROM t2 + | WHERE t2.c1 = t1.c1 + | ORDER BY t2.c2) t2 + | ORDER BY t2.c1) + """.stripMargin + assert(getNumSortsInQuery(query2) == 0) + + // nested EXISTS + val query3 = + """ + |SELECT c1 + |FROM t1 + |WHERE EXISTS (SELECT c1 + | FROM t2 + | WHERE EXISTS (SELECT c1 + | FROM t3 + | WHERE t3.c1 = t2.c1 + | ORDER BY c3) + | AND t2.c1 = t1.c1 + | ORDER BY c2) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Cases when sort is not removed from the plan + // Limit on top of sort + val query4 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) + """.stripMargin + assert(getNumSortsInQuery(query4) == 1) + + // Sort below a set operations (intersect, union) + val query5 = + """ + |SELECT c1 FROM t1 + |WHERE + |EXISTS (( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 1 + | ORDER BY t2.c1 + | ) + | UNION + | ( + | SELECT c1 FROM t2 + | WHERE t2.c1 = 2 + | ORDER BY t2.c1 + | )) + """.stripMargin + assert(getNumSortsInQuery(query5) == 2) + } + } + + test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { + withTempView("t1", "t2", "t3") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") + + // Two scalar subqueries in OR + val query1 = + """ + |SELECT * FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | ORDER BY max(t2.c1)) + |OR c2 = (SELECT min(t3.c2) + | FROM t3 + | WHERE t3.c1 = 1 + | ORDER BY min(t3.c2)) + """.stripMargin + assert(getNumSortsInQuery(query1) == 0) + + // scalar subquery - groupby and having + val query2 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + 
assert(getNumSortsInQuery(query2) == 0) + + // nested scalar subquery + val query3 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1)) + """.stripMargin + assert(getNumSortsInQuery(query3) == 0) + + // Scalar subquery in projection + val query4 = + """ + |SELECT (SELECT min(c1) from t1 group by c1 order by c1) + |FROM t1 + |WHERE t1.c1 = 1 + """.stripMargin + assert(getNumSortsInQuery(query4) == 0) + + // Limit on top of sort prevents it from being pruned. + val query5 = + """ + |SELECT * + |FROM t1 + |WHERE c1 = (SELECT max(t2.c1) + | FROM t2 + | WHERE c1 = (SELECT max(t3.c1) + | FROM t3 + | WHERE t3.c1 = 1 + | GROUP BY t3.c1 + | ORDER BY max(t3.c1) + | ) + | GROUP BY t2.c1 + | HAVING count(*) >= 1 + | ORDER BY max(t2.c1) + | LIMIT 1) + """.stripMargin + assert(getNumSortsInQuery(query5) == 1) + } + } } + --- End diff -- super nit: remove this blank line
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org