This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new dfd4fe95744 [SPARK-40192][SQL][ML] Remove redundant groupby dfd4fe95744 is described below commit dfd4fe957442d41f39e8b3f223ee5cc9adfa6b79 Author: Deshan Xiao <deshanx...@microsoft.com> AuthorDate: Thu Aug 25 08:46:24 2022 -0500 [SPARK-40192][SQL][ML] Remove redundant groupby ### What changes were proposed in this pull request? Remove redundant `groupBy()` calls that immediately precede `agg()`. ### Why are the changes needed? Code cleanup. `Dataset.agg()` already calls `groupBy()` internally, so there is no need to call `groupBy()` explicitly before `agg()`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing unit tests. Closes #37628 from deshanxiao/remove-group-by. Authored-by: Deshan Xiao <deshanx...@microsoft.com> Signed-off-by: Sean Owen <sro...@gmail.com> --- .../apache/spark/ml/feature/StringIndexer.scala | 2 +- .../scala/org/apache/spark/ml/stat/ANOVATest.scala | 3 +-- .../org/apache/spark/ml/stat/ChiSquareTest.scala | 3 +-- .../org/apache/spark/ml/stat/FValueTest.scala | 3 +-- .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 +++++----- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- .../sql/execution/WholeStageCodegenSuite.scala | 2 +- .../execution/benchmark/AggregateBenchmark.scala | 4 ++-- .../sql/execution/metric/SQLMetricsSuite.scala | 2 +- .../sql/hive/execution/AggregationQuerySuite.scala | 26 +++++++++++----------- 10 files changed, 28 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 98a42371d29..4f11c58a7dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -200,7 +200,7 @@ class StringIndexer @Since("1.4.0") ( val selectedCols = getSelectedCols(dataset, inputCols) dataset.select(selectedCols: _*) 
.toDF - .groupBy().agg(aggregator.toColumn) + .agg(aggregator.toColumn) .as[Array[OpenHashMap[String, Long]]] .collect()(0) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala index 7a7e76c457d..d7b13f1bf25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ANOVATest.scala @@ -75,8 +75,7 @@ private[ml] object ANOVATest { if (flatten) { resultDF } else { - resultDF.groupBy() - .agg(collect_list(struct("*"))) + resultDF.agg(collect_list(struct("*"))) .as[Seq[(Int, Double, Long, Double)]] .map { seq => val results = seq.toArray.sortBy(_._1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala index a38a7c446ac..e97d007a0f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala @@ -89,8 +89,7 @@ object ChiSquareTest { if (flatten) { resultDF } else { - resultDF.groupBy() - .agg(collect_list(struct("*"))) + resultDF.agg(collect_list(struct("*"))) .as[Seq[(Int, Double, Int, Double)]] .map { seq => val results = seq.toArray.sortBy(_._1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala index f315e92e86d..800c68d3b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/FValueTest.scala @@ -76,8 +76,7 @@ private[ml] object FValueTest { if (flatten) { resultDF } else { - resultDF.groupBy() - .agg(collect_list(struct("*"))) + resultDF.agg(collect_list(struct("*"))) .as[Seq[(Int, Double, Long, Double)]] .map { seq => val results = seq.toArray.sortBy(_._1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 958b3e3f53c..4ab509b5e01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -191,11 +191,11 @@ class DataFrameAggregateSuite extends QueryTest ) intercept[AnalysisException] { - courseSales.groupBy().agg(grouping("course")).explain() + courseSales.agg(grouping("course")).explain() } intercept[AnalysisException] { - courseSales.groupBy().agg(grouping_id("course")).explain() + courseSales.agg(grouping_id("course")).explain() } } @@ -755,11 +755,11 @@ class DataFrameAggregateSuite extends QueryTest // explicit global aggregations val emptyAgg = Map.empty[String, String] checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row())) + checkAnswer(spark.emptyDataFrame.agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), Seq(Row())) - checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), Seq(Row(0))) + checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(count("*")), Seq(Row(0))) // global aggregation is converted to grouping aggregation: assert(spark.emptyDataFrame.dropDuplicates().count() == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cc7e51abc4e..cbd65ede054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1923,7 +1923,7 @@ 
class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark x }) verifyCallCount( - df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + df.agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index eca22b14763..ac710c32296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -42,7 +42,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession } test("HashAggregate should be included in WholeStageCodegen") { - val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id"))) + val df = spark.range(10).agg(max(col("id")), avg(col("id"))) val plan = df.queryExecution.executedPlan assert(plan.exists(p => p.isInstanceOf[WholeStageCodegenExec] && diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index ae4281cd639..b2f1ee31f9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -57,11 +57,11 @@ object AggregateBenchmark extends SqlBasedBenchmark { val N = 100L << 20 codegenBenchmark("stddev", N) { - spark.range(N).groupBy().agg("id" -> "stddev").noop() + spark.range(N).agg("id" -> "stddev").noop() } codegenBenchmark("kurtosis", N) { - spark.range(N).groupBy().agg("id" -> "kurtosis").noop() + spark.range(N).agg("id" -> "kurtosis").noop() } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 9a63572069d..f5cfbbf5a65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -206,7 +206,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Assume the execution plan is // ... -> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> ObjectHashAggregate(nodeId = 0) - val df = testData2.groupBy().agg(collect_set($"a")) // 2 partitions + val df = testData2.agg(collect_set($"a")) // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))), 1L -> (("Exchange", Map( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index e63cdddd81c..1966e1e64fd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -751,9 +751,9 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("pearson correlation") { val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") - val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + val corr1 = df.repartition(2).agg(corr("a", "b")).collect()(0).getDouble(0) assert(math.abs(corr1 - 1.0) < 1e-12) - val corr2 = df.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + val corr2 = df.agg(corr("a", "c")).collect()(0).getDouble(0) assert(math.abs(corr2 + 1.0) < 1e-12) // non-trivial example. 
To reproduce in python, use: // >>> from scipy.stats import pearsonr @@ -768,17 +768,17 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // > cor(a, b) // [1] 0.957233913947585835 val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") - val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + val corr3 = df2.agg(corr("a", "b")).collect()(0).getDouble(0) assert(math.abs(corr3 - 0.95723391394758572) < 1e-12) val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b") - val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0) + val corr4 = df3.agg(corr("a", "b")).collect()(0) assert(corr4 == Row(null)) val df4 = Seq.tabulate(10)(i => (1 * i, 2 * i, i * -1)).toDF("a", "b", "c") - val corr5 = df4.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) + val corr5 = df4.repartition(2).agg(corr("a", "b")).collect()(0).getDouble(0) assert(math.abs(corr5 - 1.0) < 1e-12) - val corr6 = df4.groupBy().agg(corr("a", "c")).collect()(0).getDouble(0) + val corr6 = df4.agg(corr("a", "c")).collect()(0).getDouble(0) assert(math.abs(corr6 + 1.0) < 1e-12) // Test for udaf_corr in HiveCompatibilitySuite @@ -855,23 +855,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // >>> np.cov(a, b, bias = 1)[0][1] // 565.25 val df = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b") - val cov_samp = df.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + val cov_samp = df.agg(covar_samp("a", "b")).collect()(0).getDouble(0) assert(math.abs(cov_samp - 595.0) < 1e-12) - val cov_pop = df.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + val cov_pop = df.agg(covar_pop("a", "b")).collect()(0).getDouble(0) assert(math.abs(cov_pop - 565.25) < 1e-12) val df2 = Seq.tabulate(20)(x => (1 * x, x * x * x - 2)).toDF("a", "b") - val cov_samp2 = df2.groupBy().agg(covar_samp("a", "b")).collect()(0).getDouble(0) + val cov_samp2 = 
df2.agg(covar_samp("a", "b")).collect()(0).getDouble(0) assert(math.abs(cov_samp2 - 11564.0) < 1e-12) - val cov_pop2 = df2.groupBy().agg(covar_pop("a", "b")).collect()(0).getDouble(0) + val cov_pop2 = df2.agg(covar_pop("a", "b")).collect()(0).getDouble(0) assert(math.abs(cov_pop2 - 10985.799999999999) < 1e-12) // one row test val df3 = Seq.tabulate(1)(x => (1 * x, x * x * x - 2)).toDF("a", "b") - checkAnswer(df3.groupBy().agg(covar_samp("a", "b")), Row(null)) - checkAnswer(df3.groupBy().agg(covar_pop("a", "b")), Row(0.0)) + checkAnswer(df3.agg(covar_samp("a", "b")), Row(null)) + checkAnswer(df3.agg(covar_pop("a", "b")), Row(0.0)) } test("no aggregation function (SPARK-11486)") { @@ -938,7 +938,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te .find(r => r.getInt(0) == 50) .getOrElse(fail("A row with id 50 should be the expected answer.")) checkAnswer( - df.groupBy().agg(udaf(allColumns: _*)), + df.agg(udaf(allColumns: _*)), // udaf returns a Row as the output value. Row(expectedAnswer) ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org