cloud-fan commented on a change in pull request #32351:
URL: https://github.com/apache/spark/pull/32351#discussion_r620824798



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
##########
@@ -29,67 +29,105 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
-abstract class ExtractIntervalPart(
-    child: Expression,
+abstract class ExtractIntervalPart[T](
     val dataType: DataType,
-    func: CalendarInterval => Any,
-    funcName: String)
-  extends UnaryExpression with ExpectsInputTypes with NullIntolerant with 
Serializable {
-
-  override def inputTypes: Seq[AbstractDataType] = Seq(CalendarIntervalType)
-
-  override protected def nullSafeEval(interval: Any): Any = {
-    func(interval.asInstanceOf[CalendarInterval])
-  }
-
+    func: T => Any,
+    funcName: String) extends UnaryExpression with NullIntolerant with 
Serializable {
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     val iu = IntervalUtils.getClass.getName.stripSuffix("$")
     defineCodeGen(ctx, ev, c => s"$iu.$funcName($c)")
   }
+
+  override protected def nullSafeEval(interval: Any): Any = {
+    func(interval.asInstanceOf[T])
+  }
 }
 
 case class ExtractIntervalYears(child: Expression)
-  extends ExtractIntervalPart(child, IntegerType, getYears, "getYears") {
+  extends ExtractIntervalPart[CalendarInterval](IntegerType, getYears, 
"getYears") {
   override protected def withNewChildInternal(newChild: Expression): 
ExtractIntervalYears =
     copy(child = newChild)
 }
 
 case class ExtractIntervalMonths(child: Expression)
-  extends ExtractIntervalPart(child, ByteType, getMonths, "getMonths") {
+  extends ExtractIntervalPart[CalendarInterval](ByteType, getMonths, 
"getMonths") {
   override protected def withNewChildInternal(newChild: Expression): 
ExtractIntervalMonths =
     copy(child = newChild)
 }
 
 case class ExtractIntervalDays(child: Expression)
-  extends ExtractIntervalPart(child, IntegerType, getDays, "getDays") {
+  extends ExtractIntervalPart[CalendarInterval](IntegerType, getDays, 
"getDays") {
   override protected def withNewChildInternal(newChild: Expression): 
ExtractIntervalDays =
     copy(child = newChild)
 }
 
 case class ExtractIntervalHours(child: Expression)
-  extends ExtractIntervalPart(child, LongType, getHours, "getHours") {
+  extends ExtractIntervalPart[CalendarInterval](ByteType, getHours, 
"getHours") {
   override protected def withNewChildInternal(newChild: Expression): 
ExtractIntervalHours =
     copy(child = newChild)
 }
 
 case class ExtractIntervalMinutes(child: Expression)
-  extends ExtractIntervalPart(child, ByteType, getMinutes, "getMinutes") {
+  extends ExtractIntervalPart[CalendarInterval](ByteType, getMinutes, 
"getMinutes") {
   override protected def withNewChildInternal(newChild: Expression): 
ExtractIntervalMinutes =
     copy(child = newChild)
 }
 
 case class ExtractIntervalSeconds(child: Expression)
-  extends ExtractIntervalPart(child, DecimalType(8, 6), getSeconds, 
"getSeconds") {
+  extends ExtractIntervalPart[CalendarInterval](DecimalType(8, 6), getSeconds, 
"getSeconds") {
   override protected def withNewChildInternal(newChild: Expression): 
ExtractIntervalSeconds =
     copy(child = newChild)
 }
 
+case class ExtractANSIIntervalYears(child: Expression)
+    extends ExtractIntervalPart[Int](IntegerType, getYears, "getYears") {
+  override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalYears =
+    copy(child = newChild)
+}
+
+case class ExtractANSIIntervalMonths(child: Expression)
+    extends ExtractIntervalPart[Int](ByteType, getMonths, "getMonths") {
+  override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalMonths =
+    copy(child = newChild)
+}
+
+case class ExtractANSIIntervalDays(child: Expression)
+    extends ExtractIntervalPart[Long](IntegerType, getDays, "getDays") {
+  override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalDays = {
+    copy(child = newChild)
+  }
+}
+
+case class ExtractANSIIntervalHours(child: Expression)
+    extends ExtractIntervalPart[Long](ByteType, getHours, "getHours") {
+  override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalHours =
+    copy(child = newChild)
+}
+
+case class ExtractANSIIntervalMinutes(child: Expression)
+    extends ExtractIntervalPart[Long](ByteType, getMinutes, "getMinutes") {
+  override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalMinutes =
+    copy(child = newChild)
+}
+
+case class ExtractANSIIntervalSeconds(child: Expression)
+    extends ExtractIntervalPart[Long](DecimalType(8, 6), getSeconds, 
"getSeconds") {
+  override protected def withNewChildInternal(newChild: Expression): 
ExtractANSIIntervalSeconds =
+    copy(child = newChild)
+}
+
 object ExtractIntervalPart {
 
   def parseExtractField(
       extractField: String,
       source: Expression,
       errorHandleFunc: => Nothing): Expression = 
extractField.toUpperCase(Locale.ROOT) match {
+    case "YEAR" if source.dataType == YearMonthIntervalType => 
ExtractANSIIntervalYears(source)

Review comment:
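       Why don't we support all the year aliases (`"YEAR" | "Y" | "YEARS" | "YR" | "YRS"`) here? As written, only `"YEAR"` is routed to `ExtractANSIIntervalYears` for `YearMonthIntervalType`. Can we merge this into a single case, e.g.:
   ```
   case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => if (source.dataType == YearMonthIntervalType) ... else ...
   ```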




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to