This is an automated email from the ASF dual-hosted git repository.

skrawcz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git


The following commit(s) were added to refs/heads/main by this push:
     new e6f03974 fix: support `async with` on async persister factory methods (#681)
e6f03974 is described below

commit e6f03974a5db50f1d790610b5ee5020b913c03df
Author: André Ahlert <[email protected]>
AuthorDate: Sat Mar 28 16:13:24 2026 -0300

    fix: support `async with` on async persister factory methods (#681)
    
    * fix: support `async with` on async persister factory methods
    
    `AsyncSQLitePersister.from_values()` and `AsyncPostgreSQLPersister.from_values()`
    were async classmethods returning coroutines, which cannot be used directly
    with `async with`. This wraps them in `_AsyncPersisterContextManager` that
    supports both `await` (backwards compatible) and `async with` protocols.
    
    Closes #546
    
    * refactor: move _AsyncPersisterContextManager to burr/common/async_utils.py and add tests
    
    Address review feedback:
    - Move _AsyncPersisterContextManager from b_aiosqlite.py to
      burr/common/async_utils.py to avoid cross-dependency between
      unrelated integrations
    - Add type annotation to coro parameter
    - Add tests for async with pattern on from_values and from_config
    
    * fix: guard against __aexit__ crash and double consumption in context manager wrapper
    
    - __aexit__ now returns False when __aenter__ failed (persister is None),
      preventing AttributeError that would mask the original exception
    - Add _consumed flag to prevent silent coroutine reuse, raising
      RuntimeError with clear message on second await/async with
    - Add tests for both edge cases
    
    * docs: fix async persister examples and remove redundant test helpers
    
    - Add missing `await` to from_values() calls in parallelism.rst docs
    - Remove AsyncSQLiteContextManager helper class from both test files,
      now that from_values() natively supports async with
    - Replace deprecated .close() calls with .cleanup() in test fixtures
    
    * style: fix black formatting and isort import order
    
    * style: reformat with black 23.11.0 (project version)
    
    * ci: trigger workflow rerun
---
 burr/common/async_utils.py                        | 42 +++++++++++++-
 burr/integrations/persisters/b_aiosqlite.py       | 35 +++++++++---
 burr/integrations/persisters/b_asyncpg.py         | 61 ++++++++++++++-------
 docs/concepts/parallelism.rst                     |  8 +--
 tests/core/test_persistence.py                    | 28 ++++------
 tests/integrations/persisters/test_b_aiosqlite.py | 67 ++++++++++++++++-------
 6 files changed, 171 insertions(+), 70 deletions(-)

diff --git a/burr/common/async_utils.py b/burr/common/async_utils.py
index b1f60881..56ce75ff 100644
--- a/burr/common/async_utils.py
+++ b/burr/common/async_utils.py
@@ -16,7 +16,7 @@
 # under the License.
 
 import inspect
-from typing import AsyncGenerator, AsyncIterable, Generator, List, TypeVar, 
Union
+from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Generator, 
List, TypeVar, Union
 
 T = TypeVar("T")
 
@@ -27,6 +27,46 @@ SyncOrAsyncGenerator = Union[Generator[GenType, None, None], 
AsyncGenerator[GenT
 SyncOrAsyncGeneratorOrItemOrList = Union[SyncOrAsyncGenerator[GenType], 
List[GenType], GenType]
 
 
+class _AsyncPersisterContextManager:
+    """Wraps an async coroutine that returns a persister so it can be used
+    directly with ``async with``::
+
+        async with AsyncSQLitePersister.from_values(...) as persister:
+            ...
+
+    The wrapper awaits the coroutine on ``__aenter__`` and delegates
+    ``__aexit__`` to the persister's own ``__aexit__``.
+
+    .. note::
+        Each instance wraps a single coroutine and can only be consumed once,
+        either via ``await`` or ``async with``. A second use will raise
+        ``RuntimeError``.
+    """
+
+    def __init__(self, coro: Coroutine[Any, Any, Any]):
+        self._coro = coro
+        self._persister = None
+        self._consumed = False
+
+    def __await__(self):
+        if self._consumed:
+            raise RuntimeError("This factory result has already been consumed")
+        self._consumed = True
+        return self._coro.__await__()
+
+    async def __aenter__(self):
+        if self._consumed:
+            raise RuntimeError("This factory result has already been consumed")
+        self._consumed = True
+        self._persister = await self._coro
+        return await self._persister.__aenter__()
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        if self._persister is None:
+            return False
+        return await self._persister.__aexit__(exc_type, exc_value, traceback)
+
+
 async def asyncify_generator(
     generator: SyncOrAsyncGenerator[GenType],
 ) -> AsyncGenerator[GenType, None]:
diff --git a/burr/integrations/persisters/b_aiosqlite.py 
b/burr/integrations/persisters/b_aiosqlite.py
index 9ce3c4a5..a75eb682 100644
--- a/burr/integrations/persisters/b_aiosqlite.py
+++ b/burr/integrations/persisters/b_aiosqlite.py
@@ -21,6 +21,7 @@ from typing import Literal, Optional
 
 import aiosqlite
 
+from burr.common.async_utils import _AsyncPersisterContextManager
 from burr.common.types import BaseCopyable
 from burr.core import State
 from burr.core.persistence import AsyncBaseStatePersister, PersistedStateData
@@ -60,27 +61,41 @@ class AsyncSQLitePersister(AsyncBaseStatePersister, 
BaseCopyable):
     PARTITION_KEY_DEFAULT = ""
 
     @classmethod
-    async def from_config(cls, config: dict) -> "AsyncSQLitePersister":
+    def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
         """Creates a new instance of the AsyncSQLitePersister from a 
configuration dictionary.
 
+        Can be used with ``await`` or as an async context manager::
+
+            persister = await AsyncSQLitePersister.from_config(config)
+            # or
+            async with AsyncSQLitePersister.from_config(config) as persister:
+                ...
+
         The config key:value pair needed are:
         db_path: str,
         table_name: str,
         serde_kwargs: dict,
         connect_kwargs: dict,
         """
-        return await cls.from_values(**config)
+        return cls.from_values(**config)
 
     @classmethod
-    async def from_values(
+    def from_values(
         cls,
         db_path: str,
         table_name: str = "burr_state",
         serde_kwargs: dict = None,
         connect_kwargs: dict = None,
-    ) -> "AsyncSQLitePersister":
+    ) -> "_AsyncPersisterContextManager":
         """Creates a new instance of the AsyncSQLitePersister from passed in 
values.
 
+        Can be used with ``await`` or as an async context manager::
+
+            persister = await 
AsyncSQLitePersister.from_values(db_path="test.db")
+            # or
+            async with AsyncSQLitePersister.from_values(db_path="test.db") as 
persister:
+                ...
+
         :param db_path: the path the DB will be stored.
         :param table_name: the table name to store things under.
         :param serde_kwargs: kwargs for state serialization/deserialization.
@@ -88,10 +103,14 @@ class AsyncSQLitePersister(AsyncBaseStatePersister, 
BaseCopyable):
         :return: async sqlite persister instance with an open connection. You 
are responsible
             for closing the connection yourself.
         """
-        connection = await aiosqlite.connect(
-            db_path, **connect_kwargs if connect_kwargs is not None else {}
-        )
-        return cls(connection, table_name, serde_kwargs)
+
+        async def _create():
+            connection = await aiosqlite.connect(
+                db_path, **connect_kwargs if connect_kwargs is not None else {}
+            )
+            return cls(connection, table_name, serde_kwargs)
+
+        return _AsyncPersisterContextManager(_create())
 
     def __init__(
         self,
diff --git a/burr/integrations/persisters/b_asyncpg.py 
b/burr/integrations/persisters/b_asyncpg.py
index 66f91f20..c694350d 100644
--- a/burr/integrations/persisters/b_asyncpg.py
+++ b/burr/integrations/persisters/b_asyncpg.py
@@ -19,6 +19,7 @@ import json
 import logging
 from typing import Any, ClassVar, Literal, Optional
 
+from burr.common.async_utils import _AsyncPersisterContextManager
 from burr.common.types import BaseCopyable
 from burr.core import persistence, state
 from burr.integrations import base
@@ -106,12 +107,20 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, BaseCopyable
         return cls._pool
 
     @classmethod
-    async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister":
-        """Creates a new instance of the PostgreSQLPersister from a 
configuration dictionary."""
-        return await cls.from_values(**config)
+    def from_config(cls, config: dict) -> "_AsyncPersisterContextManager":
+        """Creates a new instance of the PostgreSQLPersister from a 
configuration dictionary.
+
+        Can be used with ``await`` or as an async context manager::
+
+            persister = await AsyncPostgreSQLPersister.from_config(config)
+            # or
+            async with AsyncPostgreSQLPersister.from_config(config) as 
persister:
+                ...
+        """
+        return cls.from_values(**config)
 
     @classmethod
-    async def from_values(
+    def from_values(
         cls,
         db_name: str,
         user: str,
@@ -121,9 +130,16 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, BaseCopyable
         table_name: str = "burr_state",
         use_pool: bool = False,
         **pool_kwargs,
-    ) -> "AsyncPostgreSQLPersister":
+    ) -> "_AsyncPersisterContextManager":
         """Builds a new instance of the PostgreSQLPersister from the provided 
values.
 
+        Can be used with ``await`` or as an async context manager::
+
+            persister = await AsyncPostgreSQLPersister.from_values(...)
+            # or
+            async with AsyncPostgreSQLPersister.from_values(...) as persister:
+                ...
+
         :param db_name: the name of the PostgreSQL database.
         :param user: the username to connect to the PostgreSQL database.
         :param password: the password to connect to the PostgreSQL database.
@@ -133,22 +149,25 @@ class 
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister, BaseCopyable
         :param use_pool: whether to use a connection pool (True) or a direct 
connection (False)
         :param pool_kwargs: additional kwargs to pass to the pool creation
         """
-        if use_pool:
-            pool = await cls.create_pool(
-                user=user,
-                password=password,
-                database=db_name,
-                host=host,
-                port=port,
-                **pool_kwargs,
-            )
-            return cls(connection=None, pool=pool, table_name=table_name)
-        else:
-            # Original behavior - direct connection
-            connection = await asyncpg.connect(
-                user=user, password=password, database=db_name, host=host, 
port=port
-            )
-            return cls(connection=connection, table_name=table_name)
+
+        async def _create():
+            if use_pool:
+                pool = await cls.create_pool(
+                    user=user,
+                    password=password,
+                    database=db_name,
+                    host=host,
+                    port=port,
+                    **pool_kwargs,
+                )
+                return cls(connection=None, pool=pool, table_name=table_name)
+            else:
+                connection = await asyncpg.connect(
+                    user=user, password=password, database=db_name, host=host, 
port=port
+                )
+                return cls(connection=connection, table_name=table_name)
+
+        return _AsyncPersisterContextManager(_create())
 
     def __init__(
         self,
diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst
index 0ced6bcc..c875c5f8 100644
--- a/docs/concepts/parallelism.rst
+++ b/docs/concepts/parallelism.rst
@@ -698,7 +698,7 @@ When using state persistence with async parallelism, make 
sure to use the async
     from burr.integrations.persisters.b_asyncpg import AsyncPGPersister
 
     # Create an async persister with a connection pool
-    persister = AsyncPGPersister.from_values(
+    persister = await AsyncPGPersister.from_values(
         host="localhost",
         port=5432,
         user="postgres",
@@ -707,7 +707,7 @@ When using state persistence with async parallelism, make 
sure to use the async
         use_pool=True  # Important for parallelism!
     )
 
-    app = (
+    app = await (
         ApplicationBuilder()
         .with_state_persister(persister)
         .with_action(
@@ -722,12 +722,12 @@ Remember to properly clean up your async persisters when 
you're done with them:
 
 .. code-block:: python
 
-    # Using as a context manager
+    # Using as a context manager (recommended)
     async with AsyncPGPersister.from_values(..., use_pool=True) as persister:
         # Use persister here
 
     # Or manual cleanup
-    persister = AsyncPGPersister.from_values(..., use_pool=True)
+    persister = await AsyncPGPersister.from_values(..., use_pool=True)
     try:
         # Use persister here
     finally:
diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py
index b362cd96..18c65a0d 100644
--- a/tests/core/test_persistence.py
+++ b/tests/core/test_persistence.py
@@ -110,7 +110,12 @@ def 
test_sqlite_persister_save_without_initialize_raises_runtime_error():
     try:
         with pytest.raises(RuntimeError, match="Uninitialized persister"):
             persister.save(
-                "partition_key", "app_id", 1, "position", State({"key": 
"value"}), "completed"
+                "partition_key",
+                "app_id",
+                1,
+                "position",
+                State({"key": "value"}),
+                "completed",
             )
     finally:
         persister.cleanup()
@@ -168,17 +173,6 @@ from burr.integrations.persisters.b_aiosqlite import 
AsyncSQLitePersister
 """Asyncio integration for sqlite persister + """
 
 
-class AsyncSQLiteContextManager:
-    def __init__(self, sqlite_object):
-        self.client = sqlite_object
-
-    async def __aenter__(self):
-        return self.client
-
-    async def __aexit__(self, exc_type, exc, tb):
-        await self.client.close()
-
-
 @pytest.fixture()
 async def async_persistence(request):
     yield AsyncInMemoryPersister()
@@ -276,15 +270,15 @@ async def test_AsyncSQLitePersister_connection_shutdown():
 
 @pytest.fixture()
 async def initializing_async_persistence():
-    sqlite_persister = await AsyncSQLitePersister.from_values(
+    async with AsyncSQLitePersister.from_values(
         db_path=":memory:", table_name="test_table"
-    )
-    async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
-    async with async_context_manager as client:
+    ) as client:
         yield client
 
 
-async def 
test_async_persistence_initialization_creates_table(initializing_async_persistence):
+async def test_async_persistence_initialization_creates_table(
+    initializing_async_persistence,
+):
     await asyncio.sleep(0.00001)
     await initializing_async_persistence.initialize()
     assert await initializing_async_persistence.list_app_ids("partition_key") 
== []
diff --git a/tests/integrations/persisters/test_b_aiosqlite.py 
b/tests/integrations/persisters/test_b_aiosqlite.py
index 00c98677..adb97532 100644
--- a/tests/integrations/persisters/test_b_aiosqlite.py
+++ b/tests/integrations/persisters/test_b_aiosqlite.py
@@ -25,17 +25,6 @@ from burr.core import ApplicationBuilder, State, action
 from burr.integrations.persisters.b_aiosqlite import AsyncSQLitePersister
 
 
-class AsyncSQLiteContextManager:
-    def __init__(self, sqlite_object):
-        self.client = sqlite_object
-
-    async def __aenter__(self):
-        return self.client
-
-    async def __aexit__(self, exc_type, exc, tb):
-        await self.client.cleanup()
-
-
 async def test_copy_persister(async_persistence: AsyncSQLitePersister):
     copy = async_persistence.copy()
     assert copy.table_name == async_persistence.table_name
@@ -45,11 +34,9 @@ async def test_copy_persister(async_persistence: 
AsyncSQLitePersister):
 
 @pytest.fixture()
 async def async_persistence(request):
-    sqlite_persister = await AsyncSQLitePersister.from_values(
+    async with AsyncSQLitePersister.from_values(
         db_path=":memory:", table_name="test_table"
-    )
-    async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
-    async with async_context_manager as client:
+    ) as client:
         yield client
 
 
@@ -118,6 +105,50 @@ async def test_async_persister_methods_none_partition_key(
     # these operations are stateful (i.e., read/write to a db)
 
 
+async def test_async_sqlite_from_values_as_context_manager(tmp_path):
+    """Test that from_values works directly with async with (issue #546)."""
+    db_path = str(tmp_path / "test.db")
+    async with AsyncSQLitePersister.from_values(db_path=db_path) as persister:
+        await persister.initialize()
+        await persister.save("pk", "app1", 1, "pos", State({"k": "v"}), 
"completed")
+        loaded = await persister.load("pk", "app1")
+        assert loaded is not None
+        assert loaded["state"] == State({"k": "v"})
+
+
+async def test_async_sqlite_from_config_as_context_manager(tmp_path):
+    """Test that from_config works directly with async with (issue #546)."""
+    db_path = str(tmp_path / "test.db")
+    config = {"db_path": db_path, "table_name": "burr_state"}
+    async with AsyncSQLitePersister.from_config(config) as persister:
+        await persister.initialize()
+        await persister.save("pk", "app1", 1, "pos", State({"k": "v"}), 
"completed")
+        loaded = await persister.load("pk", "app1")
+        assert loaded is not None
+
+
+async def test_async_sqlite_from_values_cannot_be_consumed_twice():
+    """Test that the factory wrapper raises on double consumption."""
+    wrapper = AsyncSQLitePersister.from_values(db_path=":memory:")
+    persister = await wrapper
+    with pytest.raises(RuntimeError, match="already been consumed"):
+        await wrapper
+    await persister.cleanup()
+
+
+async def 
test_async_sqlite_context_manager_aexit_safe_on_failed_aenter(tmp_path):
+    """Test that __aexit__ doesn't crash if __aenter__ never completed."""
+    from burr.common.async_utils import _AsyncPersisterContextManager
+
+    async def _failing_create():
+        raise ConnectionError("simulated connection failure")
+
+    mgr = _AsyncPersisterContextManager(_failing_create())
+    with pytest.raises(ConnectionError, match="simulated connection failure"):
+        async with mgr:
+            pass  # should never reach here
+
+
 async def test_AsyncSQLitePersister_from_values():
     await asyncio.sleep(0.00001)
     connection = await aiosqlite.connect(":memory:")
@@ -145,11 +176,9 @@ async def test_AsyncSQLitePersister_connection_shutdown():
 
 @pytest.fixture()
 async def initializing_async_persistence():
-    sqlite_persister = await AsyncSQLitePersister.from_values(
+    async with AsyncSQLitePersister.from_values(
         db_path=":memory:", table_name="test_table"
-    )
-    async_context_manager = AsyncSQLiteContextManager(sqlite_persister)
-    async with async_context_manager as client:
+    ) as client:
         yield client
 
 

Reply via email to