This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f0061dbe856a [SPARK-47449][SS] Refactor and split list/timer unit tests
f0061dbe856a is described below

commit f0061dbe856a55295cc95835aff5dc717aa19431
Author: jingz-db <jing.z...@databricks.com>
AuthorDate: Wed Mar 20 09:21:04 2024 +0900

    [SPARK-47449][SS] Refactor and split list/timer unit tests
    
    ### What changes were proposed in this pull request?
    
    Refactor the StatefulProcessorHandle unit test suites, and add unit tests for list state and timer state.
    As planned in the state-v2 test plan, list state and timers should be covered by both integration tests and unit tests. The StatefulProcessorHandle-related tests can be refactored to use the base suite class in `ValueStateSuite`, and list/timer state unit tests are needed in addition to the existing integration tests.
    
    ### Why are the changes needed?
    
    Compliance with the test plan for the state-v2 project.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Test suites refactored and added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45573 from jingz-db/split-timer-list-state-v2.
    
    Authored-by: jingz-db <jing.z...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../execution/streaming/state/ListStateSuite.scala | 163 +++++++++++++++++++++
 .../execution/streaming/state/MapStateSuite.scala  |   2 +-
 .../state/StatefulProcessorHandleSuite.scala       |  69 +--------
 .../sql/execution/streaming/state/TimerSuite.scala | 113 ++++++++++++++
 .../streaming/state/ValueStateSuite.scala          |   8 +-
 5 files changed, 289 insertions(+), 66 deletions(-)

diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
new file mode 100644
index 000000000000..e895e475b74d
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl}
+import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState}
+
+/**
+ * Unit tests for [[ListState]] variables obtained through a [[StatefulProcessorHandleImpl]],
+ * as used by arbitrary stateful operators such as transformWithState.
+ *
+ * The suite overrides `useMultipleValuesPerKey` to `true`; the base suite passes that flag to
+ * the state store provider, because a list state stores several values under one key.
+ */
+class ListStateSuite extends StateVariableSuiteBase {
+  // overwrite useMultipleValuesPerKey in base suite to be true for list state
+  override def useMultipleValuesPerKey: Boolean = true
+
+  /**
+   * Creates a RocksDB state store provider with column families enabled, opens version 0 of
+   * its store, wraps it in a StatefulProcessorHandleImpl (string grouping-key encoder, no
+   * timeouts) and passes the handle to `body`, all inside `tryWithProviderResource`.
+   */
+  private def withHandle(body: StatefulProcessorHandleImpl => Unit): Unit = {
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+      val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+        Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts())
+      body(handle)
+    }
+  }
+
+  /**
+   * Sets the implicit grouping key to `key`, runs `body`, and always removes the key
+   * afterwards (even when `body` throws) so it does not leak into later tests.
+   */
+  private def withImplicitKey[T](key: String)(body: => T): T = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    try body finally ImplicitGroupingKeyTracker.removeImplicitKey()
+  }
+
+  /**
+   * Runs `runListOps` against a fresh list state named "listState" and asserts that it fails
+   * with ILLEGAL_STATE_STORE_VALUE.NULL_VALUE (SQLSTATE 42601) reported for that state name.
+   */
+  private def testListStateWithNullValue(runListOps: ListState[Long] => Unit): Unit = {
+    withHandle { handle =>
+      val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong)
+      withImplicitKey("test_key") {
+        // intercept already returns the exception with its static type, so no cast is needed.
+        val e = intercept[SparkIllegalArgumentException] {
+          runListOps(listState)
+        }
+
+        checkError(
+          exception = e,
+          errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE",
+          sqlState = Some("42601"),
+          parameters = Map("stateName" -> "listState")
+        )
+      }
+    }
+  }
+
+  // Each entry pairs the operation name used in the test title with the call that passes null.
+  Seq(
+    "appendList" -> ((listState: ListState[Long]) => listState.appendList(null)),
+    "put" -> ((listState: ListState[Long]) => listState.put(null))
+  ).foreach { case (listImplFunc, runListOp) =>
+    test(s"Test list operation($listImplFunc) with null") {
+      testListStateWithNullValue(runListOp)
+    }
+  }
+
+  test("List state operations for single instance") {
+    withHandle { handle =>
+      val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong)
+      withImplicitKey("test_key") {
+        // single value: append, read back, clear
+        testState.appendValue(123)
+        assert(testState.get().toSeq === Seq(123))
+        testState.clear()
+        assert(!testState.exists())
+        assert(testState.get().toSeq === Seq.empty[Long])
+
+        // appending a whole list, then a single value after it
+        testState.appendList(Array(123, 456))
+        assert(testState.get().toSeq === Seq(123, 456))
+        testState.appendValue(789)
+        assert(testState.get().toSeq === Seq(123, 456, 789))
+
+        testState.clear()
+        assert(!testState.exists())
+        assert(testState.get().toSeq === Seq.empty[Long])
+      }
+    }
+  }
+
+  test("List state operations for multiple instance") {
+    withHandle { handle =>
+      val testState1: ListState[Long] =
+        handle.getListState[Long]("testState1", Encoders.scalaLong)
+      val testState2: ListState[Long] =
+        handle.getListState[Long]("testState2", Encoders.scalaLong)
+
+      withImplicitKey("test_key") {
+        // single values; clearing one state must not affect the other
+        testState1.appendValue(123)
+        testState2.appendValue(456)
+        assert(testState1.get().toSeq === Seq(123))
+        assert(testState2.get().toSeq === Seq(456))
+        testState1.clear()
+        assert(!testState1.exists())
+        assert(testState2.exists())
+        assert(testState1.get().toSeq === Seq.empty[Long])
+
+        // list appends are isolated per state variable
+        testState1.appendList(Array(123, 456))
+        assert(testState1.get().toSeq === Seq(123, 456))
+        testState2.appendList(Array(123))
+        assert(testState2.get().toSeq === Seq(456, 123))
+
+        testState1.appendValue(789)
+        assert(testState1.get().toSeq === Seq(123, 456, 789))
+        assert(testState2.get().toSeq === Seq(456, 123))
+
+        testState2.clear()
+        assert(!testState2.exists())
+        assert(testState1.exists())
+        assert(testState2.get().toSeq === Seq.empty[Long])
+      }
+    }
+  }
+
+  test("List state operations with list, value, another list instances") {
+    withHandle { handle =>
+      val listState1: ListState[Long] =
+        handle.getListState[Long]("listState1", Encoders.scalaLong)
+      val listState2: ListState[Long] =
+        handle.getListState[Long]("listState2", Encoders.scalaLong)
+      val valueState: ValueState[Long] = handle.getValueState[Long](
+        "valueState", Encoders.scalaLong)
+
+      withImplicitKey("test_key") {
+        // list and value state variables coexist on the same handle and key
+        valueState.update(123)
+        listState1.appendValue(123)
+        listState2.appendValue(456)
+        assert(listState1.get().toSeq === Seq(123))
+        assert(listState2.get().toSeq === Seq(456))
+        assert(valueState.get() === 123)
+
+        listState1.clear()
+        valueState.clear()
+        assert(!listState1.exists())
+        assert(listState2.exists())
+        assert(!valueState.exists())
+        assert(listState1.get().toSeq === Seq.empty[Long])
+      }
+    }
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index f7aed2045793..ce72061d39ea 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{BinaryType, StructType}
  */
 class MapStateSuite extends StateVariableSuiteBase {
   // Overwrite Key schema as MapState use composite key
-  schemaForKeyRow = new StructType()
+  override def schemaForKeyRow: StructType = new StructType()
     .add("key", BinaryType)
     .add("userKey", BinaryType)
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index 5d9a9cbcaae0..662a5dbfaac4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -19,76 +19,21 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.util.UUID
 
-import scala.util.Random
-
-import org.apache.hadoop.conf.Configuration
-import org.scalatest.BeforeAndAfter
-
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.TimeoutMode
-import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types._
 
 /**
  * Class that adds tests to verify operations based on stateful processor 
handle
  * used primarily in queries based on the `transformWithState` operator.
  */
-class StatefulProcessorHandleSuite extends SharedSparkSession
-  with BeforeAndAfter {
-
-  before {
-    StateStore.stop()
-    require(!StateStore.isMaintenanceRunning)
-  }
-
-  after {
-    StateStore.stop()
-    require(!StateStore.isMaintenanceRunning)
-  }
-
-  import StateStoreTestsHelper._
-
-  val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
-
-  val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
+class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
   private def keyExprEncoder: ExpressionEncoder[Any] =
     Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]
 
-  private def newStoreProviderWithHandle(useColumnFamilies: Boolean):
-    RocksDBStateStoreProvider = {
-    newStoreProviderWithHandle(StateStoreId(newDir(), Random.nextInt(), 0),
-      numColsPrefixKey = 0,
-      useColumnFamilies = useColumnFamilies)
-  }
-
-  private def newStoreProviderWithHandle(
-      storeId: StateStoreId,
-      numColsPrefixKey: Int,
-      sqlConf: Option[SQLConf] = None,
-      conf: Configuration = new Configuration,
-      useColumnFamilies: Boolean = false): RocksDBStateStoreProvider = {
-    val provider = new RocksDBStateStoreProvider()
-    provider.init(
-      storeId, schemaForKeyRow, schemaForValueRow, numColsPrefixKey = 
numColsPrefixKey,
-      useColumnFamilies,
-      new StateStoreConf(sqlConf.getOrElse(SQLConf.get)), conf)
-    provider
-  }
-
-  private def tryWithProviderResource[T](
-      provider: StateStoreProvider)(f: StateStoreProvider => T): T = {
-    try {
-      f(provider)
-    } finally {
-      provider.close()
-    }
-  }
-
   private def getTimeoutMode(timeoutMode: String): TimeoutMode = {
     timeoutMode match {
       case "NoTimeouts" => TimeoutMode.NoTimeouts()
@@ -100,7 +45,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession
 
   Seq("NoTimeouts", "ProcessingTime", "EventTime").foreach { timeoutMode =>
     test(s"value state creation with timeoutMode=$timeoutMode should succeed") 
{
-      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+      tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
           UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
@@ -141,7 +86,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession
   Seq("NoTimeouts", "ProcessingTime", "EventTime").foreach { timeoutMode =>
     test(s"value state creation with timeoutMode=$timeoutMode " +
       "and invalid state should fail") {
-      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+      tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
           UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
@@ -159,7 +104,7 @@ class StatefulProcessorHandleSuite extends 
SharedSparkSession
   }
 
   test("registering processing/event time timeouts with NoTimeout mode should 
fail") {
-    tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val handle = new StatefulProcessorHandleImpl(store,
         UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts())
@@ -195,7 +140,7 @@ class StatefulProcessorHandleSuite extends 
SharedSparkSession
 
   Seq("ProcessingTime", "EventTime").foreach { timeoutMode =>
     test(s"registering timeouts with timeoutMode=$timeoutMode should succeed") 
{
-      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+      tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
           UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
@@ -216,7 +161,7 @@ class StatefulProcessorHandleSuite extends 
SharedSparkSession
 
   Seq("ProcessingTime", "EventTime").foreach { timeoutMode =>
     test(s"verify listing of registered timers with timeoutMode=$timeoutMode") 
{
-      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+      tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
           UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
@@ -256,7 +201,7 @@ class StatefulProcessorHandleSuite extends 
SharedSparkSession
 
   Seq("ProcessingTime", "EventTime").foreach { timeoutMode =>
     test(s"registering timeouts with timeoutMode=$timeoutMode and invalid 
state should fail") {
-      tryWithProviderResource(newStoreProviderWithHandle(true)) { provider =>
+      tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
         val store = provider.getStore(0)
         val handle = new StatefulProcessorHandleImpl(store,
           UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
new file mode 100644
index 000000000000..1aae0e0498aa
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
TimerStateImpl}
+import org.apache.spark.sql.streaming.TimeoutMode
+
+/**
+ * Unit tests for [[TimerStateImpl]], the timer state used by arbitrary stateful operators
+ * such as transformWithState. Each test is registered once per timeout mode (processing time
+ * and event time).
+ */
+class TimerSuite extends StateVariableSuiteBase {
+  /** Grouping-key encoder shared by every TimerStateImpl created in this suite. */
+  private def keyExprEncoder: ExpressionEncoder[Any] =
+    Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]
+
+  /**
+   * Registers `testName` once per timeout mode, titled "Processing timer - <testName>" and
+   * "Event timer - <testName>". The TimeoutMode is created when the test body runs.
+   */
+  private def testWithTimeOutMode(testName: String)
+      (testFunc: TimeoutMode => Unit): Unit = {
+    Seq[(String, () => TimeoutMode)](
+      "Processing" -> (() => TimeoutMode.ProcessingTime()),
+      "Event" -> (() => TimeoutMode.EventTime())
+    ).foreach { case (modeName, newTimeoutMode) =>
+      test(s"$modeName timer - $testName") {
+        testFunc(newTimeoutMode())
+      }
+    }
+  }
+
+  /**
+   * Sets the implicit grouping key to `key`, runs `body`, and always removes the key
+   * afterwards (even when `body` throws) so it does not leak into later tests.
+   */
+  private def withImplicitKey[T](key: String)(body: => T): T = {
+    ImplicitGroupingKeyTracker.setImplicitKey(key)
+    try body finally ImplicitGroupingKeyTracker.removeImplicitKey()
+  }
+
+  testWithTimeOutMode("single instance with single key") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+
+      withImplicitKey("test_key") {
+        val timerState = new TimerStateImpl(store, timeoutMode, keyExprEncoder)
+        timerState.registerTimer(1L * 1000)
+        assert(timerState.listTimers().toSet === Set(1000L))
+        assert(timerState.getExpiredTimers().toSet === Set(("test_key", 1000L)))
+
+        timerState.registerTimer(20L * 1000)
+        assert(timerState.listTimers().toSet === Set(20000L, 1000L))
+        timerState.deleteTimer(20000L)
+        assert(timerState.listTimers().toSet === Set(1000L))
+      }
+    }
+  }
+
+  testWithTimeOutMode("multiple instances with single key") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+
+      withImplicitKey("test_key") {
+        val timerState1 = new TimerStateImpl(store, timeoutMode, keyExprEncoder)
+        val timerState2 = new TimerStateImpl(store, timeoutMode, keyExprEncoder)
+        timerState1.registerTimer(1L * 1000)
+        timerState2.registerTimer(15L * 1000)
+        // Both instances are built over the same store; the test expects timers registered
+        // through either instance to be visible through timerState1.
+        assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
+        assert(timerState1.getExpiredTimers().toSet ===
+          Set(("test_key", 15000L), ("test_key", 1000L)))
+        // Listing again after getExpiredTimers: the timers are expected to still be present.
+        assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
+
+        timerState1.registerTimer(20L * 1000)
+        assert(timerState1.listTimers().toSet === Set(20000L, 15000L, 1000L))
+        timerState1.deleteTimer(20000L)
+        assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
+      }
+    }
+  }
+
+  testWithTimeOutMode("multiple instances with multiple keys") { timeoutMode =>
+    tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+      val store = provider.getStore(0)
+
+      val timerState1 = withImplicitKey("test_key1") {
+        val state = new TimerStateImpl(store, timeoutMode, keyExprEncoder)
+        state.registerTimer(1L * 1000)
+        state.registerTimer(2L * 1000)
+        assert(state.listTimers().toSet === Set(1000L, 2000L))
+        state
+      }
+
+      val timerState2 = withImplicitKey("test_key2") {
+        val state = new TimerStateImpl(store, timeoutMode, keyExprEncoder)
+        state.registerTimer(15L * 1000)
+        state
+      }
+
+      // listTimers is expected to be scoped to the current key, while getExpiredTimers is
+      // expected to return timers for every key.
+      withImplicitKey("test_key1") {
+        assert(timerState1.getExpiredTimers().toSet ===
+          Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+        assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
+      }
+
+      withImplicitKey("test_key2") {
+        assert(timerState2.listTimers().toSet === Set(15000L))
+        assert(timerState2.getExpiredTimers().toSet ===
+          Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+      }
+    }
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index e423f9e7385a..e86ac03b70d9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -326,8 +326,10 @@ abstract class StateVariableSuiteBase extends 
SharedSparkSession
 
   import StateStoreTestsHelper._
 
-  protected var schemaForKeyRow: StructType = new StructType().add("key", 
BinaryType)
-  protected var schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
+  protected def schemaForKeyRow: StructType = new StructType().add("key", 
BinaryType)
+  protected def schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
+
+  protected def useMultipleValuesPerKey = false
 
   protected def newStoreProviderWithStateVariable(
       useColumnFamilies: Boolean): RocksDBStateStoreProvider = {
@@ -346,7 +348,7 @@ abstract class StateVariableSuiteBase extends 
SharedSparkSession
     provider.init(
       storeId, schemaForKeyRow, schemaForValueRow, numColsPrefixKey = 
numColsPrefixKey,
       useColumnFamilies,
-      new StateStoreConf(sqlConf), conf)
+      new StateStoreConf(sqlConf), conf, useMultipleValuesPerKey)
     provider
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to