This is an automated email from the ASF dual-hosted git repository.
skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git
The following commit(s) were added to refs/heads/main by this push:
new e6f03974 fix: support `async with` on async persister factory methods (#681)
e6f03974 is described below
commit e6f03974a5db50f1d790610b5ee5020b913c03df
Author: André Ahlert <[email protected]>
AuthorDate: Sat Mar 28 16:13:24 2026 -0300
fix: support `async with` on async persister factory methods (#681)
* fix: support `async with` on async persister factory methods
`AsyncSQLitePersister.from_values()` and `AsyncPostgreSQLPersister.from_values()`
(along with the `from_config()` methods that delegate to them) were async
classmethods returning coroutines, which cannot be used directly with
`async with`. This change wraps them in `_AsyncPersisterContextManager`,
which supports both the `await` protocol (backwards compatible) and the
`async with` protocol.
Closes #546
* refactor: move _AsyncPersisterContextManager to burr/common/async_utils.py and add tests
Address review feedback:
- Move _AsyncPersisterContextManager from b_aiosqlite.py to
burr/common/async_utils.py to avoid cross-dependency between
unrelated integrations
- Add type annotation to coro parameter
- Add tests for async with pattern on from_values and from_config
* fix: guard against __aexit__ crash and double consumption in context manager wrapper
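- __aexit__ now returns False when no persister was created (persister is None),
so calling it after a failed or skipped __aenter__ no longer raises
AttributeError (note that `async with` itself never calls __aexit__ when
__aenter__ raises, so there the original exception already propagates)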
- Add _consumed flag so that a second await/async with on the same wrapper
raises a RuntimeError with a clear message, instead of the generic
"cannot reuse already awaited coroutine" error
- Add tests for both edge cases
* docs: fix async persister examples and remove redundant test helpers
- Add missing `await` to the from_values() calls and the application-builder chain in the parallelism.rst examples
- Remove AsyncSQLiteContextManager helper class from both test files,
now that from_values() natively supports async with
- Replace deprecated .close() calls with .cleanup() in test fixtures
* style: fix black formatting and isort import order
* style: reformat with black 23.11.0 (project version)
* ci: trigger workflow rerun
---
burr/common/async_utils.py | 42 +++++++++++++-
burr/integrations/persisters/b_aiosqlite.py | 35 +++++++++---
burr/integrations/persisters/b_asyncpg.py | 61 ++++++++++++++-------
docs/concepts/parallelism.rst | 8 +--
tests/core/test_persistence.py | 28 ++++------
tests/integrations/persisters/test_b_aiosqlite.py | 67 ++++++++++++++++-------
6 files changed, 171 insertions(+), 70 deletions(-)
diff --git a/burr/common/async_utils.py b/burr/common/async_utils.py
index b1f60881..56ce75ff 100644
--- a/burr/common/async_utils.py
+++ b/burr/common/async_utils.py
@@ -16,7 +16,7 @@
# under the License.
import inspect
-from typing import AsyncGenerator, AsyncIterable, Generator, List, TypeVar,
Union
+from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Generator,
List, TypeVar, Union
T = TypeVar("T")
@@ -27,6 +27,46 @@ SyncOrAsyncGenerator = Union[Generator[GenType, None, None],
AsyncGenerator[GenT
SyncOrAsyncGeneratorOrItemOrList = Union[SyncOrAsyncGenerator[GenType],
List[GenType], GenType]
+class _AsyncPersisterContextManager:
+ """Wraps an async coroutine that returns a persister so it can be used
+ directly with ``async with``::
+
+ async with AsyncSQLitePersister.from_values(...) as persister:
+ ...
+
+ The wrapper awaits the coroutine on ``__aenter__`` and delegates
+ ``__aexit__`` to the persister's own ``__aexit__``.
+
+ .. note::
+ Each instance wraps a single coroutine and can only be consumed once,
+ either via ``await`` or ``async with``. A second use will raise
+ ``RuntimeError``.
+ """
+
+ def __init__(self, coro: Coroutine[Any, Any, Any]):
+ self._coro = coro
+ self._persister = None
+ self._consumed = False
+
+ def __await__(self):
+ if self._consumed:
+ raise RuntimeError("This factory result has already been consumed")
+ self._consumed = True
+ return self._coro.__await__()
+
+ async def __aenter__(self):
+ if self._consumed:
+ raise RuntimeError("This factory result has already been consumed")
+ self._consumed = True
+ self._persister = await self._coro
+ return await self._persister.__aenter__()
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ if self._persister is None:
+ return False
+ return await self._persister.__aexit__(exc_type, exc_value, traceback)
+
+
async def asyncify_generator(
generator: SyncOrAsyncGenerator[GenType],
) -> AsyncGenerator[GenType, None]:
diff --git a/burr/integrations/persisters/b_aiosqlite.py
b/burr/integrations/persisters/b_aiosqlite.py
index 9ce3c4a5..a75eb682 100644
--- a/burr/integrations/persisters/b_aiosqlite.py
+++ b/burr/integrations/persisters/b_aiosqlite.py
@@ -21,6 +21,7 @@ from typing import Literal, Optional
import aiosqlite
+from burr.common.async_utils import _AsyncPersisterContextManager
from burr.common.types import BaseCopyable
from burr.core import State
from burr.core.persistence import AsyncBaseStatePersister, PersistedStateData
@@ -60,27 +61,41 @@ class AsyncSQLitePersister(AsyncBaseStatePersister,
BaseCopyable):
PARTITION_KEY_DEFAULT = ""
@classmethod
- async def from_config(cls, config: dict) -> "AsyncSQLitePersister":
+ def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the AsyncSQLitePersister from a
configuration dictionary.
+ Can be used with ``await`` or as an async context manager::
+
+ persister = await AsyncSQLitePersister.from_config(config)
+ # or
+ async with AsyncSQLitePersister.from_config(config) as persister:
+ ...
+
The config key:value pair needed are:
db_path: str,
table_name: str,
serde_kwargs: dict,
connect_kwargs: dict,
"""
- return await cls.from_values(**config)
+ return cls.from_values(**config)
@classmethod
- async def from_values(
+ def from_values(
cls,
db_path: str,
table_name: str = "burr_state",
serde_kwargs: dict = None,
connect_kwargs: dict = None,
- ) -> "AsyncSQLitePersister":
+ ) -> "_AsyncPersisterContextManager":
"""Creates a new instance of the AsyncSQLitePersister from passed in
values.
+ Can be used with ``await`` or as an async context manager::
+
+ persister = await
AsyncSQLitePersister.from_values(db_path="test.db")
+ # or
+ async with AsyncSQLitePersister.from_values(db_path="test.db") as
persister:
+ ...
+
:param db_path: the path the DB will be stored.
:param table_name: the table name to store things under.
:param serde_kwargs: kwargs for state serialization/deserialization.
@@ -88,10 +103,14 @@ class AsyncSQLitePersister(AsyncBaseStatePersister,
BaseCopyable):
:return: async sqlite persister instance with an open connection. You
are responsible
for closing the connection yourself.
"""
- connection = await aiosqlite.connect(
- db_path, **connect_kwargs if connect_kwargs is not None else {}
- )
- return cls(connection, table_name, serde_kwargs)
+
+ async def _create():
+ connection = await aiosqlite.connect(
+ db_path, **connect_kwargs if connect_kwargs is not None else {}
+ )
+ return cls(connection, table_name, serde_kwargs)
+
+ return _AsyncPersisterContextManager(_create())
def __init__(
self,
diff --git a/burr/integrations/persisters/b_asyncpg.py
b/burr/integrations/persisters/b_asyncpg.py
index 66f91f20..c694350d 100644
--- a/burr/integrations/persisters/b_asyncpg.py
+++ b/burr/integrations/persisters/b_asyncpg.py
@@ -19,6 +19,7 @@ import json
import logging
from typing import Any, ClassVar, Literal, Optional
+from burr.common.async_utils import _AsyncPersisterContextManager
from burr.common.types import BaseCopyable
from burr.core import persistence, state
from burr.integrations import base
@@ -106,12 +107,20 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, BaseCopyable
return cls._pool
@classmethod
- async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister":
- """Creates a new instance of the PostgreSQLPersister from a
configuration dictionary."""
- return await cls.from_values(**config)
+ def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
+ """Creates a new instance of the PostgreSQLPersister from a
configuration dictionary.
+
+ Can be used with ``await`` or as an async context manager::
+
+ persister = await AsyncPostgreSQLPersister.from_config(config)
+ # or
+ async with AsyncPostgreSQLPersister.from_config(config) as
persister:
+ ...
+ """
+ return cls.from_values(**config)
@classmethod
- async def from_values(
+ def from_values(
cls,
db_name: str,
user: str,
@@ -121,9 +130,16 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, BaseCopyable
table_name: str = "burr_state",
use_pool: bool = False,
**pool_kwargs,
- ) -> "AsyncPostgreSQLPersister":
+ ) -> "_AsyncPersisterContextManager":
"""Builds a new instance of the PostgreSQLPersister from the provided
values.
+ Can be used with ``await`` or as an async context manager::
+
+ persister = await AsyncPostgreSQLPersister.from_values(...)
+ # or
+ async with AsyncPostgreSQLPersister.from_values(...) as persister:
+ ...
+
:param db_name: the name of the PostgreSQL database.
:param user: the username to connect to the PostgreSQL database.
:param password: the password to connect to the PostgreSQL database.
@@ -133,22 +149,25 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, BaseCopyable
:param use_pool: whether to use a connection pool (True) or a direct
connection (False)
:param pool_kwargs: additional kwargs to pass to the pool creation
"""
- if use_pool:
- pool = await cls.create_pool(
- user=user,
- password=password,
- database=db_name,
- host=host,
- port=port,
- **pool_kwargs,
- )
- return cls(connection=None, pool=pool, table_name=table_name)
- else:
- # Original behavior - direct connection
- connection = await asyncpg.connect(
- user=user, password=password, database=db_name, host=host,
port=port
- )
- return cls(connection=connection, table_name=table_name)
+
+ async def _create():
+ if use_pool:
+ pool = await cls.create_pool(
+ user=user,
+ password=password,
+ database=db_name,
+ host=host,
+ port=port,
+ **pool_kwargs,
+ )
+ return cls(connection=None, pool=pool, table_name=table_name)
+ else:
+ connection = await asyncpg.connect(
+ user=user, password=password, database=db_name, host=host,
port=port
+ )
+ return cls(connection=connection, table_name=table_name)
+
+ return _AsyncPersisterContextManager(_create())
def __init__(
self,
diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst
index 0ced6bcc..c875c5f8 100644
--- a/docs/concepts/parallelism.rst
+++ b/docs/concepts/parallelism.rst
@@ -698,7 +698,7 @@ When using state persistence with async parallelism, make
sure to use the async
from burr.integrations.persisters.b_asyncpg import AsyncPGPersister
# Create an async persister with a connection pool
- persister = AsyncPGPersister.from_values(
+ persister = await AsyncPGPersister.from_values(
host="localhost",
port=5432,
user="postgres",
@@ -707,7 +707,7 @@ When using state persistence with async parallelism, make
sure to use the async
use_pool=True # Important for parallelism!
)
- app = (
+ app = await (
ApplicationBuilder()
.with_state_persister(persister)
.with_action(
@@ -722,12 +722,12 @@ Remember to properly clean up your async persisters when
you're done with them:
.. code-block:: python
- # Using as a context manager
+ # Using as a context manager (recommended)
async with AsyncPGPersister.from_values(..., use_pool=True) as persister:
# Use persister here
# Or manual cleanup
- persister = AsyncPGPersister.from_values(..., use_pool=True)
+ persister = await AsyncPGPersister.from_values(..., use_pool=True)
try:
# Use persister here
finally:
diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py
index b362cd96..18c65a0d 100644
--- a/tests/core/test_persistence.py
+++ b/tests/core/test_persistence.py
@@ -110,7 +110,12 @@ def
test_sqlite_persister_save_without_initialize_raises_runtime_error():
try:
with pytest.raises(RuntimeError, match="Uninitialized persister"):
persister.save(
- "partition_key", "app_id", 1, "position", State({"key":
"value"}), "completed"
+ "partition_key",
+ "app_id",
+ 1,
+ "position",
+ State({"key": "value"}),
+ "completed",
)
finally:
persister.cleanup()
@@ -168,17 +173,6 @@ from burr.integrations.persisters.b_aiosqlite import
AsyncSQLitePersister
"""Asyncio integration for sqlite persister + """
-class AsyncSQLiteContextManager:
- def __init__(self, sqlite_object):
- self.client = sqlite_object
-
- async def __aenter__(self):
- return self.client
-
- async def __aexit__(self, exc_type, exc, tb):
- await self.client.close()
-
-
@pytest.fixture()
async def async_persistence(request):
yield AsyncInMemoryPersister()
@@ -276,15 +270,15 @@ async def test_AsyncSQLitePersister_connection_shutdown():
@pytest.fixture()
async def initializing_async_persistence():
- sqlite_persister = await AsyncSQLitePersister.from_values(
+ async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
- )
- async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
- async with async_context_manager as client:
+ ) as client:
yield client
-async def
test_async_persistence_initialization_creates_table(initializing_async_persistence):
+async def test_async_persistence_initialization_creates_table(
+ initializing_async_persistence,
+):
await asyncio.sleep(0.00001)
await initializing_async_persistence.initialize()
assert await initializing_async_persistence.list_app_ids("partition_key")
== []
diff --git a/tests/integrations/persisters/test_b_aiosqlite.py
b/tests/integrations/persisters/test_b_aiosqlite.py
index 00c98677..adb97532 100644
--- a/tests/integrations/persisters/test_b_aiosqlite.py
+++ b/tests/integrations/persisters/test_b_aiosqlite.py
@@ -25,17 +25,6 @@ from burr.core import ApplicationBuilder, State, action
from burr.integrations.persisters.b_aiosqlite import AsyncSQLitePersister
-class AsyncSQLiteContextManager:
- def __init__(self, sqlite_object):
- self.client = sqlite_object
-
- async def __aenter__(self):
- return self.client
-
- async def __aexit__(self, exc_type, exc, tb):
- await self.client.cleanup()
-
-
async def test_copy_persister(async_persistence: AsyncSQLitePersister):
copy = async_persistence.copy()
assert copy.table_name == async_persistence.table_name
@@ -45,11 +34,9 @@ async def test_copy_persister(async_persistence:
AsyncSQLitePersister):
@pytest.fixture()
async def async_persistence(request):
- sqlite_persister = await AsyncSQLitePersister.from_values(
+ async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
- )
- async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
- async with async_context_manager as client:
+ ) as client:
yield client
@@ -118,6 +105,50 @@ async def test_async_persister_methods_none_partition_key(
# these operations are stateful (i.e., read/write to a db)
+async def test_async_sqlite_from_values_as_context_manager(tmp_path):
+ """Test that from_values works directly with async with (issue #546)."""
+ db_path = str(tmp_path / "test.db")
+ async with AsyncSQLitePersister.from_values(db_path=db_path) as persister:
+ await persister.initialize()
+ await persister.save("pk", "app1", 1, "pos", State({"k": "v"}),
"completed")
+ loaded = await persister.load("pk", "app1")
+ assert loaded is not None
+ assert loaded["state"] == State({"k": "v"})
+
+
+async def test_async_sqlite_from_config_as_context_manager(tmp_path):
+ """Test that from_config works directly with async with (issue #546)."""
+ db_path = str(tmp_path / "test.db")
+ config = {"db_path": db_path, "table_name": "burr_state"}
+ async with AsyncSQLitePersister.from_config(config) as persister:
+ await persister.initialize()
+ await persister.save("pk", "app1", 1, "pos", State({"k": "v"}),
"completed")
+ loaded = await persister.load("pk", "app1")
+ assert loaded is not None
+
+
+async def test_async_sqlite_from_values_cannot_be_consumed_twice():
+ """Test that the factory wrapper raises on double consumption."""
+ wrapper = AsyncSQLitePersister.from_values(db_path=":memory:")
+ persister = await wrapper
+ with pytest.raises(RuntimeError, match="already been consumed"):
+ await wrapper
+ await persister.cleanup()
+
+
+async def
test_async_sqlite_context_manager_aexit_safe_on_failed_aenter(tmp_path):
+ """Test that __aexit__ doesn't crash if __aenter__ never completed."""
+ from burr.common.async_utils import _AsyncPersisterContextManager
+
+ async def _failing_create():
+ raise ConnectionError("simulated connection failure")
+
+ mgr = _AsyncPersisterContextManager(_failing_create())
+ with pytest.raises(ConnectionError, match="simulated connection failure"):
+ async with mgr:
+ pass # should never reach here
+
+
async def test_AsyncSQLitePersister_from_values():
await asyncio.sleep(0.00001)
connection = await aiosqlite.connect(":memory:")
@@ -145,11 +176,9 @@ async def test_AsyncSQLitePersister_connection_shutdown():
@pytest.fixture()
async def initializing_async_persistence():
- sqlite_persister = await AsyncSQLitePersister.from_values(
+ async with AsyncSQLitePersister.from_values(
db_path=":memory:", table_name="test_table"
- )
- async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
- async with async_context_manager as client:
+ ) as client:
yield client