anishshri-db commented on code in PR #43425:
URL: https://github.com/apache/spark/pull/43425#discussion_r1376953649


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala:
##########
@@ -0,0 +1,670 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.io.{File, FileWriter}
+
+import org.scalatest.Assertions
+
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, Row}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow}
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, OffsetSeqLog}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
+  import testImplicits._
+
+  test("ERROR: read the state from stateless query") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[Int]
+      val df = inputData.toDF()
+        .selectExpr("value", "value % 2 AS value2")
+
+      testStream(df)(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, 1, 2, 3, 4, 5),
+        CheckLastBatch((1, 1), (2, 0), (3, 1), (4, 0), (5, 1)),
+        AddData(inputData, 6, 7, 8),
+        CheckLastBatch((6, 0), (7, 1), (8, 0))
+      )
+
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore").load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: no committed batch on default batch ID") {
+    withTempDir { tempDir =>
+      runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+      val offsetLog = new OffsetSeqLog(spark,
+        new File(tempDir.getAbsolutePath, "offsets").getAbsolutePath)
+      val commitLog = new CommitLog(spark,
+        new File(tempDir.getAbsolutePath, "commits").getAbsolutePath)
+
+      offsetLog.purgeAfter(0)
+      commitLog.purgeAfter(-1)
+
+      intercept[IllegalStateException] {
+        spark.read.format("statestore").load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: corrupted state schema file") {
+    withTempDir { tempDir =>
+      runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+      def rewriteStateSchemaFileToDummy(): Unit = {
+        // Refer to StateSchemaCompatibilityChecker for the path of the state schema file.
+        val pathForSchema = Seq(
+          "state", "0", StateStore.PARTITION_ID_TO_CHECK_SCHEMA.toString,
+          "_metadata", "schema"
+        ).foldLeft(tempDir) { case (file, dirName) =>
+          new File(file, dirName)
+        }
+
+        assert(pathForSchema.exists())
+        assert(pathForSchema.delete())
+
+        val fileWriter = new FileWriter(pathForSchema)
+        fileWriter.write("lol dummy corrupted schema file")
+        fileWriter.close()
+
+        assert(pathForSchema.exists())
+      }
+
+      rewriteStateSchemaFileToDummy()
+
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore").load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: path is not specified") {
+    intercept[IllegalArgumentException] {
+      spark.read.format("statestore").load()
+    }
+  }
+
+  test("ERROR: operator ID specified to negative") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_OPERATOR_ID, -1)
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: batch ID specified to negative") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_BATCH_ID, -1)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: invalid value for joinSide option") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_JOIN_SIDE, "both")
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  test("ERROR: both options `joinSide` and `storeName` are specified") {
+    withTempDir { tempDir =>
+      intercept[IllegalArgumentException] {
+        spark.read.format("statestore")
+          .option(StateDataSource.PARAM_JOIN_SIDE, "right")
+          .option(StateDataSource.PARAM_STORE_NAME, "right-keyToNumValues")
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateDataSource.PARAM_BATCH_ID, 0)
+          .load(tempDir.getAbsolutePath)
+      }
+    }
+  }
+}
+
+class StateDataSourceSQLConfigSuite extends StateDataSourceTestBase {
+  // Here we build a combination of test criteria for
+  // 1) number of shuffle partitions
+  // 2) state store provider
+  // 3) compression codec
+  // and run one of the tests to verify that the above configs work.
+  // We are building 3 x 2 x 4 = 24 different combinations, and it's probably a waste of time
+  // and resources to run all of them every time, hence we randomly pick 5 combinations per run.
+
+  private val TEST_SHUFFLE_PARTITIONS = Seq(1, 3, 5)
+  private val TEST_PROVIDERS = Seq(
+    classOf[HDFSBackedStateStoreProvider].getName,
+    classOf[RocksDBStateStoreProvider].getName
+  )
+  private val TEST_COMPRESSION_CODECS = CompressionCodec.ALL_COMPRESSION_CODECS
+
+  private val ALL_COMBINATIONS = {
+    val comb = for (
+      part <- TEST_SHUFFLE_PARTITIONS;
+      provider <- TEST_PROVIDERS;
+      codec <- TEST_COMPRESSION_CODECS
+    ) yield {
+      (part, provider, codec)
+    }
+    scala.util.Random.shuffle(comb)
+  }
+
+  ALL_COMBINATIONS.take(5).foreach { case (part, provider, codec) =>
+    val testName = s"Verify the read with config 
[part=$part][provider=$provider][codec=$codec]"
+    test(testName) {
+      withTempDir { tempDir =>
+        withSQLConf(
+          SQLConf.SHUFFLE_PARTITIONS.key -> part.toString,
+          SQLConf.STATE_STORE_PROVIDER_CLASS.key -> provider,
+          SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> codec) {
+
+          runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+
+          verifyLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+        }
+      }
+    }
+  }
+
+  test("Use different configs than session config") {
+    withTempDir { tempDir =>
+      withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> "zstd") {
+
+        runLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+      }
+
+      // Set different values in the session config to validate that the state data source
+      // refers to the configs recorded in the offset log.
+      withSQLConf(
+        SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[HDFSBackedStateStoreProvider].getName,
+        SQLConf.STATE_STORE_COMPRESSION_CODEC.key -> "lz4") {
+
+        verifyLargeDataStreamingAggregationQuery(tempDir.getAbsolutePath)
+      }
+    }
+  }
+
+  private def verifyLargeDataStreamingAggregationQuery(checkpointLocation: String): Unit = {
+    val operatorId = 0
+    val batchId = 2
+
+    val stateReadDf = spark.read
+      .format("statestore")
+      .option(StateDataSource.PARAM_PATH, checkpointLocation)
+      // explicitly specifying batch ID and operator ID to test out the functionality
+      .option(StateDataSource.PARAM_BATCH_ID, batchId)
+      .option(StateDataSource.PARAM_OPERATOR_ID, operatorId)
+      .load()
+
+    val resultDf = stateReadDf
+      .selectExpr("key.groupKey AS key_groupKey", "value.count AS value_cnt",
+        "value.sum AS value_sum", "value.max AS value_max", "value.min AS 
value_min")
+
+    checkAnswer(
+      resultDf,
+      Seq(
+        Row(0, 5, 60, 30, 0), // 0, 10, 20, 30
+        Row(1, 5, 65, 31, 1), // 1, 11, 21, 31
+        Row(2, 5, 70, 32, 2), // 2, 12, 22, 32
+        Row(3, 4, 72, 33, 3), // 3, 13, 23, 33
+        Row(4, 4, 76, 34, 4), // 4, 14, 24, 34
+        Row(5, 4, 80, 35, 5), // 5, 15, 25, 35
+        Row(6, 4, 84, 36, 6), // 6, 16, 26, 36
+        Row(7, 4, 88, 37, 7), // 7, 17, 27, 37
+        Row(8, 4, 92, 38, 8), // 8, 18, 28, 38
+        Row(9, 4, 96, 39, 9) // 9, 19, 29, 39
+      )
+    )
+  }
+}
+
+class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[HDFSBackedStateStoreProvider].getName)
+  }
+}
+
+class RocksDBStateDataSourceReadSuite extends StateDataSourceReadSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,

Review Comment:
   Should we also add a couple of tests with changelog checkpointing enabled?
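   For reference, a minimal sketch of what such a suite could look like, mirroring the provider-specific suites above. It assumes changelog checkpointing is toggled via the RocksDB flag "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled"; the suite name is illustrative, not part of this PR:
   ```scala
   // Minimal sketch, not part of the PR: exercises the reader against checkpoints
   // written by the RocksDB provider with changelog checkpointing turned on.
   class RocksDBWithChangelogCheckpointStateDataSourceReadSuite
     extends StateDataSourceReadSuite {

     override def beforeAll(): Unit = {
       super.beforeAll()
       spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
         classOf[RocksDBStateStoreProvider].getName)
       // Assumed conf key for enabling changelog checkpointing in RocksDB.
       spark.conf.set(
         "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
         "true")
     }
   }
   ```
   Since changelog checkpointing only applies to the RocksDB provider, the HDFS-backed suite would stay as-is.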



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

