Github user zsxwing commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19416#discussion_r142583033
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
 ---
    @@ -0,0 +1,143 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.state
    +
    +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    +import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, 
GetStructField, IsNull, Literal, UnsafeRow}
    +import org.apache.spark.sql.execution.ObjectOperator
    +import org.apache.spark.sql.execution.streaming.GroupStateImpl
    +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
    +import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
    +
    +
    +class FlatMapGroupsWithState_StateManager(
    +    stateEncoder: ExpressionEncoder[Any],
    +    shouldStoreTimestamp: Boolean) extends Serializable {
    +
    +  val stateSchema = {
    +    val schema = new StructType().add("groupState", stateEncoder.schema, 
nullable = true)
    +    if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) 
else schema
    +  }
    +
    +  def getState(store: StateStore, keyRow: UnsafeRow): 
FlatMapGroupsWithState_StateData = {
    +    val stateRow = store.get(keyRow)
    +    stateDataForGets.withNew(
    +      keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow))
    +  }
    +
    +  def putState(store: StateStore, keyRow: UnsafeRow, state: Any, 
timestamp: Long): Unit = {
    +    val stateRow = getStateRow(state)
    +    setTimestamp(stateRow, timestamp)
    +    store.put(keyRow, stateRow)
    +  }
    +
    +  def removeState(store: StateStore, keyRow: UnsafeRow): Unit = {
    +    store.remove(keyRow)
    +  }
    +
    +  def getAllState(store: StateStore): 
Iterator[FlatMapGroupsWithState_StateData] = {
    +    val stateDataForGetAllState = FlatMapGroupsWithState_StateData()
    +    store.getRange(None, None).map { pair =>
    +      stateDataForGetAllState.withNew(
    +        pair.key, pair.value, getStateObjFromRow(pair.value), 
getTimestamp(pair.value))
    +    }
    +  }
    +
    +  private val stateAttributes: Seq[Attribute] = stateSchema.toAttributes
    +
    +  // Get the serializer for the state, taking into account whether we need 
to save timestamps
    +  private val stateSerializer = {
    +    val nestedStateExpr = CreateNamedStruct(
    +      stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e)))
    +    if (shouldStoreTimestamp) {
    +      Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP))
    +    } else {
    +      Seq(nestedStateExpr)
    +    }
    +  }
    +
    +  // Get the deserializer for the state. Note that this must be done in 
the driver, as
    +  // resolving and binding of deserializer expressions to the encoded type 
can be safely done
    +  // only in the driver.
    +  private val stateDeserializer = {
    +    val boundRefToNestedState = BoundReference(nestedStateOrdinal, 
stateEncoder.schema, true)
    +    val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
    +      case BoundReference(ordinal, _, _) => 
GetStructField(boundRefToNestedState, ordinal)
    +    }
    +    CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), 
elseValue = deser).toCodegen()
    +  }
    +
    +  private lazy val nestedStateOrdinal = 0
    +  private lazy val timeoutTimestampOrdinal = 1
    +
    +  // Converters for translating state between rows and Java objects
    +  private lazy val getStateObjFromRow = 
ObjectOperator.deserializeRowToObject(
    +    stateDeserializer, stateAttributes)
    +  private lazy val getStateRowFromObj = 
ObjectOperator.serializeObjectToRow(stateSerializer)
    +
    +  private lazy val stateDataForGets = FlatMapGroupsWithState_StateData()
    +
    +  /** Returns the state as Java object if defined */
    +  private def getStateObj(stateRow: UnsafeRow): Any = {
    +    if (stateRow == null) null
    +    // else if (stateRow.isNullAt(nestedStateOrdinal)) null
    --- End diff --
    
    nit: remove this commented-out line.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to