This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 4bfcdf3  [SPARK-34893][SS] Support session window natively
4bfcdf3 is described below

commit 4bfcdf38cf1a6e98b9677ceca7f32edc3f628d18
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Fri Jul 16 20:38:16 2021 +0900

[SPARK-34893][SS] Support session window natively

Introduction: this PR is the last part of SPARK-10816 (event-time based sessionization / session window). Please refer to #31937 for an overall view of the code change. (Note that the code diff may have diverged a bit.)

### What changes were proposed in this pull request?

This PR proposes to support session windows natively. Please refer to the comments and design doc in SPARK-10816 for more details on the rationale and design (they may be slightly outdated compared to this PR).

The boundary of a "session window" is defined as [timestamp of the start event, timestamp of the last event + gap duration). Unlike a time window, a session window is dynamic: it can expand when a new input row is added to the session. To handle this expansion, Spark defines a session window per input row and "merges" windows whose boundaries overlap.

This PR leverages two different approaches to merging session windows:

1. merging session windows with Spark's aggregation logic (a variant of sort aggregation)
2. updating the session window for all rows bound to the same session, and applying aggregation logic afterwards

The first approach is preferable since it outperforms the second, but it can only be used when merging session windows can be applied together with aggregation. Since that is not possible in all cases, the second approach covers the remaining ones.

This PR also optimizes merging input rows with existing sessions while retaining their order (group keys + start timestamp of the session window), leveraging the fact that the number of existing sessions per group key won't be huge.

The state format is versioned, so that a new state format can be introduced if a better one is found.

### Why are the changes needed?

For now, to deal with sessionization, Spark requires end users to work with (flat)MapGroupsWithState directly, which has a couple of major drawbacks:

1. (flat)MapGroupsWithState is a lower-level API, and end users have to code everything in detail to define session windows and merge them
2. built-in aggregate functions cannot be used, so end users have to handle aggregation themselves
3. (flat)MapGroupsWithState is only available in Scala/Java

With native session window support, end users simply use "session_window" the way they use "window" for tumbling/sliding windows, and can leverage built-in aggregate functions as well as UDAFs to define aggregations.
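To make the window-merging rule described above concrete, here is a minimal, self-contained Scala sketch (illustrative only, not code from this commit) of how the per-row session windows of a single group key collapse when their boundaries overlap. `Session` and `mergeSessions` are hypothetical names used only for this illustration:

```scala
// Illustrative sketch only (not part of this commit): merge per-row session
// windows [start, end) of a single group key when their boundaries overlap.
case class Session(start: Long, end: Long, count: Long)

def mergeSessions(perRow: Seq[Session]): Seq[Session] =
  perRow.sortBy(_.start).foldLeft(Vector.empty[Session]) { (acc, cur) =>
    acc.lastOption match {
      case Some(last) if cur.start < last.end =>
        // Overlapping boundaries: expand the previous session window and
        // combine the partial aggregates.
        acc.init :+ Session(last.start, math.max(last.end, cur.end), last.count + cur.count)
      case _ =>
        acc :+ cur
    }
  }
```

For example, with a 10-second gap, rows at t=0s and t=7s produce windows [0s, 10s) and [7s, 17s), which merge into a single session [0s, 17s); a row at t=30s then starts a new session [30s, 40s).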
Quoting the query example from the test suite:

```
val inputData = MemoryStream[(String, Long)]

// Split the lines into words, treat words as sessionId of events
val events = inputData.toDF()
  .select($"_1".as("value"), $"_2".as("timestamp"))
  .withColumn("eventTime", $"timestamp".cast("timestamp"))
  .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
  .withWatermark("eventTime", "30 seconds")

val sessionUpdates = events
  .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId)
  .agg(count("*").as("numEvents"))
  .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
    "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", "numEvents")
```

which is equivalent to StructuredSessionization (the native session window version is shorter and clearer, even ignoring the model classes).

https://github.com/apache/spark/blob/39542bb81f8570219770bb6533c077f44f6cbd2a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredSessionization.scala#L66-L105

(Worth noting that the code in StructuredSessionization only works with processing time; it doesn't consider that an old event can update the start time of an old session.)

### Does this PR introduce _any_ user-facing change?

Yes. This PR adds support for session windows in both batch and streaming queries via a new function "session_window", whose usage is similar to "window".

### How was this patch tested?

New test suites. Also tested with benchmark code.

Closes #33081 from HeartSaVioR/SPARK-34893-SPARK-10816-PR-31570-part-5.

Lead-authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
Co-authored-by: Liang-Chi Hsieh <vii...@gmail.com>
Co-authored-by: Yuanjian Li <yuanjian...@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
(cherry picked from commit f2bf8b051beb6a3f4b714c4fd6d1a5c5ac942d8e)
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 python/pyspark/sql/functions.py                     |  35 ++
 python/pyspark/sql/functions.pyi                    |   1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala      |  86 +++-
 .../sql/catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../sql/catalyst/expressions/SessionWindow.scala    |  77 ++++
 .../sql/catalyst/expressions/TimeWindow.scala       |   4 +-
 .../spark/sql/errors/QueryCompilationErrors.scala   |   5 +-
 .../org/apache/spark/sql/internal/SQLConf.scala     |  24 ++
 .../spark/sql/execution/SparkStrategies.scala       |  33 +-
 .../spark/sql/execution/aggregate/AggUtils.scala    | 190 ++++++++-
 .../aggregate/UpdatingSessionsIterator.scala        |  13 +-
 .../execution/python/AggregateInPandasExec.scala    |  49 ++-
 .../execution/streaming/IncrementalExecution.scala  |  20 +
 .../state/StreamingSessionWindowStateManager.scala  |  17 +-
 .../execution/streaming/statefulOperators.scala     | 289 +++++++++++++
 .../scala/org/apache/spark/sql/functions.scala      |  30 ++
 .../sql-functions/sql-expression-schema.md          |   9 +-
 .../spark/sql/DataFrameSessionWindowingSuite.scala  | 290 +++++++++++++
 .../spark/sql/DataFrameTimeWindowingSuite.scala     |   2 +-
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala  |   3 +-
 .../streaming/UpdatingSessionsIteratorSuite.scala   |   8 +-
 .../sql/expressions/ExpressionInfoSuite.scala       |   1 +
 .../streaming/StreamingSessionWindowSuite.scala     | 460 +++++++++++++++++++++
 23 files changed, 1608 insertions(+), 39 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d4f527d..06d58b8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2330,6 +2330,41 @@ def window(timeColumn, windowDuration,
slideDuration=None, startTime=None): return Column(res) +def session_window(timeColumn, gapDuration): + """ + Generates session window given a timestamp specifying column. + Session window is one of dynamic windows, which means the length of window is varying + according to the given inputs. The length of session window is defined as "the timestamp + of latest input of the session + gap duration", so when the new inputs are bound to the + current session window, the end time of session window can be expanded according to the new + inputs. + Windows can support microsecond precision. Windows in the order of months are not supported. + For a streaming query, you may use the function `current_timestamp` to generate windows on + processing time. + gapDuration is provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid + interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. + The output column will be a struct called 'session_window' by default with the nested columns + 'start' and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`. + .. versionadded:: 3.2.0 + Examples + -------- + >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> w = df.groupBy(session_window("date", "5 seconds")).agg(sum("val").alias("sum")) + >>> w.select(w.session_window.start.cast("string").alias("start"), + ... w.session_window.end.cast("string").alias("end"), "sum").collect() + [Row(start='2016-03-11 09:00:07', end='2016-03-11 09:00:12', sum=1)] + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(gapDuration, "gapDuration") + res = sc._jvm.functions.session_window(time_col, gapDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- def crc32(col): diff --git a/python/pyspark/sql/functions.pyi b/python/pyspark/sql/functions.pyi index 0a4aabf..051a6f1 100644 --- a/python/pyspark/sql/functions.pyi +++ b/python/pyspark/sql/functions.pyi @@ -135,6 +135,7 @@ def window( slideDuration: Optional[str] = ..., startTime: Optional[str] = ..., ) -> Column: ... +def session_window(timeColumn: ColumnOrName, gapDuration: str) -> Column: ... def crc32(col: ColumnOrName) -> Column: ... def md5(col: ColumnOrName) -> Column: ... def sha1(col: ColumnOrName) -> Column: ... 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3bab58d..e8ab874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -296,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager) GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: + SessionWindowing :: ResolveInlineTables :: ResolveHigherOrderFunctions(catalogManager) :: ResolveLambdaVariables :: @@ -3856,9 +3857,13 @@ object TimeWindowing extends Rule[LogicalPlan] { val windowExpressions = p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet - val numWindowExpr = windowExpressions.size + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + // Only support a single window expression for now - if (numWindowExpr == 1 && + if (numWindowExpr == 1 && windowExpressions.nonEmpty && windowExpressions.head.timeColumn.resolved && windowExpressions.head.checkInputDataTypes().isSuccess) { @@ -3933,6 +3938,83 @@ object TimeWindowing extends Rule[LogicalPlan] { } } +/** Maps a time column to a session window. */ +object SessionWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val SESSION_COL_NAME = "session_window" + private final val SESSION_START = "start" + private final val SESSION_END = "end" + + /** + * Generates the logical plan for generating session window on a timestamp column. + * Each session window is initially defined as [timestamp, timestamp + gap). + * + * This also adds a marker to the session column so that downstream can easily find the column + * on session window. 
+ */ + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val sessionExpressions = + p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet + + val numWindowExpr = p.expressions.flatMap(_.collect { + case s: SessionWindow => s + case t: TimeWindow => t + }).toSet.size + + // Only support a single session expression for now + if (numWindowExpr == 1 && sessionExpressions.nonEmpty && + sessionExpressions.head.timeColumn.resolved && + sessionExpressions.head.checkInputDataTypes().isSuccess) { + + val session = sessionExpressions.head + + val metadata = session.timeColumn match { + case a: Attribute => a.metadata + case _ => Metadata.empty + } + + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putBoolean(SessionWindow.marker, true) + .build() + + val sessionAttr = AttributeReference( + SESSION_COL_NAME, session.dataType, metadata = newMetadata)() + + val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType) + val sessionEnd = sessionStart + session.gapDuration + + val literalSessionStruct = CreateNamedStruct( + Literal(SESSION_START) :: + PreciseTimestampConversion(sessionStart, LongType, TimestampType) :: + Literal(SESSION_END) :: + PreciseTimestampConversion(sessionEnd, LongType, TimestampType) :: + Nil) + + val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)( + exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata)) + + val replacedPlan = p transformExpressions { + case s: SessionWindow => sessionAttr + } + + // As same as tumbling window, we add a filter to filter out nulls. + val filterExpr = IsNotNull(session.timeColumn) + + replacedPlan.withNewChildren( + Project(sessionStruct +: child.output, + Filter(filterExpr, child)) :: Nil) + } else if (numWindowExpr > 1) { + throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p) + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} + /** * Resolve expressions if they contains [[NamePlaceholder]]s. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 60ca1e9..234da76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -552,6 +552,7 @@ object FunctionRegistry { expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), + expression[SessionWindow]("session_window"), expression[MakeDate]("make_date"), expression[MakeTimestamp]("make_timestamp"), expression[MakeTimestampNTZ]("make_timestamp_ntz", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala new file mode 100644 index 0000000..60b0744 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types._ + +/** + * Represent the session window. + * + * @param timeColumn the start time of session window + * @param gapDuration the duration of session gap, meaning the session will close if there is + * no new element appeared within "the last element in session + gap". + */ +case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this(timeColumn: Expression, gapDuration: Expression) = { + this(timeColumn, TimeWindow.parseExpression(gapDuration)) + } + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** Validate the inputs for the gap duration in addition to the input data type. */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (gapDuration <= 0) { + return TypeCheckFailure(s"The window duration ($gapDuration) must be greater than 0.") + } + } + dataTypeCheck + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(timeColumn = newChild) +} + +object SessionWindow { + val marker = "spark.sessionWindow" + + def apply( + timeColumn: Expression, + gapDuration: String): SessionWindow = { + SessionWindow(timeColumn, + TimeWindow.getIntervalInMicroSeconds(gapDuration)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 5b13872..e79e8d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -109,7 +109,7 @@ object TimeWindow { * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond * precision. */ - private def getIntervalInMicroSeconds(interval: String): Long = { + def getIntervalInMicroSeconds(interval: String): Long = { val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval)) if (cal.months != 0) { throw new IllegalArgumentException( @@ -122,7 +122,7 @@ object TimeWindow { * Parses the duration expression to generate the long value for the original constructor so * that we can use `window` in SQL. 
*/ - private def parseExpression(expr: Expression): Long = expr match { + def parseExpression(expr: Expression): Long = expr match { case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) case IntegerLiteral(i) => i.toLong case NonNullLiteral(l, LongType) => l.toString.toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 2cee614..7a33d52a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -366,8 +366,9 @@ private[spark] object QueryCompilationErrors { } def multiTimeWindowExpressionsNotSupportedError(t: TreeNode[_]): Throwable = { - new AnalysisException("Multiple time window expressions would result in a cartesian product " + - "of rows, therefore they are currently not supported.", t.origin.line, t.origin.startPosition) + new AnalysisException("Multiple time/session window expressions would result in a cartesian " + + "product of rows, therefore they are currently not supported.", t.origin.line, + t.origin.startPosition) } def viewOutputNumberMismatchQueryColumnNamesError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cc53d92..4061965 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1610,6 +1610,27 @@ object SQLConf { .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2") .createWithDefault(2) + val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION = + buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition") + .internal() + .doc("When true, streaming session window sorts and merge sessions in local partition " + + "prior to shuffle. This is to reduce the rows to shuffle, but only beneficial when " + + "there're lots of rows in a batch being assigned to same sessions.") + .version("3.2.0") + .booleanConf + .createWithDefault(false) + + val STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION = + buildConf("spark.sql.streaming.sessionWindow.stateFormatVersion") + .internal() + .doc("State format version used by streaming session window in a streaming query. 
" + + "State between versions are tend to be incompatible, so state format version shouldn't " + + "be modified after running.") + .version("3.2.0") + .intConf + .checkValue(v => Set(1).contains(v), "Valid version is 1") + .createWithDefault(1) + val UNSUPPORTED_OPERATION_CHECK_ENABLED = buildConf("spark.sql.streaming.unsupportedOperationCheck") .internal() @@ -3678,6 +3699,9 @@ class SQLConf extends Serializable with Logging { def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) + def streamingSessionWindowMergeSessionInLocalPartition: Boolean = + getConf(STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION) + def datetimeJava8ApiEnabled: Boolean = getConf(DATETIME_JAVA8API_ENABLED) def uiExplainMode: String = getConf(UI_EXPLAIN_MODE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 65a5923..6d10fa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -324,7 +324,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw QueryCompilationErrors.groupAggPandasUDFUnsupportedByStreamingAggError() } - val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + val sessionWindowOption = namedGroupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because // `groupingExpressions` is not extracted during logical phase. @@ -335,12 +337,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - AggUtils.planStreamingAggregation( - normalizedGroupingExpressions, - aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), - rewrittenResultExpressions, - stateVersion, - planLater(child)) + sessionWindowOption match { + case Some(sessionWindow) => + val stateVersion = conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION) + + AggUtils.planStreamingAggregationForSession( + normalizedGroupingExpressions, + sessionWindow, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + conf.streamingSessionWindowMergeSessionInLocalPartition, + planLater(child)) + + case None => + val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) + + AggUtils.planStreamingAggregation( + normalizedGroupingExpressions, + aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]), + rewrittenResultExpressions, + stateVersion, + planLater(child)) + } case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 58d3411..0f239b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} +import 
org.apache.spark.sql.execution.streaming._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -113,6 +114,11 @@ object AggUtils { resultExpressions = partialResultExpressions, child = child) + // If we have session window expression in aggregation, we add MergingSessionExec to + // merge sessions with calculating aggregation values. + val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions, + aggregateExpressions, partialAggregate) + // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result @@ -126,7 +132,7 @@ object AggUtils { aggregateAttributes = finalAggregateAttributes, initialInputBufferOffset = groupingExpressions.length, resultExpressions = resultExpressions, - child = partialAggregate) + child = interExec) finalAggregate :: Nil } @@ -140,6 +146,11 @@ object AggUtils { resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // If we have session window expression in aggregation, we add UpdatingSessionsExec to + // calculate sessions for input rows and update rows' session column, so that further + // aggregations can aggregate input rows for the same session. + val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child) + val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute) val groupingAttributes = groupingExpressions.map(_.toAttribute) @@ -156,7 +167,7 @@ object AggUtils { aggregateAttributes = aggregateAttributes, resultExpressions = groupingAttributes ++ distinctAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), - child = child) + child = maySessionChild) } // 2. Create an Aggregate Operator for partial merge aggregations. 
@@ -345,4 +356,177 @@ object AggUtils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming session aggregation using the following progression: + * + * - Partial Aggregation + * - all tuples will have aggregated columns with initial value + * - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled) + * - Sort within partition (sort: all keys) + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - Shuffle & Sort (distribution: keys "without" session, sort: all keys) + * - SessionWindowStateStoreRestore (group: keys "without" session) + * - merge input tuples with stored tuples (sessions) respecting sort order + * - MergingSessionExec + * - calculate session among tuples, and aggregate tuples in session with partial merge + * - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted + * - now there is at most 1 tuple per group, key with session + * - SessionWindowStateStoreSave (group: keys "without" session) + * - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregationForSession( + groupingExpressions: Seq[NamedExpression], + sessionExpression: NamedExpression, + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + stateFormatVersion: Int, + mergeSessionsInLocalPartition: Boolean, + child: SparkPlan): Seq[SparkPlan] = { + + val groupWithoutSessionExpression = groupingExpressions.filterNot { p => + p.semanticEquals(sessionExpression) + } + + if (groupWithoutSessionExpression.isEmpty) { + throw new AnalysisException("Global aggregation with session window in streaming query" + + " is not supported.") + } + + val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + // Here doing partial merge is to have aggregated columns with default value for each row. 
+ val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = if (mergeSessionsInLocalPartition) { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + + // sort happens here to merge sessions on each partition + // this is to reduce amount of rows to shuffle + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + } else { + partialAggregate + } + + // shuffle & sort happens here: most of details are also handled in this physical plan + val restored = SessionWindowStateStoreRestoreExec(groupingWithoutSessionAttributes, + sessionExpression.toAttribute, stateInfo = None, eventTimeWatermark = None, + stateFormatVersion, partialMerged1) + + val mergedSessions = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + MergingSessionsExec( + requiredChildDistributionExpressions = None, + requiredChildDistributionOption = Some(restored.requiredChildDistribution), + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored + ) + } + + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. 
+ val saved = SessionWindowStateStoreSaveExec( + groupingWithoutSessionAttributes, + sessionExpression.toAttribute, + stateInfo = None, + outputMode = None, + eventTimeWatermark = None, + stateFormatVersion, mergedSessions) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } + + private def mayAppendUpdatingSessionExec( + groupingExpressions: Seq[NamedExpression], + maybeChildPlan: SparkPlan): SparkPlan = { + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + UpdatingSessionsExec( + groupingExpressions.map(_.toAttribute), + sessionExpression.toAttribute, + maybeChildPlan) + + case None => maybeChildPlan + } + } + + private def mayAppendMergingSessionExec( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + partialAggregate: SparkPlan): SparkPlan = { + groupingExpressions.find(_.metadata.contains(SessionWindow.marker)) match { + case Some(sessionExpression) => + val aggExpressions = aggregateExpressions.map(_.copy(mode = PartialMerge)) + val aggAttributes = aggregateExpressions.map(_.resultAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val groupingWithoutSessionExpressions = groupingExpressions.diff(Seq(sessionExpression)) + val groupingWithoutSessionsAttributes = groupingWithoutSessionExpressions + .map(_.toAttribute) + + MergingSessionsExec( + requiredChildDistributionExpressions = Some(groupingWithoutSessionsAttributes), + requiredChildDistributionOption = None, + groupingExpressions = groupingAttributes, + sessionExpression = sessionExpression, + aggregateExpressions = aggExpressions, + aggregateAttributes = aggAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate + ) + + case None => partialAggregate + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala index bb474a1..0a60ddc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala @@ -181,7 +181,8 @@ class UpdatingSessionsIterator( private val valueProj = GenerateUnsafeProjection.generate(valuesExpressions, inputSchema) private val restoreProj = GenerateUnsafeProjection.generate(inputSchema, - groupingExpressions.map(_.toAttribute) ++ valuesExpressions.map(_.toAttribute)) + groupingWithoutSession.map(_.toAttribute) ++ Seq(sessionExpression.toAttribute) ++ + valuesExpressions.map(_.toAttribute)) private def generateGroupingKey(): InternalRow = { val newRow = new 
SpecificInternalRow(Seq(sessionExpression.toAttribute).toStructType) @@ -190,19 +191,21 @@ class UpdatingSessionsIterator( } private def closeCurrentSession(keyChanged: Boolean): Unit = { - assert(returnRowsIter == null || !returnRowsIter.hasNext) - returnRows = rowsForCurrentSession rowsForCurrentSession = null - val groupingKey = generateGroupingKey() + val groupingKey = generateGroupingKey().copy() val currentRowsIter = returnRows.generateIterator().map { internalRow => val valueRow = valueProj(internalRow) restoreProj(join2(groupingKey, valueRow)).copy() } - returnRowsIter = currentRowsIter + if (returnRowsIter != null && returnRowsIter.hasNext) { + returnRowsIter = returnRowsIter ++ currentRowsIter + } else { + returnRowsIter = currentRowsIter + } if (keyChanged) processedKeys.add(currentKeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 5019008e..69802b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -53,6 +54,17 @@ case class AggregateInPandasExec( override def producedAttributes: AttributeSet = AttributeSet(output) + val sessionWindowOption = groupingExpressions.find { p => + p.metadata.contains(SessionWindow.marker) + } + + val groupingWithoutSessionExpressions = sessionWindowOption match { + case Some(sessionExpression) => + groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) } + + case None => groupingExpressions + } + override def requiredChildDistribution: Seq[Distribution] = { if (groupingExpressions.isEmpty) { AllTuples :: Nil @@ -61,6 +73,14 @@ case class AggregateInPandasExec( } } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match { + case Some(sessionExpression) => + Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression)) + .map(SortOrder(_, Ascending))) + + case None => Seq(groupingExpressions.map(SortOrder(_, Ascending))) + } + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => @@ -73,9 +93,6 @@ case class AggregateInPandasExec( } } - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingExpressions.map(SortOrder(_, Ascending))) - override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() @@ -107,13 +124,18 @@ case class AggregateInPandasExec( // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + // If we have session window expression in aggregation, we wrap iterator with + // UpdatingSessionIterator to calculate sessions for input rows and update + // rows' session column, so that further aggregations can aggregate input rows + // 
for the same session. + val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter) val prunedProj = UnsafeProjection.create(allInputs.toSeq, child.output) val grouped = if (groupingExpressions.isEmpty) { // Use an empty unsafe row as a place holder for the grouping key - Iterator((new UnsafeRow(), iter)) + Iterator((new UnsafeRow(), newIter)) } else { - GroupedIterator(iter, groupingExpressions, child.output) + GroupedIterator(newIter, groupingExpressions, child.output) }.map { case (key, rows) => (key, rows.map(prunedProj)) } @@ -157,4 +179,21 @@ case class AggregateInPandasExec( override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) + + + private def mayAppendUpdatingSessionIterator( + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val newIter = sessionWindowOption match { + case Some(sessionExpression) => + val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold + val spillThreshold = conf.windowExecBufferSpillThreshold + + new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, sessionExpression, + child.output, inMemoryThreshold, spillThreshold) + + case None => iter + } + + newIter + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index e98996b..3e772e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -149,6 +149,26 @@ class IncrementalExecution( stateFormatVersion, child) :: Nil)) + case SessionWindowStateStoreSaveExec(keys, session, None, None, None, stateFormatVersion, + UnaryExecNode(agg, + SessionWindowStateStoreRestoreExec(_, _, None, None, _, child))) => + val aggStateInfo = nextStatefulOperationStateInfo + SessionWindowStateStoreSaveExec( + keys, + session, + Some(aggStateInfo), + Some(outputMode), + Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, + agg.withNewChildren( + SessionWindowStateStoreRestoreExec( + keys, + session, + Some(aggStateInfo), + Some(offsetSeqMetadata.batchWatermarkMs), + stateFormatVersion, + child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => StreamingDeduplicateExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala index 6561286..5130933 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManager.scala @@ -68,8 +68,11 @@ sealed trait StreamingSessionWindowStateManager extends Serializable { * {@code extractKeyWithoutSession}. * @param sessions The all sessions including existing sessions if it's active. * Existing sessions which aren't included in this parameter will be removed. + * @return A tuple having two elements + * 1. number of added/updated rows + * 2. 
number of deleted rows */ - def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): Unit + def updateSessions(store: StateStore, key: UnsafeRow, sessions: Seq[UnsafeRow]): (Long, Long) /** * Removes using a predicate on values, with returning removed values via iterator. @@ -168,7 +171,7 @@ class StreamingSessionWindowStateManagerImplV1( override def updateSessions( store: StateStore, key: UnsafeRow, - sessions: Seq[UnsafeRow]): Unit = { + sessions: Seq[UnsafeRow]): (Long, Long) = { // Below two will be used multiple times - need to make sure this is not a stream or iterator. val newValues = sessions.toList val savedStates = getSessionsWithKeys(store, key) @@ -225,7 +228,7 @@ class StreamingSessionWindowStateManagerImplV1( store: StateStore, key: UnsafeRow, oldValues: List[(UnsafeRow, UnsafeRow)], - values: List[UnsafeRow]): Unit = { + values: List[UnsafeRow]): (Long, Long) = { // Here the key doesn't represent the state key - we need to construct the key for state val keyAndValues = values.map { row => val sessionStart = helper.extractTimePair(row)._1 @@ -236,16 +239,24 @@ class StreamingSessionWindowStateManagerImplV1( val keysForValues = keyAndValues.map(_._1) val keysForOldValues = oldValues.map(_._1) + var upsertedRows = 0L + var deletedRows = 0L + // We should "replace" the value instead of "delete" and "put" if the start time // equals to. This will remove unnecessary tombstone being written to the delta, which is // implementation details on state store implementations. + keysForOldValues.filterNot(keysForValues.contains).foreach { oldKey => store.remove(oldKey) + deletedRows += 1 } keyAndValues.foreach { case (key, value) => store.put(key, value) + upsertedRows += 1 } + + (upsertedRows, deletedRows) } override def abortIfNeeded(store: StateStore): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 3f6a7ba..2dd91de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID import java.util.concurrent.TimeUnit._ +import scala.annotation.tailrec import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD @@ -511,6 +513,293 @@ case class StateStoreSaveExec( copy(child = newChild) } +/** + * This class sorts input rows and existing sessions in state and provides output rows as + * sorted by "group keys + start time of session window". + * + * Refer [[MergingSortWithSessionWindowStateIterator]] for more details. 
+ */ +case class SessionWindowStateStoreRestoreExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo], + eventTimeWatermark: Option[Long], + stateFormatVersion: Int, + child: SparkPlan) + extends UnaryExecNode with StateStoreReader with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + assert(keyExpressions.nonEmpty, "Grouping key must be specified when using sessionWindow") + + private val stateManager = StreamingSessionWindowStateManager.createStateManager( + keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + child.execute().mapPartitionsWithReadStateStore( + getStateInfo, + stateManager.getStateKeySchema, + stateManager.getStateValueSchema, + numColsPrefixKey = stateManager.getNumColsForPrefixKey, + session.sessionState, + Some(session.streams.stateStoreCoordinator)) { case (store, iter) => + + // We need to filter out outdated inputs + val filteredIterator = watermarkPredicateForData match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + new MergingSortWithSessionWindowStateIterator( + filteredIterator, + stateManager, + store, + keyWithoutSessionExpressions, + sessionExpression, + child.output).map { row => + numOutputRows += 1 + row + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def outputOrdering: Seq[SortOrder] = { + (keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending)) + } + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + Seq((keyWithoutSessionExpressions ++ Seq(sessionExpression)).map(SortOrder(_, Ascending))) + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. 
+ */ +case class SessionWindowStateStoreSaveExec( + keyWithoutSessionExpressions: Seq[Attribute], + sessionExpression: Attribute, + stateInfo: Option[StatefulOperatorStateInfo] = None, + outputMode: Option[OutputMode] = None, + eventTimeWatermark: Option[Long] = None, + stateFormatVersion: Int, + child: SparkPlan) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions + + private val stateManager = StreamingSessionWindowStateManager.createStateManager( + keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion) + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(outputMode.nonEmpty, + "Incorrect planning in IncrementalExecution, outputMode has not been set") + assert(keyExpressions.nonEmpty, + "Grouping key must be specified when using sessionWindow") + + child.execute().mapPartitionsWithStateStore( + getStateInfo, + stateManager.getStateKeySchema, + stateManager.getStateValueSchema, + numColsPrefixKey = stateManager.getNumColsForPrefixKey, + session.sessionState, + Some(session.streams.stateStoreCoordinator)) { case (store, iter) => + + val numOutputRows = longMetric("numOutputRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") + val allUpdatesTimeMs = longMetric("allUpdatesTimeMs") + val allRemovalsTimeMs = longMetric("allRemovalsTimeMs") + val commitTimeMs = longMetric("commitTimeMs") + + outputMode match { + // Update and output all rows in the StateStore. + case Some(Complete) => + allUpdatesTimeMs += timeTakenMs { + putToStore(iter, store) + } + commitTimeMs += timeTakenMs { + stateManager.commit(store) + } + setStoreMetrics(store) + stateManager.iterator(store).map { row => + numOutputRows += 1 + row + } + + // Update and output only rows being evicted from the StateStore + // Assumption: watermark predicates must be non-empty if append mode is allowed + case Some(Append) => + allUpdatesTimeMs += timeTakenMs { + val filteredIter = applyRemovingRowsOlderThanWatermark(iter, + watermarkPredicateForData.get) + putToStore(filteredIter, store) + } + + val removalStartTimeNs = System.nanoTime + new NextIterator[InternalRow] { + private val removedIter = stateManager.removeByValueCondition( + store, watermarkPredicateForData.get.eval) + + override protected def getNext(): InternalRow = { + if (!removedIter.hasNext) { + finished = true + null + } else { + numRemovedStateRows += 1 + numOutputRows += 1 + removedIter.next() + } + } + + override protected def close(): Unit = { + allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs) + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + setOperatorMetrics() + } + } + + case Some(Update) => + val baseIterator = watermarkPredicateForData match { + case Some(predicate) => applyRemovingRowsOlderThanWatermark(iter, predicate) + case None => iter + } + val iterPutToStore = iteratorPutToStore(baseIterator, store, + returnOnlyUpdatedRows = true) + new NextIterator[InternalRow] { + private val updatesStartTimeNs = System.nanoTime + + override protected def getNext(): InternalRow = { + if (iterPutToStore.hasNext) { + val row = iterPutToStore.next() + numOutputRows += 1 + row + } else { + finished = true + null + } + } + + override protected def close(): Unit = { + allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs) + + allRemovalsTimeMs += timeTakenMs { + if (watermarkPredicateForData.nonEmpty) 
{ + val removedIter = stateManager.removeByValueCondition( + store, watermarkPredicateForData.get.eval) + while (removedIter.hasNext) { + numRemovedStateRows += 1 + removedIter.next() + } + } + } + commitTimeMs += timeTakenMs { store.commit() } + setStoreMetrics(store) + setOperatorMetrics() + } + } + + case _ => throw QueryExecutionErrors.invalidStreamingOutputModeError(outputMode) + } + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil + } + + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + (outputMode.contains(Append) || outputMode.contains(Update)) && + eventTimeWatermark.isDefined && + newMetadata.batchWatermarkMs > eventTimeWatermark.get + } + + private def iteratorPutToStore( + iter: Iterator[InternalRow], + store: StateStore, + returnOnlyUpdatedRows: Boolean): Iterator[InternalRow] = { + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") + + new NextIterator[InternalRow] { + var curKey: UnsafeRow = null + val curValuesOnKey = new mutable.ArrayBuffer[UnsafeRow]() + + private def applyChangesOnKey(): Unit = { + if (curValuesOnKey.nonEmpty) { + val (upserted, deleted) = stateManager.updateSessions(store, curKey, curValuesOnKey.toSeq) + numUpdatedStateRows += upserted + numRemovedStateRows += deleted + curValuesOnKey.clear + } + } + + @tailrec + override protected def getNext(): InternalRow = { + if (!iter.hasNext) { + applyChangesOnKey() + finished = true + return null + } + + val row = iter.next().asInstanceOf[UnsafeRow] + val key = stateManager.extractKeyWithoutSession(row) + + if (curKey == null || curKey != key) { + // new group appears + applyChangesOnKey() + curKey = key.copy() + } + + // must copy the row, for this row is a reference in iterator and + // will change when iter.next + curValuesOnKey += row.copy + + if (!returnOnlyUpdatedRows) { + row + } else { + if (stateManager.newOrModified(store, row)) { + row + } else { + // current row isn't the "updated" row, continue to the next row + getNext() + } + } + } + + override protected def close(): Unit = {} + } + } + + private def putToStore(baseIter: Iterator[InternalRow], store: StateStore): Unit = { + val iterPutToStore = iteratorPutToStore(baseIter, store, returnOnlyUpdatedRows = false) + while (iterPutToStore.hasNext) { + iterPutToStore.next() + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} + + /** Physical operator for executing streaming Deduplicate. */ case class StreamingDeduplicateExec( keyExpressions: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3b39d97..7db8e8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3631,6 +3631,36 @@ object functions { } /** + * Generates session window given a timestamp specifying column. + * + * Session window is one of dynamic windows, which means the length of window is varying + * according to the given inputs. 
The length of session window is defined as "the timestamp + * of latest input of the session + gap duration", so when the new inputs are bound to the + * current session window, the end time of session window can be expanded according to the new + * inputs. + * + * Windows can support microsecond precision. gapDuration in the order of months are not + * supported. + * + * For a streaming query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time column must be of TimestampType. + * @param gapDuration A string specifying the timeout of the session, e.g. `10 minutes`, + * `1 second`. Check `org.apache.spark.unsafe.types.CalendarInterval` for + * valid duration identifiers. + * + * @group datetime_funcs + * @since 3.2.0 + */ + def session_window(timeColumn: Column, gapDuration: String): Column = { + withExpr { + SessionWindow(timeColumn.expr, gapDuration) + }.as("session_window") + } + + /** * Creates timestamp from the number of seconds since UTC epoch. * @group datetime_funcs * @since 3.1.0 diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index c13a1d4..41692d2 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,8 +1,8 @@ <!-- Automatically generated by ExpressionsSchemaSuite --> ## Summary - - Number of queries: 360 - - Number of expressions that missing example: 13 - - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,window + - Number of queries: 361 + - Number of expressions that missing example: 14 + - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint,session_window,window ## Schema of Built-in Functions | Class name | Function name or alias | Query example | Output schema | | ---------- | ---------------------- | ------------- | ------------- | @@ -244,6 +244,7 @@ | org.apache.spark.sql.catalyst.expressions.SecondsToTimestamp | timestamp_seconds | SELECT timestamp_seconds(1230219000) | struct<timestamp_seconds(1230219000):timestamp> | | org.apache.spark.sql.catalyst.expressions.Sentences | sentences | SELECT sentences('Hi there! Good morning.') | struct<sentences(Hi there! 
Good morning., , ):array<array<string>>> | | org.apache.spark.sql.catalyst.expressions.Sequence | sequence | SELECT sequence(1, 5) | struct<sequence(1, 5):array<int>> | +| org.apache.spark.sql.catalyst.expressions.SessionWindow | session_window | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha | SELECT sha('Spark') | struct<sha(Spark):string> | | org.apache.spark.sql.catalyst.expressions.Sha1 | sha1 | SELECT sha1('Spark') | struct<sha1(Spark):string> | | org.apache.spark.sql.catalyst.expressions.Sha2 | sha2 | SELECT sha2('Spark', 256) | struct<sha2(Spark, 256):string> | @@ -365,4 +366,4 @@ | org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> | | org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> | | org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> | -| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> | \ No newline at end of file +| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala new file mode 100644 index 0000000..b70b2c6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.catalyst.plans.logical.Expand +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession + with BeforeAndAfterEach { + + import testImplicits._ + + test("simple session window with record at window start") { + val df = Seq( + ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"counts"), + Seq( + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1) + ) + ) + } + + test("session window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"session_window.start".asc) + .select("counts"), + Seq(Row(2), Row(1)) + ) + } + + test("session window groupBy with multiple keys statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - one distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:40:04", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum_distinct(col("value")).as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 2, 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 2) + ) + ) + } + + test("session window groupBy with multiple keys statement - two distinct") { + val df = Seq( + ("2016-03-27 19:39:34", 1, 2, "a"), + ("2016-03-27 19:39:39", 1, 2, 
"a"), + ("2016-03-27 19:39:56", 2, 4, "a"), + ("2016-03-27 19:40:04", 2, 4, "a"), + ("2016-03-27 19:39:27", 4, 8, "b")).toDF("time", "value", "value2", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // key "a" => (19:39:34 ~ 19:39:49) (19:39:56 ~ 19:40:14) + // key "b" => (19:39:27 ~ 19:39:37) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "sum", "sum2"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 4, 8), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:49", "a", 1, 2), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:14", "a", 2, 4) + ) + ) + } + + test("session window groupBy with multiple keys statement - keys overlapped with sessions") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:39", 1, "b"), + ("2016-03-27 19:39:40", 2, "a"), + ("2016-03-27 19:39:45", 2, "b"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + // session window handles sort while applying group by + // whereas time window doesn't + + // expected sessions + // a => (19:39:34 ~ 19:39:50) + // b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55) + + checkAnswer( + df.groupBy(session_window($"time", "10 seconds"), 'id) + .agg(count("*").as("counts"), sum("value").as("sum")) + .orderBy($"session_window.start".asc) + .selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)", + "id", "counts", "sum"), + + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", "b", 1, 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:50", "a", 2, 3), + Row("2016-03-27 19:39:39", "2016-03-27 19:39:55", "b", 2, 3) + ) + ) + } + + test("session window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + .select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select($"session_window.start".cast("string"), $"session_window.end".cast("string"), + $"value") + + val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand]) + assert(expands.isEmpty, "Session windows shouldn't require expand") + + checkAnswer( + df, + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + + test("session window combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"session_window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(session_window($"time", "10 seconds"), $"value") + .orderBy($"session_window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + // NOTE: unlike time 
window, joining session windows without grouping + // doesn't arrange session, so two rows will be joined only if session range is exactly same + + test("multiple session windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(session_window($"time", "10 second"), session_window($"time", "15 second")) + .collect() + } + assert(e.getMessage.contains( + "Multiple time/session window expressions would result in a cartesian product")) + } + + test("aliased session windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(session_window($"time", "10 seconds").as("session_window"), $"value") + .orderBy($"session_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName) + try { + f(tableName) + } finally { + spark.catalog.dropTempView(tableName) + } + } + + test("time window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + spark.sql(s"""select session_window(time, "10 seconds"), value from $table""") + .select($"session_window.start".cast(StringType), $"session_window.end".cast(StringType), + $"value"), + Seq( + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 4), + Row("2016-03-27 19:39:34", "2016-03-27 19:39:44", 1), + Row("2016-03-27 19:39:56", "2016-03-27 19:40:06", 2) + ) + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 4fdaeb5..2ef43dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -239,7 +239,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSparkSession { df.select(window($"time", "10 second"), window($"time", "15 second")).collect() } assert(e.getMessage.contains( - "Multiple time window expressions would result in a cartesian product")) + "Multiple time/session window expressions would result in a cartesian product")) } test("aliased windows") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b0d5c89..1e23c11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -136,7 +136,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark test("SPARK-14415: All functions should have own descriptions") { for (f <- spark.sessionState.functionRegistry.listFunction()) { - if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f.unquotedString)) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window", + "session_window").contains(f.unquotedString)) { checkKeywordsNotExist(sql(s"describe function $f"), "N/A.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala index 2a4245d..045901bc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala @@ -199,9 +199,9 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { val row6 = createRow("a", 2, 115, 125, 20, 1.2) val rows3 = List(row5, row6) + // This is to test the edge case that the last input row creates a new session. val row7 = createRow("a", 2, 127, 137, 30, 1.3) - val row8 = createRow("a", 2, 135, 145, 40, 1.4) - val rows4 = List(row7, row8) + val rows4 = List(row7) val rowsAll = rows1 ++ rows2 ++ rows3 ++ rows4 @@ -244,8 +244,8 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { } retRows4.zip(rows4).foreach { case (retRow, expectedRow) => - // session being expanded to (127 ~ 145) - assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 145) + // session being expanded to (127 ~ 137) + assertRowsEqualsWithNewSession(expectedRow, retRow, 127, 137) } assert(iterator.hasNext === false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 30ee97a..08e21d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -133,6 +133,7 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { val ignoreSet = Set( // Explicitly inherits NonSQLExpression, and has no ExpressionDescription "org.apache.spark.sql.catalyst.expressions.TimeWindow", + "org.apache.spark.sql.catalyst.expressions.SessionWindow", // Cast aliases do not need examples "org.apache.spark.sql.catalyst.expressions.Cast") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala new file mode 100644 index 0000000..a381d06 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -0,0 +1,460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import java.util.Locale + +import org.scalatest.BeforeAndAfter +import org.scalatest.matchers.must.Matchers + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider} +import org.apache.spark.sql.functions.{count, session_window, sum} +import org.apache.spark.sql.internal.SQLConf + +class StreamingSessionWindowSuite extends StreamTest + with BeforeAndAfter with Matchers with Logging { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + def testWithAllOptions(name: String, confPairs: (String, String)*) + (func: => Any): Unit = { + val mergingSessionOptions = Seq(true, false).map { value => + (SQLConf.STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION.key, value) + } + val providerOptions = Seq( + classOf[HDFSBackedStateStoreProvider].getCanonicalName, + classOf[RocksDBStateStoreProvider].getCanonicalName + ).map { value => + (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$")) + } + + val availableOptions = for ( + opt1 <- mergingSessionOptions; + opt2 <- providerOptions + ) yield (opt1, opt2) + + for (option <- availableOptions) { + test(s"$name - merging sessions in local partition: ${option._1._2} / " + + s"provider: ${option._2._2}") { + withSQLConf(confPairs ++ + Seq( + option._1._1 -> option._1._2.toString, + option._2._1 -> option._2._2): _*) { + func + } + } + } + } + + testWithAllOptions("complete mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + // note that complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + sessionUpdates.explain() + + testStream(sessionUpdates, OutputMode.Complete())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions 
after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1), + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1), + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptions("complete mode - session window - no key") { + // complete mode doesn't honor watermark: even it is specified, watermark will be + // always Unix timestamp 0 + + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation, OutputMode.Complete())( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } + + testWithAllOptions("append mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "30 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Append())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // late event which session's end 10 would be later than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // 
current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4), + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("structured", 41, 51, 10, 1) + ) + ) + } + + testWithAllOptions("append mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation)( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } + + testWithAllOptions("update mode - session window") { + // Implements StructuredSessionization.scala leveraging "session" function + // as a test, to verify the sessionization works with simple example + + val inputData = MemoryStream[(String, Long)] + + // Split the lines into words, treat words as sessionId of events + val events = inputData.toDF() + .select($"_1".as("value"), $"_2".as("timestamp")) + .withColumn("eventTime", $"timestamp".cast("timestamp")) + .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") + .withWatermark("eventTime", "10 seconds") + + val sessionUpdates = events + .groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) + .agg(count("*").as("numEvents")) + .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", + "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", + "numEvents") + + testStream(sessionUpdates, OutputMode.Update())( + AddData(inputData, + ("hello world spark streaming", 40L), + ("world hello structured streaming", 41L) + ), + // watermark: 11 + // current sessions + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + 
// ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("hello", 40, 51, 11, 2), + ("world", 40, 51, 11, 2), + ("streaming", 40, 51, 11, 2), + ("spark", 40, 50, 10, 1), + ("structured", 41, 51, 10, 1) + ), + + // placing new sessions "before" previous sessions + AddData(inputData, ("spark streaming", 25L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 35, 10, 1), + ("streaming", 25, 35, 10, 1) + ), + + // late event which session's end 10 would be later than watermark 11: should be dropped + AddData(inputData, ("spark streaming", 0L)), + // watermark: 11 + // current sessions + // ("spark", 25, 35, 10, 1), + // ("streaming", 25, 35, 10, 1), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("streaming", 40, 51, 11, 2), + // ("spark", 40, 50, 10, 1), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ), + + // concatenating multiple previous sessions into one + AddData(inputData, ("spark streaming", 30L)), + // watermark: 11 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("spark", 25, 50, 25, 3), + ("streaming", 25, 51, 26, 4) + ), + + // placing new sessions after previous sessions + AddData(inputData, ("hello apache spark", 60L)), + // watermark: 30 + // current sessions + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1), + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1) + CheckNewAnswer( + ("hello", 60, 70, 10, 1), + ("apache", 60, 70, 10, 1), + ("spark", 60, 70, 10, 1) + ), + + AddData(inputData, ("structured streaming", 90L)), + // watermark: 60 + // current sessions + // ("hello", 60, 70, 10, 1), + // ("apache", 60, 70, 10, 1), + // ("spark", 60, 70, 10, 1), + // ("structured", 90, 100, 10, 1), + // ("streaming", 90, 100, 10, 1) + // evicted + // ("spark", 25, 50, 25, 3), + // ("streaming", 25, 51, 26, 4), + // ("hello", 40, 51, 11, 2), + // ("world", 40, 51, 11, 2), + // ("structured", 41, 51, 10, 1) + CheckNewAnswer( + ("structured", 90, 100, 10, 1), + ("streaming", 90, 100, 10, 1) + ) + ) + } + + testWithAllOptions("update mode - session window - no key") { + val inputData = MemoryStream[Int] + + val windowedAggregation = inputData.toDF() + .selectExpr("*") + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .groupBy(session_window($"eventTime", "5 seconds") as 'session) + .agg(count("*") as 'count, sum("value") as 'sum) + .select($"session".getField("start").cast("long").as[Long], + $"session".getField("end").cast("long").as[Long], $"count".as[Long], $"sum".as[Long]) + + val e = intercept[StreamingQueryException] { + testStream(windowedAggregation, OutputMode.Update())( + AddData(inputData, 40), + CheckAnswer() // this is just to trigger the exception + ) + } + Seq("Global aggregation with session window", "not supported").foreach { m => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(m.toLowerCase(Locale.ROOT))) + } + } +} --------------------------------------------------------------------- To unsubscribe, 
e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
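For readers who want to try the new function without wiring up the test harness, the batch usage exercised in DataFrameSessionWindowingSuite reduces to a minimal sketch like the one below (assuming a SparkSession named `spark` is already in scope, e.g. in spark-shell; the data and column names simply mirror the suite):

```
// Assumes `spark` is an active SparkSession (e.g. spark-shell).
import org.apache.spark.sql.functions.{count, session_window}
import spark.implicits._

val df = Seq(
  ("2016-03-27 19:39:34", 1, "a"),
  ("2016-03-27 19:39:56", 2, "a"),
  ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")

// Rows sharing a group key are merged into one session when their timestamps
// fall within the 10-second gap; each session ends at "last event + gap".
df.groupBy(session_window($"time", "10 seconds"), $"id")
  .agg(count("*").as("counts"))
  .orderBy($"session_window.start".asc)
  .selectExpr("CAST(session_window.start AS STRING)",
    "CAST(session_window.end AS STRING)", "id", "counts")
  .show(truncate = false)
```

With this input no two events of the same key fall within the 10-second gap, so the three rows yield three sessions whose boundaries match those asserted in the multi-column projection test above (19:39:27 ~ 19:39:37, 19:39:34 ~ 19:39:44, 19:39:56 ~ 19:40:06).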