diff --git a/burr/common/async_utils.py b/burr/common/async_utils.py index b1f60881f..56ce75ff2 100644 --- a/burr/common/async_utils.py +++ b/burr/common/async_utils.py @@ -16,7 +16,7 @@ # under the License. import inspect -from typing import AsyncGenerator, AsyncIterable, Generator, List, TypeVar, Union +from typing import Any, AsyncGenerator, AsyncIterable, Coroutine, Generator, List, TypeVar, Union T = TypeVar("T") @@ -27,6 +27,46 @@ SyncOrAsyncGeneratorOrItemOrList = Union[SyncOrAsyncGenerator[GenType], List[GenType], GenType] +class _AsyncPersisterContextManager: + """Wraps an async coroutine that returns a persister so it can be used + directly with ``async with``:: + + async with AsyncSQLitePersister.from_values(...) as persister: + ... + + The wrapper awaits the coroutine on ``__aenter__`` and delegates + ``__aexit__`` to the persister's own ``__aexit__``. + + .. note:: + Each instance wraps a single coroutine and can only be consumed once, + either via ``await`` or ``async with``. A second use will raise + ``RuntimeError``. + """ + + def __init__(self, coro: Coroutine[Any, Any, Any]): + self._coro = coro + self._persister = None + self._consumed = False + + def __await__(self): + if self._consumed: + raise RuntimeError("This factory result has already been consumed") + self._consumed = True + return self._coro.__await__() + + async def __aenter__(self): + if self._consumed: + raise RuntimeError("This factory result has already been consumed") + self._consumed = True + self._persister = await self._coro + return await self._persister.__aenter__() + + async def __aexit__(self, exc_type, exc_value, traceback): + if self._persister is None: + return False + return await self._persister.__aexit__(exc_type, exc_value, traceback) + + async def asyncify_generator( generator: SyncOrAsyncGenerator[GenType], ) -> AsyncGenerator[GenType, None]: diff --git a/burr/integrations/persisters/b_aiosqlite.py b/burr/integrations/persisters/b_aiosqlite.py index 9ce3c4a5d..a75eb682c 100644 --- a/burr/integrations/persisters/b_aiosqlite.py +++ b/burr/integrations/persisters/b_aiosqlite.py @@ -21,6 +21,7 @@ import aiosqlite +from burr.common.async_utils import _AsyncPersisterContextManager from burr.common.types import BaseCopyable from burr.core import State from burr.core.persistence import AsyncBaseStatePersister, PersistedStateData @@ -60,27 +61,41 @@ def copy(self) -> "Self": PARTITION_KEY_DEFAULT = "" @classmethod - async def from_config(cls, config: dict) -> "AsyncSQLitePersister": + def from_config(cls, config: dict) -> "_AsyncPersisterContextManager": """Creates a new instance of the AsyncSQLitePersister from a configuration dictionary. + Can be used with ``await`` or as an async context manager:: + + persister = await AsyncSQLitePersister.from_config(config) + # or + async with AsyncSQLitePersister.from_config(config) as persister: + ... + The config key:value pair needed are: db_path: str, table_name: str, serde_kwargs: dict, connect_kwargs: dict, """ - return await cls.from_values(**config) + return cls.from_values(**config) @classmethod - async def from_values( + def from_values( cls, db_path: str, table_name: str = "burr_state", serde_kwargs: dict = None, connect_kwargs: dict = None, - ) -> "AsyncSQLitePersister": + ) -> "_AsyncPersisterContextManager": """Creates a new instance of the AsyncSQLitePersister from passed in values. + Can be used with ``await`` or as an async context manager:: + + persister = await AsyncSQLitePersister.from_values(db_path="test.db") + # or + async with AsyncSQLitePersister.from_values(db_path="test.db") as persister: + ... + :param db_path: the path the DB will be stored. :param table_name: the table name to store things under. :param serde_kwargs: kwargs for state serialization/deserialization. @@ -88,10 +103,14 @@ async def from_values( :return: async sqlite persister instance with an open connection. You are responsible for closing the connection yourself. """ - connection = await aiosqlite.connect( - db_path, **connect_kwargs if connect_kwargs is not None else {} - ) - return cls(connection, table_name, serde_kwargs) + + async def _create(): + connection = await aiosqlite.connect( + db_path, **connect_kwargs if connect_kwargs is not None else {} + ) + return cls(connection, table_name, serde_kwargs) + + return _AsyncPersisterContextManager(_create()) def __init__( self, diff --git a/burr/integrations/persisters/b_asyncpg.py b/burr/integrations/persisters/b_asyncpg.py index 66f91f206..48cab6435 100644 --- a/burr/integrations/persisters/b_asyncpg.py +++ b/burr/integrations/persisters/b_asyncpg.py @@ -22,6 +22,7 @@ from burr.common.types import BaseCopyable from burr.core import persistence, state from burr.integrations import base +from burr.common.async_utils import _AsyncPersisterContextManager try: import asyncpg @@ -106,12 +107,20 @@ async def create_pool( return cls._pool @classmethod - async def from_config(cls, config: dict) -> "AsyncPostgreSQLPersister": - """Creates a new instance of the PostgreSQLPersister from a configuration dictionary.""" - return await cls.from_values(**config) + def from_config(cls, config: dict) -> "_AsyncPersisterContextManager": + """Creates a new instance of the PostgreSQLPersister from a configuration dictionary. + + Can be used with ``await`` or as an async context manager:: + + persister = await AsyncPostgreSQLPersister.from_config(config) + # or + async with AsyncPostgreSQLPersister.from_config(config) as persister: + ... + """ + return cls.from_values(**config) @classmethod - async def from_values( + def from_values( cls, db_name: str, user: str, @@ -121,9 +130,16 @@ async def from_values( table_name: str = "burr_state", use_pool: bool = False, **pool_kwargs, - ) -> "AsyncPostgreSQLPersister": + ) -> "_AsyncPersisterContextManager": """Builds a new instance of the PostgreSQLPersister from the provided values. + Can be used with ``await`` or as an async context manager:: + + persister = await AsyncPostgreSQLPersister.from_values(...) + # or + async with AsyncPostgreSQLPersister.from_values(...) as persister: + ... + :param db_name: the name of the PostgreSQL database. :param user: the username to connect to the PostgreSQL database. :param password: the password to connect to the PostgreSQL database. @@ -133,22 +149,25 @@ async def from_values( :param use_pool: whether to use a connection pool (True) or a direct connection (False) :param pool_kwargs: additional kwargs to pass to the pool creation """ - if use_pool: - pool = await cls.create_pool( - user=user, - password=password, - database=db_name, - host=host, - port=port, - **pool_kwargs, - ) - return cls(connection=None, pool=pool, table_name=table_name) - else: - # Original behavior - direct connection - connection = await asyncpg.connect( - user=user, password=password, database=db_name, host=host, port=port - ) - return cls(connection=connection, table_name=table_name) + + async def _create(): + if use_pool: + pool = await cls.create_pool( + user=user, + password=password, + database=db_name, + host=host, + port=port, + **pool_kwargs, + ) + return cls(connection=None, pool=pool, table_name=table_name) + else: + connection = await asyncpg.connect( + user=user, password=password, database=db_name, host=host, port=port + ) + return cls(connection=connection, table_name=table_name) + + return _AsyncPersisterContextManager(_create()) def __init__( self, diff --git a/docs/concepts/parallelism.rst b/docs/concepts/parallelism.rst index 0ced6bcc7..c875c5f83 100644 --- a/docs/concepts/parallelism.rst +++ b/docs/concepts/parallelism.rst @@ -698,7 +698,7 @@ When using state persistence with async parallelism, make sure to use the async from burr.integrations.persisters.b_asyncpg import AsyncPGPersister # Create an async persister with a connection pool - persister = AsyncPGPersister.from_values( + persister = await AsyncPGPersister.from_values( host="localhost", port=5432, user="postgres", @@ -707,7 +707,7 @@ When using state persistence with async parallelism, make sure to use the async use_pool=True # Important for parallelism! ) - app = ( + app = await ( ApplicationBuilder() .with_state_persister(persister) .with_action( @@ -722,12 +722,12 @@ Remember to properly clean up your async persisters when you're done with them: .. code-block:: python - # Using as a context manager + # Using as a context manager (recommended) async with AsyncPGPersister.from_values(..., use_pool=True) as persister: # Use persister here # Or manual cleanup - persister = AsyncPGPersister.from_values(..., use_pool=True) + persister = await AsyncPGPersister.from_values(..., use_pool=True) try: # Use persister here finally: diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index b362cd96d..deadb42a8 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -168,15 +168,6 @@ def test_persister_methods_none_partition_key(persistence, method_name: str, kwa """Asyncio integration for sqlite persister + """ -class AsyncSQLiteContextManager: - def __init__(self, sqlite_object): - self.client = sqlite_object - - async def __aenter__(self): - return self.client - - async def __aexit__(self, exc_type, exc, tb): - await self.client.close() @pytest.fixture() @@ -276,11 +267,9 @@ async def test_AsyncSQLitePersister_connection_shutdown(): @pytest.fixture() async def initializing_async_persistence(): - sqlite_persister = await AsyncSQLitePersister.from_values( + async with AsyncSQLitePersister.from_values( db_path=":memory:", table_name="test_table" - ) - async_context_manager = AsyncSQLiteContextManager(sqlite_persister) - async with async_context_manager as client: + ) as client: yield client diff --git a/tests/integrations/persisters/test_b_aiosqlite.py b/tests/integrations/persisters/test_b_aiosqlite.py index 00c98677a..adb97532c 100644 --- a/tests/integrations/persisters/test_b_aiosqlite.py +++ b/tests/integrations/persisters/test_b_aiosqlite.py @@ -25,17 +25,6 @@ from burr.integrations.persisters.b_aiosqlite import AsyncSQLitePersister -class AsyncSQLiteContextManager: - def __init__(self, sqlite_object): - self.client = sqlite_object - - async def __aenter__(self): - return self.client - - async def __aexit__(self, exc_type, exc, tb): - await self.client.cleanup() - - async def test_copy_persister(async_persistence: AsyncSQLitePersister): copy = async_persistence.copy() assert copy.table_name == async_persistence.table_name @@ -45,11 +34,9 @@ async def test_copy_persister(async_persistence: AsyncSQLitePersister): @pytest.fixture() async def async_persistence(request): - sqlite_persister = await AsyncSQLitePersister.from_values( + async with AsyncSQLitePersister.from_values( db_path=":memory:", table_name="test_table" - ) - async_context_manager = AsyncSQLiteContextManager(sqlite_persister) - async with async_context_manager as client: + ) as client: yield client @@ -118,6 +105,50 @@ async def test_async_persister_methods_none_partition_key( # these operations are stateful (i.e., read/write to a db) +async def test_async_sqlite_from_values_as_context_manager(tmp_path): + """Test that from_values works directly with async with (issue #546).""" + db_path = str(tmp_path / "test.db") + async with AsyncSQLitePersister.from_values(db_path=db_path) as persister: + await persister.initialize() + await persister.save("pk", "app1", 1, "pos", State({"k": "v"}), "completed") + loaded = await persister.load("pk", "app1") + assert loaded is not None + assert loaded["state"] == State({"k": "v"}) + + +async def test_async_sqlite_from_config_as_context_manager(tmp_path): + """Test that from_config works directly with async with (issue #546).""" + db_path = str(tmp_path / "test.db") + config = {"db_path": db_path, "table_name": "burr_state"} + async with AsyncSQLitePersister.from_config(config) as persister: + await persister.initialize() + await persister.save("pk", "app1", 1, "pos", State({"k": "v"}), "completed") + loaded = await persister.load("pk", "app1") + assert loaded is not None + + +async def test_async_sqlite_from_values_cannot_be_consumed_twice(): + """Test that the factory wrapper raises on double consumption.""" + wrapper = AsyncSQLitePersister.from_values(db_path=":memory:") + persister = await wrapper + with pytest.raises(RuntimeError, match="already been consumed"): + await wrapper + await persister.cleanup() + + +async def test_async_sqlite_context_manager_aexit_safe_on_failed_aenter(tmp_path): + """Test that __aexit__ doesn't crash if __aenter__ never completed.""" + from burr.common.async_utils import _AsyncPersisterContextManager + + async def _failing_create(): + raise ConnectionError("simulated connection failure") + + mgr = _AsyncPersisterContextManager(_failing_create()) + with pytest.raises(ConnectionError, match="simulated connection failure"): + async with mgr: + pass # should never reach here + + async def test_AsyncSQLitePersister_from_values(): await asyncio.sleep(0.00001) connection = await aiosqlite.connect(":memory:") @@ -145,11 +176,9 @@ async def test_AsyncSQLitePersister_connection_shutdown(): @pytest.fixture() async def initializing_async_persistence(): - sqlite_persister = await AsyncSQLitePersister.from_values( + async with AsyncSQLitePersister.from_values( db_path=":memory:", table_name="test_table" - ) - async_context_manager = AsyncSQLiteContextManager(sqlite_persister) - async with async_context_manager as client: + ) as client: yield client