Repository: spark
Updated Branches:
  refs/heads/master 7bf09433f -> 9bf4e2baa
[SPARK-19497][SS] Implement streaming deduplication ## What changes were proposed in this pull request? This PR adds a special streaming deduplication operator to support `dropDuplicates` with `aggregation` and watermark. It reuses the `dropDuplicates` API but creates new logical plan `Deduplication` and new physical plan `DeduplicationExec`. The following cases are supported: - one or multiple `dropDuplicates()` without aggregation (with or without watermark) - `dropDuplicates` before aggregation Not supported cases: - `dropDuplicates` after aggregation Breaking changes: - `dropDuplicates` without aggregation doesn't work with `complete` or `update` mode. ## How was this patch tested? The new unit tests. Author: Shixiong Zhu <shixi...@databricks.com> Closes #16970 from zsxwing/dedup. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9bf4e2ba Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9bf4e2ba Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9bf4e2ba Branch: refs/heads/master Commit: 9bf4e2baad0e2851da554d85223ffaa029cfa490 Parents: 7bf0943 Author: Shixiong Zhu <shixi...@databricks.com> Authored: Thu Feb 23 11:25:39 2017 -0800 Committer: Tathagata Das <tathagata.das1...@gmail.com> Committed: Thu Feb 23 11:25:39 2017 -0800 ---------------------------------------------------------------------- python/pyspark/sql/dataframe.py | 6 + .../analysis/UnsupportedOperationChecker.scala | 6 +- .../sql/catalyst/optimizer/Optimizer.scala | 21 +- .../plans/logical/basicLogicalOperators.scala | 9 + .../analysis/UnsupportedOperationsSuite.scala | 56 ++++- .../optimizer/ReplaceOperatorSuite.scala | 33 ++- .../scala/org/apache/spark/sql/Dataset.scala | 39 ++- .../spark/sql/execution/SparkStrategies.scala | 15 +- .../streaming/IncrementalExecution.scala | 10 + .../execution/streaming/statefulOperators.scala | 140 ++++++++--- .../spark/sql/streaming/DeduplicateSuite.scala | 252 +++++++++++++++++++ .../sql/streaming/MapGroupsWithStateSuite.scala | 9 +- .../sql/streaming/StateStoreMetricsTest.scala | 36 +++ .../spark/sql/streaming/StreamSuite.scala | 2 +- .../streaming/StreamingAggregationSuite.scala | 2 +- 15 files changed, 578 insertions(+), 58 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 70efeaf..bb6df22 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1158,6 +1158,12 @@ class DataFrame(object): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. + For a static batch :class:`DataFrame`, it just drops duplicate rows. For a streaming + :class:`DataFrame`, it will keep all data across triggers as intermediate state to drop + duplicates rows. You can use :func:`withWatermark` to limit how late the duplicate data can + be and system will accordingly limit the state. In addition, too late data older than + watermark will be dropped to avoid any possibility of duplicates. + :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. 
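
To make the supported cases concrete, here is a minimal usage sketch of the feature this patch adds. It is not part of the patch: the socket source, column names, and the assumption of a spark-shell session with an active `SparkSession` named `spark` are all illustrative.

```scala
// Sketch: streaming deduplication with a watermark (assumes spark-shell; type
// epoch-second integers such as 10, 11, 25 into the socket on localhost:9999).
import spark.implicits._

val events = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()                                        // one string column named "value"
  .select($"value".as("id"), $"value".cast("long").cast("timestamp").as("eventTime"))

val deduped = events
  .withWatermark("eventTime", "10 seconds")      // bounds how much deduplication state is kept
  .dropDuplicates("id", "eventTime")             // the watermark column is part of the keys,
                                                 // so expired state can be evicted

val query = deduped.writeStream
  .format("console")
  .outputMode("append")                          // complete/update are not supported here
  .start()
```
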
>>> from pyspark.sql import Row http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 07b3558..397f5cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -75,7 +75,7 @@ object UnsupportedOperationChecker { if (watermarkAttributes.isEmpty) { throwError( s"$outputMode output mode not supported when there are streaming aggregations on " + - s"streaming DataFrames/DataSets")(plan) + s"streaming DataFrames/DataSets without watermark")(plan) } case InternalOutputModes.Complete if aggregates.isEmpty => @@ -120,6 +120,10 @@ object UnsupportedOperationChecker { throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + "streaming DataFrame/Dataset") + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => + throwError("dropDuplicates is not supported after aggregation on a " + + "streaming DataFrame/Dataset") + case Join(left, right, joinType, _) => joinType match { http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index af846a0..036da3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -56,7 +56,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), - RewriteDistinctAggregates) :: + RewriteDistinctAggregates, + ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// @@ -1143,6 +1144,24 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } /** + * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator. + */ +object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Deduplicate(keys, child, streaming) if !streaming => + val keyExprIds = keys.map(_.exprId) + val aggCols = child.output.map { attr => + if (keyExprIds.contains(attr.exprId)) { + attr + } else { + Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) + } + } + Aggregate(keys, aggCols, child) + } +} + +/** * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. 
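
The new `Deduplicate` case added to `UnsupportedOperationChecker` above is what rejects the "not supported" case from the PR description: `dropDuplicates` placed after a streaming aggregation. A sketch of how that surfaces to the user, assuming the streaming `events` Dataset from the earlier sketch:

```scala
// Sketch: dedup after a streaming aggregation is rejected when the query starts.
import org.apache.spark.sql.AnalysisException

val afterAgg = events.groupBy("id").count().dropDuplicates()

try {
  // The unsupported-operation check runs at start(), not when the Dataset is built.
  afterAgg.writeStream.format("console").outputMode("complete").start()
} catch {
  case e: AnalysisException =>
    // Expected message (from the checker above):
    // "dropDuplicates is not supported after aggregation on a streaming DataFrame/Dataset"
    println(e.getMessage)
}
```
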
* {{{ * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2 http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d17d12c..ce1c55d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -864,3 +864,12 @@ case object OneRowRelation extends LeafNode { override def output: Seq[Attribute] = Nil override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) } + +/** A logical plan for `dropDuplicates`. */ +case class Deduplicate( + keys: Seq[Attribute], + child: LogicalPlan, + streaming: Boolean) extends UnaryNode { + + override def output: Seq[Attribute] = child.output +} http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 3b756e8..82be69a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. 
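
For reference, the new logical node can be built directly with the catalyst test DSL (the same style `ReplaceOperatorSuite` uses below); the `streaming` flag alone decides whether the optimizer rewrites the node or the streaming planner takes it. A small sketch:

```scala
// Sketch using the catalyst test DSL, mirroring ReplaceOperatorSuite below.
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LocalRelation}

val rel = LocalRelation('a.int, 'b.int)
// Batch: ReplaceDeduplicateWithAggregate turns this into an Aggregate with first().
val batchDedup = Deduplicate(Seq(rel.output.head), rel, streaming = false)
// Streaming: left alone by the optimizer and planned as StreamingDeduplicateExec.
val streamingDedup = Deduplicate(Seq(rel.output.head), rel, streaming = true)
// In both cases the node preserves its child's output schema.
assert(batchDedup.output == rel.output)
```
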
*/ case class DummyCommand() extends Command @@ -36,6 +37,11 @@ case class DummyCommand() extends Command class UnsupportedOperationsSuite extends SparkFunSuite { val attribute = AttributeReference("a", IntegerType, nullable = true)() + val watermarkMetadata = new MetadataBuilder() + .withMetadata(attribute.metadata) + .putLong(EventTimeWatermark.delayKey, 1000L) + .build() + val attributeWithWatermark = attribute.withMetadata(watermarkMetadata) val batchRelation = LocalRelation(attribute) val streamRelation = new TestStreamingRelation(attribute) @@ -98,6 +104,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Update, expectedMsgs = Seq("multiple streaming aggregations")) + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations in update mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Update) + + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations in complete mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Complete) + + assertSupportedInStreamingPlan( + "aggregate - streaming aggregations with watermark in append mode", + Aggregate(Seq(attributeWithWatermark), aggExprs("d"), streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "aggregate - streaming aggregations without watermark in append mode", + Aggregate(Nil, aggExprs("d"), streamRelation), + outputMode = Append, + expectedMsgs = Seq("streaming aggregations", "without watermark")) + // Aggregation: Distinct aggregates not supported on streaming relation val distinctAggExprs = Seq(Count("*").toAggregateExpression(isDistinct = true).as("c")) assertSupportedInStreamingPlan( @@ -129,6 +156,33 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on batch relation inside streaming relation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), + outputMode = Append + ) + + // Deduplicate + assertSupportedInStreamingPlan( + "Deduplicate - Deduplicate on streaming relation before aggregation", + Aggregate( + Seq(attributeWithWatermark), + aggExprs("c"), + Deduplicate(Seq(att), streamRelation, streaming = true)), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "Deduplicate - Deduplicate on streaming relation after aggregation", + Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true), + outputMode = Complete, + expectedMsgs = Seq("dropDuplicates")) + + assertSupportedInStreamingPlan( + "Deduplicate - Deduplicate on batch relation inside a streaming query", + Deduplicate(Seq(att), batchRelation, streaming = false), + outputMode = Append + ) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index f23e262..e68423f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -30,7 +32,8 @@ class ReplaceOperatorSuite extends PlanTest { Batch("Replace Operators", FixedPoint(100), ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, - ReplaceIntersectWithSemiJoin) :: Nil + ReplaceIntersectWithSemiJoin, + ReplaceDeduplicateWithAggregate) :: Nil } test("replace Intersect with Left-semi Join") { @@ -71,4 +74,32 @@ class ReplaceOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("replace batch Deduplicate with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + val attrA = input.output(0) + val attrB = input.output(1) + val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a") + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate( + Seq(attrA), + Seq( + attrA, + Alias(new First(attrB).toAggregateExpression(), attrB.name)(attrB.exprId) + ), + input) + + comparePlans(optimized, correctAnswer) + } + + test("don't replace streaming Deduplicate") { + val input = LocalRelation('a.int, 'b.int) + val attrA = input.output(0) + val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a") + val optimized = Optimize.execute(query.analyze) + + comparePlans(optimized, query) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1ebc53d..3c212d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -557,7 +557,8 @@ class Dataset[T] private[sql]( * Spark will use this watermark for several purposes: * - To know when a given time window aggregation can be finalized and thus can be emitted when * using output modes that do not allow updates. - * - To minimize the amount of state that we need to keep for on-going aggregations. + * - To minimize the amount of state that we need to keep for on-going aggregations, + * `mapGroupsWithState` and `dropDuplicates` operators. * * The current watermark is computed by looking at the `MAX(eventTime)` seen across * all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost @@ -1981,6 +1982,12 @@ class Dataset[T] private[sql]( * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `distinct`. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. 
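
The `ReplaceOperatorSuite` cases above pin down what the optimizer rule does for batch plans: group on the keys and wrap every non-key column in `first()`. In DataFrame terms the rewrite is roughly equivalent to the following sketch (assumes a spark-shell session with `spark` defined):

```scala
// Sketch of the batch rewrite performed by ReplaceDeduplicateWithAggregate.
import org.apache.spark.sql.functions.first
import spark.implicits._

val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("k", "v")

val viaDropDuplicates = df.dropDuplicates("k")
// Roughly what the rule produces; note that which "v" survives for key "a" is
// picked by first() and is not guaranteed.
val viaAggregate = df.groupBy("k").agg(first("v").as("v"))

viaDropDuplicates.explain(true)   // analyzed plan shows Deduplicate, optimized plan an Aggregate
assert(viaDropDuplicates.columns.sameElements(viaAggregate.columns))
```
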
+ * * @group typedrel * @since 2.0.0 */ @@ -1990,13 +1997,19 @@ class Dataset[T] private[sql]( * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { val resolver = sparkSession.sessionState.analyzer.resolver val allColumns = queryExecution.analyzed.output - val groupCols = colNames.flatMap { colName => + val groupCols = colNames.toSet.toSeq.flatMap { (colName: String) => // It is possibly there are more than one columns with the same name, // so we call filter instead of find. val cols = allColumns.filter(col => resolver(col.name, colName)) @@ -2006,21 +2019,19 @@ class Dataset[T] private[sql]( } cols } - val groupColExprIds = groupCols.map(_.exprId) - val aggCols = logicalPlan.output.map { attr => - if (groupColExprIds.contains(attr.exprId)) { - attr - } else { - Alias(new First(attr).toAggregateExpression(), attr.name)() - } - } - Aggregate(groupCols, aggCols, logicalPlan) + Deduplicate(groupCols, logicalPlan, isStreaming) } /** * Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. + * * @group typedrel * @since 2.0.0 */ @@ -2030,6 +2041,12 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] with duplicate rows removed, considering only * the subset of columns. * + * For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it + * will keep all data across triggers as intermediate state to drop duplicates rows. You can use + * [[withWatermark]] to limit how late the duplicate data can be and system will accordingly limit + * the state. In addition, too late data older than watermark will be dropped to avoid any + * possibility of duplicates. 
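
One behavioral detail of the new `dropDuplicates(colNames)` body above: the column list now goes through `toSet`, so a repeated name collapses to a single grouping key, while an unresolvable name is still expected to fail fast during key resolution. A small sketch, reusing the `df` DataFrame from the previous sketch:

```scala
// Sketch: repeated key names collapse, unknown names are rejected (assumes `df` from above).
val d1 = df.dropDuplicates(Seq("k", "k"))
assert(d1.count() == 2)                      // one row per distinct value of k ("a" and "b")

try {
  df.dropDuplicates(Seq("does_not_exist"))
} catch {
  case e: org.apache.spark.sql.AnalysisException =>
    println(e.getMessage)                    // expected: cannot resolve the column name
}
```
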
+ * * @group typedrel * @since 2.0.0 */ http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0e3d559..027b148 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -22,9 +22,10 @@ import org.apache.spark.sql.{SaveMode, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -245,6 +246,18 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } /** + * Used to plan the streaming deduplicate operator. + */ + object StreamingDeduplicationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Deduplicate(keys, child, true) => + StreamingDeduplicateExec(keys, planLater(child)) :: Nil + + case _ => Nil + } + } + + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a3e108b..ffdcd9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -45,6 +45,7 @@ class IncrementalExecution( sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: + sparkSession.sessionState.planner.StreamingDeduplicationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies // Modified planner with stateful operations. 
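
The strategy above only fires for `Deduplicate(_, _, streaming = true)`; `IncrementalExecution` (the following hunk) then fills in the operator's state id and current watermark. A quick way to see the resulting physical operator is sketched below; `MemoryStream` is the same test source the new suite uses, and the rest assumes a spark-shell session.

```scala
// Sketch: observing StreamingDeduplicateExec in a running query's plan.
import org.apache.spark.sql.execution.streaming.MemoryStream
import spark.implicits._
implicit val sqlCtx = spark.sqlContext       // MemoryStream needs an implicit SQLContext

val input = MemoryStream[String]
val q = input.toDS().dropDuplicates()
  .writeStream
  .format("memory")
  .queryName("dedup_plan")
  .outputMode("append")
  .start()

input.addData("a", "a", "b")
q.processAllAvailable()
q.explain()                                  // plan should contain a StreamingDeduplicateExec node
spark.sql("SELECT * FROM dedup_plan").show() // "a" and "b"; the duplicate "a" was dropped
q.stop()
```
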
@@ -93,6 +94,15 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) + case StreamingDeduplicateExec(keys, child, None, None) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + + StreamingDeduplicateExec( + keys, + child, + Some(stateId), + Some(currentEventTimeWatermark)) case MapGroupsWithStateExec( f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => val stateId = http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 1292452..d925297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, NullType, StructType} import org.apache.spark.util.CompletionIterator @@ -68,6 +67,40 @@ trait StateStoreWriter extends StatefulOperator { "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) } +/** An operator that supports watermark. */ +trait WatermarkSupport extends SparkPlan { + + /** The keys that may have a watermark attribute. */ + def keyExpressions: Seq[Attribute] + + /** The watermark value. */ + def eventTimeWatermark: Option[Long] + + /** Generate a predicate that matches data older than the watermark */ + lazy val watermarkPredicate: Option[Predicate] = { + val optionalWatermarkAttribute = + keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + + optionalWatermarkAttribute.map { watermarkAttribute => + // If we are evicting based on a window, use the end of the window. Otherwise just + // use the attribute itself. + val evictionExpression = + if (watermarkAttribute.dataType.isInstanceOf[StructType]) { + LessThanOrEqual( + GetStructField(watermarkAttribute, 1), + Literal(eventTimeWatermark.get * 1000)) + } else { + LessThanOrEqual( + watermarkAttribute, + Literal(eventTimeWatermark.get * 1000)) + } + + logInfo(s"Filtering state store on: $evictionExpression") + newPredicate(evictionExpression, keyExpressions) + } + } +} + /** * For each input tuple, the key is calculated and the value from the [[StateStore]] is added * to the stream (in addition to the input tuple) if present. 
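
The `WatermarkSupport` trait above factors out the eviction predicate that previously lived inline in `StateStoreSaveExec` (removed in the following hunk) and that `StreamingDeduplicateExec` now reuses. The unit mismatch is the easy thing to miss: event-time values are stored in microseconds while the watermark is tracked in milliseconds, hence the `* 1000` in the predicate. A plain-Scala paraphrase, not the actual Spark code:

```scala
// Paraphrase of the eviction rule: drop state at or before the current watermark.
def shouldEvict(eventTimeMicros: Long, watermarkMillis: Long): Boolean =
  eventTimeMicros <= watermarkMillis * 1000L

// With a watermark of 15 seconds (15000 ms):
assert(shouldEvict(eventTimeMicros = 14000000L, watermarkMillis = 15000L))   // 14s     -> evicted
assert(shouldEvict(eventTimeMicros = 15000000L, watermarkMillis = 15000L))   // 15s     -> evicted (<=)
assert(!shouldEvict(eventTimeMicros = 15000001L, watermarkMillis = 15000L))  // just after -> kept
```
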
@@ -76,7 +109,7 @@ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], child: SparkPlan) - extends execution.UnaryExecNode with StateStoreReader { + extends UnaryExecNode with StateStoreReader { override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -113,31 +146,7 @@ case class StateStoreSaveExec( outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) - extends execution.UnaryExecNode with StateStoreWriter { - - /** Generate a predicate that matches data older than the watermark */ - private lazy val watermarkPredicate: Option[Predicate] = { - val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) - - optionalWatermarkAttribute.map { watermarkAttribute => - // If we are evicting based on a window, use the end of the window. Otherwise just - // use the attribute itself. - val evictionExpression = - if (watermarkAttribute.dataType.isInstanceOf[StructType]) { - LessThanOrEqual( - GetStructField(watermarkAttribute, 1), - Literal(eventTimeWatermark.get * 1000)) - } else { - LessThanOrEqual( - watermarkAttribute, - Literal(eventTimeWatermark.get * 1000)) - } - - logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) - } - } + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -146,8 +155,8 @@ case class StateStoreSaveExec( child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -262,8 +271,8 @@ case class MapGroupsWithStateExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, + getStateId.operatorId, + getStateId.batchId, groupingAttributes.toStructType, child.output.toStructType, sqlContext.sessionState, @@ -321,3 +330,70 @@ case class MapGroupsWithStateExec( } } } + + +/** Physical operator for executing streaming Deduplicate. 
*/ +case class StreamingDeduplicateExec( + keyExpressions: Seq[Attribute], + child: SparkPlan, + stateId: Option[OperatorStateId] = None, + eventTimeWatermark: Option[Long] = None) + extends UnaryExecNode with StateStoreWriter with WatermarkSupport { + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(keyExpressions) :: Nil + + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + getStateId.operatorId, + getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + val baseIterator = watermarkPredicate match { + case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case None => iter + } + + val result = baseIterator.filter { r => + val row = r.asInstanceOf[UnsafeRow] + val key = getKey(row) + val value = store.get(key) + if (value.isEmpty) { + store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW) + numUpdatedStateRows += 1 + numOutputRows += 1 + true + } else { + // Drop duplicated rows + false + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]](result, { + watermarkPredicate.foreach(f => store.remove(f.eval _)) + store.commit() + numTotalStateRows += store.numKeys() + }) + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +object StreamingDeduplicateExec { + private val EMPTY_ROW = + UnsafeProjection.create(Array[DataType](NullType)).apply(InternalRow.apply(null)) +} http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala new file mode 100644 index 0000000..7ea7162 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
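
The filtering loop in `StreamingDeduplicateExec` above is the heart of the operator: rows already older than the watermark are filtered out up front, a surviving row is emitted only the first time its key is seen, the key is stored with an empty value so later triggers keep dropping it, and expired keys are removed when the batch completes. A simplified model of that loop, with a mutable set standing in for the `StateStore`:

```scala
// Simplified stand-in for the per-partition dedup loop (not the actual operator code).
import scala.collection.mutable

def dedupBatch[K, R](rows: Iterator[R], keyOf: R => K, state: mutable.Set[K]): Iterator[R] =
  rows.filter { row =>
    val k = keyOf(row)
    if (state.contains(k)) {
      false          // seen before, possibly in an earlier trigger: drop
    } else {
      state += k     // first occurrence: remember the key and emit the row
      true
    }
  }

val state = mutable.Set.empty[String]
// Trigger 1: the duplicate within the batch is dropped.
assert(dedupBatch(Iterator("a", "a", "b"), (s: String) => s, state).toList == List("a", "b"))
// Trigger 2: state carries over, so "a" is still treated as a duplicate.
assert(dedupBatch(Iterator("a", "c"), (s: String) => s, state).toList == List("c"))
```
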
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.functions._ + +class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("deduplicate with all columns") { + val inputData = MemoryStream[String] + val result = inputData.toDS().dropDuplicates() + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch("a"), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a"), + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + AddData(inputData, "b"), + CheckLastBatch("b"), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("deduplicate with some columns") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().dropDuplicates("_1") + + testStream(result, Append)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a" -> 2), // Dropped + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1), + assertNumStateRows(total = 2, updated = 1) + ) + } + + test("multiple deduplicates") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS().dropDuplicates().dropDuplicates("_1") + + testStream(result, Append)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + + AddData(inputData, "a" -> 2), // Dropped from the second `dropDuplicates` + CheckLastBatch(), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(0L, 1L)), + + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with watermark") { + val inputData = MemoryStream[Int] + val result = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .select($"eventTime".cast("long").as[Long]) + + testStream(result, Append)( + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckLastBatch(10 to 15: _*), + assertNumStateRows(total = 6, updated = 6), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(25), + assertNumStateRows(total = 7, updated = 1), + + AddData(inputData, 25), // Drop states less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0), + + AddData(inputData, 45), // Advance watermark to 35 seconds + CheckLastBatch(45), + assertNumStateRows(total = 2, updated = 1), + + AddData(inputData, 45), // Drop states less than watermark + CheckLastBatch(), + assertNumStateRows(total = 1, updated = 0) + ) + } + + test("deduplicate with aggregate - append mode") { + val inputData = MemoryStream[Int] + val windowedaggregate = inputData.toDS() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .dropDuplicates() + .withWatermark("eventTime", "10 seconds") + .groupBy(window($"eventTime", "5 seconds") as 
'window) + .agg(count("*") as 'count) + .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) + + testStream(windowedaggregate)( + AddData(inputData, (1 to 5).flatMap(_ => (10 to 15)): _*), + CheckLastBatch(), + // states in aggregate in [10, 14), [15, 20) (2 windows) + // states in deduplicate is 10 to 15 + assertNumStateRows(total = Seq(2L, 6L), updated = Seq(2L, 6L)), + + AddData(inputData, 25), // Advance watermark to 15 seconds + CheckLastBatch(), + // states in aggregate in [10, 14), [15, 20) and [25, 30) (3 windows) + // states in deduplicate is 10 to 15 and 25 + assertNumStateRows(total = Seq(3L, 7L), updated = Seq(1L, 1L)), + + AddData(inputData, 25), // Emit items less than watermark and drop their state + CheckLastBatch((10 -> 5)), // 5 items (10 to 14) after deduplicate + // states in aggregate in [15, 20) and [25, 30) (2 windows, note aggregate uses the end of + // window to evict items, so [15, 20) is still in the state store) + // states in deduplicate is 25 + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + + AddData(inputData, 10), // Should not emit anything as data less than watermark + CheckLastBatch(), + assertNumStateRows(total = Seq(2L, 1L), updated = Seq(0L, 0L)), + + AddData(inputData, 40), // Advance watermark to 30 seconds + CheckLastBatch(), + // states in aggregate in [15, 20), [25, 30) and [40, 45) + // states in deduplicate is 25 and 40, + assertNumStateRows(total = Seq(3L, 2L), updated = Seq(1L, 1L)), + + AddData(inputData, 40), // Emit items less than watermark and drop their state + CheckLastBatch((15 -> 1), (25 -> 1)), + // states in aggregate in [40, 45) + // states in deduplicate is 40, + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)) + ) + } + + test("deduplicate with aggregate - update mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS() + .select($"_1" as "str", $"_2" as "num") + .dropDuplicates() + .groupBy("str") + .agg(sum("num")) + .as[(String, Long)] + + testStream(result, Update)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + AddData(inputData, "a" -> 1), // Dropped + CheckLastBatch(), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)), + AddData(inputData, "a" -> 2), + CheckLastBatch("a" -> 3L), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)), + AddData(inputData, "b" -> 1), + CheckLastBatch("b" -> 1L), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with aggregate - complete mode") { + val inputData = MemoryStream[(String, Int)] + val result = inputData.toDS() + .select($"_1" as "str", $"_2" as "num") + .dropDuplicates() + .groupBy("str") + .agg(sum("num")) + .as[(String, Long)] + + testStream(result, Complete)( + AddData(inputData, "a" -> 1), + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(1L, 1L)), + AddData(inputData, "a" -> 1), // Dropped + CheckLastBatch("a" -> 1L), + assertNumStateRows(total = Seq(1L, 1L), updated = Seq(0L, 0L)), + AddData(inputData, "a" -> 2), + CheckLastBatch("a" -> 3L), + assertNumStateRows(total = Seq(1L, 2L), updated = Seq(1L, 1L)), + AddData(inputData, "b" -> 1), + CheckLastBatch("a" -> 3L, "b" -> 1L), + assertNumStateRows(total = Seq(2L, 3L), updated = Seq(1L, 1L)) + ) + } + + test("deduplicate with file sink") { + withTempDir { output => + withTempDir { checkpointDir => + val outputPath = output.getAbsolutePath + val inputData = 
MemoryStream[String] + val result = inputData.toDS().dropDuplicates() + val q = result.writeStream + .format("parquet") + .outputMode(Append) + .option("checkpointLocation", checkpointDir.getPath) + .start(outputPath) + try { + inputData.addData("a") + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a") + + inputData.addData("a") // Dropped + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a") + + inputData.addData("b") + q.processAllAvailable() + checkDataset(spark.read.parquet(outputPath).as[String], "a", "b") + } finally { + q.stop() + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 0524898..6cf4d51 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ case class RunningCount(count: Long) -class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { +class MapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { import testImplicits._ @@ -321,13 +321,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count ) } - - private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows") - assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows") - true - } } object MapGroupsWithStateSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala new file mode 100644 index 0000000..894786c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +trait StateStoreMetricsTest extends StreamTest { + + def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery = + AssertOnQuery { q => + val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + assert( + progressWithData.stateOperators.map(_.numRowsTotal) === total, + "incorrect total rows") + assert( + progressWithData.stateOperators.map(_.numRowsUpdated) === updated, + "incorrect updates rows") + true + } + + def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = + assertNumStateRows(Seq(total), Seq(updated)) +} http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 0296a2a..f44cfad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -338,7 +338,7 @@ class StreamSuite extends StreamTest { .writeStream .format("memory") .queryName("testquery") - .outputMode("complete") + .outputMode("append") .start() try { query.processAllAvailable() http://git-wip-us.apache.org/repos/asf/spark/blob/9bf4e2ba/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index eca2647..0c80156 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -35,7 +35,7 @@ object FailureSinglton { var firstTime = true } -class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { +class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll { override def afterAll(): Unit = { super.afterAll() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org