bogao007 commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1764092513


##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,130 @@ def get_value_state(
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing value state: " 
f"{response_message[1]}")
 
+    def register_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        register_call = stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = stateMessage.TimerStateCallCommand(register=register_call)
+        call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error register timer: " 
f"{response_message[1]}")
+
+    def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        delete_call = stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = stateMessage.TimerStateCallCommand(delete=delete_call)
+        call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error delete timers: " 
f"{response_message[1]}")
+
+    def list_timers(self) -> Iterator[list[int]]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        while True:
+            list_call = stateMessage.ListTimers()
+            state_call_command = stateMessage.TimerStateCallCommand(list=list_call)
+            call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+            message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+            self._send_proto_message(message.SerializeToString())
+            response_message = self._receive_proto_message()
+            status = response_message[0]
+            if status == 1:
+                break
+            elif status == 0:
+                iterator = self._read_arrow_state()
+                batch = next(iterator)
+                result_list = []
+                batch_df = batch.to_pandas()
+                for i in range(batch.num_rows):
+                    timestamp = batch_df.at[i, 'timestamp'].item()
+                    result_list.append(timestamp)
+                yield result_list
+            else:
+                # TODO(SPARK-49233): Classify user facing errors.
+                raise PySparkRuntimeError(f"Error getting expiry timers: " 
f"{response_message[1]}")
+
+    def get_expiry_timers_iterator(self, expiry_timestamp: int) -> Iterator[list[Any, int]]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        while True:
+            expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+            timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
+            message = stateMessage.StateRequest(timerRequest=timer_request)
+
+            self._send_proto_message(message.SerializeToString())
+            response_message = self._receive_proto_message()
+            status = response_message[0]
+            if status == 1:
+                break
+            elif status == 0:
+                iterator = self._read_arrow_state()
+                batch = next(iterator)
+                result_list = []
+                key_fields = [field.name for field in self.key_schema.fields]
+                # TODO any better way to restore a grouping object from a batch?

Review Comment:
   Maybe take a look at how `load_stream` is implemented in [ApplyInPandasWithStateSerializer](https://github.com/apache/spark/blob/6fc176f4f34d73d6f6975836951562243343ba9a/python/pyspark/sql/pandas/serializers.py#L808-L948) and [TransformWithStateInPandasSerializer](https://github.com/apache/spark/blob/6fc176f4f34d73d6f6975836951562243343ba9a/python/pyspark/sql/pandas/serializers.py#L1152-L1184) in `pyspark/sql/pandas/serializers.py` (and maybe some other custom serializers in the same file).
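   
   For what it's worth, a rough sketch of that grouping pattern, not tied to this PR's API. It assumes `batch` is a `pyarrow.RecordBatch` that contains the key columns and that `key_schema` is the grouping-key `StructType`; the function name and signature are illustrative only, not what the linked serializers actually expose:
   
   ```python
   from typing import Any, Iterator, Tuple
   
   import pandas as pd
   import pyarrow as pa
   
   from pyspark.sql.types import StructType
   
   
   def iter_keys_with_rows(
       batch: pa.RecordBatch, key_schema: StructType
   ) -> Iterator[Tuple[Tuple[Any, ...], pd.DataFrame]]:
       """Yield (grouping-key tuple, rows for that key) pairs from one Arrow batch."""
       key_fields = [f.name for f in key_schema.fields]
       pdf = batch.to_pandas()
       for key, group in pdf.groupby(key_fields, sort=False):
           # Older pandas returns a scalar (not a 1-tuple) when grouping by a single key column.
           key_tuple = key if isinstance(key, tuple) else (key,)
           yield key_tuple, group
   ```
   
   Grouping the pandas frame by the key columns lets the key tuple fall out of `groupby` instead of being rebuilt row by row.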



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
