This is an automated email from the ASF dual-hosted git repository. astitcher pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/qpid-proton.git
The following commit(s) were added to refs/heads/main by this push: new 53cc3940c PROTON-2879: [Python] Convenience iterators for sessions and links 53cc3940c is described below commit 53cc3940c547ff8224c81cdbf9f1035258100ff5 Author: Andrew Stitcher <astitc...@apache.org> AuthorDate: Wed Apr 30 22:16:17 2025 -0400 PROTON-2879: [Python] Convenience iterators for sessions and links --- python/proton/_endpoints.py | 32 ++++++++-- python/tests/proton_tests/engine.py | 122 ++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+), 4 deletions(-) diff --git a/python/proton/_endpoints.py b/python/proton/_endpoints.py index 13b837444..11087055a 100644 --- a/python/proton/_endpoints.py +++ b/python/proton/_endpoints.py @@ -434,7 +434,7 @@ class Connection(Endpoint): else: return Session(ssn) - def session_head(self, mask: int) -> Optional['Session']: + def session_head(self, mask: EndpointState) -> Optional['Session']: """ Retrieve the first session from a given connection that matches the specified state mask. @@ -452,7 +452,19 @@ class Connection(Endpoint): """ return Session.wrap(pn_session_head(self._impl, mask)) - def link_head(self, mask: int) -> Optional[Union['Sender', 'Receiver']]: + def sessions(self, mask: EndpointState) -> Iterator['Session']: + """ + Returns a generator of sessions owned by the connection with the + given state mask. + + :return: Generator of sessions. + """ + session = self.session_head(mask) + while session: + yield session + session = session.next(mask) + + def link_head(self, mask: EndpointState) -> Optional['Link']: """ Retrieve the first link that matches the given state mask. @@ -469,6 +481,18 @@ class Connection(Endpoint): """ return Link.wrap(pn_link_head(self._impl, mask)) + def links(self, mask: EndpointState) -> Iterator['Link']: + """ + Returns a generator of links owned by this connection with the + given state mask. + + :return: Generator of links. 
+ """ + link = self.link_head(mask) + while link: + yield link + link = link.next(mask) + @property def error(self): """ @@ -619,7 +643,7 @@ class Session(Endpoint): self._update_cond() pn_session_close(self._impl) - def next(self, mask): + def next(self, mask: EndpointState) -> Optional['Session']: """ Retrieve the next session for this connection that matches the specified state mask. @@ -935,7 +959,7 @@ class Link(Endpoint): """ return pn_link_queued(self._impl) - def next(self, mask: int) -> Optional[Union['Sender', 'Receiver']]: + def next(self, mask: EndpointState) -> Optional['Link']: """ Retrieve the next link that matches the given state mask. diff --git a/python/tests/proton_tests/engine.py b/python/tests/proton_tests/engine.py index 8b99539d3..6272af060 100644 --- a/python/tests/proton_tests/engine.py +++ b/python/tests/proton_tests/engine.py @@ -510,6 +510,62 @@ class SessionTest(Test): self.ssn.outgoing_window = 1024 assert self.ssn.outgoing_window == 1024 + def test_multiple_iterator(self): + ssn1 = self.ssn + ssn2 = self.c1.session() + ssn3 = self.c1.session() + + # Check that the iterator gets all sessions for no mask + ssns = [ssn1, ssn2, ssn3] + for ssn in self.c1.sessions(0): + assert ssn in ssns, ssn + ssns.remove(ssn) + assert not ssns, ssns + + # Check that every session starts uninitialized local and remote + ssns = [ssn1, ssn2, ssn3] + for ssn in self.c1.sessions(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_UNINIT): + assert ssn in ssns, ssn + ssns.remove(ssn) + assert not ssns, ssns + + for ssn in self.c1.sessions(0): + ssn.open() + + self.pump() + + ssns = [ssn1, ssn2, ssn3] + for ssn in self.c1.sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_UNINIT): + assert ssn in ssns, ssn + ssns.remove(ssn) + assert not ssns, ssns + + ssns = [ssn for ssn in self.c2.sessions(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_ACTIVE)] + assert len(ssns) == 3, ssns + + for ssn in self.c2.sessions(0): + ssn.open() + + self.pump() + + # Check that every session is now 
active local and remote + ssns = [ssn1, ssn2, ssn3] + for ssn in self.c1.sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE): + assert ssn in ssns, ssn + ssns.remove(ssn) + assert not ssns, ssns + + for ssn in self.c2.sessions(0): + ssn.close() + + self.pump() + + # Check that every session is now closed local and remote + ssns = [ssn1, ssn2, ssn3] + for ssn in self.c1.sessions(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED): + assert ssn in ssns, ssn + ssns.remove(ssn) + class LinkTest(Test): @@ -621,6 +677,72 @@ class LinkTest(Test): conn.close() self.pump() + def test_multiple_iterator(self): + snd1 = self.snd + sess1 = self.snd.session + snd2 = sess1.sender('sender2') + snd3 = sess1.sender('sender3') + + # Check that the iterator gets all senders for no mask, and all senders + # are uninitialized local and remote + snds = [snd1, snd2, snd3] + for snd in sess1.connection.links(0): + assert snd.state == Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_UNINIT, snd.state + assert snd in snds, snd + snds.remove(snd) + assert not snds, snds + + for snd in sess1.connection.links(0): + snd.open() + + self.pump() + + # Check that every sender is now active local and uninitialized remote + snds = [snd1, snd2, snd3] + for snd in sess1.connection.links(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_UNINIT): + assert snd in snds, f"{snd}, {snd.state} not in {snds}" + snds.remove(snd) + assert not snds, snds + + rcvs = [rcv for rcv in self.rcv.connection.links(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_ACTIVE)] + assert len(rcvs) == 3, rcvs + + for rcv in self.rcv.connection.links(0): + rcv.open() + + self.pump() + + # Check that every sender is now active local and remote + snds = [snd1, snd2, snd3] + for snd in sess1.connection.links(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE): + assert snd in snds, f"{snd}, {snd.state} not in {snds}" + snds.remove(snd) + assert not snds, snds + + for snd in sess1.connection.links(0): + snd.close() + + self.pump() + + # Check that every sender is now 
closed local and active remote + snds = [snd1, snd2, snd3] + for snd in sess1.connection.links(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_ACTIVE): + assert snd in snds, f"{snd}, {snd.state} not in {snds}" + snds.remove(snd) + assert not snds, snds + + for rcv in self.rcv.connection.links(0): + rcv.close() + + self.pump() + + # Check that every sender is now closed local and remote + snds = [snd1, snd2, snd3] + for snd in sess1.connection.links(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED): + assert snd in snds, f"{snd}, {snd.state} not in {snds}" + snds.remove(snd) + assert not snds, snds + def test_closing_session(self): self.snd.open() self.rcv.open() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@qpid.apache.org For additional commands, e-mail: commits-h...@qpid.apache.org