Repository: spark
Updated Branches:
  refs/heads/master 9e86e6efd -> 7d05d02bf


[SPARK-13637][SQL] use more information to simplify the code in Expand builder

## What changes were proposed in this pull request?

The code in `Expand.apply` can be simplified by existing information:

* the expressions passed in the `groupByExprs` parameter are all `Attribute`s
* the `child` parameter is a `Project` that appends the aliased group by expressions
to its child's output

## How was this patch tested?

Covered by existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #11485 from cloud-fan/expand.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d05d02b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d05d02b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d05d02b

Branch: refs/heads/master
Commit: 7d05d02bffe5f1c4fbf955664bcc87e38ce01f5f
Parents: 9e86e6e
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Mar 8 23:34:42 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Mar 8 23:34:42 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  4 +-
 .../catalyst/plans/logical/basicOperators.scala | 48 +++++++++-----------
 2 files changed, 23 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d05d02b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b5fa372..268d7f2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -298,12 +298,10 @@ class Analyzer(
           }.asInstanceOf[NamedExpression]
         }
 
-        val child = Project(x.child.output ++ groupByAliases, x.child)
-
         Aggregate(
           groupByAttributes :+ VirtualColumn.groupingIdAttribute,
           aggregations,
-          Expand(x.bitmasks, groupByAttributes, gid, child))
+          Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7d05d02b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 411594c..3bc246a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -449,21 +449,21 @@ private[sql] object Expand {
    * Extract attribute set according to the grouping id.
    *
    * @param bitmask bitmask to represent the selected of the attribute sequence
-   * @param exprs the attributes in sequence
+   * @param attrs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the 
bit set to 1)
    */
-  private def buildNonSelectExprSet(
+  private def buildNonSelectAttrSet(
       bitmask: Int,
-      exprs: Seq[Expression]): ArrayBuffer[Expression] = {
-    val set = new ArrayBuffer[Expression](2)
+      attrs: Seq[Attribute]): AttributeSet = {
+    val nonSelect = new ArrayBuffer[Attribute]()
 
-    var bit = exprs.length - 1
+    var bit = attrs.length - 1
     while (bit >= 0) {
-      if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
+      if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 
1)
       bit -= 1
     }
 
-    set
+    AttributeSet(nonSelect)
   }
 
   /**
@@ -471,13 +471,15 @@ private[sql] object Expand {
    * multiple output rows for a input row.
    *
    * @param bitmasks The bitmask set represents the grouping sets
-   * @param groupByExprs The grouping by expressions
+   * @param groupByAliases The aliased original group by expressions
+   * @param groupByAttrs The attributes of aliased group by expressions
    * @param gid Attribute of the grouping id
    * @param child Child operator
    */
   def apply(
     bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
+    groupByAliases: Seq[Alias],
+    groupByAttrs: Seq[Attribute],
     gid: Attribute,
     child: LogicalPlan): Expand = {
     // Create an array of Projections for the child projection, and replace 
the projections'
@@ -485,27 +487,21 @@ private[sql] object Expand {
     // are not set for this grouping set (according to the bit mask).
     val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
-      val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, 
groupByExprs)
+      val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, 
groupByAttrs)
 
-      (child.output :+ gid).map(expr => expr transformDown {
-        // TODO this causes a problem when a column is used both for grouping 
and aggregation.
-        case x: Expression if 
nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
+      child.output ++ groupByAttrs.map { attr =>
+        if (nonSelectedGroupAttrSet.contains(attr)) {
           // if the input attribute in the Invalid Grouping Expression set of 
for this group
           // replace it with constant null
-          Literal.create(null, expr.dataType)
-        case x if x == gid =>
-          // replace the groupingId with concrete value (the bit mask)
-          Literal.create(bitmask, IntegerType)
-      })
-    }
-    val output = child.output.map { attr =>
-      if (groupByExprs.exists(_.semanticEquals(attr))) {
-        attr.withNullability(true)
-      } else {
-        attr
-      }
+          Literal.create(null, attr.dataType)
+        } else {
+          attr
+        }
+      // groupingId is the last output, here we use the bit mask as the 
concrete value for it.
+      } :+ Literal.create(bitmask, IntegerType)
     }
-    Expand(projections, output :+ gid, child)
+    val output = child.output ++ groupByAttrs :+ gid
+    Expand(projections, output, Project(child.output ++ groupByAliases, child))
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to