This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8a16aed9a17 [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect 8a16aed9a17 is described below commit 8a16aed9a17269b4c8111779229507e3c28ba945 Author: bogao007 <bo....@databricks.com> AuthorDate: Wed Jun 21 15:35:34 2023 +0900 [SPARK-43511][CONNECT][SS] Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect ### What changes were proposed in this pull request? Implemented MapGroupsWithState and FlatMapGroupsWithState APIs for Spark Connect ### Why are the changes needed? To support streaming state APIs in Spark Connect ### Does this PR introduce _any_ user-facing change? yes ### How was this patch tested? Added unit test Closes #41558 from bogao007/sc-state-api. Authored-by: bogao007 <bo....@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../apache/spark/sql/KeyValueGroupedDataset.scala | 398 +++++++++++++++++++++ .../sql/KeyValueGroupedDatasetE2ETestSuite.scala | 107 ++++++ .../CheckConnectJvmClientCompatibility.scala | 6 - .../FlatMapGroupsWithStateStreamingSuite.scala | 224 ++++++++++++ .../function/FlatMapGroupsWithStateFunction.java | 39 ++ .../java/function/MapGroupsWithStateFunction.java | 38 ++ .../main/protobuf/spark/connect/relations.proto | 16 + .../apache/spark/sql/connect/common/UdfUtils.scala | 26 ++ .../apache/spark/sql/streaming/GroupState.scala | 336 +++++++++++++++++ .../sql/connect/planner/SparkConnectPlanner.scala | 92 ++++- python/pyspark/sql/connect/proto/relations_pb2.py | 24 +- python/pyspark/sql/connect/proto/relations_pb2.pyi | 84 ++++- 12 files changed, 1359 insertions(+), 31 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 7b2fa3b52be..20c130b83cb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.ScalarUserDefinedFunction import org.apache.spark.sql.functions.col +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} /** * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -460,6 +461,356 @@ abstract class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable cogroupSorted(other)(thisSortExprs: _*)(otherSortExprs: _*)( UdfUtils.coGroupFunctionToScalaFunc(f))(encoder) } + + protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( + outputMode: Option[OutputMode], + timeoutConf: GroupStateTimeout, + initialState: Option[KeyValueGroupedDataset[K, S]], + isMapGroupWithState: Boolean)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { + throw new UnsupportedOperationException + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. 
For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See + * [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + mapGroupsWithState(GroupStateTimeout.NoTimeout)(func) + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See + * [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + flatMapGroupsWithStateHelper(None, timeoutConf, None, isMapGroupWithState = true)( + UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See + * [[org.apache.spark.sql.streaming.GroupState]] for more details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param timeoutConf + * Timeout Conf, see GroupStateTimeout for more details + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. The user defined function will be called on the state data even if + * there are no other values in the group. To convert a Dataset ds of type Dataset[(K, S)] to + * a KeyValueGroupedDataset[K, S] do {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
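+ *
+ * A minimal usage sketch (`ds`, `key`, and `func` are placeholder names):
+ * {{{
+ *   val initial = Seq(("a", 1)).toDS().groupByKey(_._1).mapValues(_._2)
+ *   ds.groupByKey(_.key).mapGroupsWithState(GroupStateTimeout.NoTimeout, initial)(func)
+ * }}}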
+ * @since 3.5.0 + */ + def mapGroupsWithState[S: Encoder, U: Encoder]( + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S])( + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { + flatMapGroupsWithStateHelper( + None, + timeoutConf, + Some(initialState), + isMapGroupWithState = true)(UdfUtils.mapGroupsWithStateFuncToFlatMapAdaptor(func)) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + mapGroupsWithState[S, U](UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( + stateEncoder, + outputEncoder) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = { + mapGroupsWithState[S, U](timeoutConf)(UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))( + stateEncoder, + outputEncoder) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. 
Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. The user defined function will be called on the state data even if + * there are no other values in the group. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + mapGroupsWithState[S, U](timeoutConf, initialState)( + UdfUtils.mapGroupsWithStateFuncToScalaFunc(func))(stateEncoder, outputEncoder) + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { + flatMapGroupsWithStateHelper( + Some(outputMode), + timeoutConf, + None, + isMapGroupWithState = false)(func) + } + + /** + * (Scala-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. 
The user defined function will be called on the state data even if + * there are no other values in the group. To covert a Dataset `ds` of type of type + * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use + * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} See [[Encoder]] for more details on what + * types are encodable to Spark SQL. + * @since 3.5.0 + */ + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S])( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { + flatMapGroupsWithStateHelper( + Some(outputMode), + timeoutConf, + Some(initialState), + isMapGroupWithState = false)(func) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout): Dataset[U] = { + val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) + flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) + } + + /** + * (Java-specific) Applies the given function to each group of data, while maintaining a + * user-defined per-group state. The result Dataset will represent the objects returned by the + * function. For a static batch Dataset, the function will be invoked once per group. For a + * streaming Dataset, the function will be invoked for each group repeatedly in every trigger, + * and updates to each group's state will be saved across invocations. See `GroupState` for more + * details. + * + * @tparam S + * The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U + * The type of the output objects. Must be encodable to Spark SQL types. + * @param func + * Function to be called on every group. + * @param outputMode + * The output mode of the function. + * @param stateEncoder + * Encoder for the state type. + * @param outputEncoder + * Encoder for the output type. + * @param timeoutConf + * Timeout configuration for groups that do not receive data for a while. + * @param initialState + * The user provided state that will be initialized when the first batch of data is processed + * in the streaming query. The user defined function will be called on the state data even if + * there are no other values in the group. 
To covert a Dataset `ds` of type of type + * `Dataset[(K, S)]` to a `KeyValueGroupedDataset[K, S]`, use + * {{{ds.groupByKey(x => x._1).mapValues(_._2)}}} + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 3.5.0 + */ + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + outputMode: OutputMode, + stateEncoder: Encoder[S], + outputEncoder: Encoder[U], + timeoutConf: GroupStateTimeout, + initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = { + val f = UdfUtils.flatMapGroupsWithStateFuncToScalaFunc(func) + flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(f)( + stateEncoder, + outputEncoder) + } } /** @@ -572,6 +923,53 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( agg(aggregator) } + override protected def flatMapGroupsWithStateHelper[S: Encoder, U: Encoder]( + outputMode: Option[OutputMode], + timeoutConf: GroupStateTimeout, + initialState: Option[KeyValueGroupedDataset[K, S]], + isMapGroupWithState: Boolean)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] = { + if (outputMode.isDefined && outputMode.get != OutputMode.Append && + outputMode.get != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + + if (initialState.isDefined) { + assert(initialState.get.isInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]]) + } + + val initialStateImpl = if (initialState.isDefined) { + initialState.get.asInstanceOf[KeyValueGroupedDatasetImpl[K, S, _, _]] + } else { + null + } + + val outputEncoder = encoderFor[U] + val nf = if (valueMapFunc == UdfUtils.identical()) { + func + } else { + UdfUtils.mapValuesAdaptor(func, valueMapFunc) + } + + sparkSession.newDataset[U](outputEncoder) { builder => + val groupMapBuilder = builder.getGroupMapBuilder + groupMapBuilder + .setInput(plan.getRoot) + .addAllGroupingExpressions(groupingExprs) + .setFunc(getUdf(nf, outputEncoder)(ivEncoder)) + .setIsMapGroupsWithState(isMapGroupWithState) + .setOutputMode(if (outputMode.isEmpty) OutputMode.Update.toString + else outputMode.get.toString) + .setTimeoutConf(timeoutConf.toString) + + if (initialStateImpl != null) { + groupMapBuilder + .addAllInitialGroupingExpressions(initialStateImpl.groupingExprs) + .setInitialInput(initialStateImpl.plan.getRoot) + } + } + } + private def getUdf[U: Encoder](nf: AnyRef, outputEncoder: AgnosticEncoder[U])( inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction = { val inputEncoders = kEncoder +: inEncoders // Apply keyAs changes by setting kEncoder diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index e7a77eed70d..404239f7e14 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -16,14 +16,21 @@ */ package org.apache.spark.sql +import java.sql.Timestamp import java.util.Arrays import io.grpc.StatusRuntimeException +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} import org.apache.spark.sql.types._ +case class ClickEvent(id: 
String, timestamp: Timestamp) + +case class ClickState(id: String, count: Int) + /** * All tests in this class requires client UDF artifacts synced with the server. */ @@ -447,4 +454,104 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { checkDataset(keys, "1", "2", "10", "20") } + + test("flatMapGroupsWithState") { + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + Iterator(ClickState(key, values.size)) + } + + val session: SparkSession = spark + import session.implicits._ + val values = Seq( + ClickEvent("a", new Timestamp(12345)), + ClickEvent("a", new Timestamp(12346)), + ClickEvent("a", new Timestamp(12347)), + ClickEvent("b", new Timestamp(12348)), + ClickEvent("b", new Timestamp(12349)), + ClickEvent("c", new Timestamp(12350))) + .toDS() + .groupByKey(_.id) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc) + + checkDataset(values, ClickState("a", 3), ClickState("b", 2), ClickState("c", 1)) + } + + test("flatMapGroupsWithState - with initial state") { + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + val currState = state.getOption.getOrElse(ClickState(key, 0)) + Iterator(ClickState(key, currState.count + values.size)) + } + + val session: SparkSession = spark + import session.implicits._ + + val initialStateDS = Seq(ClickState("a", 2), ClickState("b", 1)).toDS() + val initialState = initialStateDS.groupByKey(_.id).mapValues(x => x) + + val values = Seq( + ClickEvent("a", new Timestamp(12345)), + ClickEvent("a", new Timestamp(12346)), + ClickEvent("a", new Timestamp(12347)), + ClickEvent("b", new Timestamp(12348)), + ClickEvent("b", new Timestamp(12349)), + ClickEvent("c", new Timestamp(12350))) + .toDS() + .groupByKey(_.id) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout, initialState)(stateFunc) + + checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1)) + } + + test("mapGroupsWithState") { + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + ClickState(key, values.size) + } + + val session: SparkSession = spark + import session.implicits._ + val values = Seq( + ClickEvent("a", new Timestamp(12345)), + ClickEvent("a", new Timestamp(12346)), + ClickEvent("a", new Timestamp(12347)), + ClickEvent("b", new Timestamp(12348)), + ClickEvent("b", new Timestamp(12349)), + ClickEvent("c", new Timestamp(12350))) + .toDS() + .groupByKey(_.id) + .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc) + + checkDataset(values, ClickState("a", 3), ClickState("b", 2), ClickState("c", 1)) + } + + test("mapGroupsWithState - with initial state") { + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + val currState = state.getOption.getOrElse(ClickState(key, 0)) + ClickState(key, currState.count + values.size) + } + + val session: SparkSession = spark + import session.implicits._ + + val initialStateDS = Seq(ClickState("a", 2), ClickState("b", 1)).toDS() + val initialState = initialStateDS.groupByKey(_.id).mapValues(x => x) + + val values = Seq( + ClickEvent("a", new Timestamp(12345)), + ClickEvent("a", new Timestamp(12346)), + ClickEvent("a", new Timestamp(12347)), + ClickEvent("b", new Timestamp(12348)), + ClickEvent("b", new 
Timestamp(12349)), + ClickEvent("c", new Timestamp(12350))) + .toDS() + .groupByKey(_.id) + .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(stateFunc) + + checkDataset(values, ClickState("a", 5), ClickState("b", 3), ClickState("c", 1)) + } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 7a9a889706d..6b648fd152b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -201,12 +201,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), // KeyValueGroupedDataset - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState" - ), // streaming - ProblemFilters.exclude[Problem]( - "org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState" - ), // streaming ProblemFilters.exclude[Problem]( "org.apache.spark.sql.KeyValueGroupedDataset.queryExecution"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.KeyValueGroupedDataset.this"), diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala new file mode 100644 index 00000000000..cdb6b9a2e9c --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.sql.Timestamp + +import org.scalatest.concurrent.Eventually.eventually +import org.scalatest.concurrent.Futures.timeout +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.{SparkSession, SQLHelper} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append +import org.apache.spark.sql.connect.client.util.QueryTest +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +case class ClickEvent(id: String, timestamp: Timestamp) + +case class ClickState(id: String, count: Int) + +class FlatMapGroupsWithStateStreamingSuite extends QueryTest with SQLHelper { + + val flatMapGroupsWithStateSchema: StructType = StructType( + Array(StructField("id", StringType), StructField("timestamp", TimestampType))) + + val flatMapGroupsWithStateData: Seq[ClickEvent] = Seq( + ClickEvent("a", new Timestamp(12345)), + ClickEvent("a", new Timestamp(12346)), + ClickEvent("a", new Timestamp(12347)), + ClickEvent("b", new Timestamp(12348)), + ClickEvent("b", new Timestamp(12349)), + ClickEvent("c", new Timestamp(12350))) + + val flatMapGroupsWithStateInitialStateData: Seq[ClickState] = + Seq(ClickState("a", 2), ClickState("b", 1)) + + test("flatMapGroupsWithState - streaming") { + val session: SparkSession = spark + import session.implicits._ + + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + Iterator(ClickState(key, values.size)) + } + spark.sql("DROP TABLE IF EXISTS my_sink") + + withTempPath { dir => + val path = dir.getCanonicalPath + flatMapGroupsWithStateData.toDS().write.parquet(path) + val q = spark.readStream + .schema(flatMapGroupsWithStateSchema) + .parquet(path) + .as[ClickEvent] + .groupByKey(_.id) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout)(stateFunc) + .writeStream + .format("memory") + .queryName("my_sink") + .start() + + try { + q.processAllAvailable() + eventually(timeout(30.seconds)) { + checkDataset( + spark.table("my_sink").toDF().as[ClickState], + ClickState("c", 1), + ClickState("b", 2), + ClickState("a", 3)) + } + } finally { + q.stop() + spark.sql("DROP TABLE IF EXISTS my_sink") + } + } + } + + test("flatMapGroupsWithState - streaming - with initial state") { + val session: SparkSession = spark + import session.implicits._ + + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + val currState = state.getOption.getOrElse(ClickState(key, 0)) + Iterator(ClickState(key, currState.count + values.size)) + } + val initialState = flatMapGroupsWithStateInitialStateData + .toDS() + .groupByKey(_.id) + .mapValues(x => x) + spark.sql("DROP TABLE IF EXISTS my_sink") + + withTempPath { dir => + val path = dir.getCanonicalPath + flatMapGroupsWithStateData.toDS().write.parquet(path) + val q = spark.readStream + .schema(flatMapGroupsWithStateSchema) + .parquet(path) + .as[ClickEvent] + .groupByKey(_.id) + .flatMapGroupsWithState(Append, GroupStateTimeout.NoTimeout, initialState)(stateFunc) + .writeStream + .format("memory") + .queryName("my_sink") + .start() + + try { + q.processAllAvailable() + eventually(timeout(30.seconds)) { + checkDataset( + spark.table("my_sink").toDF().as[ClickState], + ClickState("c", 1), + ClickState("b", 3), + ClickState("a", 5)) + } + } finally { + q.stop() + spark.sql("DROP TABLE IF EXISTS my_sink") + } + } + } + + test("mapGroupsWithState - 
streaming") { + val session: SparkSession = spark + import session.implicits._ + + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + ClickState(key, values.size) + } + spark.sql("DROP TABLE IF EXISTS my_sink") + + withTempPath { dir => + val path = dir.getCanonicalPath + flatMapGroupsWithStateData.toDS().write.parquet(path) + val q = spark.readStream + .schema(flatMapGroupsWithStateSchema) + .parquet(path) + .as[ClickEvent] + .groupByKey(_.id) + .mapGroupsWithState(GroupStateTimeout.NoTimeout)(stateFunc) + .writeStream + .format("memory") + .queryName("my_sink") + .outputMode("update") + .start() + + try { + q.processAllAvailable() + eventually(timeout(30.seconds)) { + checkDataset( + spark.table("my_sink").toDF().as[ClickState], + ClickState("c", 1), + ClickState("b", 2), + ClickState("a", 3)) + } + } finally { + q.stop() + spark.sql("DROP TABLE IF EXISTS my_sink") + } + } + } + + test("mapGroupsWithState - streaming - with initial state") { + val session: SparkSession = spark + import session.implicits._ + + val stateFunc = + (key: String, values: Iterator[ClickEvent], state: GroupState[ClickState]) => { + val currState = state.getOption.getOrElse(ClickState(key, 0)) + ClickState(key, currState.count + values.size) + } + val initialState = flatMapGroupsWithStateInitialStateData + .toDS() + .groupByKey(_.id) + .mapValues(x => x) + spark.sql("DROP TABLE IF EXISTS my_sink") + + withTempPath { dir => + val path = dir.getCanonicalPath + flatMapGroupsWithStateData.toDS().write.parquet(path) + val q = spark.readStream + .schema(flatMapGroupsWithStateSchema) + .parquet(path) + .as[ClickEvent] + .groupByKey(_.id) + .mapGroupsWithState(GroupStateTimeout.NoTimeout, initialState)(stateFunc) + .writeStream + .format("memory") + .queryName("my_sink") + .outputMode("update") + .start() + + try { + q.processAllAvailable() + eventually(timeout(30.seconds)) { + checkDataset( + spark.table("my_sink").toDF().as[ClickState], + ClickState("c", 1), + ClickState("b", 3), + ClickState("a", 5)) + } + } finally { + q.stop() + spark.sql("DROP TABLE IF EXISTS my_sink") + } + } + } +} diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java new file mode 100644 index 00000000000..c917c8d28be --- /dev/null +++ b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.streaming.GroupState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@code org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState( + * FlatMapGroupsWithStateFunction, org.apache.spark.sql.streaming.OutputMode, + * org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} + * @since 3.5.0 + */ +@Experimental +@Evolving +public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable { + Iterator<R> call(K key, Iterator<V> values, GroupState<S> state) throws Exception; +} diff --git a/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java new file mode 100644 index 00000000000..ae179ad7d27 --- /dev/null +++ b/connector/connect/common/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.streaming.GroupState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@code org.apache.spark.sql.KeyValueGroupedDataset.mapGroupsWithState( + * MapGroupsWithStateFunction, org.apache.spark.sql.Encoder, org.apache.spark.sql.Encoder)} + * @since 3.5.0 + */ +@Experimental +@Evolving +public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable { + R call(K key, Iterator<V> values, GroupState<S> state) throws Exception; +} diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 6347bd7bc56..ea432bb48fc 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -857,6 +857,22 @@ message GroupMap { // (Optional) Expressions for sorting. Only used by Scala Sorted Group Map API. repeated Expression sorting_expressions = 4; + + // Below fields are only used by (Flat)MapGroupsWithState + // (Optional) Input relation for initial State. + Relation initial_input = 5; + + // (Optional) Expressions for grouping keys of the initial state input relation. 
+ repeated Expression initial_grouping_expressions = 6; + + // (Optional) True if MapGroupsWithState, false if FlatMapGroupsWithState. + optional bool is_map_groups_with_state = 7; + + // (Optional) The output mode of the function. + optional string output_mode = 8; + + // (Optional) Timeout configuration for groups that do not receive data for a while. + optional string timeout_conf = 9; } message CoGroupMap { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala index 06a6c74f268..883637ff86c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/UdfUtils.scala @@ -20,6 +20,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.api.java.function._ import org.apache.spark.sql.Row +import org.apache.spark.sql.streaming.GroupState /** * Util functions to help convert input functions between typed filter, map, flatMap, @@ -95,6 +96,31 @@ private[sql] object UdfUtils extends Serializable { } } + def mapValuesAdaptor[K, V, S, U, IV]( + f: (K, Iterator[V], GroupState[S]) => Iterator[U], + valueMapFunc: IV => V): (K, Iterator[IV], GroupState[S]) => Iterator[U] = { + (k: K, itr: Iterator[IV], s: GroupState[S]) => + { + f(k, itr.map(v => valueMapFunc(v)), s) + } + } + + def mapGroupsWithStateFuncToFlatMapAdaptor[K, V, S, U]( + f: (K, Iterator[V], GroupState[S]) => U): (K, Iterator[V], GroupState[S]) => Iterator[U] = { + (k: K, itr: Iterator[V], s: GroupState[S]) => Iterator(f(k, itr, s)) + } + + def mapGroupsWithStateFuncToScalaFunc[K, V, S, U]( + f: MapGroupsWithStateFunction[K, V, S, U]): (K, Iterator[V], GroupState[S]) => U = { + (key, data, groupState) => f.call(key, data.asJava, groupState) + } + + def flatMapGroupsWithStateFuncToScalaFunc[K, V, S, U]( + f: FlatMapGroupsWithStateFunction[K, V, S, U]) + : (K, Iterator[V], GroupState[S]) => Iterator[U] = { (key, data, groupState) => + f.call(key, data.asJava, groupState).asScala + } + def mapReduceFuncToScalaFunc[T](func: ReduceFunction[T]): (T, T) => T = func.call def identical[T](): T => T = t => t diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala new file mode 100644 index 00000000000..bd418a89534 --- /dev/null +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.catalyst.plans.logical.LogicalGroupState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with per-group state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on `KeyValueGroupedDataset`. + * + * Detail description on `[map/flatMap]GroupsWithState` operation + * -------------------------------------------------------------- Both, `mapGroupsWithState` and + * `flatMapGroupsWithState` in `KeyValueGroupedDataset` will invoke the user-given function on + * each group (defined by the grouping function in `Dataset.groupByKey()`) while maintaining a + * user-defined per-group state between invocations. For a static batch Dataset, the function will + * be invoked once per group. For a streaming Dataset, the function will be invoked for each group + * repeatedly in every trigger. That is, in every batch of the `StreamingQuery`, the function will + * be invoked once for each group that has data in the trigger. Furthermore, if timeout is set, + * then the function will be invoked on timed-out groups (more detail below). + * + * The function is invoked with the following parameters. + * - The key of the group. + * - An iterator containing all the values for this group. + * - A user-defined state object set by previous invocations of the given function. + * + * In case of a batch Dataset, there is only one invocation and the state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` is + * equivalent to `[map/flatMap]Groups` and any updates to the state and/or timeouts have no + * effect. + * + * The major difference between `mapGroupsWithState` and `flatMapGroupsWithState` is that the + * former allows the function to return one and only one record, whereas the latter allows the + * function to return any number of records (including no records). Furthermore, the + * `flatMapGroupsWithState` is associated with an operation output mode, which can be either + * `Append` or `Update`. Semantically, this defines whether the output records of one trigger is + * effectively replacing the previously output records (from previous triggers) or is appending to + * the list of previously output records. Essentially, this defines how the Result Table (refer to + * the semantics in the programming guide) is updated, and allows us to reason about the semantics + * of later operations. + * + * Important points to note about the function (both mapGroupsWithState and + * flatMapGroupsWithState). + * - In a trigger, the function will be called only the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * - If timeout is set, then the function will also be called with no values. See more details + * on `GroupStateTimeout` below. + * + * Important points to note about using `GroupState`. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `GroupState` are not thread-safe. This is to avoid memory barriers. 
+ * - If `remove()` is called, then `exists()` will return `false`, `get()` will throw + * `NoSuchElementException` and `getOption()` will return `None` + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()`will return the updated value. + * + * Important points to note about using `GroupStateTimeout`. + * - The timeout type is a global param across all the groups (set as `timeout` param in + * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable + * per group by calling `setTimeout...()` in `GroupState`. + * - Timeouts can be either based on processing time (i.e. + * `GroupStateTimeout.ProcessingTimeTimeout`) or event time (i.e. + * `GroupStateTimeout.EventTimeTimeout`). + * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling + * `GroupState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the + * set duration. Guarantees provided by this timeout with a duration of D ms are as follows: + * - Timeout will never occur before the clock time has advanced by D ms + * - Timeout will occur eventually when there is a trigger in the query (i.e. after D ms). So + * there is no strict upper bound on when the timeout would occur. For example, the trigger + * interval of the query will affect when the timeout actually occurs. If there is no data + * in the stream (for any group) for a while, then there will not be any trigger and timeout + * function call will not occur until there is data. + * - Since the processing time timeout is based on the clock time, it is affected by the + * variations in the system clock (i.e. time zone changes, clock skew, etc.). + * - With `EventTimeTimeout`, the user also has to specify the event time watermark in the query + * using `Dataset.withWatermark()`. With this setting, data that is older than the watermark + * is filtered out. The timeout can be set for a group by setting a timeout timestamp + * using`GroupState.setTimeoutTimestamp()`, and the timeout would occur when the watermark + * advances beyond the set timestamp. You can control the timeout delay by two parameters - + * (i) watermark delay and an additional duration beyond the timestamp in the event (which is + * guaranteed to be newer than watermark due to the filtering). Guarantees provided by this + * timeout are as follows: + * - Timeout will never occur before the watermark has exceeded the set timeout. + * - Similar to processing time timeouts, there is no strict upper bound on the delay when the + * timeout actually occurs. The watermark can advance only when there is data in the stream + * and the event time of the data has actually advanced. + * - When the timeout occurs for a group, the function is called for that group with no values, + * and `GroupState.hasTimedOut()` set to true. + * - The timeout is reset every time the function is called on a group, that is, when the group + * has new data, or the group has timed out. So the user has to set the timeout duration every + * time the function is called, otherwise, there will not be any timeout set. + * + * `[map/flatMap]GroupsWithState` can take a user defined initial state as an additional argument. + * This state will be applied when the first batch of the streaming query is processed. If there + * are no matching rows in the data for the keys present in the initial state, the state is still + * applied and the function will be invoked with the values being an empty iterator. 
+ * + * Scala example of using GroupState in `mapGroupsWithState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. + * def mappingFunction(key: String, value: Iterator[Int], state: GroupState[Int]): String = { + * + * if (state.hasTimedOut) { // If called when timing out, remove the state + * state.remove() + * + * } else if (state.exists) { // If state exists, use it for processing + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * state.setTimeoutDuration("1 hour") // Set the timeout + * } + * ... + * // return something + * } + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(mappingFunction) + * }}} + * + * Java example of using `GroupState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * // Additionally, it sets a timeout to remove the state if it has not received data for an hour. + * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction = + * new MapGroupsWithStateFunction<String, Integer, Integer, String>() { + * + * @Override + * public String call(String key, Iterator<Integer> value, GroupState<Integer> state) { + * if (state.hasTimedOut()) { // If called when timing out, remove the state + * state.remove(); + * + * } else if (state.exists()) { // If state exists, use it for processing + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * state.setTimeoutDuration("1 hour"); // Set the timeout + * } + * ... + * // return something + * } + * }; + * + * dataset + * .groupByKey(...) + * .mapGroupsWithState( + * mappingFunction, Encoders.INT, Encoders.STRING, GroupStateTimeout.ProcessingTimeTimeout); + * }}} + * + * @tparam S + * User-defined type of the state to be stored for each group. Must be encodable into Spark SQL + * types (see `Encoder` for more details). + * @since 3.5.0 + */ +@Experimental +@Evolving +trait GroupState[S] extends LogicalGroupState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") + def get: S + + /** Get the state value as a scala Option. */ + def getOption: Option[S] + + /** Update the value of the state. */ + def update(newState: S): Unit + + /** Remove this state. */ + def remove(): Unit + + /** + * Whether the function has been called because the key has timed out. + * @note + * This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. 
+ */ + def hasTimedOut: Boolean + + /** + * Set the timeout duration in ms for this key. + * + * @note + * [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note + * This method has no effect when used in a batch query. + */ + @throws[IllegalArgumentException]("if 'durationMs' is not positive") + @throws[UnsupportedOperationException]( + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") + def setTimeoutDuration(durationMs: Long): Unit + + /** + * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. + * + * @note + * [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note + * This method has no effect when used in a batch query. + */ + @throws[IllegalArgumentException]("if 'duration' is not a valid duration") + @throws[UnsupportedOperationException]( + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") + def setTimeoutDuration(duration: String): Unit + + /** + * Set the timeout timestamp for this key as milliseconds in epoch time. This timestamp cannot + * be older than the current watermark. + * + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no effect when used in a batch query. + */ + @throws[IllegalArgumentException]( + "if 'timestampMs' is not positive or less than the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") + def setTimeoutTimestamp(timestampMs: Long): Unit + + /** + * Set the timeout timestamp for this key as milliseconds in epoch time and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the + * additional duration) cannot be older than the current watermark. + * + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no side effect when used in a batch query. + */ + @throws[IllegalArgumentException]( + "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + + "the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") + def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit + + /** + * Set the timeout timestamp for this key as a java.sql.Date. This timestamp cannot be older + * than the current watermark. + * + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. + * @note + * This method has no side effect when used in a batch query. + */ + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") + def setTimeoutTimestamp(timestamp: java.sql.Date): Unit + + /** + * Set the timeout timestamp for this key as a java.sql.Date and an additional duration as a + * string (e.g. "1 hour", "2 days", etc.). The final timestamp (including the additional + * duration) cannot be older than the current watermark. + * + * @note + * [[GroupStateTimeout Event time timeout]] must be enabled in `[map/flatMap]GroupsWithState` + * for calling this method. 
+ * @note + * This method has no side effect when used in a batch query. + */ + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") + def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit + + /** + * Get the current event time watermark as milliseconds in epoch time. + * + * @note + * In a streaming query, this can be called only when watermark is set before calling + * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. + */ + @throws[UnsupportedOperationException]( + "if watermark has not been set before in [map|flatMap]GroupsWithState") + def getCurrentWatermarkMs(): Long + + /** + * Get the current processing time as milliseconds in epoch time. + * @note + * In a streaming query, this will return a constant value throughout the duration of a + * trigger, even if the trigger is re-executed. + */ + def getCurrentProcessingTimeMs(): Long +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index dc819fb4020..6ee252d1a58 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, Intersect, LocalRelation, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint} +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket} @@ -64,11 +65,12 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation} import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction} import org.apache.spark.sql.execution.stat.StatFunctions +import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import 
org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} -import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryProgress, Trigger} +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.CacheId @@ -570,16 +572,82 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { rel.getGroupingExpressionsList, rel.getSortingExpressionsList) - val mapped = new MapGroups( - udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], - udf.inputDeserializer(ds.groupingAttributes), - ds.valueDeserializer, - ds.groupingAttributes, - ds.dataAttributes, - ds.sortOrder, - udf.outputObjAttr, - ds.analyzed) - SerializeFromObject(udf.outputNamedExpression, mapped) + if (rel.hasIsMapGroupsWithState) { + val hasInitialState = !rel.getInitialGroupingExpressionsList.isEmpty && rel.hasInitialInput + val initialDs = if (hasInitialState) { + UntypedKeyValueGroupedDataset( + rel.getInitialInput, + rel.getInitialGroupingExpressionsList, + rel.getSortingExpressionsList) + } else { + UntypedKeyValueGroupedDataset( + rel.getInput, + rel.getGroupingExpressionsList, + rel.getSortingExpressionsList) + } + val timeoutConf = if (!rel.hasTimeoutConf) { + GroupStateTimeout.NoTimeout + } else { + groupStateTimeoutFromString(rel.getTimeoutConf) + } + val outputMode = if (!rel.hasOutputMode) { + OutputMode.Update + } else { + InternalOutputModes(rel.getOutputMode) + } + + val flatMapGroupsWithState = if (hasInitialState) { + new FlatMapGroupsWithState( + udf.function + .asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], + udf.inputDeserializer(ds.groupingAttributes), + ds.valueDeserializer, + ds.groupingAttributes, + ds.dataAttributes, + udf.outputObjAttr, + initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]], + outputMode, + rel.getIsMapGroupsWithState, + timeoutConf, + hasInitialState, + initialDs.groupingAttributes, + initialDs.dataAttributes, + initialDs.valueDeserializer, + initialDs.analyzed, + ds.analyzed) + } else { + new FlatMapGroupsWithState( + udf.function + .asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], + udf.inputDeserializer(ds.groupingAttributes), + ds.valueDeserializer, + ds.groupingAttributes, + ds.dataAttributes, + udf.outputObjAttr, + initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]], + outputMode, + rel.getIsMapGroupsWithState, + timeoutConf, + hasInitialState, + ds.groupingAttributes, + ds.dataAttributes, + udf.inputDeserializer(ds.groupingAttributes), + LocalRelation(initialDs.vEncoder.schema.toAttributes), // empty data set + ds.analyzed) + } + SerializeFromObject(udf.outputNamedExpression, flatMapGroupsWithState) + } else { + val mapped = new MapGroups( + udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], + udf.inputDeserializer(ds.groupingAttributes), + ds.valueDeserializer, + ds.groupingAttributes, + ds.dataAttributes, + ds.sortOrder, + udf.outputObjAttr, + ds.analyzed) + SerializeFromObject(udf.outputNamedExpression, mapped) + } } private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = { diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 
7b1c55408be..20e0a13c5e4 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] ) @@ -886,17 +886,17 @@ if _descriptor._USE_C_DESCRIPTORS == False: _MAPPARTITIONS._serialized_start = 10801 _MAPPARTITIONS._serialized_end = 10982 _GROUPMAP._serialized_start = 10985 - _GROUPMAP._serialized_end = 11264 - _COGROUPMAP._serialized_start = 11267 - _COGROUPMAP._serialized_end = 11793 - _APPLYINPANDASWITHSTATE._serialized_start = 11796 - _APPLYINPANDASWITHSTATE._serialized_end = 12153 - _COLLECTMETRICS._serialized_start = 12156 - _COLLECTMETRICS._serialized_end = 12292 - _PARSE._serialized_start = 12295 - _PARSE._serialized_end = 12683 + _GROUPMAP._serialized_end = 11620 + _COGROUPMAP._serialized_start = 11623 + _COGROUPMAP._serialized_end = 12149 + _APPLYINPANDASWITHSTATE._serialized_start = 12152 + _APPLYINPANDASWITHSTATE._serialized_end = 12509 + _COLLECTMETRICS._serialized_start = 12512 + _COLLECTMETRICS._serialized_end = 12648 + _PARSE._serialized_start = 12651 + _PARSE._serialized_end = 13039 _PARSE_OPTIONSENTRY._serialized_start = 3687 _PARSE_OPTIONSENTRY._serialized_end = 3745 - _PARSE_PARSEFORMAT._serialized_start = 12584 - _PARSE_PARSEFORMAT._serialized_end = 12672 + _PARSE_PARSEFORMAT._serialized_start = 12940 + _PARSE_PARSEFORMAT._serialized_end = 13028 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 69a4d6b9ccc..bd6460519a4 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -2986,6 +2986,11 @@ class GroupMap(google.protobuf.message.Message): GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int FUNC_FIELD_NUMBER: builtins.int SORTING_EXPRESSIONS_FIELD_NUMBER: builtins.int + INITIAL_INPUT_FIELD_NUMBER: builtins.int + INITIAL_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int + IS_MAP_GROUPS_WITH_STATE_FIELD_NUMBER: builtins.int + OUTPUT_MODE_FIELD_NUMBER: builtins.int + TIMEOUT_CONF_FIELD_NUMBER: builtins.int @property def input(self) -> global___Relation: """(Required) Input relation for Group Map API: apply, applyInPandas.""" @@ -3006,6 +3011,24 @@ class GroupMap(google.protobuf.message.Message): pyspark.sql.connect.proto.expressions_pb2.Expression ]: """(Optional) Expressions for sorting. 
Only used by Scala Sorted Group Map API.""" + @property + def initial_input(self) -> global___Relation: + """Below fields are only used by (Flat)MapGroupsWithState + (Optional) Input relation for initial State. + """ + @property + def initial_grouping_expressions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Optional) Expressions for grouping keys of the initial state input relation.""" + is_map_groups_with_state: builtins.bool + """(Optional) True if MapGroupsWithState, false if FlatMapGroupsWithState.""" + output_mode: builtins.str + """(Optional) The output mode of the function.""" + timeout_conf: builtins.str + """(Optional) Timeout configuration for groups that do not receive data for a while.""" def __init__( self, *, @@ -3020,23 +3043,82 @@ class GroupMap(google.protobuf.message.Message): pyspark.sql.connect.proto.expressions_pb2.Expression ] | None = ..., + initial_input: global___Relation | None = ..., + initial_grouping_expressions: collections.abc.Iterable[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ] + | None = ..., + is_map_groups_with_state: builtins.bool | None = ..., + output_mode: builtins.str | None = ..., + timeout_conf: builtins.str | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"] + self, + field_name: typing_extensions.Literal[ + "_is_map_groups_with_state", + b"_is_map_groups_with_state", + "_output_mode", + b"_output_mode", + "_timeout_conf", + b"_timeout_conf", + "func", + b"func", + "initial_input", + b"initial_input", + "input", + b"input", + "is_map_groups_with_state", + b"is_map_groups_with_state", + "output_mode", + b"output_mode", + "timeout_conf", + b"timeout_conf", + ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ + "_is_map_groups_with_state", + b"_is_map_groups_with_state", + "_output_mode", + b"_output_mode", + "_timeout_conf", + b"_timeout_conf", "func", b"func", "grouping_expressions", b"grouping_expressions", + "initial_grouping_expressions", + b"initial_grouping_expressions", + "initial_input", + b"initial_input", "input", b"input", + "is_map_groups_with_state", + b"is_map_groups_with_state", + "output_mode", + b"output_mode", "sorting_expressions", b"sorting_expressions", + "timeout_conf", + b"timeout_conf", ], ) -> None: ... + @typing.overload + def WhichOneof( + self, + oneof_group: typing_extensions.Literal[ + "_is_map_groups_with_state", b"_is_map_groups_with_state" + ], + ) -> typing_extensions.Literal["is_map_groups_with_state"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_output_mode", b"_output_mode"] + ) -> typing_extensions.Literal["output_mode"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_timeout_conf", b"_timeout_conf"] + ) -> typing_extensions.Literal["timeout_conf"] | None: ... global___GroupMap = GroupMap
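For completeness, a minimal sketch of how the new surface fits together from the client side; this is not code from the patch, and the connect string, rate source, and per-user running total are assumptions. The explicit OutputMode.Update and GroupStateTimeout.NoTimeout arguments match the server-side fallbacks the planner above uses when the output_mode and timeout_conf proto fields are unset.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

object ConnectStateSketch {
  // Hypothetical input record, used only for this illustration.
  case class Click(user: String, clicks: Long)

  def main(args: Array[String]): Unit = {
    // Spark Connect client session; the connect string is an assumption.
    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
    import spark.implicits._

    val clicks = spark.readStream.format("rate").load()
      .selectExpr("cast(value % 5 as string) as user", "1L as clicks")
      .as[Click]

    // Running total per user. On the server this becomes the
    // FlatMapGroupsWithState logical plan built in SparkConnectPlanner above.
    val totals = clicks.groupByKey(_.user)
      .flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.NoTimeout) {
        (user: String, rows: Iterator[Click], state: GroupState[Long]) =>
          val total = state.getOption.getOrElse(0L) + rows.map(_.clicks).sum
          state.update(total)
          Iterator((user, total))
      }

    totals.writeStream.outputMode("update").format("console").start().awaitTermination()
  }
}

mapGroupsWithState is the one-row-per-group variant of the same API; the planner distinguishes the two through the is_map_groups_with_state flag carried in the GroupMap message.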