This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 21548a8cc5c5 [SPARK-47545][CONNECT] Dataset `observe` support for the Scala client 21548a8cc5c5 is described below commit 21548a8cc5c527d4416a276a852f967b4410bd4b Author: Paddy Xu <xupa...@gmail.com> AuthorDate: Wed May 8 15:44:02 2024 -0400 [SPARK-47545][CONNECT] Dataset `observe` support for the Scala client ### What changes were proposed in this pull request? This PR adds support for `Dataset.observe` to the Spark Connect Scala client. Note that the support here does not include listener support, because the listener runs on the server side. This PR includes a small refactoring of the `Observation` helper class. We extracted the methods that are not bound to the SparkSession into `spark-api`, and added one subclass in each of `spark-core` and `spark-jvm-client`. ### Why are the changes needed? Before this PR, the `DF.observe` method was only supported in the Python client. ### Does this PR introduce _any_ user-facing change? Yes. The user can now issue `DF.observe(name, metrics...)` or `DF.observe(observationObject, metrics...)` to get stats of columns of a dataframe. ### How was this patch tested? Added new e2e tests. ### Was this patch authored or co-authored using generative AI tooling? Nope. Closes #45701 from xupefei/scala-observe. 
Authored-by: Paddy Xu <xupa...@gmail.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 63 ++++++- .../scala/org/apache/spark/sql/Observation.scala | 46 +++++ .../scala/org/apache/spark/sql/SparkSession.scala | 31 +++- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 43 +++++ .../CheckConnectJvmClientCompatibility.scala | 3 - .../src/main/protobuf/spark/connect/base.proto | 1 + .../spark/sql/connect/client/SparkResult.scala | 44 ++++- .../common/LiteralValueProtoConverter.scala | 2 +- .../connect/execution/ExecuteThreadRunner.scala | 1 + .../execution/SparkConnectPlanExecution.scala | 12 +- python/pyspark/sql/connect/proto/base_pb2.py | 188 ++++++++++----------- python/pyspark/sql/connect/proto/base_pb2.pyi | 5 +- .../org/apache/spark/sql/ObservationBase.scala | 113 +++++++++++++ .../scala/org/apache/spark/sql/Observation.scala | 62 +------ 14 files changed, 448 insertions(+), 166 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 9a42afebf8f2..37f770319b69 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3337,8 +3337,69 @@ class Dataset[T] private[sql] ( } } + /** + * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset + * that returns the same result as the input, with the following guarantees: <ul> <li>It will + * compute the defined aggregates (metrics) on all the data that is flowing through the Dataset + * at that point.</li> <li>It will report the value of the defined aggregate columns as soon as + * we reach a completion point. A completion point is currently defined as the end of a + * query.</li> </ul> Please note that continuous execution is currently not supported. 
+ * + * The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or + * more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that + * contain references to the input Dataset's columns must always be wrapped in an aggregate + * function. + * + * A user can retrieve the metrics by calling + * `org.apache.spark.sql.Dataset.collectResult().getObservedMetrics`. + * + * {{{ + * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it + * val observed_ds = ds.observe("my_metrics", count(lit(1)).as("rows"), max($"id").as("maxid")) + * observed_ds.write.parquet("ds.parquet") + * val metrics = observed_ds.collectResult().getObservedMetrics + * }}} + * + * @group typedrel + * @since 4.0.0 + */ + @scala.annotation.varargs def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = { - throw new UnsupportedOperationException("observe is not implemented.") + sparkSession.newDataset(agnosticEncoder) { builder => + builder.getCollectMetricsBuilder + .setInput(plan.getRoot) + .setName(name) + .addAllMetrics((expr +: exprs).map(_.expr).asJava) + } + } + + /** + * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This is + * equivalent to calling `observe(String, Column, Column*)` but does not require to collect all + * results before returning the metrics - the metrics are filled during iterating the results, + * as soon as they are available. This method does not support streaming datasets. + * + * A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`. 
+ * + * {{{ + * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it + * val observation = Observation("my_metrics") + * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid")) + * observed_ds.write.parquet("ds.parquet") + * val metrics = observation.get + * }}} + * + * @throws IllegalArgumentException + * If this is a streaming Dataset (this.isStreaming == true) + * + * @group typedrel + * @since 4.0.0 + */ + @scala.annotation.varargs + def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = { + val df = observe(observation.name, expr, exprs: _*) + sparkSession.registerObservation(df.getPlanId.get, observation) + df } def checkpoint(): Dataset[T] = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala new file mode 100644 index 000000000000..75629b6000f9 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Observation.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.util.UUID + +class Observation(name: String) extends ObservationBase(name) { + + /** + * Create an Observation instance without providing a name. This generates a random name. + */ + def this() = this(UUID.randomUUID().toString) +} + +/** + * (Scala-specific) Create instances of Observation via Scala `apply`. + * @since 4.0.0 + */ +object Observation { + + /** + * Observation constructor for creating an anonymous observation. + */ + def apply(): Observation = new Observation() + + /** + * Observation constructor for creating a named observation. + */ + def apply(name: String): Observation = new Observation(name) + +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 22bb62803fac..1188fba60a2f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicLong, AtomicReference} @@ -36,7 +37,7 @@ import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} -import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import 
org.apache.spark.sql.functions.lit @@ -80,6 +81,8 @@ class SparkSession private[sql] ( client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion } + private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]() + /** * Runtime configuration interface for Spark. * @@ -532,8 +535,12 @@ class SparkSession private[sql] ( private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = { val value = client.execute(plan) - val result = new SparkResult(value, allocator, encoder, timeZoneId) - result + new SparkResult( + value, + allocator, + encoder, + timeZoneId, + Some(setMetricsAndUnregisterObservation)) } private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = { @@ -554,6 +561,9 @@ class SparkSession private[sql] ( client.execute(plan).filter(!_.hasExecutionProgress).toSeq } + private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] = + client.execute(plan) + private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = { val command = proto.Command.newBuilder().setRegisterFunction(udf).build() execute(command) @@ -779,6 +789,21 @@ class SparkSession private[sql] ( * Set to false to prevent client.releaseSession on close() (testing only) */ private[sql] var releaseSessionOnClose = true + + private[sql] def registerObservation(planId: Long, observation: Observation): Unit = { + if (observationRegistry.putIfAbsent(planId, observation) != null) { + throw new IllegalArgumentException("An Observation can be used with a Dataset only once") + } + } + + private[sql] def setMetricsAndUnregisterObservation( + planId: Long, + metrics: Map[String, Any]): Unit = { + val observationOrNull = observationRegistry.remove(planId) + if (observationOrNull != null) { + observationOrNull.setMetricsAndNotify(Some(metrics)) + } + } } // The minimal builder needed to create a spark session. 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index a0729adb8960..73a2f6d4f88e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -22,6 +22,8 @@ import java.time.DateTimeException import java.util.Properties import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -41,6 +43,7 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper} import org.apache.spark.sql.test.SparkConnectServerUtils.port import org.apache.spark.sql.types._ +import org.apache.spark.util.SparkThreadUtils class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { @@ -1511,6 +1514,46 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM (0 until 5).foreach(i => assert(row.get(i * 2) === row.get(i * 2 + 1))) } } + + test("Observable metrics") { + val df = spark.range(99).withColumn("extra", col("id") - 1) + val ob1 = new Observation("ob1") + val observedDf = df.observe(ob1, min("id"), avg("id"), max("id")) + val observedObservedDf = observedDf.observe("ob2", min("extra"), avg("extra"), max("extra")) + + val ob1Schema = new StructType() + .add("min(id)", LongType) + .add("avg(id)", DoubleType) + .add("max(id)", LongType) + val ob2Schema = new StructType() + .add("min(extra)", LongType) + .add("avg(extra)", DoubleType) + .add("max(extra)", LongType) + val ob1Metrics = Map("ob1" -> new GenericRowWithSchema(Array(0, 49, 98), ob1Schema)) + val ob2Metrics = Map("ob2" -> new GenericRowWithSchema(Array(-1, 48, 97), 
ob2Schema)) + + assert(df.collectResult().getObservedMetrics === Map.empty) + assert(observedDf.collectResult().getObservedMetrics === ob1Metrics) + assert(observedObservedDf.collectResult().getObservedMetrics === ob1Metrics ++ ob2Metrics) + } + + test("Observation.get is blocked until the query is finished") { + val df = spark.range(99).withColumn("extra", col("id") - 1) + val observation = new Observation("ob1") + val observedDf = df.observe(observation, min("id"), avg("id"), max("id")) + + // Start a new thread to get the observation + val future = Future(observation.get)(ExecutionContext.global) + // make sure the thread is blocked right now + val e = intercept[java.util.concurrent.TimeoutException] { + SparkThreadUtils.awaitResult(future, 2.seconds) + } + assert(e.getMessage.contains("Future timed out")) + observedDf.collect() + // make sure the thread is unblocked after the query is finished + val metrics = SparkThreadUtils.awaitResult(future, 2.seconds) + assert(metrics === Map("min(id)" -> 0, "avg(id)" -> 49, "max(id)" -> 98)) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index c89dba03ed69..7be5e2ecd172 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -196,9 +196,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.COL_POS_KEY"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.DATASET_ID_KEY"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.curId"), - 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.observe"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ObservationListener$"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.queryExecution"), diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 49a33d3419b6..77dda277602a 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -434,6 +434,7 @@ message ExecutePlanResponse { string name = 1; repeated Expression.Literal values = 2; repeated string keys = 3; + int64 plan_id = 4; } message ResultComplete { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 93d1075aea02..0905ee76c3f3 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -27,10 +27,13 @@ import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch} import org.apache.arrow.vector.types.pojo import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import 
org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator} -import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ArrowUtils @@ -38,7 +41,8 @@ private[sql] class SparkResult[T]( responses: CloseableIterator[proto.ExecutePlanResponse], allocator: BufferAllocator, encoder: AgnosticEncoder[T], - timeZoneId: String) + timeZoneId: String, + setObservationMetricsOpt: Option[(Long, Map[String, Any]) => Unit] = None) extends AutoCloseable { self => case class StageInfo( @@ -79,6 +83,7 @@ private[sql] class SparkResult[T]( private[this] var arrowSchema: pojo.Schema = _ private[this] var nextResultIndex: Int = 0 private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])] + private val observedMetrics = mutable.Map.empty[String, Row] private val cleanable = SparkResult.cleaner.register(this, new SparkResultCloseable(resultMap, responses)) @@ -117,6 +122,9 @@ private[sql] class SparkResult[T]( while (!stop && responses.hasNext) { val response = responses.next() + // Collect metrics for this response + observedMetrics ++= processObservedMetrics(response.getObservedMetricsList) + // Save and validate operationId if (opId == null) { opId = response.getOperationId @@ -198,6 +206,29 @@ private[sql] class SparkResult[T]( nonEmpty } + private def processObservedMetrics( + metrics: java.util.List[ObservedMetrics]): Iterable[(String, Row)] = { + metrics.asScala.map { metric => + assert(metric.getKeysCount == metric.getValuesCount) + var schema = new StructType() + val keys = mutable.ListBuffer.empty[String] + val values = mutable.ListBuffer.empty[Any] + (0 until metric.getKeysCount).map { i => + val key = metric.getKeys(i) + val value = 
LiteralValueProtoConverter.toCatalystValue(metric.getValues(i)) + schema = schema.add(key, LiteralValueProtoConverter.toDataType(value.getClass)) + keys += key + values += value + } + // If the metrics is registered by an Observation object, attach them and unblock any + // blocked thread. + setObservationMetricsOpt.foreach { setObservationMetrics => + setObservationMetrics(metric.getPlanId, keys.zip(values).toMap) + } + metric.getName -> new GenericRowWithSchema(values.toArray, schema) + } + } + /** * Returns the number of elements in the result. */ @@ -248,6 +279,15 @@ private[sql] class SparkResult[T]( result } + /** + * Returns all observed metrics in the result. + */ + def getObservedMetrics: Map[String, Row] = { + // We need to process all responses to get all metrics. + processResponses() + observedMetrics.toMap + } + /** * Returns an iterator over the contents of the result. */ diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index ce42cc797bf3..1f3496fa8984 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -204,7 +204,7 @@ object LiteralValueProtoConverter { def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal = toLiteralProtoBuilder(literal, dataType).build() - private def toDataType(clz: Class[_]): DataType = clz match { + private[sql] def toDataType(clz: Class[_]): DataType = clz match { // primitive types case JShort.TYPE => ShortType case JInteger.TYPE => IntegerType diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index 0a6d12cbb191..4ef4f632204b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -220,6 +220,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .createObservedMetricsResponse( executeHolder.sessionHolder.sessionId, executeHolder.sessionHolder.serverSessionId, + executeHolder.request.getPlan.getRoot.getCommon.getPlanId, observedMetrics ++ accumulatedInPython)) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 4f2b8c945127..660951f22984 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -264,8 +264,14 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) name -> values } if (observedMetrics.nonEmpty) { - Some(SparkConnectPlanExecution - .createObservedMetricsResponse(sessionId, sessionHolder.serverSessionId, observedMetrics)) + val planId = executeHolder.request.getPlan.getRoot.getCommon.getPlanId + Some( + SparkConnectPlanExecution + .createObservedMetricsResponse( + sessionId, + sessionHolder.serverSessionId, + planId, + observedMetrics)) } else None } } @@ -274,11 +280,13 @@ object SparkConnectPlanExecution { def createObservedMetricsResponse( sessionId: String, serverSessionId: String, + planId: Long, metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = { val observedMetrics = metrics.map { case (name, values) => 
val metrics = ExecutePlanResponse.ObservedMetrics .newBuilder() .setName(name) + .setPlanId(planId) values.foreach { case (key, value) => metrics.addValues(toLiteralProto(value)) key.foreach(metrics.addKeys) diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 2a30ffe60a9f..a39396db4ff1 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -120,7 +120,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_start = 5196 _EXECUTEPLANREQUEST_REQUESTOPTION._serialized_end = 5361 _EXECUTEPLANRESPONSE._serialized_start = 5440 - _EXECUTEPLANRESPONSE._serialized_end = 8230 + _EXECUTEPLANRESPONSE._serialized_end = 8256 _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 7030 _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 7101 _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 7103 @@ -133,96 +133,96 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 7651 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 7653 _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 7741 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7743 - _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7859 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7861 - _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7877 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7880 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8213 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8036 - _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8213 - _KEYVALUE._serialized_start = 8232 - _KEYVALUE._serialized_end = 8297 - _CONFIGREQUEST._serialized_start = 8300 - _CONFIGREQUEST._serialized_end = 9459 - _CONFIGREQUEST_OPERATION._serialized_start = 8608 - _CONFIGREQUEST_OPERATION._serialized_end = 9106 - _CONFIGREQUEST_SET._serialized_start = 9108 - _CONFIGREQUEST_SET._serialized_end = 9160 - _CONFIGREQUEST_GET._serialized_start = 9162 - _CONFIGREQUEST_GET._serialized_end = 9187 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9189 - _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9252 - _CONFIGREQUEST_GETOPTION._serialized_start = 9254 - 
_CONFIGREQUEST_GETOPTION._serialized_end = 9285 - _CONFIGREQUEST_GETALL._serialized_start = 9287 - _CONFIGREQUEST_GETALL._serialized_end = 9335 - _CONFIGREQUEST_UNSET._serialized_start = 9337 - _CONFIGREQUEST_UNSET._serialized_end = 9364 - _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9366 - _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9400 - _CONFIGRESPONSE._serialized_start = 9462 - _CONFIGRESPONSE._serialized_end = 9637 - _ADDARTIFACTSREQUEST._serialized_start = 9640 - _ADDARTIFACTSREQUEST._serialized_end = 10642 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10115 - _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10168 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10170 - _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10281 - _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10283 - _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10376 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10379 - _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10572 - _ADDARTIFACTSRESPONSE._serialized_start = 10645 - _ADDARTIFACTSRESPONSE._serialized_end = 10917 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10836 - _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10917 - _ARTIFACTSTATUSESREQUEST._serialized_start = 10920 - _ARTIFACTSTATUSESREQUEST._serialized_end = 11246 - _ARTIFACTSTATUSESRESPONSE._serialized_start = 11249 - _ARTIFACTSTATUSESRESPONSE._serialized_end = 11601 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11444 - _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11559 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11561 - _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11601 - _INTERRUPTREQUEST._serialized_start = 11604 - _INTERRUPTREQUEST._serialized_end = 12207 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12007 - _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12135 - _INTERRUPTRESPONSE._serialized_start = 12210 - 
_INTERRUPTRESPONSE._serialized_end = 12354 - _REATTACHOPTIONS._serialized_start = 12356 - _REATTACHOPTIONS._serialized_end = 12409 - _REATTACHEXECUTEREQUEST._serialized_start = 12412 - _REATTACHEXECUTEREQUEST._serialized_end = 12818 - _RELEASEEXECUTEREQUEST._serialized_start = 12821 - _RELEASEEXECUTEREQUEST._serialized_end = 13406 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13275 - _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13287 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13289 - _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13336 - _RELEASEEXECUTERESPONSE._serialized_start = 13409 - _RELEASEEXECUTERESPONSE._serialized_end = 13574 - _RELEASESESSIONREQUEST._serialized_start = 13577 - _RELEASESESSIONREQUEST._serialized_end = 13748 - _RELEASESESSIONRESPONSE._serialized_start = 13750 - _RELEASESESSIONRESPONSE._serialized_end = 13858 - _FETCHERRORDETAILSREQUEST._serialized_start = 13861 - _FETCHERRORDETAILSREQUEST._serialized_end = 14193 - _FETCHERRORDETAILSRESPONSE._serialized_start = 14196 - _FETCHERRORDETAILSRESPONSE._serialized_end = 15751 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14425 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14599 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14602 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14970 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14933 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14970 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14973 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15382 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15284 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15352 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15385 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15732 - 
_SPARKCONNECTSERVICE._serialized_start = 15754 - _SPARKCONNECTSERVICE._serialized_end = 16700 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 7744 + _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 7885 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_start = 7887 + _EXECUTEPLANRESPONSE_RESULTCOMPLETE._serialized_end = 7903 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_start = 7906 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS._serialized_end = 8239 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_start = 8062 + _EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO._serialized_end = 8239 + _KEYVALUE._serialized_start = 8258 + _KEYVALUE._serialized_end = 8323 + _CONFIGREQUEST._serialized_start = 8326 + _CONFIGREQUEST._serialized_end = 9485 + _CONFIGREQUEST_OPERATION._serialized_start = 8634 + _CONFIGREQUEST_OPERATION._serialized_end = 9132 + _CONFIGREQUEST_SET._serialized_start = 9134 + _CONFIGREQUEST_SET._serialized_end = 9186 + _CONFIGREQUEST_GET._serialized_start = 9188 + _CONFIGREQUEST_GET._serialized_end = 9213 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 9215 + _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 9278 + _CONFIGREQUEST_GETOPTION._serialized_start = 9280 + _CONFIGREQUEST_GETOPTION._serialized_end = 9311 + _CONFIGREQUEST_GETALL._serialized_start = 9313 + _CONFIGREQUEST_GETALL._serialized_end = 9361 + _CONFIGREQUEST_UNSET._serialized_start = 9363 + _CONFIGREQUEST_UNSET._serialized_end = 9390 + _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 9392 + _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 9426 + _CONFIGRESPONSE._serialized_start = 9488 + _CONFIGRESPONSE._serialized_end = 9663 + _ADDARTIFACTSREQUEST._serialized_start = 9666 + _ADDARTIFACTSREQUEST._serialized_end = 10668 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 10141 + _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 10194 + _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 10196 + 
_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 10307 + _ADDARTIFACTSREQUEST_BATCH._serialized_start = 10309 + _ADDARTIFACTSREQUEST_BATCH._serialized_end = 10402 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 10405 + _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 10598 + _ADDARTIFACTSRESPONSE._serialized_start = 10671 + _ADDARTIFACTSRESPONSE._serialized_end = 10943 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 10862 + _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 10943 + _ARTIFACTSTATUSESREQUEST._serialized_start = 10946 + _ARTIFACTSTATUSESREQUEST._serialized_end = 11272 + _ARTIFACTSTATUSESRESPONSE._serialized_start = 11275 + _ARTIFACTSTATUSESRESPONSE._serialized_end = 11627 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 11470 + _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 11585 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 11587 + _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 11627 + _INTERRUPTREQUEST._serialized_start = 11630 + _INTERRUPTREQUEST._serialized_end = 12233 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_start = 12033 + _INTERRUPTREQUEST_INTERRUPTTYPE._serialized_end = 12161 + _INTERRUPTRESPONSE._serialized_start = 12236 + _INTERRUPTRESPONSE._serialized_end = 12380 + _REATTACHOPTIONS._serialized_start = 12382 + _REATTACHOPTIONS._serialized_end = 12435 + _REATTACHEXECUTEREQUEST._serialized_start = 12438 + _REATTACHEXECUTEREQUEST._serialized_end = 12844 + _RELEASEEXECUTEREQUEST._serialized_start = 12847 + _RELEASEEXECUTEREQUEST._serialized_end = 13432 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_start = 13301 + _RELEASEEXECUTEREQUEST_RELEASEALL._serialized_end = 13313 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_start = 13315 + _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 13362 + _RELEASEEXECUTERESPONSE._serialized_start = 13435 + _RELEASEEXECUTERESPONSE._serialized_end = 13600 + 
_RELEASESESSIONREQUEST._serialized_start = 13603 + _RELEASESESSIONREQUEST._serialized_end = 13774 + _RELEASESESSIONRESPONSE._serialized_start = 13776 + _RELEASESESSIONRESPONSE._serialized_end = 13884 + _FETCHERRORDETAILSREQUEST._serialized_start = 13887 + _FETCHERRORDETAILSREQUEST._serialized_end = 14219 + _FETCHERRORDETAILSRESPONSE._serialized_start = 14222 + _FETCHERRORDETAILSRESPONSE._serialized_end = 15777 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 14451 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 14625 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 14628 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 14996 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 14959 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 14996 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 14999 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 15408 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 15310 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 15378 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 15411 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 15758 + _SPARKCONNECTSERVICE._serialized_start = 15780 + _SPARKCONNECTSERVICE._serialized_end = 16726 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index d22502f8839d..b76f2a7f4de3 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -1406,6 +1406,7 @@ class ExecutePlanResponse(google.protobuf.message.Message): NAME_FIELD_NUMBER: builtins.int VALUES_FIELD_NUMBER: builtins.int KEYS_FIELD_NUMBER: builtins.int + PLAN_ID_FIELD_NUMBER: builtins.int name: builtins.str @property def values( @@ -1417,6 +1418,7 @@ class 
ExecutePlanResponse(google.protobuf.message.Message): def keys( self, ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ... + plan_id: builtins.int def __init__( self, *, @@ -1426,11 +1428,12 @@ class ExecutePlanResponse(google.protobuf.message.Message): ] | None = ..., keys: collections.abc.Iterable[builtins.str] | None = ..., + plan_id: builtins.int = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ - "keys", b"keys", "name", b"name", "values", b"values" + "keys", b"keys", "name", b"name", "plan_id", b"plan_id", "values", b"values" ], ) -> None: ... diff --git a/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala new file mode 100644 index 000000000000..4789ae8975d1 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/ObservationBase.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters.MapHasAsJava + +/** + * Helper class to simplify usage of `Dataset.observe(String, Column, Column*)`: + * + * {{{ + * // Observe row count (rows) and highest id (maxid) in the Dataset while writing it + * val observation = Observation("my metrics") + * val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid")) + * observed_ds.write.parquet("ds.parquet") + * val metrics = observation.get + * }}} + * + * This collects the metrics while the first action is executed on the observed dataset. Subsequent + * actions do not modify the metrics returned by [[get]]. Retrieval of the metric via [[get]] + * blocks until the first action has finished and metrics become available. + * + * This class does not support streaming datasets. + * + * @param name name of the metric + * @since 3.3.0 + */ +abstract class ObservationBase(val name: String) { + + if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty") + + @volatile protected var metrics: Option[Map[String, Any]] = None + + /** + * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish + * its first action. Only the result of the first action is available. Subsequent actions do not + * modify the result. + * + * @return the observed metrics as a `Map[String, Any]` + * @throws InterruptedException interrupted while waiting + */ + @throws[InterruptedException] + def get: Map[String, _] = { + synchronized { + // we need to loop as wait might return without us calling notify + // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610 + while (this.metrics.isEmpty) { + wait() + } + } + + this.metrics.get + } + + /** + * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish + * its first action. Only the result of the first action is available. Subsequent actions do not + * modify the result. 
+ * + * @return the observed metrics as a `java.util.Map[String, Object]` + * @throws InterruptedException interrupted while waiting + */ + @throws[InterruptedException] + def getAsJava: java.util.Map[String, AnyRef] = { + get.map { case (key, value) => (key, value.asInstanceOf[Object]) }.asJava + } + + /** + * Get the observed metrics. This returns the metrics if they are available, otherwise an empty. + * + * @return the observed metrics as a `Map[String, Any]` + */ + @throws[InterruptedException] + private[sql] def getOrEmpty: Map[String, _] = { + synchronized { + if (metrics.isEmpty) { + wait(100) // Wait for 100ms to see if metrics are available + } + metrics.getOrElse(Map.empty) + } + } + + /** + * Set the observed metrics and notify all waiting threads to resume. + * + * @return `true` if all waiting threads were notified, `false` if otherwise. + */ + private[spark] def setMetricsAndNotify(metrics: Option[Map[String, Any]]): Boolean = { + synchronized { + this.metrics = metrics + if(metrics.isDefined) { + notifyAll() + true + } else { + false + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala index 104e7c101fd1..30d5943c6092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import java.util.UUID -import scala.jdk.CollectionConverters.MapHasAsJava - import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.util.QueryExecutionListener @@ -47,9 +45,7 @@ import org.apache.spark.util.ArrayImplicits._ * @param name name of the metric * @since 3.3.0 */ -class Observation(val name: String) { - - if (name.isEmpty) throw new IllegalArgumentException("Name must not be empty") +class Observation(name: String) extends ObservationBase(name) { 
/** * Create an Observation instance without providing a name. This generates a random name. @@ -60,8 +56,6 @@ class Observation(val name: String) { @volatile private var dataframeId: Option[(SparkSession, Long)] = None - @volatile private var metrics: Option[Map[String, Any]] = None - /** * Attach this observation to the given [[Dataset]] to observe aggregation expressions. * @@ -83,55 +77,6 @@ class Observation(val name: String) { ds.observe(name, expr, exprs: _*) } - /** - * (Scala-specific) Get the observed metrics. This waits for the observed dataset to finish - * its first action. Only the result of the first action is available. Subsequent actions do not - * modify the result. - * - * @return the observed metrics as a `Map[String, Any]` - * @throws InterruptedException interrupted while waiting - */ - @throws[InterruptedException] - def get: Map[String, _] = { - synchronized { - // we need to loop as wait might return without us calling notify - // https://en.wikipedia.org/w/index.php?title=Spurious_wakeup&oldid=992601610 - while (this.metrics.isEmpty) { - wait() - } - } - - this.metrics.get - } - - /** - * (Java-specific) Get the observed metrics. This waits for the observed dataset to finish - * its first action. Only the result of the first action is available. Subsequent actions do not - * modify the result. - * - * @return the observed metrics as a `java.util.Map[String, Object]` - * @throws InterruptedException interrupted while waiting - */ - @throws[InterruptedException] - def getAsJava: java.util.Map[String, AnyRef] = { - get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava - } - - /** - * Get the observed metrics. This returns the metrics if they are available, otherwise an empty. 
- * - * @return the observed metrics as a `Map[String, Any]` - */ - @throws[InterruptedException] - private[sql] def getOrEmpty: Map[String, _] = { - synchronized { - if (metrics.isEmpty) { - wait(100) // Wait for 100ms to see if metrics are available - } - metrics.getOrElse(Map.empty) - } - } - private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = { // makes this class thread-safe: // only the first thread entering this block can set sparkSession @@ -158,9 +103,8 @@ class Observation(val name: String) { case _ => false }) { val row = qe.observedMetrics.get(name) - this.metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq)) - if (metrics.isDefined) { - notifyAll() + val metrics = row.map(r => r.getValuesMap[Any](r.schema.fieldNames.toImmutableArraySeq)) + if (setMetricsAndNotify(metrics)) { unregister() } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org