Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21739#discussion_r201597291
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
 ---
    @@ -0,0 +1,225 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.state
    +
    +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    +import org.apache.spark.sql.catalyst.expressions._
    +import org.apache.spark.sql.execution.ObjectOperator
    +import org.apache.spark.sql.execution.streaming.GroupStateImpl
    +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
    +import org.apache.spark.sql.types._
    +
    +
    +object FlatMapGroupsWithStateExecHelper {
    +
    +  val supportedVersions = Seq(1, 2)
    +  val legacyVersion = 1
    +
    +  /**
    +   * Class to capture deserialized state and timestamp returned by the state 
manager.
    +   * This is intended for reuse.
    +   */
    +  case class StateData(
    +      var keyRow: UnsafeRow = null,
    +      var stateRow: UnsafeRow = null,
    +      var stateObj: Any = null,
    +      var timeoutTimestamp: Long = -1) {
    +
    +    private[FlatMapGroupsWithStateExecHelper] def withNew(
    +        newKeyRow: UnsafeRow,
    +        newStateRow: UnsafeRow,
    +        newStateObj: Any,
    +        newTimeout: Long): this.type = {
    +      keyRow = newKeyRow
    +      stateRow = newStateRow
    +      stateObj = newStateObj
    +      timeoutTimestamp = newTimeout
    +      this
    +    }
    +  }
    +
    +  sealed trait StateManager extends Serializable {
    +    def stateSchema: StructType
    +    def getState(store: StateStore, keyRow: UnsafeRow): StateData
    +    def putState(store: StateStore, keyRow: UnsafeRow, state: Any, 
timeoutTimestamp: Long): Unit
    +    def removeState(store: StateStore, keyRow: UnsafeRow): Unit
    +    def getAllState(store: StateStore): Iterator[StateData]
    +    def version: Int
    +  }
    +
    +  def createStateManager(
    +      stateEncoder: ExpressionEncoder[Any],
    +      shouldStoreTimestamp: Boolean,
    +      stateFormatVersion: Int): StateManager = {
    +    stateFormatVersion match {
    +      case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp)
    +      case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp)
    +      case _ => throw new IllegalArgumentException(s"Version 
$stateFormatVersion is invalid")
    +    }
    +  }
    +
    +  // 
===============================================================================================
    +  // =========================== Private implementations of StateManager 
===========================
    +  // 
===============================================================================================
    +
    +  private abstract class StateManagerImplBase(val version: Int, 
shouldStoreTimestamp: Boolean)
    +    extends StateManager {
    +
    +    protected def stateSerializerExprs: Seq[Expression]
    +    protected def stateDeserializerExpr: Expression
    +    protected def timeoutTimestampOrdinalInRow: Int
    +
    +    /** Get deserialized state and corresponding timeout timestamp for a 
key */
    +    override def getState(store: StateStore, keyRow: UnsafeRow): StateData 
= {
    +      val stateRow = store.get(keyRow)
    +      stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), 
getTimestamp(stateRow))
    +    }
    +
    +    /** Put state and timeout timestamp for a key */
    +    override def putState(store: StateStore, key: UnsafeRow, state: Any, 
timestamp: Long): Unit = {
    +      val stateRow = getStateRow(state)
    +      setTimestamp(stateRow, timestamp)
    +      store.put(key, stateRow)
    +    }
    +
    +    override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = 
{
    +      store.remove(keyRow)
    +    }
    +
    +    override def getAllState(store: StateStore): Iterator[StateData] = {
    +      val stateData = StateData()
    +      store.getRange(None, None).map { p =>
    +        stateData.withNew(p.key, p.value, getStateObject(p.value), 
getTimestamp(p.value))
    +      }
    +    }
    +
    +    private lazy val stateSerializerFunc = 
ObjectOperator.serializeObjectToRow(stateSerializerExprs)
    +    private lazy val stateDeserializerFunc = {
    +      ObjectOperator.deserializeRowToObject(stateDeserializerExpr, 
stateSchema.toAttributes)
    +    }
    +    private lazy val stateDataForGets = StateData()
    +
    +    protected def getStateObject(row: UnsafeRow): Any = {
    +      if (row != null) stateDeserializerFunc(row) else null
    +    }
    +
    +    protected def getStateRow(obj: Any): UnsafeRow = {
    +      stateSerializerFunc(obj)
    +    }
    +
    +    /** Returns the timeout timestamp of a state row if it is set */
    +    private def getTimestamp(stateRow: UnsafeRow): Long = {
    +      if (shouldStoreTimestamp && stateRow != null) {
    +        stateRow.getLong(timeoutTimestampOrdinalInRow)
    +      } else NO_TIMESTAMP
    +    }
    +
    +    /** Set the timestamp in a state row */
    +    private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: 
Long): Unit = {
    +      if (shouldStoreTimestamp) 
stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps)
    +    }
    +  }
    +
    +
    +  private class StateManagerImplV1(
    +      stateEncoder: ExpressionEncoder[Any],
    +      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(1, 
shouldStoreTimestamp) {
    +
    +    private val timestampTimeoutAttribute =
    +      AttributeReference("timeoutTimestamp", dataType = IntegerType, 
nullable = false)()
    +
    +    private val stateAttributes: Seq[Attribute] = {
    +      val encSchemaAttribs = stateEncoder.schema.toAttributes
    +      if (shouldStoreTimestamp) encSchemaAttribs :+ 
timestampTimeoutAttribute else encSchemaAttribs
    +    }
    +
    +    override val stateSchema: StructType = stateAttributes.toStructType
    +
    +    override val timeoutTimestampOrdinalInRow: Int = {
    +      stateAttributes.indexOf(timestampTimeoutAttribute)
    +    }
    +
    +    override val stateSerializerExprs: Seq[Expression] = {
    +      val encoderSerializer = stateEncoder.namedExpressions
    +      if (shouldStoreTimestamp) {
    +        encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
    +      } else {
    +        encoderSerializer
    +      }
    +    }
    +
    +    override val stateDeserializerExpr: Expression = {
    +      // Note that this must be done in the driver, as resolving and 
binding of deserializer
    +      // expressions to the encoded type can be safely done only in the 
driver.
    +      stateEncoder.resolveAndBind().deserializer
    +    }
    +
    +    override protected def getStateRow(obj: Any): UnsafeRow = {
    +      require(obj != null, "State object cannot be null")
    +      super.getStateRow(obj)
    +    }
    +  }
    +
    +
    +  private class StateManagerImplV2(
    --- End diff --
    
    Add docs explaining the state format


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to