Repository: spark
Updated Branches:
  refs/heads/branch-2.0 c94288b57 -> 3566e40a4


[SPARK-18969][SQL] Support grouping by nondeterministic expressions

## What changes were proposed in this pull request?

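Currently nondeterministic expressions are allowed in `Aggregate` (see the
[comment](https://github.com/apache/spark/blob/v2.0.2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L249-L251)),
but the `PullOutNondeterministic` analyzer rule did not handle `Aggregate`; this
PR fixes that. Nondeterministic grouping expressions are now pulled out into a
`Project` beneath the `Aggregate`, and the `Aggregate` groups by the resulting
attributes instead.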

close https://github.com/apache/spark/pull/16379

There is still one remaining issue: `SELECT a + rand() FROM t GROUP BY a +
rand()` is not allowed, because the two `rand()` calls are different expressions
(a random seed is generated as the default seed for each `rand()`).
https://issues.apache.org/jira/browse/SPARK-19035 tracks this issue.

## How was this patch tested?

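Added a new test suite, `PullOutNondeterministicSuite`, and updated the expected
results in `group-by-ordinal.sql.out`.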

Author: Wenchen Fan <wenc...@databricks.com>

Closes #16404 from cloud-fan/groupby.

(cherry picked from commit 871d266649ddfed38c64dfda7158d8bb58d4b979)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3566e40a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3566e40a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3566e40a

Branch: refs/heads/branch-2.0
Commit: 3566e40a4ce319e095780062abf94154b4aba334
Parents: c94288b
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Thu Jan 12 20:21:04 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Thu Jan 12 20:25:44 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 37 ++++++++-----
 .../analysis/PullOutNondeterministicSuite.scala | 56 ++++++++++++++++++++
 .../sql-tests/results/group-by-ordinal.sql.out  | 10 ++--
 3 files changed, 86 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3566e40a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 32dc70a..9040ced 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1789,28 +1789,37 @@ class Analyzer(
       case p: Project => p
       case f: Filter => f
 
+      case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
+        val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
+        val newChild = Project(a.child.output ++ nondeterToAttr.values, 
a.child)
+        a.transformExpressions { case e =>
+          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
+        }.copy(child = newChild)
+
       // todo: It's hard to write a general rule to pull out nondeterministic 
expressions
       // from LogicalPlan, currently we only do it for UnaryNode which has 
same output
       // schema with its child.
       case p: UnaryNode if p.output == p.child.output && 
p.expressions.exists(!_.deterministic) =>
-        val nondeterministicExprs = 
p.expressions.filterNot(_.deterministic).flatMap { expr =>
-          val leafNondeterministic = expr.collect {
-            case n: Nondeterministic => n
-          }
-          leafNondeterministic.map { e =>
-            val ne = e match {
-              case n: NamedExpression => n
-              case _ => Alias(e, "_nondeterministic")(isGenerated = true)
-            }
-            new TreeNodeRef(e) -> ne
-          }
-        }.toMap
+        val nondeterToAttr = getNondeterToAttr(p.expressions)
         val newPlan = p.transformExpressions { case e =>
-          nondeterministicExprs.get(new 
TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+          nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
         }
-        val newChild = Project(p.child.output ++ nondeterministicExprs.values, 
p.child)
+        val newChild = Project(p.child.output ++ nondeterToAttr.values, 
p.child)
         Project(p.output, newPlan.withNewChildren(newChild :: Nil))
     }
+
+    private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, 
NamedExpression] = {
+      exprs.filterNot(_.deterministic).flatMap { expr =>
+        val leafNondeterministic = expr.collect { case n: Nondeterministic => 
n }
+        leafNondeterministic.distinct.map { e =>
+          val ne = e match {
+            case n: NamedExpression => n
+            case _ => Alias(e, "_nondeterministic")(isGenerated = true)
+          }
+          e -> ne
+        }
+      }.toMap
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/3566e40a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
new file mode 100644
index 0000000..72e10ea
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * Test suite for moving non-deterministic expressions into Project.
+ */
+class PullOutNondeterministicSuite extends AnalysisTest {
+
+  private lazy val a = 'a.int
+  private lazy val b = 'b.int
+  private lazy val r = LocalRelation(a, b)
+  private lazy val rnd = Rand(10).as('_nondeterministic)
+  private lazy val rndref = rnd.toAttribute
+
+  test("no-op on filter") {
+    checkAnalysis(
+      r.where(Rand(10) > Literal(1.0)),
+      r.where(Rand(10) > Literal(1.0))
+    )
+  }
+
+  test("sort") {
+    checkAnalysis(
+      r.sortBy(SortOrder(Rand(10), Ascending)),
+      r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b)
+    )
+  }
+
+  test("aggregate") {
+    checkAnalysis(
+      r.groupBy(Rand(10))(Rand(10).as("rnd")),
+      r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd"))
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3566e40a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out 
b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index 2f10b7e..e3a5a93 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -137,10 +137,14 @@ GROUP BY position 3 is an aggregate function, and 
aggregate functions are not al
 -- !query 13
 select a, rand(0), sum(b) from data group by a, 2
 -- !query 13 schema
-struct<>
+struct<a:int,rand(0):double,sum(b):bigint>
 -- !query 13 output
-org.apache.spark.sql.AnalysisException
-nondeterministic expression rand(0) should not appear in grouping expression.;
+1      0.4048454303385226      2
+1      0.8446490682263027      1
+2      0.5871875724155838      1
+2      0.8865128837019473      2
+3      0.742083829230211       1
+3      0.9179913208300406      2
 
 
 -- !query 14


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to