Reviewed-by: Patrick Robb <pr...@iol.unh.edu>

I don't have any comments beyond Luca's suggestions, but I did spot the typo below.

On Tue, May 14, 2024 at 4:15 PM <jspew...@iol.unh.edu> wrote:
> +    def __exit__(self, type: BaseException, value: BaseException, traceback: 
> TracebackType) -> None:
> +        """Exit the context block.
> +
> +        Upon exiting a context block with this class, we want to ensure that 
> the instance of the
> +        application is explicitly closed and properly cleaned up using it's 
> close method. Note that

it's -> its (possessive, not the contraction of "it is")

On Tue, May 14, 2024 at 4:15 PM <jspew...@iol.unh.edu> wrote:
>
> From: Jeremy Spewock <jspew...@iol.unh.edu>
>
> Interactive shells are currently managed in such a way that they are closed
> and cleaned up at the time of garbage collection. Due to there being no
> guarantee of when this garbage collection happens in Python, there is no
> way to consistently know when an application will be closed without
> manually closing the application yourself when you are done with it.
> This doesn't cause a problem in cases where you can start another
> instance of the same application multiple times on a server, but this
> isn't the case for primary applications in DPDK. The introduction of
> primary applications, such as testpmd, adds a need for knowing previous
> instances of the application have been stopped and cleaned up before
> starting a new one, which the garbage collector does not provide.
>
> To solve this problem, a new class is added which acts as a wrapper
> around the interactive shell that enforces that instances of the
> application be managed using a context manager. Using a context manager
> guarantees that once you leave the scope of the block where the
> application is being used for any reason, the application will be closed
> immediately. This avoids the possibility of the shell not being closed
> due to an exception being raised or user error.
>
> depends-on: patch-139227 ("dts: skip test cases based on capabilities")
>
> Signed-off-by: Jeremy Spewock <jspew...@iol.unh.edu>
> ---
>  .../critical_interactive_shell.py             | 98 +++++++++++++++++++
>  .../remote_session/interactive_shell.py       | 13 ++-
>  dts/framework/remote_session/testpmd_shell.py |  4 +-
>  dts/framework/testbed_model/sut_node.py       |  8 +-
>  dts/tests/TestSuite_pmd_buffer_scatter.py     | 28 +++---
>  dts/tests/TestSuite_smoke_tests.py            |  3 +-
>  6 files changed, 130 insertions(+), 24 deletions(-)
>  create mode 100644 dts/framework/remote_session/critical_interactive_shell.py
>
> diff --git a/dts/framework/remote_session/critical_interactive_shell.py 
> b/dts/framework/remote_session/critical_interactive_shell.py
> new file mode 100644
> index 0000000000..d61b203954
> --- /dev/null
> +++ b/dts/framework/remote_session/critical_interactive_shell.py
> @@ -0,0 +1,98 @@
> +r"""Wrapper around :class:`~.interactive_shell.InteractiveShell` that 
> handles critical applications.
> +
> +Critical applications are defined as applications that require explicit 
> clean-up before another
> +instance of some application can be started. In DPDK these are referred to 
> as "primary
> +applications" and these applications take out a lock which stops other 
> primary applications from
> +running. Much like :class:`~.interactive_shell.InteractiveShell`\s,
> +:class:`CriticalInteractiveShell` is meant to be extended by subclasses that 
> implement application
> +specific functionality and should never be instantiated directly.
> +"""
> +
> +from types import TracebackType
> +from typing import Callable, TypeVar
> +
> +from paramiko import SSHClient  # type: ignore[import]
> +
> +from framework.logger import DTSLogger
> +from framework.settings import SETTINGS
> +
> +from .interactive_shell import InteractiveShell
> +
> +CriticalInteractiveShellType = TypeVar(
> +    "CriticalInteractiveShellType", bound="CriticalInteractiveShell"
> +)
> +
> +
> +class CriticalInteractiveShell(InteractiveShell):
> +    """The base class for interactive critical applications.
> +
> +    This class is a wrapper around 
> :class:`~.interactive_shell.InteractiveShell` and should always
> +    implement the exact same functionality with the primary difference being 
> how the application
> +    is started and stopped. In contrast to normal interactive shells, this 
> class does not start the
> +    application upon initialization of the class. Instead, the application 
> is handled through a
> +    context manager. This allows for more explicit starting and stopping of 
> the application, and
> +    more guarantees for when the application is cleaned up which are not 
> present with normal
> +    interactive shells that get cleaned up upon garbage collection.
> +    """
> +
> +    _get_priviledged_command: Callable[[str], str] | None
> +
> +    def __init__(
> +        self,
> +        interactive_session: SSHClient,
> +        logger: DTSLogger,
> +        get_privileged_command: Callable[[str], str] | None,
> +        app_args: str = "",
> +        timeout: float = SETTINGS.timeout,
> +    ) -> None:
> +        """Store parameters for creating an interactive shell, but do not 
> start the application.
> +
> +        Note that this method also does not create the channel for the 
> application, as this is
> +        something that isn't needed until the application starts.
> +
> +        Args:
> +            interactive_session: The SSH session dedicated to interactive 
> shells.
> +            logger: The logger instance this session will use.
> +            get_privileged_command: A method for modifying a command to 
> allow it to use
> +                elevated privileges. If :data:`None`, the application will 
> not be started
> +                with elevated privileges.
> +            app_args: The command line arguments to be passed to the 
> application on startup.
> +            timeout: The timeout used for the SSH channel that is dedicated 
> to this interactive
> +                shell. This timeout is for collecting output, so if reading 
> from the buffer
> +                and no output is gathered within the timeout, an exception 
> is thrown. The default
> +                value for this argument may be modified using the 
> :option:`--timeout` command-line
> +                argument or the :envvar:`DTS_TIMEOUT` environment variable.
> +        """
> +        self._interactive_session = interactive_session
> +        self._logger = logger
> +        self._timeout = timeout
> +        self._app_args = app_args
> +        self._get_priviledged_command = get_privileged_command
> +
> +    def __enter__(self: CriticalInteractiveShellType) -> 
> CriticalInteractiveShellType:
> +        """Enter the context block.
> +
> +        Upon entering a context block with this class, the desired behavior 
> is to create the
> +        channel for the application to use, and then start the application.
> +
> +        Returns:
> +            Reference to the object for the application after it has been 
> started.
> +        """
> +        self._init_channel()
> +        self._start_application(self._get_priviledged_command)
> +        return self
> +
> +    def __exit__(self, type: BaseException, value: BaseException, traceback: 
> TracebackType) -> None:
> +        """Exit the context block.
> +
> +        Upon exiting a context block with this class, we want to ensure that 
> the instance of the
> +        application is explicitly closed and properly cleaned up using it's 
> close method. Note that
> +        because this method returns :data:`None` if an exception was raised 
> within the block, it is
> +        not handled and will be re-raised after the application is closed.
> +
> +        Args:
> +            type: Type of exception that was thrown in the context block if 
> there was one.
> +            value: Value of the exception thrown in the context block if 
> there was one.
> +            traceback: Traceback of the exception thrown in the context 
> block if there was one.
> +        """
> +        self.close()
> diff --git a/dts/framework/remote_session/interactive_shell.py 
> b/dts/framework/remote_session/interactive_shell.py
> index d1a9d8a6d2..08b8ba6a3e 100644
> --- a/dts/framework/remote_session/interactive_shell.py
> +++ b/dts/framework/remote_session/interactive_shell.py
> @@ -89,16 +89,19 @@ def __init__(
>                  and no output is gathered within the timeout, an exception 
> is thrown.
>          """
>          self._interactive_session = interactive_session
> -        self._ssh_channel = self._interactive_session.invoke_shell()
> -        self._stdin = self._ssh_channel.makefile_stdin("w")
> -        self._stdout = self._ssh_channel.makefile("r")
> -        self._ssh_channel.settimeout(timeout)
> -        self._ssh_channel.set_combine_stderr(True)  # combines stdout and 
> stderr streams
>          self._logger = logger
>          self._timeout = timeout
>          self._app_args = app_args
> +        self._init_channel()
>          self._start_application(get_privileged_command)
>
> +    def _init_channel(self):
> +        self._ssh_channel = self._interactive_session.invoke_shell()
> +        self._stdin = self._ssh_channel.makefile_stdin("w")
> +        self._stdout = self._ssh_channel.makefile("r")
> +        self._ssh_channel.settimeout(self._timeout)
> +        self._ssh_channel.set_combine_stderr(True)  # combines stdout and 
> stderr streams
> +
>      def _start_application(self, get_privileged_command: Callable[[str], 
> str] | None) -> None:
>          """Starts a new interactive application based on the path to the app.
>
> diff --git a/dts/framework/remote_session/testpmd_shell.py 
> b/dts/framework/remote_session/testpmd_shell.py
> index cb4642bf3d..33b3e7c5a3 100644
> --- a/dts/framework/remote_session/testpmd_shell.py
> +++ b/dts/framework/remote_session/testpmd_shell.py
> @@ -26,7 +26,7 @@
>  from framework.settings import SETTINGS
>  from framework.utils import StrEnum
>
> -from .interactive_shell import InteractiveShell
> +from .critical_interactive_shell import CriticalInteractiveShell
>
>
>  class TestPmdDevice(object):
> @@ -82,7 +82,7 @@ class TestPmdForwardingModes(StrEnum):
>      recycle_mbufs = auto()
>
>
> -class TestPmdShell(InteractiveShell):
> +class TestPmdShell(CriticalInteractiveShell):
>      """Testpmd interactive shell.
>
>      The testpmd shell users should never use
> diff --git a/dts/framework/testbed_model/sut_node.py 
> b/dts/framework/testbed_model/sut_node.py
> index 1fb536735d..7dd39fd735 100644
> --- a/dts/framework/testbed_model/sut_node.py
> +++ b/dts/framework/testbed_model/sut_node.py
> @@ -243,10 +243,10 @@ def get_supported_capabilities(
>          unsupported_capas: set[NicCapability] = set()
>          self._logger.debug(f"Checking which capabilities from {capabilities} 
> NIC are supported.")
>          testpmd_shell = self.create_interactive_shell(TestPmdShell, 
> privileged=True)
> -        for capability in capabilities:
> -            if capability not in supported_capas or capability not in 
> unsupported_capas:
> -                capability.value(testpmd_shell, supported_capas, 
> unsupported_capas)
> -        del testpmd_shell
> +        with testpmd_shell as running_testpmd:
> +            for capability in capabilities:
> +                if capability not in supported_capas or capability not in 
> unsupported_capas:
> +                    capability.value(running_testpmd, supported_capas, 
> unsupported_capas)
>          return supported_capas
>
>      def _set_up_build_target(self, build_target_config: 
> BuildTargetConfiguration) -> None:
> diff --git a/dts/tests/TestSuite_pmd_buffer_scatter.py 
> b/dts/tests/TestSuite_pmd_buffer_scatter.py
> index 3701c47408..41f6090a7e 100644
> --- a/dts/tests/TestSuite_pmd_buffer_scatter.py
> +++ b/dts/tests/TestSuite_pmd_buffer_scatter.py
> @@ -101,7 +101,7 @@ def pmd_scatter(self, mbsize: int) -> None:
>          Test:
>              Start testpmd and run functional test with preset mbsize.
>          """
> -        testpmd = self.sut_node.create_interactive_shell(
> +        testpmd_shell = self.sut_node.create_interactive_shell(
>              TestPmdShell,
>              app_parameters=(
>                  "--mbcache=200 "
> @@ -112,17 +112,21 @@ def pmd_scatter(self, mbsize: int) -> None:
>              ),
>              privileged=True,
>          )
> -        testpmd.set_forward_mode(TestPmdForwardingModes.mac)
> -        testpmd.start()
> -
> -        for offset in [-1, 0, 1, 4, 5]:
> -            recv_payload = self.scatter_pktgen_send_packet(mbsize + offset)
> -            self._logger.debug(f"Payload of scattered packet after 
> forwarding: \n{recv_payload}")
> -            self.verify(
> -                ("58 " * 8).strip() in recv_payload,
> -                f"Payload of scattered packet did not match expected payload 
> with offset {offset}.",
> -            )
> -        testpmd.stop()
> +        with testpmd_shell as testpmd:
> +            testpmd.set_forward_mode(TestPmdForwardingModes.mac)
> +            testpmd.start()
> +
> +            for offset in [-1, 0, 1, 4, 5]:
> +                recv_payload = self.scatter_pktgen_send_packet(mbsize + 
> offset)
> +                self._logger.debug(
> +                    f"Payload of scattered packet after forwarding: 
> \n{recv_payload}"
> +                )
> +                self.verify(
> +                    ("58 " * 8).strip() in recv_payload,
> +                    "Payload of scattered packet did not match expected 
> payload with offset "
> +                    f"{offset}.",
> +                )
> +            testpmd.stop()
>
>      def test_scatter_mbuf_2048(self) -> None:
>          """Run the :meth:`pmd_scatter` test with `mbsize` set to 2048."""
> diff --git a/dts/tests/TestSuite_smoke_tests.py 
> b/dts/tests/TestSuite_smoke_tests.py
> index a553e89662..360e64eb5a 100644
> --- a/dts/tests/TestSuite_smoke_tests.py
> +++ b/dts/tests/TestSuite_smoke_tests.py
> @@ -100,7 +100,8 @@ def test_devices_listed_in_testpmd(self) -> None:
>              List all devices found in testpmd and verify the configured 
> devices are among them.
>          """
>          testpmd_driver = 
> self.sut_node.create_interactive_shell(TestPmdShell, privileged=True)
> -        dev_list = [str(x) for x in testpmd_driver.get_devices()]
> +        with testpmd_driver as testpmd:
> +            dev_list = [str(x) for x in testpmd.get_devices()]
>          for nic in self.nics_in_node:
>              self.verify(
>                  nic.pci in dev_list,
> --
> 2.44.0
>

Reply via email to