bogao007 commented on code in PR #48005:
URL: https://github.com/apache/spark/pull/48005#discussion_r1779172618


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -158,11 +175,70 @@ class TransformWithStateInPandasStateServer(
           Some(message.getGetValueState.getTtl.getDurationMs)
         } else None
         initializeValueState(stateName, schema, ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.UTILSCALL =>
+        handleStatefulProcessorUtilRequest(message.getUtilsCall)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
   }
 
+  private def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): 
Unit = {

Review Comment:
   Should we add some Scala unit tests for these two new APIs?



##########
sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto:
##########
@@ -106,3 +107,17 @@ message SetHandleState {
 message TTLConfig {
   int32 durationMs = 1;
 }
+
+message UtilsCallCommand {
+  oneof method {
+    IsFirstBatch isFirstBatch = 1;
+    GetInitialState getInitialState = 2;
+  }
+}
+
+message IsFirstBatch {
+}
+
+message GetInitialState {
+  bytes value = 1;

Review Comment:
   Should we name it `key`?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -158,11 +175,70 @@ class TransformWithStateInPandasStateServer(
           Some(message.getGetValueState.getTtl.getDurationMs)
         } else None
         initializeValueState(stateName, schema, ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.UTILSCALL =>
+        handleStatefulProcessorUtilRequest(message.getUtilsCall)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
   }
 
+  private def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): 
Unit = {
+    message.getMethodCase match {
+      case UtilsCallCommand.MethodCase.ISFIRSTBATCH =>
+        if (!hasInitialState) {
+          // In physical planning, hasInitialState will always be flipped
+          // if it is not first batch
+          sendResponse(1)
+        } else {
+          sendResponse(0)
+        }
+
+      case UtilsCallCommand.MethodCase.GETINITIALSTATE =>
+        if (!hasInitialState || initialStateKeyToRowMap.isEmpty) {
+          sendResponse(1)
+        } else {
+          sendResponse(0)
+
+          outputStream.flush()
+          val arrowStreamWriter = {
+            val outputSchema = initialStateSchema
+            val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, 
timeZoneId,
+              errorOnDuplicatedFieldNames, largeVarTypes)
+            val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+              s"stdout writer for transformWithStateInPandas state socket", 0, 
Long.MaxValue)
+            val root = VectorSchemaRoot.create(arrowSchema, allocator)
+            new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, 
null, outputStream),
+              arrowTransformWithStateInPandasMaxRecordsPerBatch)
+          }
+
+          val keyBytes = message.getGetInitialState.getValue.toByteArray
+          // The key row is serialized as a byte array, we need to convert it 
back to a Row
+          val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, 
keyRowDeserializer)
+          val groupingKeyToInternalRow =
+            
ExpressionEncoder(groupingKeySchema).createSerializer().apply(keyRow)
+          val iter = initialStateKeyToRowMap
+            .get(groupingKeyToInternalRow).getOrElse(Iterator.empty)
+
+          var seenInitStateOnKey = false
+          while (iter.hasNext) {

Review Comment:
   IIUC, we only process the first item in the iterator. If that's the case, 
why do we need to use Arrow to send that single row?
   
   Also, it seems the initial state map is `Map[InternalRow, 
Iterator[InternalRow]]`. Why do we have multiple value rows mapping to a single 
key row? Is that because we allow a pandas DataFrame as the initial-state input 
on the Python side?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to