This is an automated email from the ASF dual-hosted git repository.
astitcher pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/qpid-proton.git
The following commit(s) were added to refs/heads/main by this push:
new 53cc3940c PROTON-2879: [Python] Convenience iterators for sessions and links
53cc3940c is described below
commit 53cc3940c547ff8224c81cdbf9f1035258100ff5
Author: Andrew Stitcher <[email protected]>
AuthorDate: Wed Apr 30 22:16:17 2025 -0400
PROTON-2879: [Python] Convenience iterators for sessions and links
---
python/proton/_endpoints.py | 32 ++++++++--
python/tests/proton_tests/engine.py | 122 ++++++++++++++++++++++++++++++++++++
2 files changed, 150 insertions(+), 4 deletions(-)
diff --git a/python/proton/_endpoints.py b/python/proton/_endpoints.py
index 13b837444..11087055a 100644
--- a/python/proton/_endpoints.py
+++ b/python/proton/_endpoints.py
@@ -434,7 +434,7 @@ class Connection(Endpoint):
else:
return Session(ssn)
- def session_head(self, mask: int) -> Optional['Session']:
+ def session_head(self, mask: EndpointState) -> Optional['Session']:
"""
Retrieve the first session from a given connection that matches the
specified state mask.
@@ -452,7 +452,19 @@ class Connection(Endpoint):
"""
return Session.wrap(pn_session_head(self._impl, mask))
- def link_head(self, mask: int) -> Optional[Union['Sender', 'Receiver']]:
+    def sessions(self, mask: EndpointState) -> Iterator['Session']:
+        """
+        Returns a generator of sessions owned by this connection whose
+        state matches the given state mask.
+
+        :param mask: State mask to match against, for example
+            ``Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE``; ``0``
+            matches every session.
+        :return: Generator of matching sessions.
+        """
+        session = self.session_head(mask)
+        # session_head()/next() return None once no further session matches
+        while session is not None:
+            yield session
+            session = session.next(mask)
+
+ def link_head(self, mask: EndpointState) -> Optional['Link']:
"""
Retrieve the first link that matches the given state mask.
@@ -469,6 +481,18 @@ class Connection(Endpoint):
"""
return Link.wrap(pn_link_head(self._impl, mask))
+    def links(self, mask: EndpointState) -> Iterator['Link']:
+        """
+        Returns a generator of links owned by this connection whose state
+        matches the given state mask.
+
+        :param mask: State mask to match against, for example
+            ``Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE``; ``0``
+            matches every link.
+        :return: Generator of matching links.
+        """
+        link = self.link_head(mask)
+        # link_head()/next() return None once no further link matches
+        while link is not None:
+            yield link
+            link = link.next(mask)
+
@property
def error(self):
"""
@@ -619,7 +643,7 @@ class Session(Endpoint):
self._update_cond()
pn_session_close(self._impl)
- def next(self, mask):
+ def next(self, mask: EndpointState) -> Optional['Session']:
"""
Retrieve the next session for this connection that matches the
specified state mask.
@@ -935,7 +959,7 @@ class Link(Endpoint):
"""
return pn_link_queued(self._impl)
- def next(self, mask: int) -> Optional[Union['Sender', 'Receiver']]:
+ def next(self, mask: EndpointState) -> Optional['Link']:
"""
Retrieve the next link that matches the given state mask.
diff --git a/python/tests/proton_tests/engine.py b/python/tests/proton_tests/engine.py
index 8b99539d3..6272af060 100644
--- a/python/tests/proton_tests/engine.py
+++ b/python/tests/proton_tests/engine.py
@@ -510,6 +510,62 @@ class SessionTest(Test):
self.ssn.outgoing_window = 1024
assert self.ssn.outgoing_window == 1024
+    def test_multiple_iterator(self):
+        ssn1 = self.ssn
+        ssn2 = self.c1.session()
+        ssn3 = self.c1.session()
+
+        def check_sessions(mask):
+            # The iterator must yield each of the three sessions exactly once;
+            # the final assert also catches a mask that matches nothing.
+            ssns = [ssn1, ssn2, ssn3]
+            for ssn in self.c1.sessions(mask):
+                assert ssn in ssns, ssn
+                ssns.remove(ssn)
+            assert not ssns, ssns
+
+        # Check that the iterator gets all sessions for no mask
+        check_sessions(0)
+
+        # Check that every session starts uninitialized local and remote
+        check_sessions(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_UNINIT)
+
+        for ssn in self.c1.sessions(0):
+            ssn.open()
+
+        self.pump()
+
+        # Check that every session is now active locally but uninitialized remotely
+        check_sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_UNINIT)
+
+        # The peer sees three sessions that are remotely active but locally uninitialized
+        ssns = list(self.c2.sessions(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_ACTIVE))
+        assert len(ssns) == 3, ssns
+
+        for ssn in self.c2.sessions(0):
+            ssn.open()
+
+        self.pump()
+
+        # Check that every session is now active local and remote
+        check_sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE)
+
+        for ssn in self.c2.sessions(0):
+            ssn.close()
+
+        self.pump()
+
+        # Only the peer has closed so far, so every session is still active
+        # locally but now closed remotely
+        check_sessions(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_CLOSED)
+
+        for ssn in self.c1.sessions(0):
+            ssn.close()
+
+        self.pump()
+
+        # Check that every session is now closed local and remote
+        check_sessions(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED)
+
class LinkTest(Test):
@@ -621,6 +677,72 @@ class LinkTest(Test):
conn.close()
self.pump()
+    def test_multiple_iterator(self):
+        snd1 = self.snd
+        sess1 = self.snd.session
+        snd2 = sess1.sender('sender2')
+        snd3 = sess1.sender('sender3')
+        snd_conn = sess1.connection
+        rcv_conn = self.rcv.connection
+
+        def check_senders(mask):
+            # The iterator must yield each of the three senders exactly once;
+            # the final assert also catches a mask that matches nothing.
+            snds = [snd1, snd2, snd3]
+            for snd in snd_conn.links(mask):
+                assert snd in snds, f"{snd}, {snd.state} not in {snds}"
+                snds.remove(snd)
+            assert not snds, snds
+
+        # Check that every sender starts uninitialized local and remote, and
+        # that the iterator gets all senders for no mask
+        for snd in snd_conn.links(0):
+            assert snd.state == Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_UNINIT, snd.state
+        check_senders(0)
+
+        for snd in snd_conn.links(0):
+            snd.open()
+
+        self.pump()
+
+        # Check that every sender is now active locally but uninitialized remotely
+        check_senders(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_UNINIT)
+
+        # The peer sees three links that are remotely active but locally uninitialized
+        rcvs = list(rcv_conn.links(Endpoint.LOCAL_UNINIT | Endpoint.REMOTE_ACTIVE))
+        assert len(rcvs) == 3, rcvs
+
+        for rcv in rcv_conn.links(0):
+            rcv.open()
+
+        self.pump()
+
+        # Check that every sender is now active local and remote
+        check_senders(Endpoint.LOCAL_ACTIVE | Endpoint.REMOTE_ACTIVE)
+
+        for snd in snd_conn.links(0):
+            snd.close()
+
+        self.pump()
+
+        # Check that every sender is now closed local and active remote
+        check_senders(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_ACTIVE)
+
+        for rcv in rcv_conn.links(0):
+            rcv.close()
+
+        self.pump()
+
+        # Check that every sender is now closed local and remote
+        check_senders(Endpoint.LOCAL_CLOSED | Endpoint.REMOTE_CLOSED)
+
def test_closing_session(self):
self.snd.open()
self.rcv.open()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]