alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975687838


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, 
ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, 
UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import 
org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * 
[[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling 
`functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store 
for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a 
while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(

Review Comment:
   I wonder if this can be merged with the regular FlatMapGroupsWithStateExec,
perhaps as a follow-up cleanup.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, 
ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, 
UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import 
org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * 
[[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling 
`functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store 
for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a 
while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with 
FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-40444): Add the support of initial state.
+  override protected val initialStateDeserializer: Expression = null
+  override protected val initialStateGroupAttrs: Seq[Attribute] = null
+  override protected val initialStateDataAttrs: Seq[Attribute] = null
+  override protected val initialState: SparkPlan = null
+  override protected val hasInitialState: Boolean = false
+
+  override protected val stateEncoder: ExpressionEncoder[Any] =
+    RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
+
+  override def output: Seq[Attribute] = outAttributes
+
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+
+  private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
+    groupingAttributes ++ child.output, groupingAttributes)
+  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, 
child.output)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+      groupingAttributes, getStateInfo, conf) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+    groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override def shortName: String = "applyInPandasWithState"
+
+  override protected def withNewChildInternal(
+      newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = 
newChild)
+
+  override def createInputProcessor(
+      store: StateStore): InputProcessor = new InputProcessor(store: 
StateStore) {
+
+    override def processNewData(dataIter: Iterator[InternalRow]): 
Iterator[InternalRow] = {
+      val groupedIter = GroupedIterator(dataIter, groupingAttributes, 
child.output)
+      val processIter = groupedIter.map { case (keyRow, valueRowIter) =>
+        val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
+        val stateData = stateManager.getState(store, keyUnsafeRow)
+        (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj))
+      }
+
+      process(processIter, hasTimedOut = false)
+    }
+
+    override def processNewDataWithInitialState(
+        childDataIter: Iterator[InternalRow],
+        initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      throw new UnsupportedOperationException("Should not reach here!")
+    }
+
+    override def processTimedOutState(): Iterator[InternalRow] = {
+      if (isTimeoutEnabled) {
+        val timeoutThreshold = timeoutConf match {
+          case ProcessingTimeTimeout => batchTimestampMs.get
+          case EventTimeTimeout => eventTimeWatermark.get
+          case _ =>
+            throw new IllegalStateException(
+              s"Cannot filter timed out keys for $timeoutConf")
+        }
+        val timingOutPairs = stateManager.getAllState(store).filter { state =>
+          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < 
timeoutThreshold
+        }
+
+        val processIter = timingOutPairs.map { stateData =>
+          val joinedKeyRow = unsafeProj(
+            new JoinedRow(
+              stateData.keyRow,
+              new GenericInternalRow(Array.fill(dedupAttributes.length)(null: 
Any))))
+
+          (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
+        }
+
+        process(processIter, hasTimedOut = true)
+      } else Iterator.empty
+    }
+
+    private def process(
+        iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])],
+        hasTimedOut: Boolean): Iterator[InternalRow] = {
+      val runner = new ApplyInPandasWithStatePythonRunner(
+        chainedFunc,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        Array(argOffsets),
+        StructType.fromAttributes(dedupAttributes),
+        sessionLocalTimeZone,
+        pythonRunnerConf,
+        stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
+        groupingAttributes.toStructType,
+        child.output.toStructType,
+        stateType)
+
+      val context = TaskContext.get()
+
+      val processIter = iter.map { case (keyRow, stateData, valueIter) =>
+        val groupedState = GroupStateImpl.createForStreaming(
+          Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r 
},
+          batchTimestampMs.getOrElse(NO_TIMESTAMP),
+          eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+          timeoutConf,
+          hasTimedOut = hasTimedOut,
+          watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
+        (keyRow, groupedState, valueIter)
+      }
+      runner.compute(processIter, context.partitionId(), context).flatMap {
+        case (stateIter, outputIter) =>
+          // When the iterator is consumed, then write changes to state.
+          // state does not affect each others, hence when to update does not 
affect to the result.
+          def onIteratorCompletion: Unit = {
+            stateIter.foreach { case (keyRow, newGroupState, 
oldTimeoutTimestamp) =>
+              if (newGroupState.isRemoved && 
!newGroupState.getTimeoutTimestampMs.isPresent()) {
+                stateManager.removeState(store, keyRow)
+                numRemovedStateRows += 1
+              } else {
+                val currentTimeoutTimestamp = 
newGroupState.getTimeoutTimestampMs
+                  .orElse(NO_TIMESTAMP)
+                val hasTimeoutChanged = currentTimeoutTimestamp != 
oldTimeoutTimestamp
+                val shouldWriteState = newGroupState.isUpdated || 
newGroupState.isRemoved ||
+                  hasTimeoutChanged
+
+                if (shouldWriteState) {

Review Comment:
   What happens if
   
   newGroupState.isRemoved && newGroupState.getTimeoutTimestampMs.isPresent()
   
   i.e., the state was removed but a timeout is still set? Will you keep the
user state object around until the timeout fires?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, 
StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for 
data and state with
+ * bin-packing and chunking. The caller only need to call the proper public 
methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class 
will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk 
internally.
+ *
+ * This class requires that the parameter `root` has been initialized with the 
Arrow schema like
+ * below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer the code comment in the implementation to see how the writes 
of data and state
+ * against Arrow RecordBatch work with consideration of bin-packing and 
chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), 
applyInPandasWithState requires to produce
+  // the additional data `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only single VectorSchemaRoot, which means all 
Arrow RecordBatches
+  // being sent out from ArrowStreamWriter should have same schema. That said, 
we have to construct
+  // "an" Arrow schema to contain both types of data, and also construct Arrow 
RecordBatches to

Review Comment:
   Suggested rewording: "... to contain both data and state, and also construct
Arrow RecordBatches to contain both data and state."



##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,65 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    def wrapped(key_series, value_series_gen, state):
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing pass empty iterator. Here we return an empty 
DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), 
axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the 
user-defined function "
+                    "should be pandas.DataFrame, but is 
{}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is 
empty
+            if not (

Review Comment:
   Ah, never mind, I just misread the code.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, 
StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for 
data and state with
+ * bin-packing and chunking. The caller only need to call the proper public 
methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class 
will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk 
internally.
+ *
+ * This class requires that the parameter `root` has been initialized with the 
Arrow schema like
+ * below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer the code comment in the implementation to see how the writes 
of data and state
+ * against Arrow RecordBatch work with consideration of bin-packing and 
chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), 
applyInPandasWithState requires to produce
+  // the additional data `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only single VectorSchemaRoot, which means all 
Arrow RecordBatches
+  // being sent out from ArrowStreamWriter should have same schema. That said, 
we have to construct
+  // "an" Arrow schema to contain both types of data, and also construct Arrow 
RecordBatches to
+  // contain both data.
+  //
+  // To achieve this, we extend the schema for input data to have a column for 
state at the end.
+  // But also, we logically group the columns by family (data vs state) and 
initialize writer
+  // separately, since it's lot more easier and probably performant to write 
the row directly
+  // rather than projecting the row to match up with the overall schema.
+  //
+  // Although Arrow RecordBatch enables to write the data as columnar, we 
figure out it gives
+  // strange outputs if we don't ensure that all columns have the same number 
of values. Since
+  // there are at least one data for a grouping key (we ensure this for the 
case of handling timed
+  // out state as well) whereas there is only one state for a grouping key, we 
have to fill up the
+  // empty rows in state side to ensure both have the same number of rows.
+  private val arrowWriterForData = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.dropRight(1))
+  private val arrowWriterForState = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.takeRight(1))
+
+  // - Bin-packing
+  //
+  // We apply bin-packing the data from multiple groups into one Arrow 
RecordBatch to
+  // gain the performance. In many cases, the amount of data per grouping key 
is quite
+  // small, which does not seem to maximize the benefits of using Arrow.
+  //
+  // We have to split the record batch down to each group in Python worker to 
convert the
+  // data for group to Pandas, but hopefully, Arrow RecordBatch provides the 
way to split
+  // the range of data and give a view, say, "zero-copy". To help splitting 
the range for
+  // data, we provide the "start offset" and the "number of data" in the state 
metadata.
+  //
+  // We don't bin-pack all groups into a single record batch - we have a limit 
on the number
+  // of rows in the current Arrow RecordBatch to stop adding next group.
+  //
+  // - Chunking
+  //
+  // We also chunk the data from single group into multiple Arrow RecordBatch 
to ensure
+  // scalability. Note that we don't know the volume (number of rows, overall 
size) of data for
+  // specific group key before we read the entire data. The easiest approach 
to address both
+  // bin-pack and chunk is to check the number of rows in the current Arrow 
RecordBatch for each
+  // write of row.
+  //
+  // - Consideration
+  //
+  // Since the number of rows in Arrow RecordBatch does not represent the 
actual size (bytes),
+  // the limit should be set very conservatively. Using a small number of 
limit does not introduce
+  // correctness issues.
+
+  private var numRowsForCurGroup = 0
+  private var startOffsetForCurGroup = 0
+  private var totalNumRowsForBatch = 0
+  private var totalNumStatesForBatch = 0
+
+  private var currentGroupKeyRow: UnsafeRow = _
+  private var currentGroupState: GroupStateImpl[Row] = _
+
+  /**
+   * Indicates writer to start with new grouping key.
+   *
+   * @param keyRow The grouping key row for current group.
+   * @param groupState The instance of GroupStateImpl for current group.
+   */
+  def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit 
= {
+    currentGroupKeyRow = keyRow
+    currentGroupState = groupState
+  }
+
+  /**
+   * Indicates writer to write a row in the current group.
+   *
+   * @param dataRow The row to write in the current group.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If it exceeds the condition of batch (number of records) and there is 
more data for the
+    // same group, finalize and construct a new batch.
+
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      // Provide state metadata row as intermediate
+      val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, 
currentGroupState,
+        startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false)
+      arrowWriterForState.write(stateInfoRow)
+      totalNumStatesForBatch += 1
+
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+    numRowsForCurGroup += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Indicates writer that current group has finalized and there will be no 
further row bound to
+   * the current group.
+   */
+  def finalizeGroup(): Unit = {
+    // Provide state metadata row
+    val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+      startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true)
+    arrowWriterForState.write(stateInfoRow)
+    totalNumStatesForBatch += 1
+
+    // The start offset for next group would be same as the total number of 
rows for batch,
+    // unless the next group starts with new batch.
+    startOffsetForCurGroup = totalNumRowsForBatch
+  }
+
+  /**
+   * Indicates writer that all groups have been processed.
+   */
+  def finalizeData(): Unit = {
+    if (numRowsForCurGroup > 0) {
+      // We still have some rows in the current record batch. Need to finalize 
them as well.
+      finalizeCurrentArrowBatch()
+    }
+  }
+
+  private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = 
{
+    val children = fieldVectors.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private def buildStateInfoRow(
+      keyRow: UnsafeRow,
+      groupState: GroupStateImpl[Row],
+      startOffset: Int,
+      numRows: Int,
+      isLastChunk: Boolean): InternalRow = {
+    // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+    val stateUnderlyingRow = new GenericInternalRow(
+      Array[Any](
+        UTF8String.fromString(groupState.json()),
+        keyRow.getBytes,
+        groupState.getOption.map(PythonSQLUtils.toPyRow).orNull,
+        startOffset,
+        numRows,
+        isLastChunk
+      )
+    )
+    new GenericInternalRow(Array[Any](stateUnderlyingRow))
+  }
+
+  private def finalizeCurrentArrowBatch(): Unit = {
+    val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
+    (0 until remainingEmptyStateRows).foreach { _ =>
+      arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
+    }
+
+    arrowWriterForState.finish()
+    arrowWriterForData.finish()
+    writer.writeBatch()
+    arrowWriterForState.reset()
+    arrowWriterForData.reset()
+
+    startOffsetForCurGroup = 0
+    numRowsForCurGroup = 0
+    totalNumRowsForBatch = 0
+    totalNumStatesForBatch = 0
+  }
+}
+
+object ApplyInPandasWithStateWriter {
+  val STATE_METADATA_SCHEMA: StructType = StructType(

Review Comment:
   Please add a comment describing the semantics of each column. In particular,
`isLastChunk` is not obvious, but it is important to how the protocol works.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to