Github user hvanhovell commented on a diff in the pull request: https://github.com/apache/spark/pull/11485#discussion_r54940996 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala --- @@ -473,41 +473,44 @@ private[sql] object Expand { * multiple output rows for a input row. * * @param bitmasks The bitmask set represents the grouping sets - * @param groupByExprs The grouping by expressions + * @param groupByAttrs The attributes of aliased group by expressions * @param gid Attribute of the grouping id - * @param child Child operator + * @param project The child project operator */ def apply( bitmasks: Seq[Int], - groupByExprs: Seq[Expression], + groupByAttrs: Seq[Attribute], gid: Attribute, - child: LogicalPlan): Expand = { + project: Project): Expand = { + + val originalOutput = project.child.output + assert(project.output.length == (originalOutput ++ groupByAttrs).length) + assert(project.output.zip(originalOutput ++ groupByAttrs).forall { + case (attr1, attr2) => attr1 semanticEquals attr2 + }) + // Create an array of Projections for the child projection, and replace the projections' // expressions which equal GroupBy expressions with Literal(null), if those expressions // are not set for this grouping set (according to the bit mask). val projections = bitmasks.map { bitmask => // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) + val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs) - (child.output :+ gid).map(expr => expr transformDown { - // TODO this causes a problem when a column is used both for grouping and aggregation. - case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) => + originalOutput ++ groupByAttrs.map { attr => + if (nonSelectedGroupAttrSet.contains(attr)) { // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null - Literal.create(null, expr.dataType) - case x if x == gid => - // replace the groupingId with concrete value (the bit mask) - Literal.create(bitmask, IntegerType) - }) - } - val output = child.output.map { attr => - if (groupByExprs.exists(_.semanticEquals(attr))) { - attr.withNullability(true) - } else { - attr + Literal.create(null, attr.dataType) + } else { + attr + } + } :+ { --- End diff -- Nit: Funny brackets... Any reason for this besides this comment?
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org