beliefer commented on a change in pull request #27428:
URL: https://github.com/apache/spark/pull/27428#discussion_r450581405



##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
##########
@@ -148,24 +204,105 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     val distinctAggs = exprs.flatMap { _.collect {
       case ae: AggregateExpression if ae.isDistinct => ae
     }}
-    // We need at least two distinct aggregates for this rule because aggregation
-    // strategy can handle a single distinct group.
+    // We need at least two distinct aggregates or a single distinct aggregate with a filter for
+    // this rule because aggregation strategy can handle a single distinct group without a filter.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1
+    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) => rewrite(a)
+    case a: Aggregate if mayNeedtoRewrite(a.aggregateExpressions) =>
+      val expandAggregate = extractFiltersInDistinctAggregates(a)
+      rewriteDistinctAggregates(expandAggregate)
   }
 
-  def rewrite(a: Aggregate): Aggregate = {
+  private def extractFiltersInDistinctAggregates(a: Aggregate): Aggregate = {
+    val aggExpressions = collectAggregateExprs(a)
+    val (distinctAggExpressions, regularAggExpressions) = 
aggExpressions.partition(_.isDistinct)
+    if (distinctAggExpressions.exists(_.filter.isDefined)) {
+      // Constructs pairs between old and new expressions for regular 
aggregates. Because we
+      // will construct a new `Aggregate` and the children of the distinct 
aggregates will be
+      // changed to generated ones, we need to create new references to avoid 
collisions between
+      // distinct and regular aggregate children.
+      val regularAggExprs = 
regularAggExpressions.filter(_.children.exists(!_.foldable))
+      val regularFunChildren = regularAggExprs
+        .flatMap(_.aggregateFunction.children.filter(!_.foldable))
+      val regularFilterAttrs = regularAggExprs.flatMap(_.filterAttributes)
+      val regularAggChildren = (regularFunChildren ++ 
regularFilterAttrs).distinct
+      val regularAggChildrenMap = regularAggChildren.map {
+        case ne: NamedExpression => ne -> ne
+        case other => other -> Alias(other, other.toString)()
+      }
+      val namedRegularAggChildren = regularAggChildrenMap.map(_._2)
+      val regularAggChildAttrLookup = regularAggChildrenMap.map { kv =>
+        (kv._1, kv._2.toAttribute)
+      }.toMap
+      val regularAggPairs = regularAggExprs.map {
+        case ae @ AggregateExpression(af, _, _, filter, _) =>
+          val newChildren = af.children.map(c => 
regularAggChildAttrLookup.getOrElse(c, c))
+          val raf = 
af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+          val filterOpt = filter.map(_.transform {
+            case a: Attribute => regularAggChildAttrLookup.getOrElse(a, a)
+          })
+          val aggExpr = ae.copy(aggregateFunction = raf, filter = filterOpt)
+          (ae, aggExpr)
+      }
 
-    // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression => ae
+      // Constructs pairs between old and new expressions for distinct 
aggregates, too.
+      val distinctAggExprs = distinctAggExpressions.filter(e => 
e.children.exists(!_.foldable))
+      val (projections, distinctAggPairs) = distinctAggExprs.map {
+        case ae @ AggregateExpression(af, _, _, filter, _) =>
+          // First, In order to reduce costs, it is better to handle the 
filter clause locally.
+          // e.g. COUNT (DISTINCT a) FILTER (WHERE id > 1), evaluate expression
+          // If(id > 1) 'a else null first, and use the result as output.
+          // Second, If at least two DISTINCT aggregate expression which may 
references the
+          // same attributes. We need to construct the generated attributes so 
as the output not
+          // lost. e.g. SUM (DISTINCT a), COUNT (DISTINCT a) FILTER (WHERE id 
> 1) will output
+          // attribute '_gen_distinct-1 and attribute '_gen_distinct-2 instead 
of two 'a.
+          // Note: The illusionary mechanism may result in at least two 
distinct groups, so we
+          // still need to call `rewrite`.
+          val unfoldableChildren = af.children.filter(!_.foldable)
+          // Expand projection
+          val projectionMap = unfoldableChildren.map {
+            case e if filter.isDefined =>
+              val ife = If(filter.get, e, nullify(e))
+              e -> Alias(ife, 
s"_gen_distinct_${NamedExpression.newExprId.id}")()
+            case e => e -> Alias(e, 
s"_gen_distinct_${NamedExpression.newExprId.id}")()

Review comment:
       OK




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to