This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3953754 [SPARK-35735][SQL][FOLLOWUP] Fix case minute to second regex can cover by hour to minute and unit case-sensitive issue 3953754 is described below commit 3953754f36656e1a0bee16b89fae0142f172a91a Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Wed Jul 7 12:37:19 2021 +0300 [SPARK-35735][SQL][FOLLOWUP] Fix case minute to second regex can cover by hour to minute and unit case-sensitive issue ### What changes were proposed in this pull request? When casting `10:10` to interval minute to second, the input can be caught by the hour to minute regex first; this PR fixes that. ### Why are the changes needed? Fix bug ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UT Closes #33242 from AngersZhuuuu/SPARK-35735-FOLLOWUP. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../spark/sql/catalyst/util/IntervalUtils.scala | 92 +++++++++++----------- .../sql/catalyst/expressions/CastSuiteBase.scala | 4 + 2 files changed, 48 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index b174165..ad87f2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.time.{Duration, Period} import java.time.temporal.ChronoUnit +import java.util.Locale import java.util.concurrent.TimeUnit import scala.collection.mutable @@ -173,15 +174,14 @@ object IntervalUtils { startField: Byte, endField: Byte): Int = { - def checkYMIntervalStringDataType(ym: YM): Unit = - checkIntervalStringDataType(input, startField, endField, ym) + def checkTargetType(targetStartField: Byte, targetEndField: Byte): 
Boolean = + startField == targetStartField && endField == targetEndField input.trimAll().toString match { - case yearMonthRegex(sign, year, month) => - checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH)) + case yearMonthRegex(sign, year, month) if checkTargetType(YM.YEAR, YM.MONTH) => toYMInterval(year, month, finalSign(sign)) - case yearMonthLiteralRegex(firstSign, secondSign, year, month) => - checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH)) + case yearMonthLiteralRegex(firstSign, secondSign, year, month) + if checkTargetType(YM.YEAR, YM.MONTH) => toYMInterval(year, month, finalSign(firstSign, secondSign)) case yearMonthIndividualRegex(firstSign, value) => safeToInterval("year-month") { @@ -195,15 +195,16 @@ object IntervalUtils { input, startField, endField, "year-month", YM(startField, endField).typeName) } } - case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, suffix) => + case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, unit) => safeToInterval("year-month") { val sign = finalSign(firstSign, secondSign) - if ("YEAR".equalsIgnoreCase(suffix)) { - checkYMIntervalStringDataType(YM(YM.YEAR, YM.YEAR)) - sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR) - } else { - checkYMIntervalStringDataType(YM(YM.MONTH, YM.MONTH)) - Math.toIntExact(sign * value.toLong) + unit.toUpperCase(Locale.ROOT) match { + case "YEAR" if checkTargetType(YM.YEAR, YM.YEAR) => + sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR) + case "MONTH" if checkTargetType(YM.MONTH, YM.MONTH) => + Math.toIntExact(sign * value.toLong) + case _ => throwIllegalIntervalFormatException(input, startField, endField, + "year-month", YM(startField, endField).typeName) } } case _ => throwIllegalIntervalFormatException(input, startField, endField, @@ -303,48 +304,45 @@ object IntervalUtils { } } - def checkDTIntervalStringDataType(dt: DT): Unit = - checkIntervalStringDataType(input, startField, endField, dt, Some(fallbackNotice)) + def 
checkTargetType(targetStartField: Byte, targetEndField: Byte): Boolean = + startField == targetStartField && endField == targetEndField input.trimAll().toString match { - case dayHourRegex(sign, day, hour) => - checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR)) + case dayHourRegex(sign, day, hour) if checkTargetType(DT.DAY, DT.HOUR) => toDTInterval(day, hour, "0", "0", finalSign(sign)) - case dayHourLiteralRegex(firstSign, secondSign, day, hour) => - checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR)) + case dayHourLiteralRegex(firstSign, secondSign, day, hour) + if checkTargetType(DT.DAY, DT.HOUR) => toDTInterval(day, hour, "0", "0", finalSign(firstSign, secondSign)) - case dayMinuteRegex(sign, day, hour, minute) => - checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE)) + case dayMinuteRegex(sign, day, hour, minute) if checkTargetType(DT.DAY, DT.MINUTE) => toDTInterval(day, hour, minute, "0", finalSign(sign)) - case dayMinuteLiteralRegex(firstSign, secondSign, day, hour, minute) => - checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE)) + case dayMinuteLiteralRegex(firstSign, secondSign, day, hour, minute) + if checkTargetType(DT.DAY, DT.MINUTE) => toDTInterval(day, hour, minute, "0", finalSign(firstSign, secondSign)) - case daySecondRegex(sign, day, hour, minute, second, micro) => - checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND)) + case daySecondRegex(sign, day, hour, minute, second, micro) + if checkTargetType(DT.DAY, DT.SECOND) => toDTInterval(day, hour, minute, secondAndMicro(second, micro), finalSign(sign)) - case daySecondLiteralRegex(firstSign, secondSign, day, hour, minute, second, micro) => - checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND)) + case daySecondLiteralRegex(firstSign, secondSign, day, hour, minute, second, micro) + if checkTargetType(DT.DAY, DT.SECOND) => toDTInterval(day, hour, minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign)) - case hourMinuteRegex(sign, hour, minute) => - 
checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE)) + case hourMinuteRegex(sign, hour, minute) if checkTargetType(DT.HOUR, DT.MINUTE) => toDTInterval(hour, minute, "0", finalSign(sign)) - case hourMinuteLiteralRegex(firstSign, secondSign, hour, minute) => - checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE)) + case hourMinuteLiteralRegex(firstSign, secondSign, hour, minute) + if checkTargetType(DT.HOUR, DT.MINUTE) => toDTInterval(hour, minute, "0", finalSign(firstSign, secondSign)) - case hourSecondRegex(sign, hour, minute, second, micro) => - checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND)) + case hourSecondRegex(sign, hour, minute, second, micro) + if checkTargetType(DT.HOUR, DT.SECOND) => toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(sign)) - case hourSecondLiteralRegex(firstSign, secondSign, hour, minute, second, micro) => - checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND)) + case hourSecondLiteralRegex(firstSign, secondSign, hour, minute, second, micro) + if checkTargetType(DT.HOUR, DT.SECOND) => toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign)) - case minuteSecondRegex(sign, minute, second, micro) => - checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND)) + case minuteSecondRegex(sign, minute, second, micro) + if checkTargetType(DT.MINUTE, DT.SECOND) => toDTInterval(minute, secondAndMicro(second, micro), finalSign(sign)) - case minuteSecondLiteralRegex(firstSign, secondSign, minute, second, micro) => - checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND)) + case minuteSecondLiteralRegex(firstSign, secondSign, minute, second, micro) + if checkTargetType(DT.MINUTE, DT.SECOND) => toDTInterval(minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign)) case dayTimeIndividualRegex(firstSign, value, suffix) => @@ -369,18 +367,16 @@ object IntervalUtils { case dayTimeIndividualLiteralRegex(firstSign, secondSign, value, suffix, unit) => safeToInterval("day-time") { val sign 
= finalSign(firstSign, secondSign) - unit match { - case "DAY" if suffix == null && value.length <= 9 => - checkDTIntervalStringDataType(DT(DT.DAY, DT.DAY)) + unit.toUpperCase(Locale.ROOT) match { + case "DAY" if suffix == null && value.length <= 9 && checkTargetType(DT.DAY, DT.DAY) => sign * value.toLong * MICROS_PER_DAY - case "HOUR" if suffix == null && value.length <= 10 => - checkDTIntervalStringDataType(DT(DT.HOUR, DT.HOUR)) + case "HOUR" if suffix == null && value.length <= 10 + && checkTargetType(DT.HOUR, DT.HOUR) => sign * value.toLong * MICROS_PER_HOUR - case "MINUTE" if suffix == null && value.length <= 12 => - checkDTIntervalStringDataType(DT(DT.MINUTE, DT.MINUTE)) + case "MINUTE" if suffix == null && value.length <= 12 + && checkTargetType(DT.MINUTE, DT.MINUTE) => sign * value.toLong * MICROS_PER_MINUTE - case "SECOND" if value.length <= 13 => - checkDTIntervalStringDataType(DT(DT.SECOND, DT.SECOND)) + case "SECOND" if value.length <= 13 && checkTargetType(DT.SECOND, DT.SECOND) => sign match { case 1 => parseSecondNano(secondAndMicro(value, suffix)) case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 8313242..38e012c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -1107,8 +1107,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Literal.create(str), dataType), value) if (dataType == YearMonthIntervalType(YEAR)) { checkEvaluation(cast(Literal.create(s"INTERVAL '$str' YEAR"), dataType), value) + checkEvaluation(cast(Literal.create(s"INTERVAL '$str' year"), dataType), value) } else { checkEvaluation(cast(Literal.create(s"INTERVAL '$str' 
MONTH"), dataType), value) + checkEvaluation(cast(Literal.create(s"INTERVAL '$str' month"), dataType), value) } } @@ -1177,6 +1179,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { ("01", DayTimeIntervalType(MINUTE, MINUTE), (60) * MICROS_PER_SECOND), ("-01", DayTimeIntervalType(MINUTE, MINUTE), -(60) * MICROS_PER_SECOND), + ("01:01", DayTimeIntervalType(MINUTE, SECOND), ((60 + 1) * MICROS_PER_SECOND)), ("01:01.12345", DayTimeIntervalType(MINUTE, SECOND), ((60 + 1.12345) * MICROS_PER_SECOND).toLong), ("-01:01.12345", DayTimeIntervalType(MINUTE, SECOND), @@ -1209,6 +1212,7 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { ("INTERVAL '9223372036854.775807' SECOND", DayTimeIntervalType(SECOND), Long.MaxValue)) .foreach { case (interval, dataType, dt) => checkEvaluation(cast(Literal.create(interval), dataType), dt) + checkEvaluation(cast(Literal.create(interval.toLowerCase(Locale.ROOT)), dataType), dt) } Seq(("INTERVAL '-106751991' DAY", DayTimeIntervalType(DAY), -106751991L * MICROS_PER_DAY), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org