Repository: spark
Updated Branches:
  refs/heads/master 72ecfd095 -> 6c5cb8585


[SPARK-24763][SS] Remove redundant key data from value in streaming aggregation

## What changes were proposed in this pull request?

This patch proposes a new option for stateful aggregation: remove redundant
key data from the state value.
With the new option enabled, the query behaves the same as before but uses less
memory for state, depending on the key/value fields of the state operator.

Please refer below link to see detailed perf. test result:
https://issues.apache.org/jira/browse/SPARK-24763?focusedCommentId=16536539&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16536539

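Since state written with the option enabled is not compatible with state written
with it disabled, the chosen state format version is recorded in OffsetSeqMetadata,
which prevents the option from being changed after the query has run. Queries
restarted from checkpoints that predate the option fall back to the legacy format
(version 1) to ensure backward compatibility, while new queries use version 2 by
default.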

## How was this patch tested?

Modified unit tests to cover both the option disabled and the option enabled.
Also ran manual tests to check whether the proposed patch reduces state memory usage.

Closes #21733 from HeartSaVioR/SPARK-24763.

Authored-by: Jungtaek Lim <kabh...@gmail.com>
Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c5cb858
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c5cb858
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c5cb858

Branch: refs/heads/master
Commit: 6c5cb85856235efd464b109558896f81ae2c4c75
Parents: 72ecfd0
Author: Jungtaek Lim <kabh...@gmail.com>
Authored: Tue Aug 21 15:22:42 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Aug 21 15:22:42 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/internal/SQLConf.scala |  10 +
 .../spark/sql/execution/SparkStrategies.scala   |   3 +
 .../sql/execution/aggregate/AggUtils.scala      |   5 +-
 .../streaming/IncrementalExecution.scala        |   6 +-
 .../sql/execution/streaming/OffsetSeq.scala     |   8 +-
 .../StreamingAggregationStateManager.scala      | 205 +++++++++++++++++++
 .../execution/streaming/statefulOperators.scala |  61 +++---
 .../commits/0                                   |   2 +
 .../commits/1                                   |   2 +
 .../metadata                                    |   1 +
 .../offsets/0                                   |   3 +
 .../offsets/1                                   |   3 +
 .../state/0/0/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/0/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/1/1.delta                           | Bin 0 -> 77 bytes
 .../state/0/1/2.delta                           | Bin 0 -> 77 bytes
 .../state/0/2/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/2/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/3/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/3/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/4/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/4/2.delta                           | Bin 0 -> 77 bytes
 .../streaming/state/MemoryStateStore.scala      |  49 +++++
 .../StreamingAggregationStateManagerSuite.scala | 126 ++++++++++++
 .../streaming/FlatMapGroupsWithStateSuite.scala |  24 +--
 .../streaming/StreamingAggregationSuite.scala   | 150 +++++++++++---
 26 files changed, 573 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index bffdddc..b44bfe7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -888,6 +888,16 @@ object SQLConf {
     .intConf
     .createWithDefault(2)
 
+  val STREAMING_AGGREGATION_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.aggregation.stateFormatVersion")
+      .internal()
+      .doc("State format version used by streaming aggregation operations in a 
streaming query. " +
+        "State between versions are tend to be incompatible, so state format 
version shouldn't " +
+        "be modified after running.")
+      .intConf  // valid values mirror StreamingAggregationStateManager.supportedVersions
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)  // new queries use v2; older checkpoints fall back to v1
+
   val UNSUPPORTED_OPERATION_CHECK_ENABLED =
     buildConf("spark.sql.streaming.unsupportedOperationCheck")
       .internal()

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b4179f4..4c39990 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -328,10 +328,13 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
             "Streaming aggregation doesn't support group aggregate pandas UDF")
         }
 
+        val stateVersion = 
conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
+
         aggregate.AggUtils.planStreamingAggregation(
           namedGroupingExpressions,
           aggregateExpressions.map(expr => 
expr.asInstanceOf[AggregateExpression]),
           rewrittenResultExpressions,
+          stateVersion,
           planLater(child))
 
       case _ => Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index c8ef2b3..6be88c4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -260,6 +260,7 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
       resultExpressions: Seq[NamedExpression],
+      stateFormatVersion: Int,
       child: SparkPlan): Seq[SparkPlan] = {
 
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
@@ -291,7 +292,8 @@ object AggUtils {
         child = partialAggregate)
     }
 
-    val restored = StateStoreRestoreExec(groupingAttributes, None, 
partialMerged1)
+    val restored = StateStoreRestoreExec(groupingAttributes, None, 
stateFormatVersion,
+      partialMerged1)
 
     val partialMerged2: SparkPlan = {
       val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = 
PartialMerge))
@@ -315,6 +317,7 @@ object AggUtils {
         stateInfo = None,
         outputMode = None,
         eventTimeWatermark = None,
+        stateFormatVersion = stateFormatVersion,
         partialMerged2)
 
     val finalAndCompleteAggregate: SparkPlan = {

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 725abb3..fad287e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -102,19 +102,21 @@ class IncrementalExecution(
   val state = new Rule[SparkPlan] {
 
     override def apply(plan: SparkPlan): SparkPlan = plan transform {
-      case StateStoreSaveExec(keys, None, None, None,
+      case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
              UnaryExecNode(agg,
-               StateStoreRestoreExec(_, None, child))) =>
+               StateStoreRestoreExec(_, None, _, child))) =>
         val aggStateInfo = nextStatefulOperationStateInfo
         StateStoreSaveExec(
           keys,
           Some(aggStateInfo),
           Some(outputMode),
           Some(offsetSeqMetadata.batchWatermarkMs),
+          stateFormatVersion,
           agg.withNewChildren(
             StateStoreRestoreExec(
               keys,
               Some(aggStateInfo),
+              stateFormatVersion,
               child) :: Nil))
 
       case StreamingDeduplicateExec(keys, child, None, None) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 9847756..73cf355 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.RuntimeConfig
-import 
org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
+import 
org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper,
 StreamingAggregationStateManager}
 import 
org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION,
 _}
 
 /**
@@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging {
   private implicit val format = Serialization.formats(NoTypeHints)
   private val relevantSQLConfs = Seq(
     SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, 
STREAMING_MULTIPLE_WATERMARK_POLICY,
-    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
+    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, 
STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
 
   /**
    * Default values of relevant configurations that are used for backward 
compatibility.
@@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging {
   private val relevantSQLConfDefaultValues = Map[String, String](
     STREAMING_MULTIPLE_WATERMARK_POLICY.key -> 
MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
     FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
-      FlatMapGroupsWithStateExecHelper.legacyVersion.toString
+      FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
+    STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
+      StreamingAggregationStateManager.legacyVersion.toString
   )
 
   def apply(json: String): OffsetSeqMetadata = 
Serialization.read[OffsetSeqMetadata](json)

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
new file mode 100644
index 0000000..9bfb956
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, 
GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Base trait for state managers used by streaming aggregations.
+ */
+sealed trait StreamingAggregationStateManager extends Serializable {
+
+  /** Extract the key columns from the input row and return them as a new row. */
+  def getKey(row: UnsafeRow): UnsafeRow
+
+  /** Calculate the schema of the state value. The schema is mainly passed to the StateStoreRDD. */
+  def getStateValueSchema: StructType
+
+  /** Get the current value of a non-null key from the target state store. */
+  def get(store: StateStore, key: UnsafeRow): UnsafeRow
+
+  /**
+   * Put a new value for a non-null key into the target state store. Note that the key is
+   * extracted from the input row, and it is the same as the result of getKey(inputRow).
+   */
+  def put(store: StateStore, row: UnsafeRow): Unit
+
+  /**
+   * Commit all the updates that have been made to the target state store, and return the
+   * new version.
+   */
+  def commit(store: StateStore): Long
+
+  /** Remove a single non-null key from the target state store. */
+  def remove(store: StateStore, key: UnsafeRow): Unit
+
+  /** Return an iterator containing all the key-value pairs in the target state store. */
+  def iterator(store: StateStore): Iterator[UnsafeRowPair]
+
+  /** Return an iterator containing all the keys in the target state store. */
+  def keys(store: StateStore): Iterator[UnsafeRow]
+
+  /** Return an iterator containing all the values in the target state store. */
+  def values(store: StateStore): Iterator[UnsafeRow]
+}
+
+object StreamingAggregationStateManager extends Logging {
+  val supportedVersions = Seq(1, 2)  // versions accepted by createStateManager
+  val legacyVersion = 1  // fallback when a checkpoint's offset metadata predates the config
+
+  def createStateManager(  // picks the implementation matching stateFormatVersion
+      keyExpressions: Seq[Attribute],
+      inputRowAttributes: Seq[Attribute],
+      stateFormatVersion: Int): StreamingAggregationStateManager = {
+    stateFormatVersion match {
+      case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, 
inputRowAttributes)
+      case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, 
inputRowAttributes)
+      case _ => throw new IllegalArgumentException(s"Version 
$stateFormatVersion is invalid")
+    }
+  }
+}
+
+abstract class StreamingAggregationStateManagerBaseImpl(  // logic shared by V1 and V2
+    protected val keyExpressions: Seq[Attribute],
+    protected val inputRowAttributes: Seq[Attribute]) extends 
StreamingAggregationStateManager {
+
+  @transient protected lazy val keyProjector =  // not serialized; generated on first use
+    GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)
+
+  override def getKey(row: UnsafeRow): UnsafeRow = keyProjector(row)
+
+  override def commit(store: StateStore): Long = store.commit()
+
+  override def remove(store: StateStore, key: UnsafeRow): Unit = 
store.remove(key)
+
+  override def keys(store: StateStore): Iterator[UnsafeRow] = {
+    // only the keys are needed, so skip converting/restoring the values
+    store.getRange(None, None).map(_.key)
+  }
+}
+
+/**
+ * The implementation of StreamingAggregationStateManager for state version 1.
+ * In state version 1, the key and value schemas of the state are as follows:
+ *
+ * - key: Same as the key expressions.
+ * - value: Same as the input row attributes, so the value also contains the key columns.
+ *
+ * @param keyExpressions The attributes of the key.
+ * @param inputRowAttributes The attributes of the input row.
+ */
+class StreamingAggregationStateManagerImplV1(
+    keyExpressions: Seq[Attribute],
+    inputRowAttributes: Seq[Attribute])
+  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, 
inputRowAttributes) {
+
+  override def getStateValueSchema: StructType = 
inputRowAttributes.toStructType
+
+  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
+    store.get(key)
+  }
+
+  override def put(store: StateStore, row: UnsafeRow): Unit = {
+    store.put(getKey(row), row)  // the whole input row is stored as the value
+  }
+
+  override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
+    store.iterator()  // stored values already have the input row layout
+  }
+
+  override def values(store: StateStore): Iterator[UnsafeRow] = {
+    store.iterator().map(_.value)
+  }
+}
+
+/**
+ * The implementation of StreamingAggregationStateManager for state version 2.
+ * In state version 2, the key and value schemas of the state are as follows:
+ *
+ * - key: Same as the key expressions.
+ * - value: The input row attributes minus the key expressions (`inputRowAttributes.diff`).
+ *
+ * The value schema was changed to reduce the memory/space usage of the state by removing the
+ * columns duplicated between key and value. Hence key columns are excluded from the value schema.
+ *
+ * @param keyExpressions The attributes of the key.
+ * @param inputRowAttributes The attributes of the input row.
+ */
+class StreamingAggregationStateManagerImplV2(
+    keyExpressions: Seq[Attribute],
+    inputRowAttributes: Seq[Attribute])
+  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, 
inputRowAttributes) {
+
+  private val valueExpressions: Seq[Attribute] = 
inputRowAttributes.diff(keyExpressions)
+  private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ 
valueExpressions
+
+  // whether the joined (key ++ value) row must be re-projected into the input row order,
+  // e.g. when the key columns are not a leading prefix of the input row attributes
+  private val needToProjectToRestoreValue: Boolean =
+    keyValueJoinedExpressions != inputRowAttributes
+
+  @transient private lazy val valueProjector =
+    GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)
+
+  @transient private lazy val joiner =
+    GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
+      StructType.fromAttributes(valueExpressions))
+  @transient private lazy val restoreValueProjector = 
GenerateUnsafeProjection.generate(
+    inputRowAttributes, keyValueJoinedExpressions)
+
+  override def getStateValueSchema: StructType = valueExpressions.toStructType
+
+  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
+    val savedState = store.get(key)
+    if (savedState == null) {
+      return savedState  // key not in the store: propagate null
+    }
+
+    restoreOriginalRow(key, savedState)
+  }
+
+  override def put(store: StateStore, row: UnsafeRow): Unit = {
+    val key = keyProjector(row)
+    val value = valueProjector(row)  // only the non-key columns are stored
+    store.put(key, value)
+  }
+
+  override def iterator(store: StateStore): Iterator[UnsafeRowPair] = {
+    store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, 
restoreOriginalRow(rowPair)))
+  }
+
+  override def values(store: StateStore): Iterator[UnsafeRow] = {
+    store.iterator().map(rowPair => restoreOriginalRow(rowPair))
+  }
+
+  private def restoreOriginalRow(rowPair: UnsafeRowPair): UnsafeRow = {
+    restoreOriginalRow(rowPair.key, rowPair.value)
+  }
+
+  private def restoreOriginalRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow 
= {
+    val joinedRow = joiner.join(key, value)  // NOTE(review): output buffer may be reused
+    if (needToProjectToRestoreValue) {
+      restoreValueProjector(joinedRow)
+    } else {
+      joinedRow
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 6759fb4..34e26d8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.TimeUnit._
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
@@ -167,6 +165,18 @@ trait WatermarkSupport extends UnaryExecNode {
       }
     }
   }
+
+  protected def removeKeysOlderThanWatermark(  // drop keys the watermark predicate matches
+      storeManager: StreamingAggregationStateManager,
+      store: StateStore): Unit = {  // caller is responsible for committing
+    if (watermarkPredicateForKeys.nonEmpty) {  // no-op without a key watermark predicate
+      storeManager.keys(store).foreach { keyRow =>  // NOTE(review): removes while iterating
+        if (watermarkPredicateForKeys.get.eval(keyRow)) {
+          storeManager.remove(store, keyRow)
+        }
+      }
+    }
+  }
 }
 
 object WatermarkSupport {
@@ -201,20 +211,23 @@ object WatermarkSupport {
 case class StateStoreRestoreExec(
     keyExpressions: Seq[Attribute],
     stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader {
 
+  private[sql] val stateManager = 
StreamingAggregationStateManager.createStateManager(
+    keyExpressions, child.output, stateFormatVersion)
+
   override protected def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
-      child.output.toStructType,
+      stateManager.getStateValueSchema,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
-        val getKey = GenerateUnsafeProjection.generate(keyExpressions, 
child.output)
         val hasInput = iter.hasNext
         if (!hasInput && keyExpressions.isEmpty) {
           // If our `keyExpressions` are empty, we're getting a global 
aggregation. In that case
@@ -224,10 +237,10 @@ case class StateStoreRestoreExec(
           store.iterator().map(_.value)
         } else {
           iter.flatMap { row =>
-            val key = getKey(row)
-            val savedState = store.get(key)
+            val key = stateManager.getKey(row.asInstanceOf[UnsafeRow])
+            val restoredRow = stateManager.get(store, key)
             numOutputRows += 1
-            Option(savedState).toSeq :+ row
+            Option(restoredRow).toSeq :+ row
           }
         }
     }
@@ -254,9 +267,13 @@ case class StateStoreSaveExec(
     stateInfo: Option[StatefulOperatorStateInfo] = None,
     outputMode: Option[OutputMode] = None,
     eventTimeWatermark: Option[Long] = None,
+    stateFormatVersion: Int,
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
 
+  private[sql] val stateManager = 
StreamingAggregationStateManager.createStateManager(
+    keyExpressions, child.output, stateFormatVersion)
+
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
     assert(outputMode.nonEmpty,
@@ -265,11 +282,10 @@ case class StateStoreSaveExec(
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
-      child.output.toStructType,
+      stateManager.getStateValueSchema,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
-        val getKey = GenerateUnsafeProjection.generate(keyExpressions, 
child.output)
         val numOutputRows = longMetric("numOutputRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
         val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
@@ -282,19 +298,18 @@ case class StateStoreSaveExec(
             allUpdatesTimeMs += timeTakenMs {
               while (iter.hasNext) {
                 val row = iter.next().asInstanceOf[UnsafeRow]
-                val key = getKey(row)
-                store.put(key, row)
+                stateManager.put(store, row)
                 numUpdatedStateRows += 1
               }
             }
             allRemovalsTimeMs += 0
             commitTimeMs += timeTakenMs {
-              store.commit()
+              stateManager.commit(store)
             }
             setStoreMetrics(store)
-            store.iterator().map { rowPair =>
+            stateManager.values(store).map { valueRow =>
               numOutputRows += 1
-              rowPair.value
+              valueRow
             }
 
           // Update and output only rows being evicted from the StateStore
@@ -304,14 +319,13 @@ case class StateStoreSaveExec(
               val filteredIter = iter.filter(row => 
!watermarkPredicateForData.get.eval(row))
               while (filteredIter.hasNext) {
                 val row = filteredIter.next().asInstanceOf[UnsafeRow]
-                val key = getKey(row)
-                store.put(key, row)
+                stateManager.put(store, row)
                 numUpdatedStateRows += 1
               }
             }
 
             val removalStartTimeNs = System.nanoTime
-            val rangeIter = store.getRange(None, None)
+            val rangeIter = stateManager.iterator(store)
 
             new NextIterator[InternalRow] {
               override protected def getNext(): InternalRow = {
@@ -319,7 +333,7 @@ case class StateStoreSaveExec(
                 while(rangeIter.hasNext && removedValueRow == null) {
                   val rowPair = rangeIter.next()
                   if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
-                    store.remove(rowPair.key)
+                    stateManager.remove(store, rowPair.key)
                     removedValueRow = rowPair.value
                   }
                 }
@@ -333,7 +347,7 @@ case class StateStoreSaveExec(
 
               override protected def close(): Unit = {
                 allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
removalStartTimeNs)
-                commitTimeMs += timeTakenMs { store.commit() }
+                commitTimeMs += timeTakenMs { stateManager.commit(store) }
                 setStoreMetrics(store)
               }
             }
@@ -352,8 +366,7 @@ case class StateStoreSaveExec(
               override protected def getNext(): InternalRow = {
                 if (baseIterator.hasNext) {
                   val row = baseIterator.next().asInstanceOf[UnsafeRow]
-                  val key = getKey(row)
-                  store.put(key, row)
+                  stateManager.put(store, row)
                   numOutputRows += 1
                   numUpdatedStateRows += 1
                   row
@@ -367,8 +380,10 @@ case class StateStoreSaveExec(
                 allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
updatesStartTimeNs)
 
                 // Remove old aggregates if watermark specified
-                allRemovalsTimeMs += timeTakenMs { 
removeKeysOlderThanWatermark(store) }
-                commitTimeMs += timeTakenMs { store.commit() }
+                allRemovalsTimeMs += timeTakenMs {
+                  removeKeysOlderThanWatermark(stateManager, store)
+                }
+                commitTimeMs += timeTakenMs { stateManager.commit(store) }
                 setStoreMetrics(store)
               }
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0
new file mode 100644
index 0000000..83321cd
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/0
@@ -0,0 +1,2 @@
+v1
+{}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1
new file mode 100644
index 0000000..83321cd
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/commits/1
@@ -0,0 +1,2 @@
+v1
+{}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata
new file mode 100644
index 0000000..c160d73
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/metadata
@@ -0,0 +1 @@
+{"id":"2f32aca2-1b97-458f-a48f-109328724f09"}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0
new file mode 100644
index 0000000..acdc6e6
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/0
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":0,"batchTimestampMs":1533784347136,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}}
+0
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1
new file mode 100644
index 0000000..27353e8
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/offsets/1
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":0,"batchTimestampMs":1533784349160,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}}
+1
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/0/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta
new file mode 100644
index 0000000..281b21e
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta
new file mode 100644
index 0000000..b701841
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/1/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/2/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/3/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta
new file mode 100644
index 0000000..f4fb252
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/state/0/4/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
new file mode 100644
index 0000000..98586d6
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * A [[StateStore]] kept entirely in a Java map, intended for unit tests.
+ *
+ * It has no versioning or persistence: `version` is always 0, `commit()` always reports 1,
+ * `abort()` does nothing and `id` is null.
+ */
+class MemoryStateStore extends StateStore() {
+  import scala.collection.JavaConverters._
+
+  // Backing storage. `put` stores copies of the key and value rows, so later changes to the
+  // caller's row objects do not leak into the stored state.
+  private val rows = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
+
+  /** Iterates over the stored entries, wrapping each one as an [[UnsafeRowPair]]. */
+  override def iterator(): Iterator[UnsafeRowPair] =
+    rows.entrySet.iterator.asScala.map(entry => new UnsafeRowPair(entry.getKey, entry.getValue))
+
+  /** Returns the stored value for `key`, or null when absent (plain `ConcurrentHashMap.get`). */
+  override def get(key: UnsafeRow): UnsafeRow = rows.get(key)
+
+  override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
+    val keyCopy = key.copy()
+    val valueCopy = newValue.copy()
+    rows.put(keyCopy, valueCopy)
+  }
+
+  override def remove(key: UnsafeRow): Unit = {
+    rows.remove(key)
+  }
+
+  override def commit(): Long = version + 1
+
+  override def abort(): Unit = ()
+
+  override def id: StateStoreId = null
+
+  override def version: Long = 0
+
+  // Built from the current entry count; the remaining arguments are constant (0 / empty).
+  override def metrics: StateStoreMetrics = new StateStoreMetrics(rows.size, 0, Map.empty)
+
+  override def hasCommitted: Boolean = true
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
new file mode 100644
index 0000000..daacdfd
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, 
SpecificInternalRow, UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * Checks that every [[StreamingAggregationStateManager]] format version stores the expected
+ * key/value rows in a state store, yet returns the original full input row to callers.
+ */
+class StreamingAggregationStateManagerSuite extends StreamTest {
+  // ============================ fields and method for test data ============================
+
+  // Simulated aggregation output: two grouping-key columns followed by two aggregate columns.
+  val testKeys: Seq[String] = Seq("key1", "key2")
+  val testValues: Seq[String] = Seq("sum(key1)", "sum(key2)")
+
+  val testOutputSchema: StructType =
+    StructType((testKeys ++ testValues).map(createIntegerField))
+
+  val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes
+  val testKeyAttributes: Seq[Attribute] =
+    testOutputAttributes.filter(attr => testKeys.contains(attr.name))
+  val testValuesAttributes: Seq[Attribute] =
+    testOutputAttributes.filter(attr => testValues.contains(attr.name))
+  val expectedTestValuesSchema: StructType = testValuesAttributes.toStructType
+
+  // Full output row whose every integer field holds its own ordinal: (0, 1, 2, 3).
+  val testRow: UnsafeRow = {
+    val toUnsafe = UnsafeProjection.create(testOutputSchema)
+    val row = toUnsafe(new SpecificInternalRow(testOutputSchema))
+    val fieldCount = testKeys.size + testValues.size
+    (0 until fieldCount).foreach(ordinal => row.setInt(ordinal, ordinal))
+    row
+  }
+
+  // Grouping-key columns projected out of testRow.
+  val expectedTestKeyRow: UnsafeRow =
+    GenerateUnsafeProjection.generate(testKeyAttributes, testOutputAttributes).apply(testRow)
+
+  // Aggregate-value columns projected out of testRow; this is what format v2 stores.
+  val expectedTestValueRowForV2: UnsafeRow =
+    GenerateUnsafeProjection.generate(testValuesAttributes, testOutputAttributes).apply(testRow)
+
+  private def createIntegerField(name: String): StructField =
+    StructField(name, IntegerType, nullable = false)
+
+  // ============================ StateManagerImplV1 ============================
+
+  test("StateManager v1 - get, put, iter") {
+    val stateManager = StreamingAggregationStateManager.createStateManager(
+      testKeyAttributes, testOutputAttributes, 1)
+
+    // In V1 the whole input row is stored as the state value.
+    testGetPutIterOnStateManager(
+      stateManager,
+      expectedValueSchema = testOutputSchema,
+      inputRow = testRow,
+      expectedStateKey = expectedTestKeyRow,
+      expectedStateValue = testRow)
+  }
+
+  // ============================ StateManagerImplV2 ============================
+  test("StateManager v2 - get, put, iter") {
+    val stateManager = StreamingAggregationStateManager.createStateManager(
+      testKeyAttributes, testOutputAttributes, 2)
+
+    // In V2 only the value columns (the input row minus its key columns) are stored as the
+    // state value, but reading through the manager still yields the full row, as in V1.
+    testGetPutIterOnStateManager(
+      stateManager,
+      expectedValueSchema = expectedTestValuesSchema,
+      inputRow = testRow,
+      expectedStateKey = expectedTestKeyRow,
+      expectedStateValue = expectedTestValueRowForV2)
+  }
+
+  /**
+   * Puts `inputRow` into a fresh [[MemoryStateStore]] through `stateManager`. It then checks
+   * both the raw key/value stored in the store and the rows the manager reconstructs.
+   */
+  private def testGetPutIterOnStateManager(
+      stateManager: StreamingAggregationStateManager,
+      expectedValueSchema: StructType,
+      inputRow: UnsafeRow,
+      expectedStateKey: UnsafeRow,
+      expectedStateValue: UnsafeRow): Unit = {
+
+    assert(stateManager.getStateValueSchema === expectedValueSchema)
+
+    val store = new MemoryStateStore()
+    stateManager.put(store, inputRow)
+
+    assert(store.iterator().size === 1)
+    assert(stateManager.iterator(store).size === store.iterator().size)
+
+    val keyRow = stateManager.getKey(inputRow)
+    assert(keyRow === expectedStateKey)
+
+    // Raw store contents: the key row plus the format-specific value row.
+    val storedPair = store.iterator().next()
+    assert(storedPair.key === keyRow)
+    assert(storedPair.value === expectedStateValue)
+
+    // Through the manager, values come back as the original full input row.
+    val managedPair = stateManager.iterator(store).next()
+    assert(managedPair.key === keyRow)
+    assert(managedPair.value === inputRow)
+
+    // The keys-only and values-only views agree with the pair iterator.
+    assert(stateManager.keys(store).next() === keyRow)
+    assert(stateManager.values(store).next() === inputRow)
+
+    // A direct lookup in the store returns the format-specific value again.
+    assert(store.get(keyRow) === expectedStateValue)
+
+    // Whatever the format version, the manager's get returns a row equal to the input row.
+    assert(inputRow === stateManager.get(store, keyRow))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 82d7755..76511ae 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
 
 import java.io.File
 import java.sql.Date
-import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.commons.io.FileUtils
 import org.scalatest.BeforeAndAfterAll
@@ -34,7 +33,7 @@ import 
org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming._
-import 
org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper,
 StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
+import 
org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper,
 MemoryStateStore, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -1286,27 +1285,6 @@ object FlatMapGroupsWithStateSuite {
 
   var failInTask = true
 
-  class MemoryStateStore extends StateStore() {
-    import scala.collection.JavaConverters._
-    private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
-
-    override def iterator(): Iterator[UnsafeRowPair] = {
-      map.entrySet.iterator.asScala.map { case e => new 
UnsafeRowPair(e.getKey, e.getValue) }
-    }
-
-    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
-      map.put(key.copy(), newValue.copy())
-    }
-    override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def commit(): Long = version + 1
-    override def abort(): Unit = { }
-    override def id: StateStoreId = null
-    override def version: Long = 0
-    override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 
0, Map.empty)
-    override def hasCommitted: Boolean = true
-  }
-
   def assertCanGetProcessingTime(predicate: => Boolean): Unit = {
     if (!predicate) throw new TestFailedException("Could not get processing 
time", 20)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5cb858/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 382da13..1ae6ff3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.streaming
 
+import java.io.File
 import java.util.{Locale, TimeZone}
 
-import org.scalatest.Assertions
-import org.scalatest.BeforeAndAfterAll
+import org.apache.commons.io.FileUtils
+import org.scalatest.{Assertions, BeforeAndAfterAll}
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.rdd.BlockRDD
@@ -31,13 +32,15 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StreamingAggregationStateManager}
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode._
 import org.apache.spark.sql.streaming.util.{MockSourceProvider, 
StreamManualClock}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
+import org.apache.spark.util.Utils
 
 object FailureSingleton {
   var firstTime = true
@@ -53,7 +56,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
 
   import testImplicits._
 
-  test("simple count, update mode") {
+  /**
+   * Runs `func` with `confPairs` applied as SQL confs. It also sets the streaming aggregation
+   * state format version to `stateVersion`.
+   */
+  def executeFuncWithStateVersionSQLConf(
+      stateVersion: Int,
+      confPairs: Seq[(String, String)],
+      func: => Any): Unit = {
+    val stateFormatConf =
+      SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString
+    val allConfs = confPairs :+ stateFormatConf
+    withSQLConf(allConfs: _*) {
+      func
+    }
+  }
+
+  /** Registers one `test` per supported state format version, each running `func`. */
+  def testWithAllStateVersions(name: String, confPairs: (String, String)*)
+      (func: => Any): Unit = {
+    StreamingAggregationStateManager.supportedVersions.foreach { version =>
+      test(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+      }
+    }
+  }
+
+  /** Like [[testWithAllStateVersions]], but registers each variant with `testQuietly`. */
+  def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
+      (func: => Any): Unit = {
+    StreamingAggregationStateManager.supportedVersions.foreach { version =>
+      testQuietly(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+      }
+    }
+  }
+
+  testWithAllStateVersions("simple count, update mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -77,7 +108,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  test("count distinct") {
+  testWithAllStateVersions("count distinct") {
     val inputData = MemoryStream[(Int, Seq[Int])]
 
     val aggregated =
@@ -93,7 +124,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
     )
   }
 
-  test("simple count, complete mode") {
+  testWithAllStateVersions("simple count, complete mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -116,7 +147,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("simple count, append mode") {
+  testWithAllStateVersions("simple count, append mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -133,7 +164,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     }
   }
 
-  test("sort after aggregate in complete mode") {
+  testWithAllStateVersions("sort after aggregate in complete mode") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -158,7 +189,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("state metrics") {
+  testWithAllStateVersions("state metrics") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -211,7 +242,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("multiple keys") {
+  testWithAllStateVersions("multiple keys") {
     val inputData = MemoryStream[Int]
 
     val aggregated =
@@ -228,7 +259,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  testQuietly("midbatch failure") {
+  testQuietlyWithAllStateVersions("midbatch failure") {
     val inputData = MemoryStream[Int]
     FailureSingleton.firstTime = true
     val aggregated =
@@ -254,7 +285,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("typed aggregators") {
+  testWithAllStateVersions("typed aggregators") {
     val inputData = MemoryStream[(String, Int)]
     val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))
 
@@ -264,7 +295,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("prune results by current_time, complete mode") {
+  testWithAllStateVersions("prune results by current_time, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock
     val inputData = MemoryStream[Long]
@@ -316,7 +347,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("prune results by current_date, complete mode") {
+  testWithAllStateVersions("prune results by current_date, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock
     val tz = TimeZone.getDefault.getID
@@ -365,7 +396,8 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("SPARK-19690: do not convert batch aggregation in streaming query to 
streaming") {
+  testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in 
streaming query " +
+    "to streaming") {
     val streamInput = MemoryStream[Int]
     val batchDF = Seq(1, 2, 3, 4, 5)
         .toDF("value")
@@ -429,7 +461,8 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     true
   }
 
-  test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned 
to 1") {
+  testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD 
should be " +
+    "repartitioned to 1") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
       // `coalesce(1)` changes the partitioning of data to `SinglePartition` 
which by default
@@ -467,8 +500,8 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     }
   }
 
-  test("SPARK-21977: coalesce(1) with aggregation should still be 
repartitioned when it " +
-    "has non-empty grouping keys") {
+  testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should 
still be " +
+    "repartitioned when it has non-empty grouping keys") {
     val inputSource = new BlockRDDBackedSource(spark)
     MockSourceProvider.withMockSources(inputSource) {
       withTempDir { tempDir =>
@@ -520,7 +553,7 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     }
   }
 
-  test("SPARK-22230: last should change with new batches") {
+  testWithAllStateVersions("SPARK-22230: last should change with new batches") 
{
     val input = MemoryStream[Int]
 
     val aggregated = input.toDF().agg(last('value))
@@ -536,7 +569,8 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     )
   }
 
-  test("SPARK-23004: Ensure that TypedImperativeAggregate functions do not 
throw errors") {
+  testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate 
functions " +
+    "do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
     // See the JIRA SPARK-23004 for more details. In short, this test 
reproduces the error
     // by ensuring the following.
     // - A streaming query with a streaming aggregation.
@@ -545,22 +579,72 @@ class StreamingAggregationSuite extends 
StateStoreMetricsTest
     //   ObjectHashAggregateExec falls back to sort-based aggregation). This 
is done by having a
     //   micro-batch with 128 records that shuffle to a single partition.
     // This test throws the exact error reported in SPARK-23004 without the 
corresponding fix.
-    withSQLConf("spark.sql.shuffle.partitions" -> "1") {
-      val input = MemoryStream[Int]
-      val df = input.toDF().toDF("value")
-        .selectExpr("value as group", "value")
-        .groupBy("group")
-        .agg(collect_list("value"))
-      testStream(df, outputMode = OutputMode.Update)(
-        AddData(input, (1 to 
spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
-        AssertOnQuery { q =>
-          q.processAllAvailable()
-          true
+    val input = MemoryStream[Int]
+    val df = input.toDF().toDF("value")
+      .selectExpr("value as group", "value")
+      .groupBy("group")
+      .agg(collect_list("value"))
+    testStream(df, outputMode = OutputMode.Update)(
+      AddData(input, (1 to 
spark.sqlContext.conf.objectAggSortBasedFallbackThreshold): _*),
+      AssertOnQuery { q =>
+        q.processAllAvailable()
+        true
+      }
+    )
+  }
+
+
+  // Restores a query from a checkpoint written by Spark 2.3.1, which predates the state format
+  // option. It checks that the query keeps state format version 1 even when version 2 is
+  // requested, and that it still produces correct counts.
+  test("simple count, update mode - recovery from checkpoint uses state format version 1") {
+    val inputData = MemoryStream[Int]
+
+    val aggregated =
+      inputData.toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .as[(Int, Long)]
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregate-state-format-1/").toURI
+
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    inputData.addData(3)
+    inputData.addData(3, 2)
+
+    testStream(aggregated, Update)(
+      StartStream(checkpointLocation = checkpointDir.getAbsolutePath,
+        additionalConfs = Map(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "2")),
+      /*
+        Note: The checkpoint was generated using the following input in Spark version 2.3.1
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1))
+       */
+
+      AddData(inputData, 3, 2, 1),
+      CheckLastBatch((3, 3), (2, 2), (1, 1)),
+
+      Execute { query =>
+        // Verify state format = 1 on both the restore and the save operator.
+        val stateVersions = query.lastExecution.executedPlan.collect {
+          case f: StateStoreSaveExec => f.stateFormatVersion
+          case f: StateStoreRestoreExec => f.stateFormatVersion
         }
-      )
-    }
+        assert(stateVersions.size === 2)
+        assert(stateVersions.forall(_ == 1))
+      },
+
+      // The recovered query keeps running with state format version 1 (asserted above), even
+      // though version 2 was requested through additionalConfs; further batches must still
+      // aggregate correctly.
+      AddData(inputData, 4, 4, 4, 4),
+      CheckLastBatch((4, 4))
+    )
   }
 
+
   /** Add blocks of data to the `BlockRDDBackedSource`. */
   case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) 
extends AddData {
     override def addData(query: Option[StreamExecution]): (Source, Offset) = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to