This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch branch-3.2 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push: new d611d1f6676 [SPARK-39259][SQL][3.2] Evaluate timestamps consistently in subqueries d611d1f6676 is described below commit d611d1f66761bd39fee850ca3f435027f9fc1e3c Author: Ole Sasse <ole.sa...@databricks.com> AuthorDate: Tue Jun 7 13:54:46 2022 +0300 [SPARK-39259][SQL][3.2] Evaluate timestamps consistently in subqueries ### What changes were proposed in this pull request? Apply the optimizer rule ComputeCurrentTime consistently across subqueries. This is a backport of https://github.com/apache/spark/pull/36654. ### Why are the changes needed? At the moment timestamp functions like now() can return different values within a query if subqueries are involved ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? A new unit test was added Closes #36753 from olaky/SPARK-39259-spark_3_2. Lead-authored-by: Ole Sasse <ole.sa...@databricks.com> Co-authored-by: Josh Rosen <joshro...@databricks.com> Co-authored-by: Dongjoon Hyun <dongj...@apache.org> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../sql/catalyst/optimizer/finishAnalysis.scala | 41 +++++----- .../spark/sql/catalyst/plans/QueryPlan.scala | 17 ++++- .../optimizer/ComputeCurrentTimeSuite.scala | 89 ++++++++++++++++------ 3 files changed, 103 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 802e0b4ef7c..897751e99fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.collection.mutable +import java.time.{Instant, LocalDateTime} import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER import org.apache.spark.sql.catalyst.expressions._ @@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ} +import org.apache.spark.sql.catalyst.trees.TreePatternBits +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -76,29 +78,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { */ object ComputeCurrentTime extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val currentDates = mutable.Map.empty[String, Literal] - val timeExpr = CurrentTimestamp() - val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long] - val currentTime = Literal.create(timestamp, timeExpr.dataType) + val instant = Instant.now() + val currentTimestampMicros = instantToMicros(instant) + val currentTime = Literal.create(currentTimestampMicros, TimestampType) val timezone = Literal.create(conf.sessionLocalTimeZone, StringType) - val localTimestamps = mutable.Map.empty[String, Literal] - plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) { - case currentDate @ CurrentDate(Some(timeZoneId)) => - currentDates.getOrElseUpdate(timeZoneId, { - Literal.create(currentDate.eval().asInstanceOf[Int], DateType) - }) - case CurrentTimestamp() | Now() => currentTime - case CurrentTimeZone() => timezone - case localTimestamp @ LocalTimestamp(Some(timeZoneId)) => - localTimestamps.getOrElseUpdate(timeZoneId, { - Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType) - }) + def transformCondition(treePatternbits: TreePatternBits): Boolean = { + treePatternbits.containsPattern(CURRENT_LIKE) + } + + plan.transformDownWithSubqueriesAndPruning(transformCondition) { + case subQuery => + subQuery.transformAllExpressionsWithPruning(transformCondition) { + case cd: CurrentDate => + Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType) + case CurrentTimestamp() | Now() => currentTime + case CurrentTimeZone() => timezone + case localTimestamp: LocalTimestamp => + val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId) + Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType) + } } } } - /** * Replaces the expression of CurrentDatabase with the current database name. * Replaces the expression of CurrentCatalog with the current catalog name. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a67f8d3a3b6..9087916a4a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -473,6 +473,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] * When the partial function does not apply to a given node, it is left unchanged. */ def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = { + transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) + } + + /** + * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. + * Returns a copy of this node where the given partial function has been recursively applied + * first to this node, then this node's subqueries and finally this node's children. + * When the partial function does not apply to a given node, it is left unchanged. + */ + def transformDownWithSubqueriesAndPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId) + (f: PartialFunction[PlanType, PlanType]): PlanType = { val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { override def isDefinedAt(x: PlanType): Boolean = true @@ -480,13 +493,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] val transformed = f.applyOrElse[PlanType, PlanType](plan, identity) transformed transformExpressionsDown { case planExpression: PlanExpression[PlanType] => - val newPlan = planExpression.plan.transformDownWithSubqueries(f) + val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f) planExpression.withNewPlan(newPlan) } } } - transformDown(g) + transformDownWithPruning(cond, ruleId)(g) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala index 9b04dcddfb2..86461522f74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer import java.time.{LocalDateTime, ZoneId} +import scala.collection.JavaConverters.mapAsScalaMap +import scala.concurrent.duration._ + import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp} +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf @@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest { val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = (System.currentTimeMillis() + 1) * 1000 - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } + val lits = literals[Long](plan) assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) @@ -59,11 +58,7 @@ class ComputeCurrentTimeSuite extends PlanTest { val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.currentDate(ZoneId.systemDefault()) - val lits = new scala.collection.mutable.ArrayBuffer[Int] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Int] - e - } + val lits = literals[Int](plan) assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) @@ -73,13 +68,9 @@ class ComputeCurrentTimeSuite extends PlanTest { test("SPARK-33469: Add current_timezone function") { val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation()) val plan = Optimize.execute(in.analyze).asInstanceOf[Project] - val lits = new scala.collection.mutable.ArrayBuffer[String] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[UTF8String].toString - e - } + val lits = literals[UTF8String](plan) assert(lits.size == 1) - assert(lits.head == SQLConf.get.sessionLocalTimeZone) + assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone)) } test("analyzer should replace localtimestamp with literals") { @@ -92,14 +83,66 @@ class ComputeCurrentTimeSuite extends PlanTest { val plan = Optimize.execute(in.analyze).asInstanceOf[Project] val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId)) - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } + val lits = literals[Long](plan) assert(lits.size == 2) assert(lits(0) >= min && lits(0) <= max) assert(lits(1) >= min && lits(1) <= max) assert(lits(0) == lits(1)) } + + test("analyzer should use equal timestamps across subqueries") { + val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation()) + val listSubQuery = ListQuery(timestampInSubQuery) + val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")()) + val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery) + val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation())) + + val plan = Optimize.execute(input.analyze).asInstanceOf[Project] + + val lits = literals[Long](plan) + assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice + assert(lits.toSet.size == 1) + } + + test("analyzer should use consistent timestamps for different timezones") { + val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS) + .map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq + val input = Project(localTimestamps, LocalRelation()) + + val plan = Optimize.execute(input).asInstanceOf[Project] + + val lits = literals[Long](plan) + assert(lits.size === localTimestamps.size) + // there are timezones with a 30 or 45 minute offset + val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet + assert(offsetsFromQuarterHour.size == 1) + } + + test("analyzer should use consistent timestamps for different timestamp functions") { + val differentTimestamps = Seq( + Alias(CurrentTimestamp(), "currentTimestamp")(), + Alias(Now(), "now")(), + Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")() + ) + val input = Project(differentTimestamps, LocalRelation()) + + val plan = Optimize.execute(input).asInstanceOf[Project] + + val lits = literals[Long](plan) + assert(lits.size === differentTimestamps.size) + // there are timezones with a 30 or 45 minute offset + val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet + assert(offsetsFromQuarterHour.size == 1) + } + + private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = { + val literals = new scala.collection.mutable.ArrayBuffer[T] + plan.transformWithSubqueries { case subQuery => + subQuery.transformAllExpressions { case expression: Literal => + literals += expression.value.asInstanceOf[T] + expression + } + } + literals + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org