sven-weber-db commented on code in PR #55716: URL: https://github.com/apache/spark/pull/55716#discussion_r3280917966
########## python/pyspark/tests/test_spark_message_receiver.py: ########## @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import io +import unittest +from typing import BinaryIO + +from pyspark.messages.spark_message_receiver import SparkMessageReceiver +from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream + + +class StubMessageReceiver(SparkMessageReceiver): + """Concrete stub for testing the state machine in SparkMessageReceiver.""" + + def __init__(self) -> None: + super().__init__() + self.stream_finished = True + + def _do_get_init_message(self) -> ZeroCopyByteStream: + return ZeroCopyByteStream(memoryview(b"init")) + + def _do_get_data_stream(self) -> BinaryIO: + return io.BytesIO(b"data") + + def _do_is_stream_finished(self) -> bool: + return self.stream_finished + + +class SparkMessageReceiverTests(unittest.TestCase): + """Tests for SparkMessageReceiver state transitions.""" + + def test_happy_path(self): + """Calling init -> data -> finish in order succeeds.""" + receiver = StubMessageReceiver() + init_msg = receiver.get_init_message() + self.assertIsInstance(init_msg, ZeroCopyByteStream) + data = receiver.get_data_stream() + self.assertEqual(data.read(), b"data") + 
self.assertTrue(receiver.is_stream_finished()) + + def test_invalid_transitions_fail(self): + """Calling methods out of order raises AssertionError.""" + # Each entry: (setup_calls, invalid_call, description) + # setup_calls are executed first to reach a certain state, + # then invalid_call is expected to raise AssertionError. + cases = [ Review Comment: Makes sense. I reduced the tests so that they only check that calling a method twice fails. I do agree with the earlier review comments that we should have at least some validation of the state machine inside the message receiver. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
