This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 99431e28f95 [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions 99431e28f95 is described below commit 99431e28f950bb25c421abd51888a3f9f4b46685 Author: Supun Nakandala <supun.nakand...@databricks.com> AuthorDate: Fri Feb 10 23:56:34 2023 +0800 [SPARK-42162] Introduce MultiCommutativeOp expression as a memory optimization for canonicalizing large trees of commutative expressions ### What changes were proposed in this pull request? - This PR introduces a new expression called `MultiCommutativeOp` which is used by the commutative expressions (e.g., `Add`, `Multiply`, `And`, `Or`, `BitwiseOr`, `BitwiseAnd`, `BitwiseXor`) during canonicalization. - During canonicalization, when there is a list of consecutive commutative expressions, we now create a `MultiCommutativeOp` expression with references to the original operands, instead of creating new objects. - This new expression is added as a memory optimization to reduce the number of intermediate objects generated during canonicalization. ### Why are the changes needed? - With the [recent changes](https://github.com/apache/spark/pull/37851) in the expression canonicalization, a complex query with a large number of commutative operations could end up consuming significantly more (sometimes > 10X) memory on the executors. - In our case, this issue happens for a specific complex query that has a huge expression tree containing Add operators interleaved by non-Add operators. - The issue is related to canonicalization; it affects the executors because the codegen component relies on expression canonicalization to deduplicate expressions. 
- When we have a large number of Adds interleaved by non-Add operators, [this line](https://github.com/apache/spark/pull/37851/files#diff-7278f2db37934522ee7c74b71525153234cff245cefaf996957e4a9ff3dbaacdR1171) ends up materializing a new canonicalized expression tree at every non-Add operator. - In our case, analyzing the executor heap histogram shows that the additional memory is consumed by a large number of Add objects. - The high memory usage causes the executors to lose heartbeat signals and results in task failures. - The proposed `MultiCommutativeOp` expression avoids generating new Add expressions and keeps the extra memory usage to a minimum. ### Does this PR introduce _any_ user-facing change? - No ### How was this patch tested? - Existing unit tests and new unit tests. Closes #39722 from db-scnakandala/SPARK-42162. Authored-by: Supun Nakandala <supun.nakand...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/expressions/Expression.scala | 74 +++++++++++++ .../sql/catalyst/expressions/arithmetic.scala | 13 ++- .../catalyst/expressions/bitwiseExpressions.scala | 15 ++- .../sql/catalyst/expressions/predicates.scala | 10 +- .../org/apache/spark/sql/internal/SQLConf.scala | 9 ++ .../catalyst/expressions/CanonicalizeSuite.scala | 122 ++++++++++++++++++++- 6 files changed, 234 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 37ec4802993..2d2236a8a80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, Tre import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.{QueryErrorsBase, 
QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1335,4 +1336,77 @@ trait CommutativeExpression extends Expression { protected def orderCommutative( f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = gatherCommutative(this, f).sortBy(_.hashCode()) + + /** + * Helper method to generated a canonicalized plan. If the number of operands are + * greater than the MULTI_COMMUTATIVE_OP_OPT_THRESHOLD, this method creates a + * [[MultiCommutativeOp]] as the canonicalized plan. + */ + protected def buildCanonicalizedPlan( + collectOperands: PartialFunction[Expression, Seq[Expression]], + buildBinaryOp: (Expression, Expression) => Expression, + evalMode: Option[EvalMode.Value] = None): Expression = { + val operands = orderCommutative(collectOperands) + val reorderResult = + if (operands.length < SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)) { + operands.reduce(buildBinaryOp) + } else { + MultiCommutativeOp(operands, this.getClass, evalMode)(this) + } + reorderResult + } +} + +/** + * A helper class used by the Commutative expressions during canonicalization. During + * canonicalization, when we have a long tree of commutative operations, we use the MultiCommutative + * expression to represent that tree instead of creating new commutative objects. + * This class is added as a memory optimization for processing large commutative operation trees + * without creating a large number of new intermediate objects. + * The MultiCommutativeOp memory optimization is applied to the following commutative + * expressions: + * Add, Multiply, And, Or, BitwiseAnd, BitwiseOr, BitwiseXor. + * @param operands A sequence of operands that produces a commutative expression tree. 
+ * @param opCls The class of the root operator of the expression tree. + * @param evalMode The optional expression evaluation mode. + * @param originalRoot Root operator of the commutative expression tree before canonicalization. + * This object reference is used to deduce the return dataType of Add and + * Multiply operations when the input datatype is decimal. + */ +case class MultiCommutativeOp( + operands: Seq[Expression], + opCls: Class[_], + evalMode: Option[EvalMode.Value])(originalRoot: Expression) extends Unevaluable { + // Helper method to deduce the data type of a single operation. + private def singleOpDataType(lType: DataType, rType: DataType): DataType = { + originalRoot match { + case add: Add => + (lType, rType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + add.resultDecimalType(p1, s1, p2, s2) + case _ => lType + } + case multiply: Multiply => + (lType, rType) match { + case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => + multiply.resultDecimalType(p1, s1, p2, s2) + case _ => lType + } + } + } + + override def dataType: DataType = { + originalRoot match { + case _: Add | _: Multiply => + operands.map(_.dataType).reduce((l, r) => singleOpDataType(l, r)) + case other => other.dataType + } + } + + override def nullable: Boolean = operands.exists(_.nullable) + + override def children: Seq[Expression] = operands + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = + this.copy(operands = newChildren)(originalRoot) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d5694e58cc9..88f7fabf121 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -479,8 +479,11 @@ case class Add( 
override lazy val canonicalized: Expression = { // TODO: do not reorder consecutive `Add`s with different `evalMode` - val reorderResult = - orderCommutative({ case Add(l, r, _) => Seq(l, r) }).reduce(Add(_, _, evalMode)) + val reorderResult = buildCanonicalizedPlan( + { case Add(l, r, _) => Seq(l, r) }, + { case (l: Expression, r: Expression) => Add(l, r, evalMode)}, + Some(evalMode) + ) if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) { reorderResult } else { @@ -632,7 +635,11 @@ case class Multiply( override lazy val canonicalized: Expression = { // TODO: do not reorder consecutive `Multiply`s with different `evalMode` - orderCommutative({ case Multiply(l, r, _) => Seq(l, r) }).reduce(Multiply(_, _, evalMode)) + buildCanonicalizedPlan( + { case Multiply(l, r, _) => Seq(l, r) }, + { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)}, + Some(evalMode) + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 70c3d11deda..6061f625ef0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -62,7 +62,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme newLeft: Expression, newRight: Expression): BitwiseAnd = copy(left = newLeft, right = newRight) override lazy val canonicalized: Expression = { - orderCommutative({ case BitwiseAnd(l, r) => Seq(l, r) }).reduce(BitwiseAnd) + buildCanonicalizedPlan( + { case BitwiseAnd(l, r) => Seq(l, r) }, + { case (l: Expression, r: Expression) => BitwiseAnd(l, r)} + ) } } @@ -106,7 +109,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet newLeft: Expression, newRight: Expression): BitwiseOr = copy(left = newLeft, 
right = newRight) override lazy val canonicalized: Expression = { - orderCommutative({ case BitwiseOr(l, r) => Seq(l, r) }).reduce(BitwiseOr) + buildCanonicalizedPlan( + { case BitwiseOr(l, r) => Seq(l, r) }, + { case (l: Expression, r: Expression) => BitwiseOr(l, r)} + ) } } @@ -150,7 +156,10 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme newLeft: Expression, newRight: Expression): BitwiseXor = copy(left = newLeft, right = newRight) override lazy val canonicalized: Expression = { - orderCommutative({ case BitwiseXor(l, r) => Seq(l, r) }).reduce(BitwiseXor) + buildCanonicalizedPlan( + { case BitwiseXor(l, r) => Seq(l, r) }, + { case (l: Expression, r: Expression) => BitwiseXor(l, r)} + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b7d7c5e700e..64bee643c86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -805,7 +805,10 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with copy(left = newLeft, right = newRight) override lazy val canonicalized: Expression = { - orderCommutative({ case And(l, r) => Seq(l, r) }).reduce(And) + buildCanonicalizedPlan( + { case And(l, r) => Seq(l, r) }, + { case (l: Expression, r: Expression) => And(l, r)} + ) } } @@ -899,7 +902,10 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P copy(left = newLeft, right = newRight) override lazy val canonicalized: Expression = { - orderCommutative({ case Or(l, r) => Seq(l, r) }).reduce(Or) + buildCanonicalizedPlan( + { case Or(l, r) => Seq(l, r) }, + { case (l: Expression, r: Expression) => Or(l, r)} + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 34295d1c42a..8d8aacbc9cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -240,6 +240,15 @@ object SQLConf { .intConf .createWithDefault(100) + val MULTI_COMMUTATIVE_OP_OPT_THRESHOLD = + buildConf("spark.sql.analyzer.canonicalization.multiCommutativeOpMemoryOptThreshold") + .internal() + .doc("The minimum number of operands in a commutative expression tree to" + + " invoke the MultiCommutativeOp memory optimization during canonicalization.") + .version("3.4.0") + .intConf + .createWithDefault(3) + val OPTIMIZER_EXCLUDED_RULES = buildConf("spark.sql.optimizer.excludedRules") .doc("Configures a list of rules to be disabled in the optimizer, in which the rules are " + "specified by their rule names and separated by comma. It is not guaranteed that all the " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 057fb98c239..f2a9eac8216 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -23,7 +23,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.Range -import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.MULTI_COMMUTATIVE_OP_OPT_THRESHOLD +import org.apache.spark.sql.types.{BooleanType, Decimal, DecimalType, IntegerType, LongType, StringType, StructField, StructType} 
class CanonicalizeSuite extends SparkFunSuite { @@ -206,4 +208,122 @@ class CanonicalizeSuite extends SparkFunSuite { assert(!Add(Add(literal4, literal5), literal1).semanticEquals( Add(Add(literal1, literal5), literal4))) } + + test("SPARK-42162: Commutative expression canonicalization should work" + + " with the MultiCommutativeOp memory optimization") { + val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD) + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "3") + + // Add + val d = Decimal(1.2) + val literal1 = Literal.create(d, DecimalType(2, 1)) + val literal2 = Literal.create(d, DecimalType(2, 1)) + val literal3 = Literal.create(d, DecimalType(3, 2)) + assert(Add(literal1, Add(literal2, literal3)) + .semanticEquals(Add(Add(literal1, literal2), literal3))) + assert(Add(literal1, Add(literal2, literal3)).canonicalized.isInstanceOf[MultiCommutativeOp]) + + // Multiply + assert(Multiply(literal1, Multiply(literal2, literal3)) + .semanticEquals(Multiply(Multiply(literal1, literal2), literal3))) + assert(Multiply(literal1, Multiply(literal2, literal3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // And + val literalBool1 = Literal.create(true, BooleanType) + val literalBool2 = Literal.create(true, BooleanType) + val literalBool3 = Literal.create(true, BooleanType) + assert(And(literalBool1, And(literalBool2, literalBool3)) + .semanticEquals(And(And(literalBool1, literalBool2), literalBool3))) + assert(And(literalBool1, And(literalBool2, literalBool3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // Or + assert(Or(literalBool1, Or(literalBool2, literalBool3)) + .semanticEquals(Or(Or(literalBool1, literalBool2), literalBool3))) + assert(Or(literalBool1, Or(literalBool2, literalBool3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // BitwiseAnd + val literalBit1 = Literal(1) + val literalBit2 = Literal(2) + val literalBit3 = Literal(3) + assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3)) 
+ .semanticEquals(BitwiseAnd(BitwiseAnd(literalBit1, literalBit2), literalBit3))) + assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // BitwiseOr + assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3)) + .semanticEquals(BitwiseOr(BitwiseOr(literalBit1, literalBit2), literalBit3))) + assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // BitwiseXor + assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3)) + .semanticEquals(BitwiseXor(BitwiseXor(literalBit1, literalBit2), literalBit3))) + assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString) + } + + test("SPARK-42162: Commutative expression canonicalization should not use" + + " MultiCommutativeOp memory optimization when threshold is not met") { + val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD) + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "100") + + // Add + val d = Decimal(1.2) + val literal1 = Literal.create(d, DecimalType(2, 1)) + val literal2 = Literal.create(d, DecimalType(2, 1)) + val literal3 = Literal.create(d, DecimalType(3, 2)) + assert(Add(literal1, Add(literal2, literal3)) + .semanticEquals(Add(Add(literal1, literal2), literal3))) + assert(!Add(literal1, Add(literal2, literal3)).canonicalized.isInstanceOf[MultiCommutativeOp]) + + // Multiply + assert(Multiply(literal1, Multiply(literal2, literal3)) + .semanticEquals(Multiply(Multiply(literal1, literal2), literal3))) + assert(!Multiply(literal1, Multiply(literal2, literal3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // And + val literalBool1 = Literal.create(true, BooleanType) + val literalBool2 = Literal.create(true, BooleanType) + val literalBool3 = Literal.create(true, 
BooleanType) + assert(And(literalBool1, And(literalBool2, literalBool3)) + .semanticEquals(And(And(literalBool1, literalBool2), literalBool3))) + assert(!And(literalBool1, And(literalBool2, literalBool3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // Or + assert(Or(literalBool1, Or(literalBool2, literalBool3)) + .semanticEquals(Or(Or(literalBool1, literalBool2), literalBool3))) + assert(!Or(literalBool1, Or(literalBool2, literalBool3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // BitwiseAnd + val literalBit1 = Literal(1) + val literalBit2 = Literal(2) + val literalBit3 = Literal(3) + assert(BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3)) + .semanticEquals(BitwiseAnd(BitwiseAnd(literalBit1, literalBit2), literalBit3))) + assert(!BitwiseAnd(literalBit1, BitwiseAnd(literalBit2, literalBit3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // BitwiseOr + assert(BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3)) + .semanticEquals(BitwiseOr(BitwiseOr(literalBit1, literalBit2), literalBit3))) + assert(!BitwiseOr(literalBit1, BitwiseOr(literalBit2, literalBit3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + // BitwiseXor + assert(BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3)) + .semanticEquals(BitwiseXor(BitwiseXor(literalBit1, literalBit2), literalBit3))) + assert(!BitwiseXor(literalBit1, BitwiseXor(literalBit2, literalBit3)) + .canonicalized.isInstanceOf[MultiCommutativeOp]) + + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org