anishshri-db commented on code in PR #43425: URL: https://github.com/apache/spark/pull/43425#discussion_r1376953649
########## sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala: ########## @@ -0,0 +1,670 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.state + +import java.io.{File, FileWriter} + +import org.scalatest.Assertions + +import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil +import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, OffsetSeqLog} +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.{IntegerType, StructType} + +class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase { + import testImplicits._ + + test("ERROR: read the state from stateless query") { + withTempDir { 
tempDir => + val inputData = MemoryStream[Int] + val df = inputData.toDF() + .selectExpr("value", "value % 2 AS value2") + + testStream(df)( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, 1, 2, 3, 4, 5), + CheckLastBatch((1, 1), (2, 0), (3, 1), (4, 0), (5, 1)), + AddData(inputData, 6, 7, 8), + CheckLastBatch((6, 0), (7, 1), (8, 0)) + ) + + intercept[IllegalArgumentException] { + spark.read.format("statestore").load(tempDir.getAbsolutePath) + } + } + } + + test("ERROR: no committed batch on default batch ID") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + val offsetLog = new OffsetSeqLog(spark, + new File(tempDir.getAbsolutePath, "offsets").getAbsolutePath) + val commitLog = new CommitLog(spark, + new File(tempDir.getAbsolutePath, "commits").getAbsolutePath) + + offsetLog.purgeAfter(0) + commitLog.purgeAfter(-1) + + intercept[IllegalStateException] { + spark.read.format("statestore").load(tempDir.getAbsolutePath) + } + } + } + + test("ERROR: corrupted state schema file") { + withTempDir { tempDir => + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + def rewriteStateSchemaFileToDummy(): Unit = { + // Refer to the StateSchemaCompatibilityChecker for the path of state schema file + val pathForSchema = Seq( + "state", "0", StateStore.PARTITION_ID_TO_CHECK_SCHEMA.toString, + "_metadata", "schema" + ).foldLeft(tempDir) { case (file, dirName) => + new File(file, dirName) + } + + assert(pathForSchema.exists()) + assert(pathForSchema.delete()) + + val fileWriter = new FileWriter(pathForSchema) + fileWriter.write("lol dummy corrupted schema file") + fileWriter.close() + + assert(pathForSchema.exists()) + } + + rewriteStateSchemaFileToDummy() + + intercept[IllegalArgumentException] { + spark.read.format("statestore").load(tempDir.getAbsolutePath) + } + } + } + + test("ERROR: path is not specified") { + intercept[IllegalArgumentException] { + 
spark.read.format("statestore").load() + } + } + + test("ERROR: operator ID specified to negative") { + withTempDir { tempDir => + intercept[IllegalArgumentException] { + spark.read.format("statestore") + .option(StateDataSource.PARAM_OPERATOR_ID, -1) + // trick to bypass getting the last committed batch before validating operator ID + .option(StateDataSource.PARAM_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + } + } + + test("ERROR: batch ID specified to negative") { + withTempDir { tempDir => + intercept[IllegalArgumentException] { + spark.read.format("statestore") + .option(StateDataSource.PARAM_BATCH_ID, -1) + .load(tempDir.getAbsolutePath) + } + } + } + + test("ERROR: invalid value for joinSide option") { + withTempDir { tempDir => + intercept[IllegalArgumentException] { + spark.read.format("statestore") + .option(StateDataSource.PARAM_JOIN_SIDE, "both") + // trick to bypass getting the last committed batch before validating operator ID + .option(StateDataSource.PARAM_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + } + } + + test("ERROR: both options `joinSide` and `storeName` are specified") { + withTempDir { tempDir => + intercept[IllegalArgumentException] { + spark.read.format("statestore") + .option(StateDataSource.PARAM_JOIN_SIDE, "right") + .option(StateDataSource.PARAM_STORE_NAME, "right-keyToNumValues") + // trick to bypass getting the last committed batch before validating operator ID + .option(StateDataSource.PARAM_BATCH_ID, 0) + .load(tempDir.getAbsolutePath) + } + } + } +} + +class StateDataSourceSQLConfigSuite extends StateDataSourceTestBase { + // Here we build a combination of test criteria for + // 1) number of shuffle partitions + // 2) state store provider + // 3) compression codec + // and run one of the test to verify that above configs work. 
+ // We are building 3 x 2 x 4 = 24 different test criteria, and it's probably waste of time + // and resource to run all combinations for all times, hence we will randomly pick 5 tests + // per run. + + private val TEST_SHUFFLE_PARTITIONS = Seq(1, 3, 5) + private val TEST_PROVIDERS = Seq( + classOf[HDFSBackedStateStoreProvider].getName, + classOf[RocksDBStateStoreProvider].getName + ) + private val TEST_COMPRESSION_CODECS = CompressionCodec.ALL_COMPRESSION_CODECS + + private val ALL_COMBINATIONS = { + val comb = for ( + part <- TEST_SHUFFLE_PARTITIONS; + provider <- TEST_PROVIDERS; + codec <- TEST_COMPRESSION_CODECS + ) yield { + (part, provider, codec) + } + scala.util.Random.shuffle(comb) + } + + ALL_COMBINATIONS.take(5).foreach { case (part, provider, codec) => + val testName = s"Verify the read with config [part=$part][provider=$provider][codec=$codec]" + test(testName) { + withTempDir { tempDir => + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> part.toString, + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> provider, + SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> codec) { + + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + + verifyLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + } + } + } + } + + test("Use different configs than session config") { + withTempDir { tempDir => + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName, + SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> "zstd") { + + runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + } + + // Set the different values in session config, to validate whether state data source refers + // to the config in offset log. 
+ withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName, + SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> "lz4") { + + verifyLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath) + } + } + } + + private def verifyLargeDataStreamingAggregationQuery(checkpointLocation: String): Unit = { + val operatorId = 0 + val batchId = 2 + + val stateReadDf = spark.read + .format("statestore") + .option(StateDataSource.PARAM_PATH, checkpointLocation) + // explicitly specifying batch ID and operator ID to test out the functionality + .option(StateDataSource.PARAM_BATCH_ID, batchId) + .option(StateDataSource.PARAM_OPERATOR_ID, operatorId) + .load() + + val resultDf = stateReadDf + .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt", + "value.sum AS value_sum", "value.max AS value_max", "value.min AS value_min") + + checkAnswer( + resultDf, + Seq( + Row(0, 5, 60, 30, 0), // 0, 10, 20, 30 + Row(1, 5, 65, 31, 1), // 1, 11, 21, 31 + Row(2, 5, 70, 32, 2), // 2, 12, 22, 32 + Row(3, 4, 72, 33, 3), // 3, 13, 23, 33 + Row(4, 4, 76, 34, 4), // 4, 14, 24, 34 + Row(5, 4, 80, 35, 5), // 5, 15, 25, 35 + Row(6, 4, 84, 36, 6), // 6, 16, 26, 36 + Row(7, 4, 88, 37, 7), // 7, 17, 27, 37 + Row(8, 4, 92, 38, 8), // 8, 18, 28, 38 + Row(9, 4, 96, 39, 9) // 9, 19, 29, 39 + ) + ) + } +} + +class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[HDFSBackedStateStoreProvider].getName) + } +} + +class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite { + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, Review Comment: Should we also add a couple of tests with changelog checkpointing enabled? -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org