bogao007 commented on code in PR #47878: URL: https://github.com/apache/spark/pull/47878#discussion_r1764092513
########## python/pyspark/sql/streaming/stateful_processor_api_client.py: ########## @@ -124,6 +129,130 @@ def get_value_state( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + def register_timer(self, expiry_time_stamp_ms: int) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + register_call = stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms) + state_call_command = stateMessage.TimerStateCallCommand(register=register_call) + call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error register timer: " f"{response_message[1]}") + + def delete_timer(self, expiry_time_stamp_ms: int) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + delete_call = stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms) + state_call_command = stateMessage.TimerStateCallCommand(delete=delete_call) + call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. 
+ raise PySparkRuntimeError(f"Error delete timers: " f"{response_message[1]}") + + def list_timers(self) -> Iterator[list[int]]: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + while True: + list_call = stateMessage.ListTimers() + state_call_command = stateMessage.TimerStateCallCommand(list=list_call) + call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status == 1: + break + elif status == 0: + iterator = self._read_arrow_state() + batch = next(iterator) + result_list = [] + batch_df = batch.to_pandas() + for i in range(batch.num_rows): + timestamp = batch_df.at[i, 'timestamp'].item() + result_list.append(timestamp) + yield result_list + else: + # TODO(SPARK-49233): Classify user facing errors. + raise PySparkRuntimeError(f"Error getting expiry timers: " f"{response_message[1]}") + + def get_expiry_timers_iterator(self, expiry_timestamp: int) -> Iterator[list[Any, int]]: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + while True: + expiry_timer_call = stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp) + timer_request = stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call) + message = stateMessage.StateRequest(timerRequest=timer_request) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status == 1: + break + elif status == 0: + iterator = self._read_arrow_state() + batch = next(iterator) + result_list = [] + key_fields = [field.name for field in self.key_schema.fields] + # TODO any better way to restore a grouping object from a batch? 
Review Comment: Maybe take a look at how `load_stream` is implemented in [ApplyInPandasWithStateSerializer](https://github.com/apache/spark/blob/6fc176f4f34d73d6f6975836951562243343ba9a/python/pyspark/sql/pandas/serializers.py#L808-L948) and [TransformWithStateInPandasSerializer](https://github.com/apache/spark/blob/6fc176f4f34d73d6f6975836951562243343ba9a/python/pyspark/sql/pandas/serializers.py#L1152-L1184) in `pyspark/sql/pandas/serializers.py` (and maybe some other customized serializers in the same file). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org