Github user attilapiros commented on a diff in the pull request: https://github.com/apache/spark/pull/20046#discussion_r158456767 --- Diff: sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala --- @@ -154,6 +154,217 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(2.0d), Row(2.0d))) } + test("corr, covar_pop, stddev_pop functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + corr("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + covar_pop("value1", "value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value1") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_pop("value2") + .over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + + // As stddev_pop(expr) = sqrt(var_pop(expr)) + // the "stddev_pop" column can be calculated from the "var_pop" column. + // + // As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2)) + // the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns. + Seq( + Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0), + Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0), + Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0), + Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0))) + } + + test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") { + val df = Seq( + ("a", "p1", 10.0, 20.0), + ("b", "p1", 20.0, 10.0), + ("c", "p2", 20.0, 20.0), + ("d", "p2", 20.0, 20.0), + ("e", "p3", 0.0, 0.0), + ("f", "p3", 6.0, 12.0), + ("g", "p3", 6.0, 12.0), + ("h", "p3", 8.0, 16.0), + ("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2") + checkAnswer( + df.select( + $"key", + covar_samp("value1", "value2").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + var_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + variance("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev_samp("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + stddev("value1").over(Window.partitionBy("partitionId") + .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) + ), + Seq( + Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755), + Row("c", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("d", 0.0, 0.0, 0.0, 0.0, 0.0 ), + Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544 ), + Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN))) + } + + test("collect_list in ascending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq( + Row("a", Array("1", "2", "2", "3")), + Row("b", Array("1", "2", "2", "3")), + Row("c", Array("1", "2", "2", "3")), + Row("d", Array("1", "2", "2", "3")), + Row("e", Array("1", "2", "2", "3")), + Row("f", Array("10", "11")), + Row("g", Array("10", "11")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_list in descending ordered window") { + val df = Seq( + ("a", "p1", "1"), + ("b", "p1", "2"), + ("c", "p1", "2"), + ("d", "p1", null), + ("e", "p1", "3"), + ("f", "p2", "10"), + ("g", "p2", "11"), + ("h", "p3", "20"), + ("i", "p4", null)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc) + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq( + Row("a", Array("3", "2", "2", "1")), + Row("b", Array("3", "2", "2", "1")), + Row("c", Array("3", "2", "2", "1")), + Row("d", Array("3", "2", "2", "1")), + Row("e", Array("3", "2", "2", "1")), + Row("f", Array("11", "10")), + Row("g", Array("11", "10")), + Row("h", Array("20")), + Row("i", Array()))) + } + + test("collect_set in window") { + val df = Seq( + ("a", "p1", 1), + ("b", "p1", 2), + ("c", "p1", 2), + ("d", "p1", 3), + ("e", "p1", 3), + ("f", "p2", 10), + ("g", "p2", 11), + ("h", "p3", 20)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + collect_set("value").over(Window.partitionBy($"partition").orderBy($"value") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + Seq( + Row("a", Array(1, 2, 3)), + Row("b", Array(1, 2, 3)), + Row("c", Array(1, 2, 3)), + Row("d", Array(1, 2, 3)), + Row("e", Array(1, 2, 3)), + Row("f", Array(10, 11)), + Row("g", Array(10, 11)), + Row("h", Array(20)))) + } + + test("skewness and kurtosis functions in window") { + val df = Seq( + ("a", "p1", 1.0), + ("b", "p1", 1.0), + ("c", "p1", 2.0), + ("d", "p1", 2.0), + ("e", "p1", 3.0), + ("f", "p1", 3.0), + ("g", "p1", 3.0), + ("h", "p2", 1.0), + ("i", "p2", 2.0), + ("j", "p2", 5.0)).toDF("key", "partition", "value") + checkAnswer( + df.select( + $"key", + skewness("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)), + kurtosis("value").over(Window.partitionBy("partition").orderBy($"key") + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))), + // results are checked by scipy.stats.skew() and scipy.stats.kurtosis() + Seq( + Row("a", -0.27238010581457267, -1.506920415224914), + Row("b", -0.27238010581457267, -1.506920415224914), + Row("c", -0.27238010581457267, -1.506920415224914), + Row("d", -0.27238010581457267, -1.506920415224914), + Row("e", -0.27238010581457267, -1.506920415224914), + Row("f", -0.27238010581457267, -1.506920415224914), + Row("g", -0.27238010581457267, -1.506920415224914), + Row("h", 0.5280049792181881, -1.5000000000000013), + Row("i", 0.5280049792181881, -1.5000000000000013), + Row("j", 0.5280049792181881, -1.5000000000000013))) + } + + test("aggregation function on invalid column") { + val df = Seq((1, "1")).toDF("key", "value") + val e = intercept[AnalysisException]( + df.select( + $"key", + count("invalid").over( --- End diff -- Thanks, I will remove the unnecessary parts.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org