This is an automated email from the ASF dual-hosted git repository.
ebenizzy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/burr.git
The following commit(s) were added to refs/heads/main by this push:
new ac782937 Pool based async persister (#535)
ac782937 is described below
commit ac7829370172f681f9a648bbf6fe196b1e52bf8e
Author: gamarin2 <[email protected]>
AuthorDate: Mon Jul 21 06:09:13 2025 +0200
Pool based async persister (#535)
* pool based asyncpg persister
* docs
* remove ruff formatting to avoid diff
* fix docstring
* hook fixes
---------
Co-authored-by: Gautier MARIN <[email protected]>
---
burr/core/parallelism.py | 2 +-
burr/integrations/persisters/b_asyncpg.py | 322 +++++++++++++++++++++---------
docs/concepts/actions.rst | 16 ++
docs/concepts/parallelism.rst | 102 +++++++++-
docs/concepts/sync-vs-async.rst | 25 +++
5 files changed, 367 insertions(+), 100 deletions(-)
diff --git a/burr/core/parallelism.py b/burr/core/parallelism.py
index 91c8f30a..25c9e6f2 100644
--- a/burr/core/parallelism.py
+++ b/burr/core/parallelism.py
@@ -875,4 +875,4 @@ def map_reduce_action(
"""Experimental API for creating a map-reduce action easily. We'll be
improving this."""
return PassThroughMapActionsAndStates(
action=action, state=state, reducer=reducer, reads=reads,
writes=writes, inputs=inputs
- )
+ )
\ No newline at end of file
diff --git a/burr/integrations/persisters/b_asyncpg.py
b/burr/integrations/persisters/b_asyncpg.py
index d6f788f4..592d6c29 100644
--- a/burr/integrations/persisters/b_asyncpg.py
+++ b/burr/integrations/persisters/b_asyncpg.py
@@ -1,20 +1,27 @@
+import json
+import logging
+from typing import Literal, Optional, ClassVar
+from typing import Any
+from burr.common.types import BaseCopyable
+from burr.core import persistence, state
from burr.integrations import base
+
try:
import asyncpg
except ImportError as e:
base.require_plugin(e, "asyncpg")
-import json
-import logging
-from typing import Literal, Optional
+try:
+ from typing import Self
+except ImportError:
+ Self = Any
-from burr.core import persistence, state
logger = logging.getLogger(__name__)
-class AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
+class AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister,
BaseCopyable):
"""Class for async PostgreSQL persistence of state.
.. warning::
@@ -24,6 +31,11 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
We suggest to use the persister either as a context manager through
the ``async with`` clause or
using the method ``.cleanup()``.
+ .. warning::
+ If you intend to use parallelism features or need to share this
persister across multiple tasks,
+ you should initialize it with a connection pool (set ``use_pool=True``
in ``from_values``).
+ Direct connections cannot be shared across different tasks and may
cause errors in concurrent scenarios.
+
.. note::
The implementation relies on the popular asyncpg library:
https://github.com/MagicStack/asyncpg
@@ -35,10 +47,10 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
.. code:: bash
- docker run --name local-psql \ # container name
- -v local_psql_data:/SOME/FILE_PATH/ \ # mounting a volume
for data persistence
- -p 54320:5432 \ # port mapping
- -e POSTGRES_PASSWORD=my_password \ # superuser password
+ docker run --name local-psql \\\\ # container name
+ -v local_psql_data:/SOME/FILE_PATH/ \\\\ # mounting a
volume for data persistence
+ -p 54320:5432 \\\\ # port mapping
+ -e POSTGRES_PASSWORD=my_password \\\\ # superuser password
-d postgres # database name
Then you should be able to create the class like this:
@@ -48,11 +60,35 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
p = await AsyncPostgreSQLPersister.from_values("postgres", "postgres",
"my_password",
"localhost", 54320,
table_name="burr_state")
-
"""
PARTITION_KEY_DEFAULT = ""
+ # Class variable to hold the connection pool
+ _pool: ClassVar[Optional[asyncpg.Pool]] = None
+
+ @classmethod
+ async def create_pool(
+ cls,
+ user: str,
+ password: str,
+ database: str,
+ host: str,
+ port: int,
+ **pool_kwargs,
+ ) -> asyncpg.Pool:
+ """Creates a connection pool that can be shared across persisters."""
+ if cls._pool is None:
+ cls._pool = await asyncpg.create_pool(
+ user=user,
+ password=password,
+ database=database,
+ host=host,
+ port=port,
+ **pool_kwargs,
+ )
+ return cls._pool
+
@classmethod
async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister":
"""Creates a new instance of the PostgreSQLPersister from a
configuration dictionary."""
@@ -67,6 +103,8 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
host: str,
port: int,
table_name: str = "burr_state",
+ use_pool: bool = False,
+ **pool_kwargs,
) -> "AsyncPostgreSQLPersister":
"""Builds a new instance of the PostgreSQLPersister from the provided
values.
@@ -76,55 +114,119 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
:param host: the host of the PostgreSQL database.
:param port: the port of the PostgreSQL database.
:param table_name: the table name to store things under.
+ :param use_pool: whether to use a connection pool (True) or a direct
connection (False)
+ :param pool_kwargs: additional kwargs to pass to the pool creation
"""
- connection = await asyncpg.connect(
- user=user, password=password, database=db_name, host=host,
port=port
- )
- return cls(connection, table_name)
+ if use_pool:
+ pool = await cls.create_pool(
+ user=user,
+ password=password,
+ database=db_name,
+ host=host,
+ port=port,
+ **pool_kwargs,
+ )
+ return cls(connection=None, pool=pool, table_name=table_name)
+ else:
+ # Original behavior - direct connection
+ connection = await asyncpg.connect(
+ user=user, password=password, database=db_name, host=host,
port=port
+ )
+ return cls(connection=connection, table_name=table_name)
- def __init__(self, connection, table_name: str = "burr_state",
serde_kwargs: dict = None):
+ def __init__(
+ self,
+ connection=None,
+ pool=None,
+ table_name: str = "burr_state",
+ serde_kwargs: dict = None,
+ ):
"""Constructor
- :param connection: the connection to the PostgreSQL database.
+ :param connection: the connection to the PostgreSQL database (optional
if pool is provided)
+ :param pool: a connection pool to use instead of a direct connection
(optional if connection is provided)
:param table_name: the table name to store things under.
+ :param serde_kwargs: kwargs for state serialization/deserialization
"""
+ if connection is None and pool is None:
+ raise ValueError("Either connection or pool must be provided")
+
self.table_name = table_name
self.connection = connection
+ self.pool = pool
self.serde_kwargs = serde_kwargs or {}
self._initialized = False
+ def copy(self) -> "Self":
+ """Creates a copy of this persister.
+
+ If using a pool, returns a new persister that will acquire its own
connection from the pool.
+ If using a direct connection, just returns a new persister with the
same connection (won't work for async parallelism)
+ """
+ if self.pool is not None:
+ return AsyncPostgreSQLPersister(
+ connection=None,
+ pool=self.pool,
+ table_name=self.table_name,
+ serde_kwargs=self.serde_kwargs,
+ )
+ else:
+ return AsyncPostgreSQLPersister(
+ connection=self.connection,
+ table_name=self.table_name,
+ serde_kwargs=self.serde_kwargs,
+ )
+
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
- await self.connection.close()
+ await self.cleanup()
return False
+ async def _get_connection(self):
+ """Gets a connection - either the dedicated one or one from the
pool."""
+ if self.pool is not None:
+ return await self.pool.acquire(), True
+ elif self.connection is not None:
+ return self.connection, False
+ else:
+ raise ValueError("No connection or pool available")
+
+ async def _release_connection(self, connection, acquired):
+ """Releases a connection back to the pool if it was acquired."""
+ if acquired and self.pool is not None:
+ await self.pool.release(connection)
+
def set_serde_kwargs(self, serde_kwargs: dict):
"""Sets the serde_kwargs for the persister."""
self.serde_kwargs = serde_kwargs
async def create_table(self, table_name: str):
"""Helper function to create the table where things are stored."""
- async with self.connection.transaction():
- await self.connection.execute(
- f"""
- CREATE TABLE IF NOT EXISTS {table_name} (
- partition_key TEXT DEFAULT '{self.PARTITION_KEY_DEFAULT}',
- app_id TEXT NOT NULL,
- sequence_id INTEGER NOT NULL,
- position TEXT NOT NULL,
- status TEXT NOT NULL,
- state JSONB NOT NULL,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- PRIMARY KEY (partition_key, app_id, sequence_id, position)
- )"""
- )
- await self.connection.execute(
- f"""
- CREATE INDEX IF NOT EXISTS {table_name}_created_at_index ON
{table_name} (created_at);
- """
- )
+ conn, acquired = await self._get_connection()
+ try:
+ async with conn.transaction():
+ await conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {table_name} (
+ partition_key TEXT DEFAULT
'{self.PARTITION_KEY_DEFAULT}',
+ app_id TEXT NOT NULL,
+ sequence_id INTEGER NOT NULL,
+ position TEXT NOT NULL,
+ status TEXT NOT NULL,
+ state JSONB NOT NULL,
+ created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ PRIMARY KEY (partition_key, app_id, sequence_id,
position)
+ )"""
+ )
+ await conn.execute(
+ f"""
+ CREATE INDEX IF NOT EXISTS {table_name}_created_at_index
ON {table_name} (created_at);
+ """
+ )
+ finally:
+ await self._release_connection(conn, acquired)
async def initialize(self):
"""Creates the table"""
@@ -139,23 +241,35 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
if self._initialized:
return True
- query = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE
table_name = $1)"
- self._initialized = await self.connection.fetchval(query,
self.table_name, column=0)
- return self._initialized
+ conn, acquired = await self._get_connection()
+ try:
+ query = "SELECT EXISTS (SELECT FROM information_schema.tables
WHERE table_name = $1)"
+ self._initialized = await conn.fetchval(query, self.table_name,
column=0)
+ return self._initialized
+ finally:
+ await self._release_connection(conn, acquired)
async def list_app_ids(self, partition_key: str, **kwargs) -> list[str]:
"""Lists the app_ids for a given partition_key."""
- query = (
- f"SELECT DISTINCT app_id, created_at FROM {self.table_name} "
- "WHERE partition_key = $1 "
- "ORDER BY created_at DESC"
- )
- fetched_data = await self.connection.fetch(query, partition_key)
- app_ids = [row[0] for row in fetched_data]
- return app_ids
+ conn, acquired = await self._get_connection()
+ try:
+ query = (
+ f"SELECT DISTINCT app_id, created_at FROM {self.table_name} "
+ "WHERE partition_key = $1 "
+ "ORDER BY created_at DESC"
+ )
+ fetched_data = await conn.fetch(query, partition_key)
+ app_ids = [row[0] for row in fetched_data]
+ return app_ids
+ finally:
+ await self._release_connection(conn, acquired)
async def load(
- self, partition_key: Optional[str], app_id: str, sequence_id: int =
None, **kwargs
+ self,
+ partition_key: Optional[str],
+ app_id: str,
+ sequence_id: int = None,
+ **kwargs,
) -> Optional[persistence.PersistedStateData]:
"""Loads state for a given partition id.
@@ -171,47 +285,53 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
partition_key = self.PARTITION_KEY_DEFAULT
logger.debug("Loading %s, %s, %s", partition_key, app_id, sequence_id)
- if app_id is None:
- # get latest for all app_ids
- query = (
- f"SELECT position, state, sequence_id, app_id, created_at,
status FROM {self.table_name} "
- "WHERE partition_key = $1 "
- f"ORDER BY CREATED_AT DESC LIMIT 1"
- )
- row = await self.connection.fetchrow(query, partition_key)
-
- elif sequence_id is None:
- query = (
- f"SELECT position, state, sequence_id, app_id, created_at,
status FROM {self.table_name} "
- "WHERE partition_key = $1 AND app_id = $2 "
- f"ORDER BY sequence_id DESC LIMIT 1"
- )
- row = await self.connection.fetchrow(query, partition_key, app_id)
- else:
- query = (
- f"SELECT position, state, sequence_id, app_id, created_at,
status FROM {self.table_name} "
- "WHERE partition_key = $1 AND app_id = $2 AND sequence_id = $3
"
- )
- row = await self.connection.fetchrow(
- query,
- partition_key,
- app_id,
- sequence_id,
- )
- if row is None:
- return None
- # converts from asyncpg str to dict
- json_row = json.loads(row[1])
- _state = state.State.deserialize(json_row, **self.serde_kwargs)
- return {
- "partition_key": partition_key,
- "app_id": row[3],
- "sequence_id": row[2],
- "position": row[0],
- "state": _state,
- "created_at": row[4],
- "status": row[5],
- }
+ conn, acquired = await self._get_connection()
+ try:
+ row = None
+ if app_id is None:
+ # get latest for all app_ids
+ query = (
+ f"SELECT position, state, sequence_id, app_id, created_at,
status FROM {self.table_name} "
+ "WHERE partition_key = $1 "
+ f"ORDER BY CREATED_AT DESC LIMIT 1"
+ )
+ row = await conn.fetchrow(query, partition_key)
+ elif sequence_id is None:
+ query = (
+ f"SELECT position, state, sequence_id, app_id, created_at,
status FROM {self.table_name} "
+ "WHERE partition_key = $1 AND app_id = $2 "
+ f"ORDER BY sequence_id DESC LIMIT 1"
+ )
+ row = await conn.fetchrow(query, partition_key, app_id)
+ else:
+ query = (
+ f"SELECT position, state, sequence_id, app_id, created_at,
status FROM {self.table_name} "
+ "WHERE partition_key = $1 AND app_id = $2 AND sequence_id
= $3 "
+ )
+ row = await conn.fetchrow(
+ query,
+ partition_key,
+ app_id,
+ sequence_id,
+ )
+
+ if row is None:
+ return None
+
+ # converts from asyncpg str to dict
+ json_row = json.loads(row[1])
+ _state = state.State.deserialize(json_row, **self.serde_kwargs)
+ return {
+ "partition_key": partition_key,
+ "app_id": row[3],
+ "sequence_id": row[2],
+ "position": row[0],
+ "state": _state,
+ "created_at": row[4],
+ "status": row[5],
+ }
+ finally:
+ await self._release_connection(conn, acquired)
async def save(
self,
@@ -250,15 +370,21 @@ class
AsyncPostgreSQLPersister(persistence.AsyncBaseStatePersister):
status,
)
- json_state = json.dumps(state.serialize(**self.serde_kwargs))
- query = (
- f"INSERT INTO {self.table_name} (partition_key, app_id,
sequence_id, position, state, status) "
- "VALUES ($1, $2, $3, $4, $5, $6)"
- )
- await self.connection.execute(
- query, partition_key, app_id, sequence_id, position, json_state,
status
- )
+ conn, acquired = await self._get_connection()
+ try:
+ json_state = json.dumps(state.serialize(**self.serde_kwargs))
+ query = (
+ f"INSERT INTO {self.table_name} (partition_key, app_id,
sequence_id, position, state, status) "
+ "VALUES ($1, $2, $3, $4, $5, $6)"
+ )
+ await conn.execute(
+ query, partition_key, app_id, sequence_id, position,
json_state, status
+ )
+ finally:
+ await self._release_connection(conn, acquired)
async def cleanup(self):
"""Closes the connection to the database."""
- await self.connection.close()
+ if self.connection is not None:
+ await self.connection.close()
+ self.connection = None
diff --git a/docs/concepts/actions.rst b/docs/concepts/actions.rst
index 541cea83..2316fae3 100644
--- a/docs/concepts/actions.rst
+++ b/docs/concepts/actions.rst
@@ -15,6 +15,22 @@ Actions do the heavy-lifting in a workflow. They should
contain all complex comp
either through a class-based or function-based API. If actions implement
``async def run`` then will be run in an
asynchronous context (and thus require one of the async application functions).
+.. note::
+ When implementing asynchronous actions with ``async def run``, you must
also override the ``is_async`` method
+ to return ``True``. This tells the framework to execute the action in an
asynchronous context:
+
+ .. code-block:: python
+
+ class AsyncAction(Action):
+ @property
+ def is_async(self) -> bool:
+ return True
+
+ async def run(self, state: State) -> dict:
+ # Async implementation
+ ...
+
+
Actions have two primary responsibilities:
1. ``run`` -- compute a result
diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst
index ba7fa455..b878921f 100644
--- a/docs/concepts/parallelism.rst
+++ b/docs/concepts/parallelism.rst
@@ -617,6 +617,106 @@ To do this, you would:
3. Join them in parallel, waiting for any user-input if provided
4. Decide after every step of the first graph whether you want to cancel the
second graph or not -- E.G. is the user satisfied.
+
+Async Parallelism
+=================
+
+Burr also supports asynchronous parallelism. When working in an async context,
you need to make a few adjustments to your parallel actions:
+
+1. Make your methods async
+--------------------------
+
+The `action`, `states`, `reduce`, and other methods should be defined as async:
+
+.. code-block:: python
+
+ class AsyncMapActionsAndStatesExample(MapActionsAndStates):
+
+ async def action(self, state: State, inputs: Dict[str, Any]) ->
AsyncGenerator[Action, None]:
+ # Yield multiple model components to run in parallel
+ for i, model_config in enumerate(self._model_configs):
+ yield
ModelResponse(config=model_config).with_name(f"model_{i}")
+
+ async def states(self, state: State, inputs: Dict[str, Any]) ->
AsyncGenerator[State, None]:
+ # Prepare the state with the user query
+ for prompt in [
+ "What is the meaning of life?",
+ "What is the airspeed velocity of an unladen swallow?",
+ "What is the best way to cook a steak?",
+ ]:
+ yield state.update(prompt=prompt)
+
+ async def reduce(self, state: State, states: AsyncGenerator[State,
None]) -> State:
+ # Collect all model responses
+ all_responses = []
+ async for sub_state in states:
+ model_key = sub_state.get("model_key")
+ response = sub_state.get(model_key, [])[-1].get("content", "")
+ all_responses.append(response)
+
+ return state.update(ensemble_responses=all_responses)
+
+2. Implement the is_async method
+-------------------------------
+
+You must override the `is_async` method to return `True`:
+
+.. code-block:: python
+
+ class AsyncMapActionsAndStatesExample(MapActionsAndStates):
+
+ @property
+ def is_async(self) -> bool:
+ return True
+
+ # ... other methods ...
+
+3. Use async persisters with connection pools
+--------------------------------------------
+
+When using state persistence with async parallelism, make sure to use the
async version of persisters and initialize them with a connection pool:
+
+.. code-block:: python
+
+    from burr.integrations.persisters.b_asyncpg import AsyncPostgreSQLPersister
+
+ # Create an async persister with a connection pool
+    persister = await AsyncPostgreSQLPersister.from_values(
+        user="postgres",
+        password="postgres",
+        db_name="burr",
+        host="localhost",
+        port=5432,
+        use_pool=True,  # Important for parallelism!
+ )
+
+ app = (
+ ApplicationBuilder()
+ .with_state_persister(persister)
+ .with_action(
+ async_parallel_action=AsyncMapActionsAndStatesExample(),
+ )
+ .abuild()
+ )
+
+Connection pools are crucial for handling concurrent operations. Direct
connections cannot be shared across different tasks and may cause errors in
concurrent scenarios.
+
+Remember to properly clean up your async persisters when you're done with them:
+
+.. code-block:: python
+
+ # Using as a context manager
+    async with await AsyncPostgreSQLPersister.from_values(..., use_pool=True) as persister:
+ # Use persister here
+
+ # Or manual cleanup
+    persister = await AsyncPostgreSQLPersister.from_values(..., use_pool=True)
+ try:
+ # Use persister here
+ finally:
+ await persister.cleanup()
+
+
Notes
=====
@@ -631,4 +731,4 @@ Things that may change:
1. We will likely alter the executor API to be more flexible, although we will
probably allow for use of the current executor API
2. We will be adding guard-rails for generator-types (sync versus async)
3. The UI is a WIP -- we have more sophisticated capabilities but are still
polishing them
-4. Support for action-level executors
+4. Support for action-level executors
\ No newline at end of file
diff --git a/docs/concepts/sync-vs-async.rst b/docs/concepts/sync-vs-async.rst
index c24e37c5..4058eb4f 100644
--- a/docs/concepts/sync-vs-async.rst
+++ b/docs/concepts/sync-vs-async.rst
@@ -21,6 +21,31 @@ Burr gives you the ability to write synchronous (standard
python) and asynchrono
* :py:meth:`.run() <.Application.run()>`
* :py:meth:`.stream_result() <.Application.stream_result()>`
+Checklist for Async Applications
+--------------------------------
+
+When building asynchronous applications with Burr, ensure you:
+
+1. **Use async action implementations**:
+ * Implement ``async def run`` methods in your actions
+ * Override the ``is_async`` property to return ``True`` in all async
class-based actions
+ * Use ``await`` for all I/O operations inside your actions
+
+2. **Use async builder and application methods**:
+ * Use ``.abuild()`` instead of ``.build()``
+ * Use ``.arun()``, ``.aiterate()``, and ``.astream_result()`` instead of
their sync counterparts
+
+3. **Use async hooks and persisters**:
+ * Implement async hooks with ``async def`` methods
+   * Use async persisters (e.g., ``AsyncPostgreSQLPersister`` instead of
``PostgreSQLPersister``)
+ * Properly clean up async resources using context managers or explicit
cleanup calls
+
+4. **For parallel actions**:
+ * Make ``actions``, ``states``, and ``reduce`` methods async
+ * Override ``is_async`` to return ``True``
+ * Use ``AsyncGenerator`` return types
+ * Use async persisters with connection pools
+
Comparison
----------