HyukjinKwon commented on a change in pull request #33142: URL: https://github.com/apache/spark/pull/33142#discussion_r667633052
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala ########## @@ -170,65 +167,71 @@ class EquivalentExpressions { // can cause error like NPE. (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null) - if (!skip && !addFunc(expr)) { - childrenToRecurse(expr).foreach(addExprTree(_, addFunc)) - commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc)) + if (!skip && !addExprToMap(expr, map)) { + childrenToRecurse(expr).foreach(addExprTree(_, map)) + commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map)) } } /** - * Returns all of the expression trees that are equivalent to `e`. Returns - * an empty collection if there are none. + * Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no + * equivalent expressions. */ - def getEquivalentExprs(e: Expression): Seq[Expression] = { - equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq + def getExprState(e: Expression): Option[ExpressionStats] = { + equivalenceMap.get(ExpressionEquals(e)) + } + + // Exposed for testing. + private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = { + equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height) } /** - * Returns all the equivalent sets of expressions which appear more than given `repeatTimes` - * times. + * Returns a sequence of expressions that more than one equivalent expressions. */ - def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = { - equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq - .sortBy(_.head)(new ExpressionContainmentOrdering) + def getCommonSubexpressions: Seq[Expression] = { + getAllExprStates(1).map(_.expr) } /** * Returns the state of the data structure as a string. If `all` is false, skips sets of * equivalent expressions with cardinality 1. */ def debugString(all: Boolean = false): String = { - val sb: mutable.StringBuilder = new StringBuilder() + val sb = new java.lang.StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.foreach { case (k, v) => - if (all || v.length > 1) { - sb.append(" " + v.mkString(", ")).append("\n") - } + equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats => + sb.append(" ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n') } sb.toString() } } /** - * Orders `Expression` by parent/child relations. The child expression is smaller - * than parent expression. If there is child-parent relationships among the subexpressions, - * we want the child expressions come first than parent expressions, so we can replace - * child expressions in parent expressions with subexpression evaluation. Note that - * this is not for general expression ordering. For example, two irrelevant or semantically-equal - * expressions will be considered as equal by this ordering. But for the usage here, the order of - * irrelevant expressions does not matter. + * Wrapper around an Expression that provides semantic equality. */ -class ExpressionContainmentOrdering extends Ordering[Expression] { - override def compare(x: Expression, y: Expression): Int = { - if (x.find(_.semanticEquals(y)).isDefined) { - // `y` is child expression of `x`. - 1 - } else if (y.find(_.semanticEquals(x)).isDefined) { - // `x` is child expression of `y`. - -1 - } else { - // Irrelevant or semantically-equal expressions - 0 - } +case class ExpressionEquals(e: Expression) { + override def equals(o: Any): Boolean = o match { + case other: ExpressionEquals => e.semanticEquals(other.e) + case _ => false + } + + override def hashCode: Int = e.semanticHash() +} + +/** + * A wrapper in place of using Seq[Expression] to record a group of equivalent expressions. + * + * This saves a lot of memory when there are a lot of expressions in a same equivalence group. + * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened" + * useCount in this wrapper in-place. + */ +case class ExpressionStats(expr: Expression)(var useCount: Int = 1) { + // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can + // only be a parent of expr2 if expr1.height is larger than expr2.height. Review comment: ohh it's correct! sorry for a false alarm -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org