This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 26d1bb1 [SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM 26d1bb1 is described below commit 26d1bb16bc565dbcb1a3f536dc78cd87be6c2468 Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Tue Jul 6 13:51:06 2021 +0300 [SPARK-36023][SPARK-35735][SPARK-35768][SQL] Refactor code about parse string to DT/YM ### What changes were proposed in this pull request? Refactor the code that parses strings to DT/YM intervals. ### Why are the changes needed? Extracting the common code for parsing strings to DT/YM intervals should improve code maintainability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing UTs. Closes #33217 from AngersZhuuuu/SPARK-35735-35768. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../spark/sql/catalyst/util/IntervalUtils.scala | 201 ++++++++++----------- .../sql/catalyst/expressions/CastSuiteBase.scala | 28 ++- 2 files changed, 123 insertions(+), 106 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala index 30a2fa5..b174165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros import org.apache.spark.sql.catalyst.util.IntervalStringStyles.{ANSI_STYLE, HIVE_STYLE, IntervalStyle} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM} +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType => DT, Decimal, YearMonthIntervalType => YM} import 
org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} // The style of textual representation of intervals @@ -110,7 +110,7 @@ object IntervalUtils { private val yearMonthIndividualLiteralRegex = (s"(?i)^INTERVAL\\s+([+|-])?'$yearMonthIndividualPatternString'\\s+(YEAR|MONTH)$$").r - private def getSign(firstSign: String, secondSign: String): Int = { + private def finalSign(firstSign: String, secondSign: String = null): Int = { (firstSign, secondSign) match { case ("-", "-") => 1 case ("-", _) => -1 @@ -119,6 +119,39 @@ object IntervalUtils { } } + private def throwIllegalIntervalFormatException( + input: UTF8String, + startFiled: Byte, + endField: Byte, + intervalStr: String, + typeName: String, + fallBackNotice: Option[String] = None) = { + throw new IllegalArgumentException( + s"Interval string does not match $intervalStr format of " + + s"${supportedFormat((startFiled, endField)).map(format => s"`$format`").mkString(", ")} " + + s"when cast to $typeName: ${input.toString}" + + s"${fallBackNotice.map(s => s", $s").getOrElse("")}") + } + + private def checkIntervalStringDataType( + input: UTF8String, + targetStartField: Byte, + targetEndField: Byte, + inputIntervalType: DataType, + fallBackNotice: Option[String] = None): Unit = { + val (intervalStr, typeName, inputStartField, inputEndField) = inputIntervalType match { + case DT(startField, endField) => + ("day-time", DT(targetStartField, targetEndField).typeName, startField, endField) + case YM(startField, endField) => + ("year-month", YM(targetStartField, targetEndField).typeName, startField, endField) + } + if (targetStartField != inputStartField || targetEndField != inputEndField) { + throwIllegalIntervalFormatException( + input, targetStartField, targetEndField, intervalStr, typeName, fallBackNotice) + } + } + + val supportedFormat = Map( (YM.YEAR, YM.MONTH) -> Seq("[+|-]y-m", "INTERVAL [+|-]'[+|-]y-m' YEAR TO MONTH"), (YM.YEAR, YM.YEAR) -> Seq("[+|-]y", "INTERVAL [+|-]'[+|-]y' YEAR"), @@ -140,56 +173,41 
@@ object IntervalUtils { startField: Byte, endField: Byte): Int = { - def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = { - if (startField != targetStartField || endField != targetEndField) { - throw new IllegalArgumentException(s"Interval string does not match year-month format of " + - s"${supportedFormat((targetStartField, targetStartField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${YM(startField, endField).typeName}: ${input.toString}") - } - } + def checkYMIntervalStringDataType(ym: YM): Unit = + checkIntervalStringDataType(input, startField, endField, ym) input.trimAll().toString match { - case yearMonthRegex("-", year, month) => - checkStringIntervalType(YM.YEAR, YM.MONTH) - toYMInterval(year, month, -1) - case yearMonthRegex(_, year, month) => - checkStringIntervalType(YM.YEAR, YM.MONTH) - toYMInterval(year, month, 1) + case yearMonthRegex(sign, year, month) => + checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH)) + toYMInterval(year, month, finalSign(sign)) case yearMonthLiteralRegex(firstSign, secondSign, year, month) => - checkStringIntervalType(YM.YEAR, YM.MONTH) - toYMInterval(year, month, getSign(firstSign, secondSign)) - case yearMonthIndividualRegex(secondSign, value) => - safeToInterval { - val sign = getSign("+", secondSign) + checkYMIntervalStringDataType(YM(YM.YEAR, YM.MONTH)) + toYMInterval(year, month, finalSign(firstSign, secondSign)) + case yearMonthIndividualRegex(firstSign, value) => + safeToInterval("year-month") { + val sign = finalSign(firstSign) if (endField == YM.YEAR) { sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR) } else if (startField == YM.MONTH) { Math.toIntExact(sign * value.toLong) } else { - throw new IllegalArgumentException( - s"Interval string does not match year-month format of " + - s"${supportedFormat((YM.YEAR, YM.MONTH)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${YM(startField, endField).typeName}: ${input.toString}") + 
throwIllegalIntervalFormatException( + input, startField, endField, "year-month", YM(startField, endField).typeName) } } case yearMonthIndividualLiteralRegex(firstSign, secondSign, value, suffix) => - safeToInterval { - val sign = getSign(firstSign, secondSign) + safeToInterval("year-month") { + val sign = finalSign(firstSign, secondSign) if ("YEAR".equalsIgnoreCase(suffix)) { - checkStringIntervalType(YM.YEAR, YM.YEAR) + checkYMIntervalStringDataType(YM(YM.YEAR, YM.YEAR)) sign * Math.toIntExact(value.toLong * MONTHS_PER_YEAR) } else { - checkStringIntervalType(YM.MONTH, YM.MONTH) + checkYMIntervalStringDataType(YM(YM.MONTH, YM.MONTH)) Math.toIntExact(sign * value.toLong) } } - case _ => throw new IllegalArgumentException( - s"Interval string does not match year-month format of " + - s"${supportedFormat((YM.YEAR, YM.MONTH)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${YM(startField, endField).typeName}: ${input.toString}") + case _ => throwIllegalIntervalFormatException(input, startField, endField, + "year-month", YM(startField, endField).typeName) } } @@ -201,28 +219,26 @@ object IntervalUtils { def fromYearMonthString(input: String): CalendarInterval = { require(input != null, "Interval year-month string must be not null") input.trim match { - case yearMonthRegex("-", yearStr, monthStr) => - new CalendarInterval(toYMInterval(yearStr, monthStr, -1), 0, 0) - case yearMonthRegex(_, yearStr, monthStr) => - new CalendarInterval(toYMInterval(yearStr, monthStr, 1), 0, 0) + case yearMonthRegex(sign, yearStr, monthStr) => + new CalendarInterval(toYMInterval(yearStr, monthStr, finalSign(sign)), 0, 0) case _ => throw new IllegalArgumentException( s"Interval string does not match year-month format of 'y-m': $input") } } - private def safeToInterval[T](f: => T): T = { + private def safeToInterval[T](interval: String)(f: => T): T = { try { f } catch { case NonFatal(e) => throw new IllegalArgumentException( - s"Error parsing interval year-month string: 
${e.getMessage}", e) + s"Error parsing interval $interval string: ${e.getMessage}", e) } } private def toYMInterval(yearStr: String, monthStr: String, sign: Int): Int = { - safeToInterval { + safeToInterval("year-month") { val years = toLongWithRange(YEAR, yearStr, 0, Integer.MAX_VALUE / MONTHS_PER_YEAR) val totalMonths = sign * (years * MONTHS_PER_YEAR + toLongWithRange(MONTH, monthStr, 0, 11)) Math.toIntExact(totalMonths) @@ -279,15 +295,6 @@ object IntervalUtils { startField: Byte, endField: Byte): Long = { - def checkStringIntervalType(targetStartField: Byte, targetEndField: Byte): Unit = { - if (startField != targetStartField || endField != targetEndField) { - throw new IllegalArgumentException(s"Interval string does not match day-time format of " + - s"${supportedFormat((targetStartField, targetStartField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${DT(startField, endField).typeName}: ${input.toString}") - } - } - def secondAndMicro(second: String, micro: String): String = { if (micro != null) { s"$second$micro" @@ -296,50 +303,53 @@ object IntervalUtils { } } + def checkDTIntervalStringDataType(dt: DT): Unit = + checkIntervalStringDataType(input, startField, endField, dt, Some(fallbackNotice)) + input.trimAll().toString match { case dayHourRegex(sign, day, hour) => - checkStringIntervalType(DT.DAY, DT.HOUR) - toDTInterval(day, hour, "0", "0", getSign(null, sign)) + checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR)) + toDTInterval(day, hour, "0", "0", finalSign(sign)) case dayHourLiteralRegex(firstSign, secondSign, day, hour) => - checkStringIntervalType(DT.DAY, DT.HOUR) - toDTInterval(day, hour, "0", "0", getSign(firstSign, secondSign)) + checkDTIntervalStringDataType(DT(DT.DAY, DT.HOUR)) + toDTInterval(day, hour, "0", "0", finalSign(firstSign, secondSign)) case dayMinuteRegex(sign, day, hour, minute) => - checkStringIntervalType(DT.DAY, DT.MINUTE) - toDTInterval(day, hour, minute, "0", getSign(null, sign)) + 
checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE)) + toDTInterval(day, hour, minute, "0", finalSign(sign)) case dayMinuteLiteralRegex(firstSign, secondSign, day, hour, minute) => - checkStringIntervalType(DT.DAY, DT.MINUTE) - toDTInterval(day, hour, minute, "0", getSign(firstSign, secondSign)) + checkDTIntervalStringDataType(DT(DT.DAY, DT.MINUTE)) + toDTInterval(day, hour, minute, "0", finalSign(firstSign, secondSign)) case daySecondRegex(sign, day, hour, minute, second, micro) => - checkStringIntervalType(DT.DAY, DT.SECOND) - toDTInterval(day, hour, minute, secondAndMicro(second, micro), getSign(null, sign)) + checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND)) + toDTInterval(day, hour, minute, secondAndMicro(second, micro), finalSign(sign)) case daySecondLiteralRegex(firstSign, secondSign, day, hour, minute, second, micro) => - checkStringIntervalType(DT.DAY, DT.SECOND) + checkDTIntervalStringDataType(DT(DT.DAY, DT.SECOND)) toDTInterval(day, hour, minute, secondAndMicro(second, micro), - getSign(firstSign, secondSign)) + finalSign(firstSign, secondSign)) case hourMinuteRegex(sign, hour, minute) => - checkStringIntervalType(DT.HOUR, DT.MINUTE) - toDTInterval(hour, minute, "0", getSign(null, sign)) + checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE)) + toDTInterval(hour, minute, "0", finalSign(sign)) case hourMinuteLiteralRegex(firstSign, secondSign, hour, minute) => - checkStringIntervalType(DT.HOUR, DT.MINUTE) - toDTInterval(hour, minute, "0", getSign(firstSign, secondSign)) + checkDTIntervalStringDataType(DT(DT.HOUR, DT.MINUTE)) + toDTInterval(hour, minute, "0", finalSign(firstSign, secondSign)) case hourSecondRegex(sign, hour, minute, second, micro) => - checkStringIntervalType(DT.HOUR, DT.SECOND) - toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(null, sign)) + checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND)) + toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(sign)) case hourSecondLiteralRegex(firstSign, 
secondSign, hour, minute, second, micro) => - checkStringIntervalType(DT.HOUR, DT.SECOND) - toDTInterval(hour, minute, secondAndMicro(second, micro), getSign(firstSign, secondSign)) + checkDTIntervalStringDataType(DT(DT.HOUR, DT.SECOND)) + toDTInterval(hour, minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign)) case minuteSecondRegex(sign, minute, second, micro) => - checkStringIntervalType(DT.MINUTE, DT.SECOND) - toDTInterval(minute, secondAndMicro(second, micro), getSign(null, sign)) + checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND)) + toDTInterval(minute, secondAndMicro(second, micro), finalSign(sign)) case minuteSecondLiteralRegex(firstSign, secondSign, minute, second, micro) => - checkStringIntervalType(DT.MINUTE, DT.SECOND) - toDTInterval(minute, secondAndMicro(second, micro), getSign(firstSign, secondSign)) + checkDTIntervalStringDataType(DT(DT.MINUTE, DT.SECOND)) + toDTInterval(minute, secondAndMicro(second, micro), finalSign(firstSign, secondSign)) - case dayTimeIndividualRegex(secondSign, value, suffix) => - safeToInterval { - val sign = getSign("+", secondSign) + case dayTimeIndividualRegex(firstSign, value, suffix) => + safeToInterval("day-time") { + val sign = finalSign(firstSign) (startField, endField) match { case (DT.DAY, DT.DAY) if suffix == null && value.length <= 9 => sign * value.toLong * MICROS_PER_DAY @@ -352,46 +362,35 @@ object IntervalUtils { case 1 => parseSecondNano(secondAndMicro(value, suffix)) case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}") } - case (_, _) => throw new IllegalArgumentException( - s"Interval string does not match day-time format of " + - s"${supportedFormat((startField, endField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${DT(startField, endField).typeName}: ${input.toString}") + case (_, _) => throwIllegalIntervalFormatException(input, startField, endField, + "day-time", DT(startField, endField).typeName, Some(fallbackNotice)) } } case 
dayTimeIndividualLiteralRegex(firstSign, secondSign, value, suffix, unit) => - safeToInterval { - val sign = getSign(firstSign, secondSign) + safeToInterval("day-time") { + val sign = finalSign(firstSign, secondSign) unit match { case "DAY" if suffix == null && value.length <= 9 => - checkStringIntervalType(DT.DAY, DT.DAY) + checkDTIntervalStringDataType(DT(DT.DAY, DT.DAY)) sign * value.toLong * MICROS_PER_DAY case "HOUR" if suffix == null && value.length <= 10 => - checkStringIntervalType(DT.HOUR, DT.HOUR) + checkDTIntervalStringDataType(DT(DT.HOUR, DT.HOUR)) sign * value.toLong * MICROS_PER_HOUR case "MINUTE" if suffix == null && value.length <= 12 => - checkStringIntervalType(DT.MINUTE, DT.MINUTE) + checkDTIntervalStringDataType(DT(DT.MINUTE, DT.MINUTE)) sign * value.toLong * MICROS_PER_MINUTE case "SECOND" if value.length <= 13 => - checkStringIntervalType(DT.SECOND, DT.SECOND) + checkDTIntervalStringDataType(DT(DT.SECOND, DT.SECOND)) sign match { case 1 => parseSecondNano(secondAndMicro(value, suffix)) case -1 => parseSecondNano(s"-${secondAndMicro(value, suffix)}") } - case _ => throw new IllegalArgumentException( - s"Interval string does not match day-time format of " + - s"${supportedFormat((startField, endField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${DT(startField, endField).typeName}: ${input.toString}") + case _ => throwIllegalIntervalFormatException(input, startField, endField, + "day-time", DT(startField, endField).typeName, Some(fallbackNotice)) } } - case _ => - throw new IllegalArgumentException( - s"Interval string does not match day-time format of " + - s"${supportedFormat((startField, endField)) - .map(format => s"`$format`").mkString(", ")} " + - s"when cast to ${DT(startField, endField).typeName}: ${input.toString}, " + - s"$fallbackNotice") + case _ => throwIllegalIntervalFormatException(input, startField, endField, + "day-time", DT(startField, endField).typeName, Some(fallbackNotice)) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index 66f5b50..8313242 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils.microsToDuration import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1113,10 +1114,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { if (!isTryCast) { Seq("INTERVAL '1-1' YEAR", "INTERVAL '1-1' MONTH").foreach { interval => + val dataType = YearMonthIntervalType() val e = intercept[IllegalArgumentException] { - cast(Literal.create(interval), YearMonthIntervalType()).eval() + cast(Literal.create(interval), dataType).eval() }.getMessage - assert(e.contains("Interval string does not match year-month format")) + assert(e.contains(s"Interval string does not match year-month format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval")) } Seq(("1", YearMonthIntervalType(YEAR, MONTH)), ("1", YearMonthIntervalType(YEAR, MONTH)), @@ -1132,7 +1137,10 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val e = intercept[IllegalArgumentException] { cast(Literal.create(interval), dataType).eval() }.getMessage - assert(e.contains("Interval string does not match year-month 
format")) + assert(e.contains(s"Interval string does not match year-month format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval")) } } } @@ -1249,7 +1257,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val e = intercept[IllegalArgumentException] { cast(Literal.create(interval), dataType).eval() }.getMessage - assert(e.contains("Interval string does not match day-time format")) + assert(e.contains(s"Interval string does not match day-time format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval, " + + s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + + "to restore the behavior before Spark 3.0.")) } // Check first field outof bound @@ -1267,7 +1280,12 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { val e = intercept[IllegalArgumentException] { cast(Literal.create(interval), dataType).eval() }.getMessage - assert(e.contains("Interval string does not match day-time format")) + assert(e.contains(s"Interval string does not match day-time format of " + + s"${IntervalUtils.supportedFormat((dataType.startField, dataType.endField)) + .map(format => s"`$format`").mkString(", ")} " + + s"when cast to ${dataType.typeName}: $interval, " + + s"set ${SQLConf.LEGACY_FROM_DAYTIME_STRING.key} to true " + + "to restore the behavior before Spark 3.0.")) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org