Github user hvanhovell commented on a diff in the pull request:

    https://github.com/apache/spark/pull/11485#discussion_r54940996
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala ---
    @@ -473,41 +473,44 @@ private[sql] object Expand {
        * multiple output rows for a input row.
        *
        * @param bitmasks The bitmask set represents the grouping sets
    -   * @param groupByExprs The grouping by expressions
    +   * @param groupByAttrs The attributes of aliased group by expressions
        * @param gid Attribute of the grouping id
    -   * @param child Child operator
    +   * @param project The child project operator
        */
       def apply(
         bitmasks: Seq[Int],
    -    groupByExprs: Seq[Expression],
    +    groupByAttrs: Seq[Attribute],
         gid: Attribute,
    -    child: LogicalPlan): Expand = {
    +    project: Project): Expand = {
    +
    +    val originalOutput = project.child.output
    +    assert(project.output.length == (originalOutput ++ 
groupByAttrs).length)
    +    assert(project.output.zip(originalOutput ++ groupByAttrs).forall {
    +      case (attr1, attr2) => attr1 semanticEquals attr2
    +    })
    +
         // Create an array of Projections for the child projection, and 
replace the projections'
         // expressions which equal GroupBy expressions with Literal(null), if 
those expressions
         // are not set for this grouping set (according to the bit mask).
         val projections = bitmasks.map { bitmask =>
           // get the non selected grouping attributes according to the bit mask
    -      val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, 
groupByExprs)
    +      val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, 
groupByAttrs)
     
    -      (child.output :+ gid).map(expr => expr transformDown {
    -        // TODO this causes a problem when a column is used both for 
grouping and aggregation.
    -        case x: Expression if 
nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
    +      originalOutput ++ groupByAttrs.map { attr =>
    +        if (nonSelectedGroupAttrSet.contains(attr)) {
               // if the input attribute in the Invalid Grouping Expression set 
of for this group
               // replace it with constant null
    -          Literal.create(null, expr.dataType)
    -        case x if x == gid =>
    -          // replace the groupingId with concrete value (the bit mask)
    -          Literal.create(bitmask, IntegerType)
    -      })
    -    }
    -    val output = child.output.map { attr =>
    -      if (groupByExprs.exists(_.semanticEquals(attr))) {
    -        attr.withNullability(true)
    -      } else {
    -        attr
    +          Literal.create(null, attr.dataType)
    +        } else {
    +          attr
    +        }
    +      } :+ {
    --- End diff --
    
    Nit: Funny brackets... Any reason for this besides this comment?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to