This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2605b87990c [SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics
2605b87990c is described below

commit 2605b87990c9826d05ad0943045e8dfa79af13e9
Author: Bhuwan Sahni <bhuwan.sa...@databricks.com>
AuthorDate: Sun Nov 12 16:50:01 2023 +0900

    [SPARK-45655][SQL][SS] Allow non-deterministic expressions inside AggregateFunctions in CollectMetrics
    
    ### What changes were proposed in this pull request?
    
    This PR allows non-deterministic expressions wrapped inside an
    `AggregateFunction` (such as `count`) in a `CollectMetrics` node.
    `CollectMetrics` is used to collect arbitrary metrics from a query, and in
    certain scenarios users want to collect metrics about filtering based on
    non-deterministic expressions (see the query example below).
    
    Currently, the Analyzer does not allow non-deterministic expressions inside
    an `AggregateFunction` in `CollectMetrics`. This PR relaxes that constraint
    so that such metrics can be collected. Note that the metrics are relevant
    only for a completed batch and can change if the batch is replayed, because
    a non-deterministic expression can evaluate differently on different runs.
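    
    As a rough illustration of the replay caveat, here is a toy model in plain
    Scala (not Spark code). A seeded `scala.util.Random` stands in for whatever
    per-run state a non-deterministic expression depends on; `ReplaySketch` and
    `observedCount` are hypothetical names. The count over the same input rows
    is guaranteed to be reproducible only when that per-run state is the same.
    
    ```scala
    import scala.util.Random
    
    // Toy model, not Spark code: the "metric" counts rows for which a
    // non-deterministic predicate (a coin flip, standing in for something like
    // rand() < 0.5) holds during one run of a batch.
    object ReplaySketch {
      def observedCount(rows: Seq[String], runSeed: Long): Long = {
        // Hypothetical per-run state: a replay with different state may
        // produce a different count over exactly the same rows.
        val rng = new Random(runSeed)
        rows.count(_ => rng.nextBoolean()).toLong
      }
    }
    ```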
    
    While working on this feature, I found an issue with the `checkMetric` logic
    that validates non-deterministic expressions inside an `AggregateExpression`.
    An expression is considered non-deterministic if any of its children is
    non-deterministic, so the `!e.deterministic && !seenAggregate` case must be
    matched only after checking whether the current expression is an
    `AggregateExpression`. If the current expression is an `AggregateExpression`,
    we should validate further down in the tree recursi [...]
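    
    The ordering issue can be sketched with a toy expression tree in plain
    Scala. This is not Catalyst code: `Expr`, `Attr`, `Lit`, `BatchTs`, `Gt`,
    `CountAgg`, `Named` and `MetricCheckSketch` are hypothetical stand-ins for
    `Expression`, `Attribute`, `Literal`, `CurrentBatchTimestamp`,
    `GreaterThan`, an aggregate, and `Alias`, and the window, nested-aggregate
    and filter checks are omitted. `checkOld` and `checkNew` follow the case
    order before and after this change, as read from the diff.
    
    ```scala
    // Toy expression tree (not Catalyst): just enough to show why the match
    // order in a checkMetric-style validator matters.
    sealed trait Expr {
      def children: Seq[Expr]
      // Same convention the PR describes: non-deterministic if any child is.
      def deterministic: Boolean = children.forall(_.deterministic)
    }
    final case class Attr(name: String) extends Expr { def children: Seq[Expr] = Nil }
    final case class Lit(value: Any) extends Expr { def children: Seq[Expr] = Nil }
    // Stand-in for a non-deterministic leaf such as CurrentBatchTimestamp.
    case object BatchTs extends Expr {
      def children: Seq[Expr] = Nil
      override def deterministic: Boolean = false
    }
    final case class Gt(left: Expr, right: Expr) extends Expr {
      def children: Seq[Expr] = Seq(left, right)
    }
    // Stand-in for an aggregate such as Count / AggregateExpression.
    final case class CountAgg(child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) }
    // Stand-in for Alias.
    final case class Named(child: Expr, name: String) extends Expr {
      def children: Seq[Expr] = Seq(child)
    }
    
    object MetricCheckSketch {
      // Returns the first error found among the children, if any.
      private def firstError(results: Seq[Option[String]]): Option[String] =
        results.collectFirst { case Some(err) => err }
    
      // Old order: the non-determinism case fires before the aggregate case, so
      // an alias over count(<non-deterministic>) is rejected at the root.
      def checkOld(e: Expr, seenAgg: Boolean = false): Option[String] = e match {
        case _ if !e.deterministic && !seenAgg => Some(s"non-deterministic: $e")
        case _: CountAgg => firstError(e.children.map(checkOld(_, seenAgg = true)))
        case _: Attr if !seenAgg => Some(s"attribute outside aggregate: $e")
        case _ => firstError(e.children.map(checkOld(_, seenAgg)))
      }
    
      // New order: aggregates and aliases are unwrapped first, and the
      // non-determinism case only applies outside any aggregate.
      def checkNew(e: Expr, seenAgg: Boolean = false): Option[String] = e match {
        case _: CountAgg => firstError(e.children.map(checkNew(_, seenAgg = true)))
        case _: Attr if !seenAgg => Some(s"attribute outside aggregate: $e")
        case Named(child, _) => checkNew(child, seenAgg)
        case _ if !e.deterministic && !seenAgg => Some(s"non-deterministic: $e")
        case _ => firstError(e.children.map(checkNew(_, seenAgg)))
      }
    }
    ```
    
    With these definitions, `checkOld(Named(CountAgg(Gt(Lit(1), BatchTs)),
    "count"))` reports an error at the root, because the alias inherits
    non-determinism from its child, while `checkNew` accepts it. `checkNew`
    still rejects a non-deterministic expression that is not wrapped in an
    aggregate, such as `Named(Gt(Lit(1), BatchTs), "bad")`.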
    
    ```
    
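    // Assumes a SparkSession `spark` with `import spark.implicits._` and an
    // implicit SQLContext in scope, as in Spark's streaming test suites.
    import java.sql.Timestamp
    import org.apache.spark.sql.execution.streaming.MemoryStream
    import org.apache.spark.sql.functions.{count, expr}
    
    val inputData = MemoryStream[Timestamp]
    
    inputData.toDF()
          .filter("value < current_date()")
          .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))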
          .writeStream
          .queryName("ts_metrics_test")
          .format("memory")
          .outputMode("append")
          .start()
    
    ```
    
    ### Why are the changes needed?
    
    1. Added a test case that calculates the dropped rows (via
    `CurrentBatchTimestamp`) and ensures the query succeeds.
    
    For example, without this change the query below fails because of the
    `observe` call on the DataFrame.
    
    ```
    
    val inputData = MemoryStream[Timestamp]
    
    inputData.toDF()
          .filter("value < current_date()")
          .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
          .writeStream
          .queryName("ts_metrics_test")
          .format("memory")
          .outputMode("append")
          .start()
    
    ```
    2. Added a test in `AnalysisSuite` for non-deterministic expressions inside
    an `AggregateFunction`.
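    
    When reading the example query, note that in Spark SQL `count(expr)` counts
    the rows for which `expr` is non-null, whether it evaluates to true or
    false, while `count_if(expr)` counts only the rows for which it is true.
    The plain-Scala toy model below is not Spark code (`CountSemanticsSketch`,
    `sqlCount`, `sqlCountIf` and `geq` are hypothetical names); it represents a
    SQL boolean as `Option[Boolean]`, with `None` for NULL, to show the
    difference.
    
    ```scala
    // Toy model of SQL aggregate semantics, not Spark code.
    object CountSemanticsSketch {
      // SQL three-valued boolean: None models NULL.
      type SqlBool = Option[Boolean]
    
      // Like count(expr): rows where expr is non-null, whether true or false.
      def sqlCount(values: Seq[SqlBool]): Long = values.count(_.isDefined).toLong
    
      // Like count_if(expr): rows where expr is true.
      def sqlCountIf(values: Seq[SqlBool]): Long = values.count(_.contains(true)).toLong
    
      // Like `value >= cutoff`: a NULL input yields a NULL result.
      def geq(value: Option[Long], cutoff: Long): SqlBool = value.map(_ >= cutoff)
    }
    ```
    
    For inputs `Some(1L)`, `Some(5L)` and `None` with a cutoff of `3L`,
    `sqlCount` returns 2 (one false, one true) and `sqlCountIf` returns 1.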
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test cases added.
    
    ```
    
    [warn] 20 warnings found
    WARNING: Using incubator modules: jdk.incubator.vector, jdk.incubator.foreign
    [info] StreamingQueryStatusAndProgressSuite:
    09:14:39.684 WARN org.apache.hadoop.util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    [info] Passed: Total 0, Failed 0, Errors 0, Passed 0
    [info] No tests to run for hive / Test / testOnly
    [info] - StreamingQueryProgress - prettyJson (436 milliseconds)
    [info] - StreamingQueryProgress - json (3 milliseconds)
    [info] - StreamingQueryProgress - toString (5 milliseconds)
    [info] - StreamingQueryProgress - jsonString and fromJson (163 milliseconds)
    [info] - StreamingQueryStatus - prettyJson (1 millisecond)
    [info] - StreamingQueryStatus - json (1 millisecond)
    [info] - StreamingQueryStatus - toString (2 milliseconds)
    09:14:41.674 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-34d2749f-f4d0-46d8-bc51-29da6411e1c5. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
    09:14:41.710 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
    [info] - progress classes should be Serializable (5 seconds, 552 milliseconds)
    09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-3a41d397-c3c1-490b-9cc7-d775b0c42208. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
    09:14:46.345 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
    [info] - SPARK-19378: Continue reporting stateOp metrics even if there is no active trigger (1 second, 337 milliseconds)
    09:14:47.677 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
    [info] - SPARK-29973: Make `processedRowsPerSecond` calculated more accurately and meaningfully (455 milliseconds)
    09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /Users/bhuwan.sahni/workspace/spark/target/tmp/temporary-360fc3b9-a2c5-430c-a892-c9869f1f8339. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
    09:14:48.174 WARN org.apache.spark.sql.execution.streaming.ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
    [info] - SPARK-45655: Use current batch timestamp in observe API (587 milliseconds)
    09:14:48.768 WARN org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite:
    
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43517 from sahnib/SPARK-45655.
    
    Authored-by: Bhuwan Sahni <bhuwan.sa...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 25 +++++++-----
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 17 +++++++--
 .../StreamingQueryStatusAndProgressSuite.scala     | 44 ++++++++++++++++++++++
 3 files changed, 74 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 352b3124a86..d41345f38c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Median, PercentileCont, PercentileDisc}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc}
 import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -476,10 +476,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
                   e.failAnalysis(
                     "INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
                     Map("expr" -> toSQLExpr(s)))
-                case _ if !e.deterministic && !seenAggregate =>
-                  e.failAnalysis(
-                    "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
-                    Map("expr" -> toSQLExpr(s)))
                 case a: AggregateExpression if seenAggregate =>
                   e.failAnalysis(
                     "INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
@@ -492,12 +488,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
                   e.failAnalysis(
                     "INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
                     Map("expr" -> toSQLExpr(s)))
+                case _: AggregateExpression | _: AggregateFunction =>
+                  e.children.foreach(checkMetric (s, _, seenAggregate = true))
                 case _: Attribute if !seenAggregate =>
                   e.failAnalysis(
                     "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
                     Map("expr" -> toSQLExpr(s)))
-                case _: AggregateExpression =>
-                  e.children.foreach(checkMetric (s, _, seenAggregate = true))
+                case a: Alias =>
+                  checkMetric(s, a.child, seenAggregate)
+                case a if !e.deterministic && !seenAggregate =>
+                  e.failAnalysis(
+                    "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
+                    Map("expr" -> toSQLExpr(s)))
                 case _ =>
                   e.children.foreach(checkMetric (s, _, seenAggregate))
               }
@@ -734,8 +736,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
                 "dataType" -> toSQLType(mapCol.dataType)))
 
           case o if o.expressions.exists(!_.deterministic) &&
-            !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
-            !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] &&
+            !o.isInstanceOf[Project] &&
+            // non-deterministic expressions inside CollectMetrics have been
+            // already validated inside checkMetric function
+            !o.isInstanceOf[CollectMetrics] &&
+            !o.isInstanceOf[Filter] &&
+            !o.isInstanceOf[Aggregate] &&
+            !o.isInstanceOf[Window] &&
             !o.isInstanceOf[Expand] &&
             !o.isInstanceOf[Generate] &&
             !o.isInstanceOf[CreateVariable] &&
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ca22c55b49e..8e514e245cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -794,9 +794,20 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     // No columns
     assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)
 
-    def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit = {
-      assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0), errors)
-    }
+    // non-deterministic expression inside an aggregate function is valid
+    val tsLiteral = Literal.create(java.sql.Timestamp.valueOf("2023-11-30 21:05:00.000000"),
+      TimestampType)
+
+    assertAnalysisSuccess(
+      CollectMetrics(
+        "invalid",
+        Count(
+          GreaterThan(tsLiteral, CurrentBatchTimestamp(1699485296000L, TimestampType))
+        ).as("count") :: Nil,
+        testRelation,
+        0
+      )
+    )
 
     // Unwrapped attribute
     assertAnalysisErrorClass(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index 8fe4ef39b25..8ff71473f27 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.streaming
 
+import java.sql.Timestamp
+import java.time.Instant
+import java.time.temporal.ChronoUnit
 import java.util.UUID
 
 import scala.jdk.CollectionConverters._
@@ -355,6 +358,47 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
     )
   }
 
+  test("SPARK-45655: Use current batch timestamp in observe API") {
+    import testImplicits._
+
+    val inputData = MemoryStream[Timestamp]
+
+    // current_date() internally uses current batch timestamp on streaming query
+    val query = inputData.toDF()
+      .filter("value < current_date()")
+      .observe("metrics", count(expr("value >= current_date()")).alias("dropped"))
+      .writeStream
+      .queryName("ts_metrics_test")
+      .format("memory")
+      .outputMode("append")
+      .start()
+
+    val timeNow = Instant.now().truncatedTo(ChronoUnit.SECONDS)
+
+    // this value would be accepted by the filter and would not count towards
+    // dropped metrics.
+    val validValue = Timestamp.from(timeNow.minus(2, ChronoUnit.DAYS))
+    inputData.addData(validValue)
+
+    // would be dropped by the filter and count towards dropped metrics
+    inputData.addData(Timestamp.from(timeNow.plus(2, ChronoUnit.DAYS)))
+
+    query.processAllAvailable()
+    query.stop()
+
+    val dropped = query.recentProgress.map { p =>
+      val metricVal = Option(p.observedMetrics.get("metrics"))
+      metricVal.map(_.getLong(0)).getOrElse(0L)
+    }.sum
+    // ensure dropped metrics are correct
+    assert(dropped == 1)
+
+    val data = spark.read.table("ts_metrics_test").collect()
+
+    // ensure valid value ends up in output
+    assert(data(0).getAs[Timestamp](0).equals(validValue))
+  }
+
   def waitUntilBatchProcessed: AssertOnQuery = Execute { q =>
     eventually(Timeout(streamingTimeout)) {
       if (q.exception.isEmpty) {

