This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f0061dbe856a [SPARK-47449][SS] Refactor and split list/timer unit tests f0061dbe856a is described below commit f0061dbe856a55295cc95835aff5dc717aa19431 Author: jingz-db <jing.z...@databricks.com> AuthorDate: Wed Mar 20 09:21:04 2024 +0900 [SPARK-47449][SS] Refactor and split list/timer unit tests ### What changes were proposed in this pull request? Refactor the StatefulProcessorHandle unit test suites. Add list state and timer state unit tests. As planned in the test plan for state-v2, list and timer state should be tested in both integration and unit tests. This change refactors the StatefulProcessorHandle-related tests to use the shared base suite class (`StateVariableSuiteBase`, defined in `ValueStateSuite.scala`), and adds the list/timer state unit tests that are needed in addition to the integration tests. ### Why are the changes needed? Compliance with the test plan for the state-v2 project. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Test suites were refactored and new test suites were added. ### Was this patch authored or co-authored using generative AI tooling? No Closes #45573 from jingz-db/split-timer-list-state-v2. 
Authored-by: jingz-db <jing.z...@databricks.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../execution/streaming/state/ListStateSuite.scala | 163 +++++++++++++++++++++ .../execution/streaming/state/MapStateSuite.scala | 2 +- .../state/StatefulProcessorHandleSuite.scala | 69 +-------- .../sql/execution/streaming/state/TimerSuite.scala | 113 ++++++++++++++ .../streaming/state/ValueStateSuite.scala | 8 +- 5 files changed, 289 insertions(+), 66 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala new file mode 100644 index 000000000000..e895e475b74d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import java.util.UUID + +import org.apache.spark.SparkIllegalArgumentException +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl} +import org.apache.spark.sql.streaming.{ListState, TimeoutMode, ValueState} + +/** + * Class that adds unit tests for ListState types used in arbitrary stateful + * operators such as transformWithState + */ +class ListStateSuite extends StateVariableSuiteBase { + // overwrite useMultipleValuesPerKey in base suite to be true for list state + override def useMultipleValuesPerKey: Boolean = true + + private def testMapStateWithNullUserKey()(runListOps: ListState[Long] => Unit): Unit = { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + + val listState: ListState[Long] = handle.getListState[Long]("listState", Encoders.scalaLong) + + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + val e = intercept[SparkIllegalArgumentException] { + runListOps(listState) + } + + checkError( + exception = e.asInstanceOf[SparkIllegalArgumentException], + errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE", + sqlState = Some("42601"), + parameters = Map("stateName" -> "listState") + ) + } + } + + Seq("appendList", "put").foreach { listImplFunc => + test(s"Test list operation($listImplFunc) with null") { + testMapStateWithNullUserKey() { listState => + listImplFunc match { + case "appendList" => listState.appendList(null) + case "put" => listState.put(null) + } + } + } + } + + test("List state operations for single instance") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val 
store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + + val testState: ListState[Long] = handle.getListState[Long]("testState", Encoders.scalaLong) + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + + // simple put and get test + testState.appendValue(123) + assert(testState.get().toSeq === Seq(123)) + testState.clear() + assert(!testState.exists()) + assert(testState.get().toSeq === Seq.empty[Long]) + + // put list test + testState.appendList(Array(123, 456)) + assert(testState.get().toSeq === Seq(123, 456)) + testState.appendValue(789) + assert(testState.get().toSeq === Seq(123, 456, 789)) + + testState.clear() + assert(!testState.exists()) + assert(testState.get().toSeq === Seq.empty[Long]) + } + } + + test("List state operations for multiple instance") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + + val testState1: ListState[Long] = handle.getListState[Long]("testState1", Encoders.scalaLong) + val testState2: ListState[Long] = handle.getListState[Long]("testState2", Encoders.scalaLong) + + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + + // simple put and get test + testState1.appendValue(123) + testState2.appendValue(456) + assert(testState1.get().toSeq === Seq(123)) + assert(testState2.get().toSeq === Seq(456)) + testState1.clear() + assert(!testState1.exists()) + assert(testState2.exists()) + assert(testState1.get().toSeq === Seq.empty[Long]) + + // put list test + testState1.appendList(Array(123, 456)) + assert(testState1.get().toSeq === Seq(123, 456)) + testState2.appendList(Array(123)) + assert(testState2.get().toSeq === Seq(456, 123)) + + testState1.appendValue(789) + 
assert(testState1.get().toSeq === Seq(123, 456, 789)) + assert(testState2.get().toSeq === Seq(456, 123)) + + testState2.clear() + assert(!testState2.exists()) + assert(testState1.exists()) + assert(testState2.get().toSeq === Seq.empty[Long]) + } + } + + test("List state operations with list, value, another list instances") { + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]], TimeoutMode.NoTimeouts()) + + val listState1: ListState[Long] = handle.getListState[Long]("listState1", Encoders.scalaLong) + val listState2: ListState[Long] = handle.getListState[Long]("listState2", Encoders.scalaLong) + val valueState: ValueState[Long] = handle.getValueState[Long]( + "valueState", Encoders.scalaLong) + + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + // simple put and get test + valueState.update(123) + listState1.appendValue(123) + listState2.appendValue(456) + assert(listState1.get().toSeq === Seq(123)) + assert(listState2.get().toSeq === Seq(456)) + assert(valueState.get() === 123) + + listState1.clear() + valueState.clear() + assert(!listState1.exists()) + assert(listState2.exists()) + assert(!valueState.exists()) + assert(listState1.get().toSeq === Seq.empty[Long]) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala index f7aed2045793..ce72061d39ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{BinaryType, StructType} */ class MapStateSuite extends StateVariableSuiteBase { // Overwrite Key schema as MapState use 
composite key - schemaForKeyRow = new StructType() + override def schemaForKeyRow: StructType = new StructType() .add("key", BinaryType) .add("userKey", BinaryType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala index 5d9a9cbcaae0..662a5dbfaac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala @@ -19,76 +19,21 @@ package org.apache.spark.sql.execution.streaming.state import java.util.UUID -import scala.util.Random - -import org.apache.hadoop.conf.Configuration -import org.scalatest.BeforeAndAfter - import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.TimeoutMode -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types._ /** * Class that adds tests to verify operations based on stateful processor handle * used primarily in queries based on the `transformWithState` operator. 
*/ -class StatefulProcessorHandleSuite extends SharedSparkSession - with BeforeAndAfter { - - before { - StateStore.stop() - require(!StateStore.isMaintenanceRunning) - } - - after { - StateStore.stop() - require(!StateStore.isMaintenanceRunning) - } - - import StateStoreTestsHelper._ - - val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - - val schemaForValueRow: StructType = new StructType().add("value", BinaryType) +class StatefulProcessorHandleSuite extends StateVariableSuiteBase { private def keyExprEncoder: ExpressionEncoder[Any] = Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]] - private def newStoreProviderWithHandle(useColumnFamilies: Boolean): - RocksDBStateStoreProvider = { - newStoreProviderWithHandle(StateStoreId(newDir(), Random.nextInt(), 0), - numColsPrefixKey = 0, - useColumnFamilies = useColumnFamilies) - } - - private def newStoreProviderWithHandle( - storeId: StateStoreId, - numColsPrefixKey: Int, - sqlConf: Option[SQLConf] = None, - conf: Configuration = new Configuration, - useColumnFamilies: Boolean = false): RocksDBStateStoreProvider = { - val provider = new RocksDBStateStoreProvider() - provider.init( - storeId, schemaForKeyRow, schemaForValueRow, numColsPrefixKey = numColsPrefixKey, - useColumnFamilies, - new StateStoreConf(sqlConf.getOrElse(SQLConf.get)), conf) - provider - } - - private def tryWithProviderResource[T]( - provider: StateStoreProvider)(f: StateStoreProvider => T): T = { - try { - f(provider) - } finally { - provider.close() - } - } - private def getTimeoutMode(timeoutMode: String): TimeoutMode = { timeoutMode match { case "NoTimeouts" => TimeoutMode.NoTimeouts() @@ -100,7 +45,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession Seq("NoTimeouts", "ProcessingTime", "EventTime").foreach { timeoutMode => test(s"value state creation with timeoutMode=$timeoutMode should succeed") { - tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + 
tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) @@ -141,7 +86,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession Seq("NoTimeouts", "ProcessingTime", "EventTime").foreach { timeoutMode => test(s"value state creation with timeoutMode=$timeoutMode " + "and invalid state should fail") { - tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) @@ -159,7 +104,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession } test("registering processing/event time timeouts with NoTimeout mode should fail") { - tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), keyExprEncoder, TimeoutMode.NoTimeouts()) @@ -195,7 +140,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession Seq("ProcessingTime", "EventTime").foreach { timeoutMode => test(s"registering timeouts with timeoutMode=$timeoutMode should succeed") { - tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) @@ -216,7 +161,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession Seq("ProcessingTime", "EventTime").foreach { timeoutMode => test(s"verify listing of registered timers with timeoutMode=$timeoutMode") { - 
tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) @@ -256,7 +201,7 @@ class StatefulProcessorHandleSuite extends SharedSparkSession Seq("ProcessingTime", "EventTime").foreach { timeoutMode => test(s"registering timeouts with timeoutMode=$timeoutMode and invalid state should fail") { - tryWithProviderResource(newStoreProviderWithHandle(true)) { provider => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => val store = provider.getStore(0) val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(), keyExprEncoder, getTimeoutMode(timeoutMode)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala new file mode 100644 index 000000000000..1aae0e0498aa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, TimerStateImpl} +import org.apache.spark.sql.streaming.TimeoutMode + +/** + * Class that adds unit tests for Timer State used in arbitrary stateful + * operators such as transformWithState + */ +class TimerSuite extends StateVariableSuiteBase { + private def testWithTimeOutMode(testName: String) + (testFunc: TimeoutMode => Unit): Unit = { + Seq("Processing", "Event").foreach { timeoutMode => + test(s"$timeoutMode timer - " + testName) { + timeoutMode match { + case "Processing" => testFunc(TimeoutMode.ProcessingTime()) + case "Event" => testFunc(TimeoutMode.EventTime()) + } + } + } + } + + testWithTimeOutMode("single instance with single key") { timeoutMode => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + val timerState = new TimerStateImpl(store, timeoutMode, + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + timerState.registerTimer(1L * 1000) + assert(timerState.listTimers().toSet === Set(1000L)) + assert(timerState.getExpiredTimers().toSet === Set(("test_key", 1000L))) + + timerState.registerTimer(20L * 1000) + assert(timerState.listTimers().toSet === Set(20000L, 1000L)) + timerState.deleteTimer(20000L) + assert(timerState.listTimers().toSet === Set(1000L)) + } + } + + testWithTimeOutMode("multiple instances with single key") { timeoutMode => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + + ImplicitGroupingKeyTracker.setImplicitKey("test_key") + val timerState1 = new TimerStateImpl(store, timeoutMode, + 
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + val timerState2 = new TimerStateImpl(store, timeoutMode, + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + timerState1.registerTimer(1L * 1000) + timerState2.registerTimer(15L * 1000) + assert(timerState1.listTimers().toSet === Set(15000L, 1000L)) + assert(timerState1.getExpiredTimers().toSet === + Set(("test_key", 15000L), ("test_key", 1000L))) + assert(timerState1.listTimers().toSet === Set(15000L, 1000L)) + + timerState1.registerTimer(20L * 1000) + assert(timerState1.listTimers().toSet === Set(20000L, 15000L, 1000L)) + timerState1.deleteTimer(20000L) + assert(timerState1.listTimers().toSet === Set(15000L, 1000L)) + } + } + + testWithTimeOutMode("multiple instances with multiple keys") { timeoutMode => + tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider => + val store = provider.getStore(0) + + ImplicitGroupingKeyTracker.setImplicitKey("test_key1") + val timerState1 = new TimerStateImpl(store, timeoutMode, + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + timerState1.registerTimer(1L * 1000) + timerState1.registerTimer(2L * 1000) + assert(timerState1.listTimers().toSet === Set(1000L, 2000L)) + ImplicitGroupingKeyTracker.removeImplicitKey() + + ImplicitGroupingKeyTracker.setImplicitKey("test_key2") + val timerState2 = new TimerStateImpl(store, timeoutMode, + Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]]) + timerState2.registerTimer(15L * 1000) + ImplicitGroupingKeyTracker.removeImplicitKey() + + ImplicitGroupingKeyTracker.setImplicitKey("test_key1") + assert(timerState1.getExpiredTimers().toSet === + Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L))) + assert(timerState1.listTimers().toSet === Set(1000L, 2000L)) + ImplicitGroupingKeyTracker.removeImplicitKey() + + ImplicitGroupingKeyTracker.setImplicitKey("test_key2") + assert(timerState2.listTimers().toSet === Set(15000L)) + assert(timerState2.getExpiredTimers().toSet === + 
Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala index e423f9e7385a..e86ac03b70d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala @@ -326,8 +326,10 @@ abstract class StateVariableSuiteBase extends SharedSparkSession import StateStoreTestsHelper._ - protected var schemaForKeyRow: StructType = new StructType().add("key", BinaryType) - protected var schemaForValueRow: StructType = new StructType().add("value", BinaryType) + protected def schemaForKeyRow: StructType = new StructType().add("key", BinaryType) + protected def schemaForValueRow: StructType = new StructType().add("value", BinaryType) + + protected def useMultipleValuesPerKey = false protected def newStoreProviderWithStateVariable( useColumnFamilies: Boolean): RocksDBStateStoreProvider = { @@ -346,7 +348,7 @@ abstract class StateVariableSuiteBase extends SharedSparkSession provider.init( storeId, schemaForKeyRow, schemaForValueRow, numColsPrefixKey = numColsPrefixKey, useColumnFamilies, - new StateStoreConf(sqlConf), conf) + new StateStoreConf(sqlConf), conf, useMultipleValuesPerKey) provider } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org