rednaxelafx commented on a change in pull request #28707: URL: https://github.com/apache/spark/pull/28707#discussion_r434457124
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala ##########

@@ -94,6 +110,28 @@ abstract class StreamingAggregationStateManagerBaseImpl(
     // discard and don't convert values to avoid computation
     store.getRange(None, None).map(_.key)
   }
+
+  override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = {
+    if (checkFormat && SQLConf.get.getConf(
+      SQLConf.STREAMING_STATE_FORMAT_CHECK_ENABLED) && row != null) {
+      if (schema.fields.length != row.numFields) {

Review comment:
I was hoping we could move the core validation logic to either `UnsafeRow` itself or some sort of `UnsafeRowUtils`, maybe somewhere in `sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util`. This util function would either return a boolean indicating whether the integrity check passed or failed, or it could return more details; I'd probably go with the former first. It would not do any conf checks -- that's the caller's responsibility. This utility is useful for debugging low-level issues in general, and would come in handy in both Spark SQL and Structured Streaming debugging.

Then we can call that util function from here, after checking the confs. The exception-throwing logic can be left here too.

########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala ##########

@@ -94,6 +112,28 @@ abstract class StreamingAggregationStateManagerBaseImpl(
     // discard and don't convert values to avoid computation
     store.getRange(None, None).map(_.key)
   }
+
+  override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = {
+    if (checkFormat && SQLConf.get.getConf(
+      SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED) && row != null) {
+      if (schema.fields.length != row.numFields) {
+        throw new InvalidUnsafeRowException
+      }
+      schema.fields.zipWithIndex
+        .filterNot(field => UnsafeRow.isFixedLength(field._1.dataType)).foreach {
+          case (_, index) =>
+            val offsetAndSize = row.getLong(index)
+            val offset = (offsetAndSize >> 32).toInt
+            val size = offsetAndSize.toInt
+            if (size < 0 ||
+              offset < UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields ||

Review comment:
`UnsafeRow.calculateBitSetWidthInBytes(row.numFields) + 8 * row.numFields` is loop invariant. Please hoist it out of the loop manually here. It's the same kind of logic as `UnsafeRowWriter`'s

```java
this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
this.fixedSize = nullBitsSize + 8 * numFields;
```

We may want to use the same or similar names for the hoisted variables. `row.getSizeInBytes` on the next line is also loop invariant; let's hoist that out as well.

########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala ##########

@@ -59,6 +61,9 @@ sealed trait StreamingAggregationStateManager extends Serializable {

   /** Return an iterator containing all the values in target state store. */
   def values(store: StateStore): Iterator[UnsafeRow]
+
+  /** Check the UnsafeRow format with the expected schema */
+  def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit

Review comment:
Nit: I'd like to use "verb + noun" names for actions and "nouns" for properties. Here it'd be some form of "validate structural integrity". WDYT?
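As a concrete illustration of the `UnsafeRowUtils` idea above, here is a minimal sketch, assuming an illustrative `validateStructuralIntegrity(row, schema): Boolean` signature and hoisting the loop invariants along the lines of `UnsafeRowWriter`. The object name, placement, variable names, and the final bound check against `row.getSizeInBytes` are assumptions inferred from the review comments, not part of the actual patch:

```scala
package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType

object UnsafeRowUtils {

  /**
   * Best-effort structural integrity check of an UnsafeRow against an expected schema.
   * Returns true if the row could plausibly match the schema, false otherwise.
   * Deliberately does no conf checks and throws no exceptions; that is the caller's job.
   */
  def validateStructuralIntegrity(row: UnsafeRow, expectedSchema: StructType): Boolean = {
    if (expectedSchema.fields.length != row.numFields) {
      false
    } else {
      // Hoisted loop invariants, mirroring UnsafeRowWriter's nullBitsSize / fixedSize.
      val nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields)
      val fixedSize = nullBitsSize + 8 * row.numFields
      val rowSizeInBytes = row.getSizeInBytes
      expectedSchema.fields.zipWithIndex.forall {
        case (field, index) if !UnsafeRow.isFixedLength(field.dataType) =>
          // Variable-length fields store (offset, size) packed into a single long.
          val offsetAndSize = row.getLong(index)
          val offset = (offsetAndSize >> 32).toInt
          val size = offsetAndSize.toInt
          // The offset must land past the fixed-size region and the value must fit in the row
          // (the second bound is inferred from the "row.getSizeInBytes" remark above).
          size >= 0 && offset >= fixedSize && offset + size <= rowSizeInBytes
        case _ => true
      }
    }
  }
}
```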
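On the caller side, the wiring might then look roughly like the following sketch, keeping the conf check and the exception throwing in `StreamingAggregationStateManagerBaseImpl` as suggested. It reuses `checkFormat`, `SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED`, and `InvalidUnsafeRowException` from the diff hunks; the `UnsafeRowUtils.validateStructuralIntegrity` name is the assumed helper above:

```scala
// Sketch only: body of the override inside StreamingAggregationStateManagerBaseImpl.
override def unsafeRowFormatValidation(row: UnsafeRow, schema: StructType): Unit = {
  if (checkFormat && SQLConf.get.getConf(
      SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_CHECK_ENABLED) && row != null) {
    // The structural check lives in the catalyst-side util; only the conf gating and
    // the exception throwing remain here.
    if (!UnsafeRowUtils.validateStructuralIntegrity(row, schema)) {
      throw new InvalidUnsafeRowException
    }
  }
}
```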