cloud-fan commented on code in PR #55982:
URL: https://github.com/apache/spark/pull/55982#discussion_r3280276035
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRandSuite.scala:
##########
@@ -173,4 +176,95 @@ class OptimizeRandSuite extends PlanTest {
}
}
+ test("Optimize arithmetic expressions with rand") {
+ // rand() * 3 < 3 should be optimized to true
+ val plan1 = testRelation.select((rand5 * literal3d <
literal3d).as("flag")).analyze
+ val actual1 = Optimize.execute(plan1)
+ val correctAnswer1 = testRelation.select(Alias(TrueLiteral,
"flag")()).analyze
+ comparePlans(actual1, correctAnswer1)
+
+ // rand() + 1 < 2 should be optimized to true
+ val plan2 = testRelation.select((rand5 + literal1d <
literal2d).as("flag")).analyze
+ val actual2 = Optimize.execute(plan2)
+ val correctAnswer2 = testRelation.select(Alias(TrueLiteral,
"flag")()).analyze
+ comparePlans(actual2, correctAnswer2)
+
+ // rand() - 1 < 0 should be optimized to true
+ val plan3 = testRelation.select((rand5 - literal1d <
literal0d).as("flag")).analyze
+ val actual3 = Optimize.execute(plan3)
+ val correctAnswer3 = testRelation.select(Alias(TrueLiteral,
"flag")()).analyze
+ comparePlans(actual3, correctAnswer3)
+
+ // rand() / 2 < 1 should be optimized to true
+ val plan4 = testRelation.select((rand5 / literal2d <
literal1d).as("flag")).analyze
+ val actual4 = Optimize.execute(plan4)
+ val correctAnswer4 = testRelation.select(Alias(TrueLiteral,
"flag")()).analyze
+ comparePlans(actual4, correctAnswer4)
+
+ // rand() * 2 > 3 should be optimized to false
+ val plan5 = testRelation.select((rand5 * literal2d >
literal3d).as("flag")).analyze
+ val actual5 = Optimize.execute(plan5)
+ val correctAnswer5 = testRelation.select(Alias(FalseLiteral,
"flag")()).analyze
+ comparePlans(actual5, correctAnswer5)
+ }
+
+ test("Optimize equality comparison with rand") {
+ // rand() == 0.5 cannot be optimized (value is in [0, 1) range)
+ val plan1 = testRelation.select((rand5 === literalHalf).as("flag")).analyze
+ val actual1 = Optimize.execute(plan1)
+ comparePlans(actual1, plan1)
+
+ // rand() == 2 should be optimized to false (value outside [0, 1) range)
+ val plan2 = testRelation.select((rand5 === literal2d).as("flag")).analyze
+ val actual2 = Optimize.execute(plan2)
+ val correctAnswer2 = testRelation.select(Alias(FalseLiteral,
"flag")()).analyze
+ comparePlans(actual2, correctAnswer2)
+
+ // rand() == -1 should be optimized to false (value outside [0, 1) range)
+ val plan3 = testRelation.select((rand5 ===
negativeLiteral1d).as("flag")).analyze
+ val actual3 = Optimize.execute(plan3)
+ val correctAnswer3 = testRelation.select(Alias(FalseLiteral,
"flag")()).analyze
+ comparePlans(actual3, correctAnswer3)
+
+ // 2 == rand() should be optimized to false (literal on left side)
+ val plan4 = testRelation.select((literal2d === rand5).as("flag")).analyze
+ val actual4 = Optimize.execute(plan4)
+ val correctAnswer4 = testRelation.select(Alias(FalseLiteral,
"flag")()).analyze
+ comparePlans(actual4, correctAnswer4)
+
+ // -1 == rand() should be optimized to false (literal on left side)
+ val plan5 = testRelation.select((negativeLiteral1d ===
rand5).as("flag")).analyze
+ val actual5 = Optimize.execute(plan5)
+ val correctAnswer5 = testRelation.select(Alias(FalseLiteral,
"flag")()).analyze
+ comparePlans(actual5, correctAnswer5)
+ }
+
+ test("Benchmark: rand optimization performance benefit") {
Review Comment:
This test compares `analyze` only (the "unoptimized" loop) against
`analyze + Optimize.execute` (the "optimized" loop). The optimized side does strictly
more work per iteration, so `(unoptimized - optimized) / unoptimized * 100` is
structurally non-positive — the printed "improvement" measures planning-time
noise around zero, not the benefit of the optimization. The test also has no
assertion, so it can't fail. Microbenchmarks for optimizer rules belong under
`dev/benchmarks/` using `org.apache.spark.benchmark.Benchmark` (see e.g.
`sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/`), not in
the unit suite. Suggest dropping this test; if you want to record perf numbers
for the JIRA, run a real benchmark and post the results in the PR description.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala:
##########
@@ -17,47 +17,201 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.{BinaryComparison,
DoubleLiteral, Expression, GreaterThan, GreaterThanOrEqual, LessThan,
LessThanOrEqual, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Add, BinaryComparison,
Divide,
+ DoubleLiteral, EqualTo, Expression, GreaterThan, GreaterThanOrEqual,
LessThan,
+ LessThanOrEqual, Literal, Multiply, Rand, Subtract}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON,
EXPRESSION_WITH_RANDOM_SEED, LITERAL}
+import
org.apache.spark.sql.catalyst.trees.TreePattern.EXPRESSION_WITH_RANDOM_SEED
-/**
- * Rand() generates a random column with i.i.d. uniformly distributed values
in [0, 1), so
- * compare double literal value with 1.0 or 0.0 could eliminate Rand() in
binary comparison.
- *
- * 1. Converts the binary comparison to true literal when the comparison value
must be true.
- * 2. Converts the binary comparison to false literal when the comparison
value must be false.
- */
object OptimizeRand extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan =
- plan.transformAllExpressionsWithPruning(_.containsAllPatterns(
- EXPRESSION_WITH_RANDOM_SEED, LITERAL, BINARY_COMPARISON), ruleId) {
- case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
eliminateRand(swapComparison(op))
- case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) =>
eliminateRand(op)
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(
+ EXPRESSION_WITH_RANDOM_SEED), ruleId) {
+ case op @ EqualTo(DoubleLiteral(_), _: Rand) =>
+ eliminateRand(EqualTo(op.right, op.left))
+ case op @ BinaryComparison(DoubleLiteral(_), _: Rand)
+ if !op.isInstanceOf[EqualTo] =>
+ eliminateRand(swapComparison(op))
+ case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) =>
+ eliminateRand(op)
+ case op: BinaryComparison
+ if isDirectRandChild(op.left) || isDirectRandChild(op.right) =>
+ optimizeArithmetic(op)
+ }
+
+ private def isDirectRandChild(expr: Expression): Boolean = expr match {
+ case _: Rand => true
+ case Add(l, r, _) => l.isInstanceOf[Rand] || r.isInstanceOf[Rand]
+ case Subtract(l, r, _) => l.isInstanceOf[Rand] || r.isInstanceOf[Rand]
+ case Multiply(l, r, _) => l.isInstanceOf[Rand] || r.isInstanceOf[Rand]
+ case Divide(l, r, _) => l.isInstanceOf[Rand] || r.isInstanceOf[Rand]
+ case _ => false
}
- /**
- * Swaps the left and right sides of some binary comparisons. e.g.,
transform "a < b" to "b > a"
- */
- private def swapComparison(comparison: BinaryComparison): BinaryComparison =
comparison match {
- case a LessThan b => GreaterThan(b, a)
- case a LessThanOrEqual b => GreaterThanOrEqual(b, a)
- case a GreaterThan b => LessThan(b, a)
- case a GreaterThanOrEqual b => LessThanOrEqual(b, a)
- case o => o
+ private def hasRand(expr: Expression): Boolean = expr match {
+ case _: Rand => true
+ case a: Add => hasRand(a.left) || hasRand(a.right)
+ case s: Subtract => hasRand(s.left) || hasRand(s.right)
+ case m: Multiply => hasRand(m.left) || hasRand(m.right)
+ case d: Divide => hasRand(d.left) || hasRand(d.right)
+ case _ => false
}
Review Comment:
Independent of the broader design discussion: the pattern matches in
`isDirectRandChild` use
`Add(l, r, _)` / `Subtract(l, r, _)` / `Multiply(l, r, _)` / `Divide(l, r, _)`,
which destructure the case class and discard the third field —
`evalContext: NumericEvalContext`. That field carries the eval mode (`LEGACY` / `ANSI` /
`TRY`); under `EvalMode.TRY`, the arithmetic node returns null on errors rather
than throwing or wrapping. Even though `rand()` is `DoubleType` and IEEE
arithmetic doesn't throw, `try_add(rand(), 1)` and friends are spellable in SQL
and the user would get `null` from the original expression but a folded
`True`/`False` from the rewrite. The rule needs to check `evalContext.evalMode
!= EvalMode.TRY` before applying. (Same gap if you split the logic into a new
rule — gate it there.)
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeRand.scala:
##########
@@ -17,47 +17,201 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.expressions.{BinaryComparison,
DoubleLiteral, Expression, GreaterThan, GreaterThanOrEqual, LessThan,
LessThanOrEqual, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Add, BinaryComparison,
Divide,
+ DoubleLiteral, EqualTo, Expression, GreaterThan, GreaterThanOrEqual,
LessThan,
+ LessThanOrEqual, Literal, Multiply, Rand, Subtract}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLiteral}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON,
EXPRESSION_WITH_RANDOM_SEED, LITERAL}
+import
org.apache.spark.sql.catalyst.trees.TreePattern.EXPRESSION_WITH_RANDOM_SEED
-/**
- * Rand() generates a random column with i.i.d. uniformly distributed values
in [0, 1), so
- * compare double literal value with 1.0 or 0.0 could eliminate Rand() in
binary comparison.
- *
- * 1. Converts the binary comparison to true literal when the comparison value
must be true.
- * 2. Converts the binary comparison to false literal when the comparison
value must be false.
- */
object OptimizeRand extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan =
- plan.transformAllExpressionsWithPruning(_.containsAllPatterns(
- EXPRESSION_WITH_RANDOM_SEED, LITERAL, BINARY_COMPARISON), ruleId) {
- case op @ BinaryComparison(DoubleLiteral(_), _: Rand) =>
eliminateRand(swapComparison(op))
- case op @ BinaryComparison(_: Rand, DoubleLiteral(_)) =>
eliminateRand(op)
+ plan.transformAllExpressionsWithPruning(_.containsAnyPattern(
Review Comment:
The original `containsAllPatterns(EXPRESSION_WITH_RANDOM_SEED, LITERAL,
BINARY_COMPARISON)` is a tighter prefilter: every case in `apply` requires a
`BinaryComparison` and at least one `Literal`. The change to
`containsAnyPattern(EXPRESSION_WITH_RANDOM_SEED)` makes the rule visit any
subtree containing rand (e.g. `select rand() as r`), even when no comparison
can possibly match. Restoring `containsAllPatterns(EXPRESSION_WITH_RANDOM_SEED,
LITERAL, BINARY_COMPARISON)` keeps the rule cheap; if you take the "split into
a new rule" direction in the summary, the new rule's prefilter is
`containsAllPatterns(BINARY_COMPARISON, BINARY_ARITHMETIC, LITERAL)` and
`OptimizeRand`'s stays unchanged from master.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]