uros-db commented on code in PR #51826:
URL: https://github.com/apache/spark/pull/51826#discussion_r2282222452


##########
sql/api/src/main/scala/org/apache/spark/sql/functions.scala:
##########
@@ -6292,6 +6292,29 @@ object functions {
   def timestamp_add(unit: String, quantity: Column, ts: Column): Column =
     Column.internalFn("timestampadd", lit(unit), quantity, ts)
 
+  /**
+   * Returns the difference between two times, measured in specified units.
+   *
+   * @param unit
+   *   A STRING representing the unit of the time difference. Supported units are: "HOUR",
+   *   "MINUTE", "SECOND", "MILLISECOND", and "MICROSECOND". The unit is case-insensitive.
+   * @param start
+   *   A starting TIME.
+   * @param end
+   *   An ending TIME.
+   * @return
+   *   The difference between two times, measured in specified units.
+   * @note
+   *   If any of the inputs is `NULL`, the result is `NULL`.
+   * @throws IllegalArgumentException

Review Comment:
   Actually, it is a `SparkIllegalArgumentException`. Fixing now.



##########
sql/core/src/test/scala/org/apache/spark/sql/TimeFunctionsSuiteBase.scala:
##########
@@ -241,6 +241,57 @@ abstract class TimeFunctionsSuiteBase extends QueryTest with SharedSparkSession
     checkAnswer(result2, expected)
   }
 
+  test("SPARK-53108: time_diff function") {
+    // Input data for the function.
+    val schema = StructType(Seq(
+      StructField("unit", StringType, nullable = false),
+      StructField("start", TimeType(), nullable = false),
+      StructField("end", TimeType(), nullable = false)
+    ))
+    val data = Seq(
+      Row("HOUR", LocalTime.parse("20:30:29"), LocalTime.parse("21:30:28")),
+      Row("second", LocalTime.parse("09:32:05.359123"), 
LocalTime.parse("17:23:49.906152")),
+      Row("MicroSecond", LocalTime.parse("09:32:05.359123"), 
LocalTime.parse("17:23:49.906152"))
+    )
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), 
schema)
+
+    // Test the function using both `selectExpr` and `select`.
+    val result1 = df.selectExpr(
+      "time_diff(unit, start, end)"
+    )
+    val result2 = df.select(
+      time_diff(col("unit"), col("start"), col("end"))
+    )
+    // Check that both methods produce the same result.
+    checkAnswer(result1, result2)
+
+    // Expected output of the function.
+    val expected = Seq(
+      0,
+      28304,
+      28304547029L
+    ).toDF("diff").select(col("diff"))
+    // Check that the results match the expected output.
+    checkAnswer(result1, expected)
+    checkAnswer(result2, expected)
+
+    // Error is thrown for malformed input.
+    val invalidUnitDF = Seq(
+      ("invalid_unit", LocalTime.parse("01:02:03"), 
LocalTime.parse("01:02:03"))
+    ).toDF("unit", "start", "end")
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {

Review Comment:
   Updating.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to