diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..7c403cc --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,125 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + # ----------------------------------------------------------------------- + # Unit tests — no external services (fakeredis + SQLite) + # ----------------------------------------------------------------------- + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run unit tests + run: | + uv run pytest tests/ \ + --ignore=tests/test_kafka_integration.py \ + -o "addopts=" \ + -v --tb=long + + # ----------------------------------------------------------------------- + # Kafka integration tests — real broker via docker run + # ----------------------------------------------------------------------- + test-kafka: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Start Kafka broker + run: | + docker run -d --name kafka \ + -p 9092:9092 \ + -e KAFKA_NODE_ID=1 \ + -e KAFKA_PROCESS_ROLES=broker,controller \ + -e KAFKA_CONTROLLER_QUORUM_VOTERS=1@localhost:9093 \ + -e KAFKA_CONTROLLER_LISTENER_NAMES=CONTROLLER \ + -e KAFKA_LISTENERS=PLAINTEXT://:9092,CONTROLLER://:9093 \ + -e KAFKA_ADVERTISED_LISTENERS=PLAINTEXT://localhost:9092 \ + -e KAFKA_LISTENER_SECURITY_PROTOCOL_MAP=PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT \ + -e KAFKA_INTER_BROKER_LISTENER_NAME=PLAINTEXT \ + -e KAFKA_LOG_CLEANER_MIN_COMPACTION_LAG_MS=0 \ + -e KAFKA_LOG_CLEANER_MIN_CLEANABLE_RATIO=0.01 \ + -e KAFKA_LOG_RETENTION_MS=60000 \ + -e KAFKA_NUM_PARTITIONS=1 \ + -e 
KAFKA_AUTO_CREATE_TOPICS_ENABLE=true \ + -e KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS=0 \ + -e KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR=1 \ + -e CLUSTER_ID=ciTestCluster0001 \ + apache/kafka:3.9.0 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Set up Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --dev --extra kafka + + - name: Wait for Kafka to be ready + run: | + echo "Waiting for Kafka..." + for i in $(seq 1 30); do + if nc -z localhost 9092 2>/dev/null; then + echo "Kafka port is open" + sleep 5 + echo "Kafka is ready" + exit 0 + fi + echo " attempt $i/30..." + sleep 2 + done + echo "Kafka failed to start" + docker logs kafka + exit 1 + + - name: Run Kafka integration tests + timeout-minutes: 2 + run: | + uv run pytest tests/test_kafka_integration.py \ + -o "addopts=" \ + -v --tb=long 2>&1 | tee /tmp/kafka_test_output.txt + exit ${PIPESTATUS[0]} + env: + AGENTEXEC_STATE_BACKEND: agentexec.state.kafka + KAFKA_BOOTSTRAP_SERVERS: localhost:9092 + AGENTEXEC_KAFKA_DEFAULT_PARTITIONS: "2" + AGENTEXEC_KAFKA_REPLICATION_FACTOR: "1" + + - name: Show Kafka logs on failure + if: failure() + run: docker logs kafka 2>&1 | tail -50 + + - name: Create failure check annotation with output + if: failure() + run: | + if [ -f /tmp/kafka_test_output.txt ]; then + grep -E '\[queue_|FAILED|ERROR|AssertionError|TIMEOUT|short test summary' /tmp/kafka_test_output.txt | tail -9 | while IFS= read -r line; do + echo "::warning::$line" + done + fi diff --git a/README.md b/README.md index f54ca38..985a685 100644 --- a/README.md +++ b/README.md @@ -122,8 +122,8 @@ async def start_research(company: str) -> dict: return {"agent_id": str(task.agent_id), "status": "queued"} # Return agent_id for status polling @router.get("/research/{agent_id}") -def get_status(agent_id: UUID, db: Session = Depends(get_db)) -> ax.activity.ActivityDetailSchema: - return ax.activity.detail(db, agent_id=agent_id) # Query by agent_id 
+async def get_status(agent_id: UUID) -> ax.activity.ActivityDetailSchema: + return await ax.activity.detail(agent_id=agent_id) ``` ### 4. Run Workers @@ -150,8 +150,8 @@ task = await ax.enqueue( ) # Filter activities by metadata -activities = ax.activity.list(db, metadata_filter={"organization_id": "org-123"}) -detail = ax.activity.detail(db, agent_id, metadata_filter={"organization_id": "org-123"}) +activities = await ax.activity.list(metadata_filter={"organization_id": "org-123"}) +detail = await ax.activity.detail(agent_id=agent_id, metadata_filter={"organization_id": "org-123"}) # Access metadata programmatically (excluded from API serialization by default) org_id = detail.metadata["organization_id"] @@ -186,7 +186,7 @@ agent = Agent( Update progress explicitly from your task: ```python -ax.activity.update(agent_id, "Processing batch 3 of 10", percentage=30) +await ax.activity.update(agent_id, "Processing batch 3 of 10", percentage=30) ``` ### Task Locking @@ -202,11 +202,9 @@ async def associate(agent_id: UUID, context: ObservationContext): pool.add_task("associate_observation", handler, lock_key="user:{user_id}") ``` -The `lock_key` is a string template evaluated against the task context fields. When a worker dequeues a task whose lock is held, it puts the task back at the end of the queue and moves on. The lock is released automatically when the task completes or errors. +The `lock_key` is a string template evaluated against the task context fields. Tasks with the same evaluated lock key are routed to a dedicated partition queue (`{prefix}:{lock_key}`) where they execute one at a time. Workers skip locked partitions and move on to the next available one — no requeuing, no wasted cycles. -The lock TTL (`AGENTEXEC_LOCK_TTL`, default 1800s) is a safety net for worker process death — locks are always explicitly released on task completion or error. Set this higher than your longest expected task duration. 
- -**Note:** When a task is requeued due to a held lock, it goes to the back of the queue. This means strict FIFO ordering is not guaranteed between tasks sharing the same lock key — if tasks T2 and T3 are both waiting on T1's lock, either could run next after T1 completes. +The lock is released automatically when a task completes or errors. The lock TTL (`AGENTEXEC_LOCK_TTL`, default 1800s) is a safety net for worker process death (OOM, SIGKILL) — under normal operation, locks are always explicitly released. Set this higher than your longest expected task duration. ### Scheduled Tasks @@ -366,8 +364,7 @@ if __name__ == "__main__": try: pool.run() except KeyboardInterrupt: - with Session(engine) as db: - ax.activity.cancel_pending(db) + asyncio.run(ax.activity.cancel_pending()) ``` ### Docker Deployment @@ -396,11 +393,10 @@ import agentexec as ax engine = create_engine(os.environ["DATABASE_URL"]) pool = ax.Pool(engine=engine) -def cleanup() -> None: - with Session(engine) as db: - ax.activity.cancel_pending(db) +async def cleanup() -> None: + await ax.activity.cancel_pending() -atexit.register(cleanup) +atexit.register(lambda: asyncio.run(cleanup())) @pool.task("my_task") async def my_task(agent_id: UUID, context: MyContext) -> None: @@ -421,11 +417,13 @@ docker run -e DATABASE_URL=... -e REDIS_URL=... -e OPENAI_API_KEY=... my-worker ## Backend Architecture -### Redis +### Redis (Default) + +agentexec uses Redis for task queuing, result storage, and coordination between workers. The queue uses a partitioned design where tasks with a `lock_key` go to dedicated partition queues (`{prefix}:{lock_key}`) and are serialized by a lock, while tasks without a lock key go to the default queue for concurrent processing. -agentexec uses Redis for task queuing, result storage, real-time log streaming, and coordination between workers. We chose Redis because it provides exactly the primitives we need (lists, pubsub, atomic counters) with minimal operational overhead. 
+Workers dequeue using Redis `SCAN`, which iterates keys in hash-table order — effectively random. This provides fair distribution across partitions without explicit round-robin. See `examples/queue-fairness/` for benchmarks showing uniform distribution at 1000+ partitions. -**AWS Compatible:** Since we use standard Redis features, AWS ElastiCache works out of the box. +**AWS Compatible:** Standard Redis features only — AWS ElastiCache works out of the box. ```bash AGENTEXEC_REDIS_URL=redis://localhost:6379/0 @@ -433,18 +431,45 @@ AGENTEXEC_REDIS_URL=redis://localhost:6379/0 AGENTEXEC_REDIS_URL=redis://my-cluster.abc123.use1.cache.amazonaws.com:6379 ``` +### Kafka (Experimental) + +Kafka can be used as an alternative backend for task queuing and schedule storage. Activity tracking always uses PostgreSQL regardless of backend — Kafka is not a KV store, so state operations (`get`/`set`, counters) are not supported and will raise `NotImplementedError`. + +```bash +pip install agentexec[kafka] + +AGENTEXEC_STATE_BACKEND=agentexec.state.kafka +KAFKA_BOOTSTRAP_SERVERS=localhost:9092 +``` + +Kafka uses consumer groups for work distribution instead of Redis's scan-based dequeue. Topics are auto-created on first use. Schedule storage uses a compacted topic that is replayed on each poll. + +**When to consider Kafka:** +- You already run Kafka and want to avoid adding Redis +- You need durable, replayable task queues with built-in replication +- You want partition-level ordering guarantees (tasks with the same key go to the same partition) + +**Limitations:** +- No KV state — `backend.state.get/set/delete` and counters raise `NotImplementedError` +- No partition-level locking (Kafka partition assignment handles isolation instead) +- Schedule `get_due()` replays the entire compacted topic on every poll +- `lock_key` is used as a Kafka partition key (routing), not as a mutex + +See [Kafka configuration](#kafka-settings) below for all available settings. 
+ ### Extensible State Backend -The state backend is pluggable. We're adding support for additional backends (DynamoDB, PostgreSQL, in-memory for testing). You can also implement your own: +The state backend is pluggable. Implement `BaseBackend` with `state`, `queue`, and `schedule` sub-backends: ```bash -AGENTEXEC_STATE_BACKEND=agentexec.state.redis_backend # Default -AGENTEXEC_STATE_BACKEND=myapp.state.dynamodb_backend # Custom +AGENTEXEC_STATE_BACKEND=agentexec.state.redis # Default +AGENTEXEC_STATE_BACKEND=agentexec.state.kafka # Experimental +AGENTEXEC_STATE_BACKEND=myapp.state.custom # Custom (must export Backend class) ``` ### Database -Activity tracking uses SQLAlchemy with two tables: +Activity tracking uses SQLAlchemy with two tables (always PostgreSQL/SQLite, independent of the state backend): **`agentexec_activity`** - Main activity records - `agent_id` - Unique identifier (UUID) @@ -478,25 +503,23 @@ from agentexec.activity.schemas import ( **List activities:** ```python -with Session(engine) as db: - result = ax.activity.list(db, page=1, page_size=20) - # Returns ActivityListSchema: - # { - # "items": [...], # List of ActivityListItemSchema - # "total": 150, - # "page": 1, - # "page_size": 20, - # "total_pages": 8 - # } +result = await ax.activity.list(page=1, page_size=20) +# Returns ActivityListSchema: +# { +# "items": [...], # List of ActivityListItemSchema +# "total": 150, +# "page": 1, +# "page_size": 20, +# "total_pages": 8 +# } ``` **Get activity detail:** ```python -activity = ax.activity.detail(db, agent_id=agent_id) +activity = await ax.activity.detail(agent_id=agent_id) # Returns ActivityDetailSchema: # { -# "id": "...", # "agent_id": "...", # "agent_type": "research_company", # "created_at": "2024-01-15T10:30:00Z", @@ -512,7 +535,7 @@ activity = ax.activity.detail(db, agent_id=agent_id) **Count active agents:** ```python -count = ax.activity.active_count(db) +count = await ax.activity.count_active() # Returns number of agents with status 
QUEUED or RUNNING ``` @@ -527,13 +550,15 @@ from sqlalchemy.orm import Session import agentexec as ax def build_table(db: Session) -> Table: - table = Table(title=f"Active Agents: {ax.activity.active_count(db)}") + count = asyncio.run(ax.activity.count_active()) + table = Table(title=f"Active Agents: {count}") table.add_column("Status") table.add_column("Task") table.add_column("Message") table.add_column("Progress") - for item in ax.activity.list(db, page=1, page_size=10).items: + activities = asyncio.run(ax.activity.list(page=1, page_size=10)) + for item in activities.items: table.add_row( item.status, item.agent_type, @@ -647,7 +672,7 @@ async def scheduled(agent_id: UUID, context: MyContext) -> None: ... pool.add_schedule("name", "0 * * * *", MyContext(), repeat=3) # Schedule separately -pool.run() # Blocking - runs workers + scheduler +pool.run() # Blocking - runs workers + scheduler + retry handling pool.start() # Non-blocking - starts workers in background pool.shutdown() # Graceful shutdown ``` @@ -658,20 +683,20 @@ pool.shutdown() # Graceful shutdown import agentexec as ax # Create activity (returns agent_id for tracking) -agent_id = ax.activity.create(task_name, message="Starting...") +agent_id = await ax.activity.create(task_name, message="Starting...") # Update progress -ax.activity.update(agent_id, message, percentage=50) -ax.activity.complete(agent_id, message="Done") -ax.activity.error(agent_id, error="Failed: ...") +await ax.activity.update(agent_id, message, percentage=50) +await ax.activity.complete(agent_id, message="Done") +await ax.activity.error(agent_id, message="Failed: ...") -# Query activities -activities = ax.activity.list(db, page=1, page_size=20) -activity = ax.activity.detail(db, agent_id=agent_id) -count = ax.activity.active_count(db) +# Query activities (uses database session) +activities = await ax.activity.list(page=1, page_size=20) +activity = await ax.activity.detail(agent_id=agent_id) +count = await ax.activity.count_active() # 
Cleanup -canceled = ax.activity.cancel_pending(db) +canceled = await ax.activity.cancel_pending() ``` ### Runners @@ -733,13 +758,16 @@ ax.Base # SQLAlchemy declarative base for activity tables All settings via environment variables: ```bash -# Redis (required) -AGENTEXEC_REDIS_URL=redis://localhost:6379/0 +# Redis +AGENTEXEC_REDIS_URL=redis://localhost:6379/0 # Also accepts REDIS_URL +AGENTEXEC_REDIS_POOL_SIZE=10 +AGENTEXEC_REDIS_POOL_TIMEOUT=5 # Workers AGENTEXEC_NUM_WORKERS=4 -AGENTEXEC_QUEUE_NAME=agentexec_tasks +AGENTEXEC_QUEUE_PREFIX=agentexec_tasks # Also accepts AGENTEXEC_QUEUE_NAME AGENTEXEC_GRACEFUL_SHUTDOWN_TIMEOUT=300 +AGENTEXEC_MAX_TASK_RETRIES=3 # 0 to disable retries # Database AGENTEXEC_TABLE_PREFIX=agentexec_ @@ -747,14 +775,15 @@ AGENTEXEC_TABLE_PREFIX=agentexec_ # Results AGENTEXEC_RESULT_TTL=3600 -# Task locking +# Task locking (Redis backend only) AGENTEXEC_LOCK_TTL=1800 # Scheduling AGENTEXEC_SCHEDULER_TIMEZONE=UTC +AGENTEXEC_SCHEDULER_POLL_INTERVAL=10 # State backend -AGENTEXEC_STATE_BACKEND=agentexec.state.redis_backend +AGENTEXEC_STATE_BACKEND=agentexec.state.redis # or agentexec.state.kafka AGENTEXEC_KEY_PREFIX=agentexec # Activity messages (customizable) @@ -764,6 +793,21 @@ AGENTEXEC_ACTIVITY_MESSAGE_COMPLETE="Task completed successfully." 
AGENTEXEC_ACTIVITY_MESSAGE_ERROR="Task failed with error: {error}" ``` +### Kafka Settings + +These settings only apply when using the Kafka state backend (`AGENTEXEC_STATE_BACKEND=agentexec.state.kafka`): + +```bash +KAFKA_BOOTSTRAP_SERVERS=localhost:9092 # Also accepts AGENTEXEC_KAFKA_BOOTSTRAP_SERVERS +AGENTEXEC_KAFKA_DEFAULT_PARTITIONS=6 # Partitions for auto-created topics +AGENTEXEC_KAFKA_REPLICATION_FACTOR=1 # Replication factor for auto-created topics +AGENTEXEC_KAFKA_MAX_BATCH_SIZE=16384 # Producer max batch size (bytes) +AGENTEXEC_KAFKA_LINGER_MS=5 # Producer linger time (ms) +AGENTEXEC_KAFKA_RETENTION_MS=-1 # Retention for compacted topics (-1 = forever) +``` + +For single-node development, set `KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR=1` on your broker or consumer groups will hang. + --- ## Development @@ -812,4 +856,5 @@ MIT License - see [LICENSE](LICENSE) for details. - **Documentation**: [docs/](docs/) - **Example App**: [examples/openai-agents-fastapi/](examples/openai-agents-fastapi/) - **Multi-Tenancy Example**: [examples/multi-tenancy/](examples/multi-tenancy/) +- **Queue Fairness Benchmark**: [examples/queue-fairness/](examples/queue-fairness/) - **Issues**: [GitHub Issues](https://github.com/Agent-CI/agentexec/issues) diff --git a/docker-compose.kafka.yml b/docker-compose.kafka.yml new file mode 100644 index 0000000..0080d51 --- /dev/null +++ b/docker-compose.kafka.yml @@ -0,0 +1,48 @@ +# Kafka development environment for running integration tests locally. 
+# +# Usage: +# docker compose -f docker-compose.kafka.yml up -d +# +# KAFKA_BOOTSTRAP_SERVERS=localhost:9092 \ +# AGENTEXEC_STATE_BACKEND=agentexec.state.kafka \ +# uv run pytest tests/test_kafka_integration.py -v +# +# docker compose -f docker-compose.kafka.yml down +# +# Kafka UI available at http://localhost:8080 + +services: + kafka: + image: apache/kafka:3.9.0 + ports: + - "9092:9092" + environment: + KAFKA_NODE_ID: "1" + KAFKA_PROCESS_ROLES: broker,controller + KAFKA_CONTROLLER_QUORUM_VOTERS: 1@localhost:9093 + KAFKA_CONTROLLER_LISTENER_NAMES: CONTROLLER + KAFKA_LISTENERS: PLAINTEXT://:9092,CONTROLLER://:9093 + KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://localhost:9092 + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT + KAFKA_INTER_BROKER_LISTENER_NAME: PLAINTEXT + CLUSTER_ID: "agentexec-dev-cluster-01" + KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: "1" + KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: "1" + KAFKA_TRANSACTION_STATE_LOG_MIN_ISR: "1" + healthcheck: + test: /opt/kafka/bin/kafka-topics.sh --bootstrap-server localhost:9092 --list + interval: 5s + timeout: 10s + retries: 15 + start_period: 15s + + kafka-ui: + image: provectuslabs/kafka-ui:latest + ports: + - "8080:8080" + environment: + KAFKA_CLUSTERS_0_NAME: agentexec + KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092 + depends_on: + kafka: + condition: service_healthy diff --git a/examples/queue-fairness/README.md b/examples/queue-fairness/README.md new file mode 100644 index 0000000..374efba --- /dev/null +++ b/examples/queue-fairness/README.md @@ -0,0 +1,75 @@ +# Queue Fairness Benchmark + +Validates that the scan-based partitioned queue distributes work fairly across both workers and partition keys. + +## Background + +agentexec uses Redis `SCAN` to iterate partition queues during dequeue. SCAN returns keys in hash-table order, which is effectively random — this gives us pseudo-random partition selection without explicit shuffling or round-robin bookkeeping. 
+ +This benchmark measures two dimensions of fairness: + +- **Worker fairness**: Are tasks spread evenly across workers? +- **Partition fairness**: Are all partitions served at a similar pace, or do some starve while others get immediate attention? + +## Usage + +```bash +uv run python examples/queue-fairness/run.py +uv run python examples/queue-fairness/run.py --partitions 1000 --tasks-per-partition 10 --workers 8 +``` + +Requires a running Redis instance (`REDIS_URL` environment variable). + +## What it does + +1. Enqueues `partitions * tasks_per_partition` tasks, each routed to a named partition queue +2. Spawns N async workers that pop, simulate work, then release the partition lock via `complete()` +3. Records timing data for every task: which worker, which partition, wait time, pickup time +4. Reports fairness metrics at the end + +## Results + +At 1000 partitions, 10 tasks each (10,000 total), 8 workers: + +### Worker fairness + +Each worker processed between 1243 and 1257 tasks (ideal: 1250). Standard deviation of 5.2 across 8 workers — essentially uniform distribution. + +``` +Worker 0: 1257 tasks (12.6%) +Worker 1: 1249 tasks (12.5%) +Worker 2: 1248 tasks (12.5%) +Worker 3: 1257 tasks (12.6%) +Worker 4: 1246 tasks (12.5%) +Worker 5: 1243 tasks (12.4%) +Worker 6: 1247 tasks (12.5%) +Worker 7: 1253 tasks (12.5%) +``` + +### Partition fairness + +The "first-task pickup time" measures when each partition's first task gets served, relative to the start. A fair system serves all partitions at roughly the same pace — no partition should wait significantly longer than others for its first task. + +``` +First-task pickup time (seconds after start): + Mean: 15.606s + Median: 15.685s + Stdev: 9.030s + Min: 0.019s + Max: 31.103s +``` + +The median first pickup (15.7s) lands at almost exactly half the total runtime (31.6s), which is what you'd expect from a uniform distribution. No partitions were flagged as starved (first pickup > 2x the median). 
+ +### Throughput + +Throughput held steady at ~317 tasks/sec across all partition counts tested (50, 200, 1000). SCAN-based dequeue does not degrade as the number of partitions grows. + +## Why it works + +Redis `SCAN` iterates the hash table in slot order, which is determined by the hash of each key. Since partition keys hash to different slots, the iteration order is effectively random and changes as keys are added or removed. This gives us: + +- **No hot spots**: No partition is systematically visited first or last +- **No coordination**: Workers don't need to agree on which partition to try next +- **Free rebalancing**: As partitions drain and their keys disappear, SCAN naturally skips them +- **Lock-aware skipping**: Locked partitions are skipped immediately, so workers don't block on busy partitions — they move on to the next available one diff --git a/examples/queue-fairness/run.py b/examples/queue-fairness/run.py new file mode 100644 index 0000000..99f4ca4 --- /dev/null +++ b/examples/queue-fairness/run.py @@ -0,0 +1,215 @@ +"""Queue fairness benchmark. + +Validates that tasks distributed across many partition queues get +roughly equal treatment under the scan-based dequeue strategy. + +Measures two dimensions of fairness: + - Worker fairness: are tasks spread evenly across workers? + - Partition fairness: are partitions served in a balanced order, + or do some starve while others get picked up immediately? 
+ +Usage: + uv run python examples/queue-fairness/run.py + uv run python examples/queue-fairness/run.py --partitions 100 --tasks-per-partition 5 --workers 8 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import statistics +import time +from uuid import uuid4 + +from pydantic import BaseModel + +import agentexec as ax +from agentexec.state import backend + + +class BenchContext(BaseModel): + partition_id: int + task_index: int + queued_at: float + + +async def enqueue_tasks(partitions: int, tasks_per_partition: int) -> int: + """Push tasks across N partitions with M tasks each.""" + total = 0 + for p in range(partitions): + partition_key = f"partition:{p}" + for t in range(tasks_per_partition): + task = ax.Task( + task_name="bench_task", + context={ + "partition_id": p, + "task_index": t, + "queued_at": time.time(), + }, + agent_id=uuid4(), + ) + await backend.queue.push( + task.model_dump_json(), + partition_key=partition_key, + ) + total += 1 + return total + + +async def worker( + worker_id: int, + results: list[dict], + stop_event: asyncio.Event, + work_duration: float, +): + """Simulated worker that pops tasks and records timing.""" + while not stop_event.is_set(): + data = await backend.queue.pop(timeout=1) + if data is None: + await asyncio.sleep(0.1) + continue + + picked_up_at = time.time() + context = data.get("context", {}) + queued_at = context.get("queued_at", picked_up_at) + wait_time = picked_up_at - queued_at + + results.append({ + "worker_id": worker_id, + "partition_id": context.get("partition_id"), + "task_index": context.get("task_index"), + "wait_time": wait_time, + "picked_up_at": picked_up_at, + }) + + # Simulate work + await asyncio.sleep(work_duration) + + # Release the partition lock + partition_key = f"partition:{context.get('partition_id')}" + await backend.queue.complete(partition_key) + + +async def run( + partitions: int, + tasks_per_partition: int, + num_workers: int, + work_duration: 
float, +): + total = partitions * tasks_per_partition + print(f"Enqueueing {partitions} partitions x {tasks_per_partition} tasks = {total} total") + enqueue_start = time.time() + await enqueue_tasks(partitions, tasks_per_partition) + print(f"Enqueued {total} tasks in {time.time() - enqueue_start:.1f}s") + + results: list[dict] = [] + stop_event = asyncio.Event() + + print(f"Starting {num_workers} workers (simulated work: {work_duration}s)") + start = time.time() + + workers = [ + asyncio.create_task(worker(i, results, stop_event, work_duration)) + for i in range(num_workers) + ] + + while len(results) < total: + await asyncio.sleep(0.5) + elapsed = time.time() - start + print(f" {len(results)}/{total} tasks processed ({elapsed:.1f}s)", end="\r") + + elapsed = time.time() - start + stop_event.set() + await asyncio.gather(*workers, return_exceptions=True) + + print(f"\n\nCompleted {len(results)} tasks in {elapsed:.1f}s") + print(f"Throughput: {len(results) / elapsed:.1f} tasks/sec") + + # --- Wait time analysis --- + all_waits = [r["wait_time"] for r in results] + print(f"\nWait time (enqueue → pickup):") + print(f" Mean: {statistics.mean(all_waits):.3f}s") + print(f" Median: {statistics.median(all_waits):.3f}s") + print(f" Stdev: {statistics.stdev(all_waits):.3f}s") + print(f" Min: {min(all_waits):.3f}s") + print(f" Max: {max(all_waits):.3f}s") + + # --- Worker fairness --- + worker_counts: dict[int, int] = {} + for r in results: + wid = r["worker_id"] + worker_counts[wid] = worker_counts.get(wid, 0) + 1 + + worker_vals = list(worker_counts.values()) + ideal_per_worker = total / num_workers + + print(f"\nWorker fairness ({num_workers} workers, ideal {ideal_per_worker:.0f} each):") + for wid in sorted(worker_counts): + count = worker_counts[wid] + pct = count / total * 100 + print(f" Worker {wid}: {count} tasks ({pct:.1f}%)") + if len(worker_vals) > 1: + print(f" Stdev: {statistics.stdev(worker_vals):.1f}") + + # --- Partition fairness --- + # For each partition, 
when was its first task picked up (relative to start)? + # A fair system serves all partitions at roughly the same pace. + partition_first_pickup: dict[int, float] = {} + partition_waits: dict[int, list[float]] = {} + for r in results: + pid = r["partition_id"] + pickup_offset = r["picked_up_at"] - start + if pid not in partition_first_pickup or pickup_offset < partition_first_pickup[pid]: + partition_first_pickup[pid] = pickup_offset + if pid not in partition_waits: + partition_waits[pid] = [] + partition_waits[pid].append(r["wait_time"]) + + first_pickups = list(partition_first_pickup.values()) + avg_waits = [statistics.mean(w) for w in partition_waits.values()] + + print(f"\nPartition fairness ({len(partition_first_pickup)} partitions):") + + print(f" First-task pickup time (seconds after start):") + print(f" Mean: {statistics.mean(first_pickups):.3f}s") + print(f" Median: {statistics.median(first_pickups):.3f}s") + print(f" Stdev: {statistics.stdev(first_pickups):.3f}s") + print(f" Min: {min(first_pickups):.3f}s") + print(f" Max: {max(first_pickups):.3f}s") + print(f" Spread: {max(first_pickups) - min(first_pickups):.3f}s") + + print(f" Average wait per partition:") + print(f" Mean: {statistics.mean(avg_waits):.3f}s") + print(f" Stdev: {statistics.stdev(avg_waits):.3f}s") + print(f" Spread: {max(avg_waits) - min(avg_waits):.3f}s") + + # Identify starved partitions (first pickup > 2x median) + median_pickup = statistics.median(first_pickups) + starved = [pid for pid, t in partition_first_pickup.items() if t > median_pickup * 2] + if starved: + print(f" Starved partitions (first pickup > 2x median): {len(starved)}/{partitions}") + else: + print(f" No starved partitions detected") + + await backend.close() + + +def main(): + parser = argparse.ArgumentParser(description="Queue fairness benchmark") + parser.add_argument("--partitions", type=int, default=500, help="Number of partition queues") + parser.add_argument("--tasks-per-partition", type=int, default=12, 
help="Tasks per partition") + parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers") + parser.add_argument("--work-duration", type=float, default=0.5, help="Simulated work time (seconds)") + args = parser.parse_args() + + asyncio.run(run( + partitions=args.partitions, + tasks_per_partition=args.tasks_per_partition, + num_workers=args.workers, + work_duration=args.work_duration, + )) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 42ab646..d92754d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,11 @@ dependencies = [ "croniter>=6.0.0", ] +[project.optional-dependencies] +kafka = [ + "aiokafka>=0.11.0", +] + [project.urls] Homepage = "https://github.com/Agent-CI/agentexec" diff --git a/src/agentexec/activity/__init__.py b/src/agentexec/activity/__init__.py index b47d7ae..a983b33 100644 --- a/src/agentexec/activity/__init__.py +++ b/src/agentexec/activity/__init__.py @@ -1,21 +1,83 @@ -from agentexec.activity.models import Activity, ActivityLog, Status +from agentexec.activity.models import Activity, ActivityLog +from agentexec.activity.status import Status from agentexec.activity.schemas import ( ActivityDetailSchema, ActivityListItemSchema, ActivityListSchema, ActivityLogSchema, ) -from agentexec.activity.tracker import ( +from agentexec.activity.handlers import ActivityHandler, PostgresHandler +from agentexec.activity.producer import ( create, update, complete, error, cancel_pending, - list, - detail, - count_active, + generate_agent_id, + normalize_agent_id, ) +handler: ActivityHandler = PostgresHandler() + +import uuid +from typing import Any + +from sqlalchemy.orm import Session + + +async def list( + session: Session | None = None, + page: int = 1, + page_size: int = 50, + metadata_filter: dict[str, Any] | None = None, +) -> ActivityListSchema: + """List activities with pagination.""" + from agentexec.core.db import get_session + + with session or get_session() as db: 
+ query = db.query(Activity) + if metadata_filter: + for key, value in metadata_filter.items(): + query = query.filter(Activity.metadata_[key].as_string() == str(value)) + total = query.count() + + rows = Activity.get_list(db, page=page, page_size=page_size, metadata_filter=metadata_filter) + return ActivityListSchema( + items=[ActivityListItemSchema.model_validate(row) for row in rows], + total=total, + page=page, + page_size=page_size, + ) + + +async def detail( + session: Session | None = None, + agent_id: str | uuid.UUID | None = None, + metadata_filter: dict[str, Any] | None = None, +) -> ActivityDetailSchema | None: + """Get a single activity by agent_id.""" + from agentexec.core.db import get_session + + if agent_id is None: + return None + if isinstance(agent_id, str): + agent_id = uuid.UUID(agent_id) + + with session or get_session() as db: + item = Activity.get_by_agent_id(db, agent_id, metadata_filter=metadata_filter) + if item is not None: + return ActivityDetailSchema.model_validate(item) + return None + + +async def count_active(session: Session | None = None) -> int: + """Count active (queued or running) agents.""" + from agentexec.core.db import get_session + + with session or get_session() as db: + return Activity.get_active_count(db) + + __all__ = [ # Models "Activity", @@ -32,6 +94,8 @@ "complete", "error", "cancel_pending", + "generate_agent_id", + "normalize_agent_id", # Query API "list", "detail", diff --git a/src/agentexec/activity/events.py b/src/agentexec/activity/events.py new file mode 100644 index 0000000..e041a9e --- /dev/null +++ b/src/agentexec/activity/events.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from pydantic import BaseModel + + +class ActivityCreated(BaseModel): + agent_id: uuid.UUID + task_name: str + message: str + metadata: dict[str, Any] | None = None + + +class ActivityUpdated(BaseModel): + agent_id: uuid.UUID + message: str + status: str + percentage: int | None = 
None + + +# Resolve forward references +ActivityCreated.model_rebuild() +ActivityUpdated.model_rebuild() diff --git a/src/agentexec/activity/handlers.py b/src/agentexec/activity/handlers.py new file mode 100644 index 0000000..74ccde4 --- /dev/null +++ b/src/agentexec/activity/handlers.py @@ -0,0 +1,103 @@ +"""Activity event handlers — pluggable persistence for lifecycle events. + +The activity system uses a handler pattern to decouple event emission from +persistence. Every call to ``activity.create()``, ``activity.update()``, etc. +emits a typed event (``ActivityCreated`` or ``ActivityUpdated``) and routes it +through ``activity.handler``, a callable that decides what to do with it. + +Two handlers are provided: + +- ``PostgresHandler`` (default): Writes events directly to Postgres via + SQLAlchemy. Used by API servers and the pool's main process. + +- ``IPCHandler``: Serializes events onto a ``multiprocessing.Queue`` for + the pool to receive and persist. Used by worker processes, which don't + have database access. + +The handler is swapped at process init time. Workers set the IPC handler +during startup; everything else uses the default Postgres handler:: + + # Worker process (set automatically by Pool) + from agentexec.activity.handlers import IPCHandler + activity.handler = IPCHandler(tx_queue) + + # API server or pool process (default, no setup needed) + await activity.update(agent_id, "Processing", percentage=50) + # → writes directly to Postgres + +Custom handlers can be implemented by conforming to the ``ActivityHandler`` +protocol — any callable that accepts ``ActivityCreated | ActivityUpdated``. +""" + +from __future__ import annotations + +import multiprocessing as mp +from typing import Protocol + +from agentexec.activity.events import ActivityCreated, ActivityUpdated +from agentexec.activity.status import Status + + +class ActivityHandler(Protocol): + """Protocol for activity event handlers. 
+ + Any callable that accepts an ``ActivityCreated`` or ``ActivityUpdated`` + event satisfies this protocol. + """ + def __call__(self, event: ActivityCreated | ActivityUpdated) -> None: ... + + +class PostgresHandler: + """Writes activity events directly to Postgres. + + This is the default handler. It creates a short-lived database session + for each event, writes the appropriate records, and commits. + """ + + def __call__(self, event: ActivityCreated | ActivityUpdated) -> None: + match event: + case ActivityCreated(agent_id=agent_id, task_name=task_name, message=message, metadata=metadata): + from agentexec.activity.models import Activity, ActivityLog + from agentexec.core.db import get_session + + with get_session() as db: + record = Activity(agent_id=agent_id, agent_type=task_name, metadata_=metadata) + db.add(record) + db.flush() + db.add(ActivityLog( + activity_id=record.id, + message=message, + status=Status.QUEUED, + percentage=0, + )) + db.commit() + + case ActivityUpdated(agent_id=agent_id, message=message, status=status, percentage=percentage): + from agentexec.activity.models import Activity + from agentexec.core.db import get_session + + with get_session() as db: + Activity.append_log( + session=db, + agent_id=agent_id, + message=message, + status=Status(status), + percentage=percentage, + ) + + +class IPCHandler: + """Sends activity events to the pool via multiprocessing queue. + + Worker processes use this handler so they don't need database access. + Events are picked up by the pool's event loop and written to Postgres + using the default ``PostgresHandler``. 
+ """ + + tx: mp.Queue + + def __init__(self, tx: mp.Queue) -> None: + self.tx = tx + + def __call__(self, event: ActivityCreated | ActivityUpdated) -> None: + self.tx.put_nowait(event) diff --git a/src/agentexec/activity/models.py b/src/agentexec/activity/models.py index 8d05565..e7aa853 100644 --- a/src/agentexec/activity/models.py +++ b/src/agentexec/activity/models.py @@ -1,5 +1,5 @@ from __future__ import annotations -from enum import Enum as PyEnum + import uuid from datetime import UTC, datetime @@ -22,20 +22,11 @@ from sqlalchemy.engine import RowMapping from sqlalchemy.orm import Mapped, Session, aliased, mapped_column, relationship, declared_attr +from agentexec.activity.status import Status from agentexec.config import CONF from agentexec.core.db import Base -class Status(str, PyEnum): - """Agent execution status.""" - - QUEUED = "queued" - RUNNING = "running" - COMPLETE = "complete" - ERROR = "error" - CANCELED = "canceled" - - class Activity(Base): """Tracks background agent execution sessions. diff --git a/src/agentexec/activity/producer.py b/src/agentexec/activity/producer.py new file mode 100644 index 0000000..29faded --- /dev/null +++ b/src/agentexec/activity/producer.py @@ -0,0 +1,179 @@ +"""Activity event producer — the public API for activity lifecycle. + +All activity methods emit typed events routed through ``activity.handler``. +By default, events are written directly to Postgres. In worker processes, +the handler is swapped to send events via IPC to the pool. + +See ``activity.handlers`` for the handler implementations. 
+""" + +from __future__ import annotations + +import uuid +from typing import Any + +from sqlalchemy.orm import Session + +import agentexec.activity as activity +from agentexec.activity.events import ActivityCreated, ActivityUpdated +from agentexec.activity.status import Status + + +def generate_agent_id() -> uuid.UUID: + """Generate a new UUID4 agent identifier.""" + return uuid.uuid4() + + +def normalize_agent_id(agent_id: str | uuid.UUID) -> uuid.UUID: + """Coerce a string or UUID to a UUID object.""" + if isinstance(agent_id, str): + return uuid.UUID(agent_id) + return agent_id + + + + +async def create( + task_name: str, + message: str = "Agent queued", + agent_id: str | uuid.UUID | None = None, + session: Session | None = None, + metadata: dict[str, Any] | None = None, +) -> uuid.UUID: + """Create a new activity record with an initial "queued" log entry. + + Called during ``ax.enqueue()`` to register the task in the activity + stream before it hits the queue. + + Args: + task_name: The registered task name (e.g. ``"research"``). + message: Initial log message. + agent_id: Optional pre-generated agent ID. Auto-generated if omitted. + session: Unused — kept for backwards compatibility. + metadata: Arbitrary key-value pairs attached to the activity + (e.g. ``{"organization_id": "org-123"}``). + + Returns: + The agent_id (UUID) of the created record. + + Example:: + + agent_id = await activity.create("research", metadata={"org": "acme"}) + """ + agent_id = normalize_agent_id(agent_id) if agent_id else generate_agent_id() + activity.handler(ActivityCreated( + agent_id=agent_id, + task_name=task_name, + message=message, + metadata=metadata, + )) + return agent_id + + +async def update( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None = None, + status: Status | None = None, + session: Session | None = None, +) -> bool: + """Append a log entry to an existing activity record. + + Defaults to ``Status.RUNNING`` if no status is provided. 
+ + Args: + agent_id: The agent to update. + message: Log message describing the current state. + percentage: Optional completion percentage (0-100). + status: Optional status override (default: ``RUNNING``). + session: Unused — kept for backwards compatibility. + + Example:: + + await activity.update(agent_id, "Fetching data", percentage=30) + """ + activity.handler(ActivityUpdated( + agent_id=normalize_agent_id(agent_id), + message=message, + status=(status or Status.RUNNING).value, + percentage=percentage, + )) + return True + + +async def complete( + agent_id: str | uuid.UUID, + message: str = "Agent completed", + percentage: int = 100, + session: Session | None = None, +) -> bool: + """Mark an activity as complete. + + Args: + agent_id: The agent to mark complete. + message: Completion log message. + percentage: Final percentage (default: 100). + session: Unused — kept for backwards compatibility. + + Example:: + + await activity.complete(agent_id) + """ + activity.handler(ActivityUpdated( + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.COMPLETE.value, + percentage=percentage, + )) + return True + + +async def error( + agent_id: str | uuid.UUID, + message: str = "Agent failed", + percentage: int = 100, + session: Session | None = None, +) -> bool: + """Mark an activity as failed. + + Args: + agent_id: The agent to mark as errored. + message: Error log message. + percentage: Final percentage (default: 100). + session: Unused — kept for backwards compatibility. + + Example:: + + await activity.error(agent_id, "Connection timeout") + """ + activity.handler(ActivityUpdated( + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.ERROR.value, + percentage=percentage, + )) + return True + + +async def cancel_pending(session: Session | None = None) -> int: + """Cancel all queued and running activities. + + Typically called during pool shutdown to mark in-flight tasks as + canceled. 
Reads pending IDs from Postgres and emits cancel events. + + Returns: + Number of activities canceled. + """ + from agentexec.activity.models import Activity + from agentexec.core.db import get_session + + with session or get_session() as db: + pending_ids = Activity.get_pending_ids(db) + for agent_id in pending_ids: + activity.handler(ActivityUpdated( + agent_id=agent_id, + message="Canceled due to shutdown", + status=Status.CANCELED.value, + percentage=None, + )) + return len(pending_ids) diff --git a/src/agentexec/activity/schemas.py b/src/agentexec/activity/schemas.py index e326348..144a73f 100644 --- a/src/agentexec/activity/schemas.py +++ b/src/agentexec/activity/schemas.py @@ -2,9 +2,9 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, ConfigDict, Field, computed_field +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field -from agentexec.activity.models import Status +from agentexec.activity.status import Status class ActivityLogSchema(BaseModel): @@ -22,15 +22,19 @@ class ActivityLogSchema(BaseModel): class ActivityDetailSchema(BaseModel): """Schema for an agent activity record with optional logs.""" - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True, populate_by_name=True) - id: uuid.UUID + id: uuid.UUID | None = None agent_id: uuid.UUID agent_type: str created_at: datetime updated_at: datetime logs: list[ActivityLogSchema] = Field(default_factory=list) - metadata: dict[str, Any] | None = Field(default=None, alias="metadata_", exclude=True) + metadata: dict[str, Any] | None = Field( + default=None, + validation_alias=AliasChoices("metadata_", "metadata"), + exclude=True, + ) class ActivityListItemSchema(BaseModel): diff --git a/src/agentexec/activity/status.py b/src/agentexec/activity/status.py new file mode 100644 index 0000000..b0afcb4 --- /dev/null +++ b/src/agentexec/activity/status.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class 
Status(str, Enum): + """Agent execution status.""" + + QUEUED = "queued" + RUNNING = "running" + COMPLETE = "complete" + ERROR = "error" + CANCELED = "canceled" diff --git a/src/agentexec/activity/tracker.py b/src/agentexec/activity/tracker.py deleted file mode 100644 index 12aeff8..0000000 --- a/src/agentexec/activity/tracker.py +++ /dev/null @@ -1,286 +0,0 @@ -import uuid -from typing import Any - -from sqlalchemy.orm import Session - -from agentexec.activity.models import Activity, ActivityLog, Status -from agentexec.activity.schemas import ( - ActivityDetailSchema, - ActivityListItemSchema, - ActivityListSchema, -) -from agentexec.core.db import get_global_session - - -def generate_agent_id() -> uuid.UUID: - """Generate a new UUID for an agent. - - This is the centralized function for generating agent IDs. - Users can override this if they need custom ID generation logic. - - Returns: - A new UUID4 object - """ - return uuid.uuid4() - - -def normalize_agent_id(agent_id: str | uuid.UUID) -> uuid.UUID: - """Normalize agent_id to UUID object. - - Args: - agent_id: Either a string UUID or UUID object - - Returns: - UUID object - - Raises: - ValueError: If string is not a valid UUID - """ - if isinstance(agent_id, str): - return uuid.UUID(agent_id) - return agent_id - - -def create( - task_name: str, - message: str = "Agent queued", - agent_id: str | uuid.UUID | None = None, - session: Session | None = None, - metadata: dict[str, Any] | None = None, -) -> uuid.UUID: - """Create a new agent activity record with initial queued status. - - Args: - task_name: Name/type of the task (e.g., "research", "analysis") - message: Initial log message (default: "Agent queued") - agent_id: Optional custom agent ID (string or UUID). If not provided, one will be auto-generated. - session: Optional SQLAlchemy session. If not provided, uses global session factory. - metadata: Optional dict of arbitrary metadata to attach to the activity. 
- Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). - - Returns: - The agent_id (as UUID object) of the created record - """ - agent_id = normalize_agent_id(agent_id) if agent_id else generate_agent_id() - db = session or get_global_session() - - activity_record = Activity( - agent_id=agent_id, - agent_type=task_name, - metadata_=metadata, - ) - db.add(activity_record) - db.flush() - - log = ActivityLog( - activity_id=activity_record.id, - message=message, - status=Status.QUEUED, - percentage=0, - ) - db.add(log) - db.commit() - - return agent_id - - -def update( - agent_id: str | uuid.UUID, - message: str, - percentage: int | None = None, - status: Status | None = None, - session: Session | None = None, -) -> bool: - """Update an agent's activity by adding a new log message. - - This function will set the status to RUNNING unless a different status is explicitly provided. - - Args: - agent_id: The agent_id of the agent to update - message: Log message to append - percentage: Optional completion percentage (0-100) - status: Optional status to set (default: RUNNING) - session: Optional SQLAlchemy session. If not provided, uses global session factory. - - Returns: - True if successful - - Raises: - ValueError: If agent_id not found - """ - db = session or get_global_session() - - Activity.append_log( - session=db, - agent_id=normalize_agent_id(agent_id), - message=message, - status=status if status else Status.RUNNING, - percentage=percentage, - ) - return True - - -def complete( - agent_id: str | uuid.UUID, - message: str = "Agent completed", - percentage: int = 100, - session: Session | None = None, -) -> bool: - """Mark an agent activity as complete. - - Args: - agent_id: The agent_id of the agent to mark as complete - message: Log message (default: "Agent completed") - percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses global session factory. 
- - Returns: - True if successful - - Raises: - ValueError: If agent_id not found - """ - db = session or get_global_session() - - Activity.append_log( - session=db, - agent_id=normalize_agent_id(agent_id), - message=message, - status=Status.COMPLETE, - percentage=percentage, - ) - return True - - -def error( - agent_id: str | uuid.UUID, - message: str = "Agent failed", - percentage: int = 100, - session: Session | None = None, -) -> bool: - """Mark an agent activity as failed. - - Args: - agent_id: The agent_id of the agent to mark as failed - message: Log message (default: "Agent failed") - percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses ScopedSession. - - Returns: - True if successful - - Raises: - ValueError: If agent_id not found - """ - db = session or get_global_session() - - Activity.append_log( - session=db, - agent_id=normalize_agent_id(agent_id), - message=message, - status=Status.ERROR, - percentage=percentage, - ) - return True - - -def cancel_pending( - session: Session | None = None, -) -> int: - """Mark all queued and running agents as canceled. - - Useful during application shutdown to clean up pending tasks. - - Returns: - Number of agents that were canceled - """ - db = session or get_global_session() - - pending_agent_ids = Activity.get_pending_ids(db) - for agent_id in pending_agent_ids: - Activity.append_log( - session=db, - agent_id=agent_id, - message="Canceled due to shutdown", - status=Status.CANCELED, - percentage=None, - ) - - db.commit() - return len(pending_agent_ids) - - -def list( - session: Session, - page: int = 1, - page_size: int = 50, - metadata_filter: dict[str, Any] | None = None, -) -> ActivityListSchema: - """List activities with pagination. - - Args: - session: SQLAlchemy session to use for the query - page: Page number (1-indexed) - page_size: Number of items per page - metadata_filter: Optional dict of key-value pairs to filter by. 
- Activities must have metadata containing all specified keys - with exactly matching values. - - Returns: - ActivityList with list of ActivityListItemSchema items - """ - # Build base query for total count - query = session.query(Activity) - if metadata_filter: - for key, value in metadata_filter.items(): - query = query.filter(Activity.metadata_[key].as_string() == str(value)) - total = query.count() - - rows = Activity.get_list( - session, - page=page, - page_size=page_size, - metadata_filter=metadata_filter, - ) - - return ActivityListSchema( - items=[ActivityListItemSchema.model_validate(row) for row in rows], - total=total, - page=page, - page_size=page_size, - ) - - -def detail( - session: Session, - agent_id: str | uuid.UUID, - metadata_filter: dict[str, Any] | None = None, -) -> ActivityDetailSchema | None: - """Get a single activity by agent_id with all logs. - - Args: - session: SQLAlchemy session to use for the query - agent_id: The agent_id to look up - metadata_filter: Optional dict of key-value pairs to filter by. - If provided and the activity's metadata doesn't match, - returns None (same as if not found). - - Returns: - ActivityDetailSchema with full log history, or None if not found - or if metadata doesn't match - """ - if item := Activity.get_by_agent_id(session, agent_id, metadata_filter=metadata_filter): - return ActivityDetailSchema.model_validate(item) - return None - - -def count_active(session: Session) -> int: - """Get count of active (queued or running) agents. 
- - Args: - session: SQLAlchemy session to use for the query - - Returns: - Count of agents with QUEUED or RUNNING status - """ - return Activity.get_active_count(session) diff --git a/src/agentexec/config.py b/src/agentexec/config.py index 0f9f12a..56092b8 100644 --- a/src/agentexec/config.py +++ b/src/agentexec/config.py @@ -17,10 +17,10 @@ class Config(BaseSettings): description="Prefix for database table names", validation_alias="AGENTEXEC_TABLE_PREFIX", ) - queue_name: str = Field( + queue_prefix: str = Field( default="agentexec_tasks", - description="Name of the Redis list to use as task queue", - validation_alias="AGENTEXEC_QUEUE_NAME", + description="Prefix for task queue keys. Partition queues are {prefix}:{lock_key}.", + validation_alias=AliasChoices("AGENTEXEC_QUEUE_PREFIX", "AGENTEXEC_QUEUE_NAME"), ) num_workers: int = Field( default=4, @@ -72,22 +72,61 @@ class Config(BaseSettings): result_ttl: int = Field( default=3600, - description="TTL in seconds for task results in Redis", + description="TTL in seconds for task results", validation_alias="AGENTEXEC_RESULT_TTL", ) state_backend: str = Field( - default="agentexec.state.redis_backend", - description="State backend to use (fully-qualified module path)", + default="agentexec.state.redis", + description="State backend: 'agentexec.state.redis' or 'agentexec.state.kafka'", validation_alias="AGENTEXEC_STATE_BACKEND", ) + kafka_bootstrap_servers: str | None = Field( + default=None, + description="Kafka bootstrap servers (e.g. 
'localhost:9092')", + validation_alias=AliasChoices( + "AGENTEXEC_KAFKA_BOOTSTRAP_SERVERS", "KAFKA_BOOTSTRAP_SERVERS" + ), + ) + kafka_default_partitions: int = Field( + default=6, + description="Default number of partitions for auto-created topics", + validation_alias="AGENTEXEC_KAFKA_DEFAULT_PARTITIONS", + ) + kafka_replication_factor: int = Field( + default=1, + description="Replication factor for auto-created topics", + validation_alias="AGENTEXEC_KAFKA_REPLICATION_FACTOR", + ) + kafka_max_batch_size: int = Field( + default=16384, + description="Producer max batch size in bytes", + validation_alias="AGENTEXEC_KAFKA_MAX_BATCH_SIZE", + ) + kafka_linger_ms: int = Field( + default=5, + description="Producer linger time in milliseconds", + validation_alias="AGENTEXEC_KAFKA_LINGER_MS", + ) + kafka_retention_ms: int = Field( + default=-1, + description="Retention for compacted topics in ms (-1 = forever)", + validation_alias="AGENTEXEC_KAFKA_RETENTION_MS", + ) + key_prefix: str = Field( default="agentexec", description="Prefix for state backend keys", validation_alias="AGENTEXEC_KEY_PREFIX", ) + scheduler_poll_interval: int = Field( + default=10, + description="Seconds between schedule polls", + validation_alias="AGENTEXEC_SCHEDULER_POLL_INTERVAL", + ) + scheduler_timezone: str = Field( default="UTC", description=( @@ -96,14 +135,24 @@ class Config(BaseSettings): ), validation_alias="AGENTEXEC_SCHEDULER_TIMEZONE", ) + max_task_retries: int = Field( + default=3, + description=( + "Maximum number of times a failed task will be retried before " + "being marked as a permanent error. Set to 0 to disable retries. " + "With the Kafka backend, retries preserve partition ordering — " + "the task stays in its original position in the queue." + ), + validation_alias="AGENTEXEC_MAX_TASK_RETRIES", + ) + lock_ttl: int = Field( default=1800, description=( - "TTL in seconds for task lock keys in Redis. 
" - "This is a safety net for worker process death (OOM, SIGKILL) — " + "TTL in seconds for task lock keys (Redis backend only). " + "Safety net for worker process death (OOM, SIGKILL) — " "locks are always explicitly released on task completion or error. " - "Set this higher than your longest expected task duration to avoid " - "premature lock expiry while a task is still running." + "Ignored by the Kafka backend (partition assignment handles isolation)." ), validation_alias="AGENTEXEC_LOCK_TTL", ) diff --git a/src/agentexec/core/db.py b/src/agentexec/core/db.py index e5f1a00..8880738 100644 --- a/src/agentexec/core/db.py +++ b/src/agentexec/core/db.py @@ -1,62 +1,52 @@ from sqlalchemy import Engine -from sqlalchemy.orm import DeclarativeBase, Session, scoped_session, sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker __all__ = [ "Base", - "get_global_session", - "set_global_session", - "remove_global_session", + "configure_engine", + "get_session", ] class Base(DeclarativeBase): - """Base class for all SQLAlchemy models in agent-runner. + """Base class for all SQLAlchemy models. + + Example:: - Example: # In alembic/env.py import agentexec as ax target_metadata = ax.Base.metadata """ - pass -# We need one session per worker process with a shared engine across the application. -# SQLAlchemy's scoped_session provides process-local session management out of the box. -_session_factory: scoped_session[Session] = scoped_session(sessionmaker()) +_engine: Engine | None = None +_session_factory: sessionmaker[Session] | None = None -def set_global_session(engine: Engine) -> None: - """Configure the global session factory with an engine. +def configure_engine(engine: Engine) -> None: + """Set the shared engine for the application. - Called by workers on startup to bind the session to their database. - - Args: - engine: SQLAlchemy engine to bind sessions to. + Called once during Pool initialization. 
Workers inherit the engine + via multiprocessing. """ - _session_factory.configure(bind=engine) + global _engine, _session_factory + _engine = engine + _session_factory = sessionmaker(bind=engine) -def get_global_session() -> Session: - """Get the worker's process-local session. +def get_session() -> Session: + """Create a new session from the shared engine. - This is distinct from request-scoped sessions used in API handlers. - Use this for background task execution within workers. + Use as a context manager:: - Returns: - A session bound to the configured engine. + with get_session() as db: + db.query(...) Raises: - RuntimeError: If set_global_session() hasn't been called. + RuntimeError: If ``configure_engine()`` hasn't been called. """ + if _session_factory is None: + raise RuntimeError("Database engine not configured. Call configure_engine() first.") return _session_factory() - - -def remove_global_session() -> None: - """Close and remove the worker's process-local session. - - Called during worker cleanup to close the session and return - connections to the pool. - """ - _session_factory.remove() diff --git a/src/agentexec/core/logging.py b/src/agentexec/core/logging.py index 9a79565..0df26df 100644 --- a/src/agentexec/core/logging.py +++ b/src/agentexec/core/logging.py @@ -1,9 +1,3 @@ -"""Unified logging for main and worker processes. - -Uses multiprocessing's built-in logger which handles cross-process -logging correctly on macOS (spawn mode). 
-""" - import logging import multiprocessing diff --git a/src/agentexec/core/queue.py b/src/agentexec/core/queue.py index e1dc2cf..b3dace5 100644 --- a/src/agentexec/core/queue.py +++ b/src/agentexec/core/queue.py @@ -1,13 +1,11 @@ -import json from enum import Enum from typing import Any from pydantic import BaseModel -from agentexec import state -from agentexec.config import CONF from agentexec.core.logging import get_logger from agentexec.core.task import Task +from agentexec.state import backend logger = get_logger(__name__) @@ -28,105 +26,40 @@ async def enqueue( context: BaseModel, *, priority: Priority = Priority.LOW, - queue_name: str | None = None, metadata: dict[str, Any] | None = None, ) -> Task: """Enqueue a task for background execution. - Pushes the task to the queue for worker processing. The task must be - registered with a WorkerPool via @pool.task() decorator. + Creates an activity record, serializes the context, and pushes the + task to the queue for workers to process. Args: - task_name: Name of the task to execute. - context: Task context as a Pydantic BaseModel. - priority: Task priority (Priority.HIGH or Priority.LOW). - queue_name: Queue name. Defaults to CONF.queue_name. - metadata: Optional dict of arbitrary metadata to attach to the activity. - Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). + task_name: Name of the registered task (must match a ``@pool.task()``). + context: Pydantic model with the task's input data. + priority: ``Priority.HIGH`` pushes to the front of the queue. + metadata: Optional dict attached to the activity record (e.g. + ``{"organization_id": "org-123"}`` for multi-tenancy). Returns: - Task instance with typed context and agent_id for tracking. + The created Task with its ``agent_id`` for tracking. - Example: - @pool.task("research_company") - async def research(agent_id: UUID, context: ResearchContext): - ... 
+ Example:: - task = await ax.enqueue("research_company", ResearchContext(company="Acme")) - - # With metadata for multi-tenancy - task = await ax.enqueue( - "research_company", - ResearchContext(company="Acme"), - metadata={"organization_id": "org-123"} - ) + task = await ax.enqueue("research", ResearchContext(company="Acme")) + print(task.agent_id) # UUID for tracking """ - push_func = { - Priority.HIGH: state.backend.rpush, - Priority.LOW: state.backend.lpush, - }[priority] - - task = Task.create( + task = await Task.create( task_name=task_name, context=context, metadata=metadata, ) - push_func( - queue_name or CONF.queue_name, + + await backend.queue.push( task.model_dump_json(), + high_priority=(priority == Priority.HIGH), ) logger.info(f"Enqueued task {task.task_name} with agent_id {task.agent_id}") return task -def requeue( - task: Task, - *, - queue_name: str | None = None, -) -> int: - """Push a task back to the end of the queue. - - Used when a task's lock cannot be acquired — the task is returned to the - queue so it can be retried after the lock is released. - - Args: - task: Task to requeue. - queue_name: Queue name. Defaults to CONF.queue_name. - - Returns: - Length of the queue after the push. - """ - return state.backend.lpush( - queue_name or CONF.queue_name, - task.model_dump_json(), - ) - - -async def dequeue( - *, - queue_name: str | None = None, - timeout: int = 1, -) -> dict[str, Any] | None: - """Dequeue a task from the queue. - - Blocks for up to timeout seconds waiting for a task. - - Args: - queue_name: Queue name. Defaults to CONF.queue_name. - timeout: Maximum seconds to wait for a task. - - Returns: - Parsed task data if available, None otherwise. 
- """ - result = await state.backend.brpop( - queue_name or CONF.queue_name, - timeout=timeout, - ) - - if result is None: - return None - - _, task_data = result - data: dict[str, Any] = json.loads(task_data) - return data diff --git a/src/agentexec/core/results.py b/src/agentexec/core/results.py index 204f8ff..2f3c1a2 100644 --- a/src/agentexec/core/results.py +++ b/src/agentexec/core/results.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from agentexec import state +from agentexec.state import KEY_RESULT, backend if TYPE_CHECKING: from agentexec.core.task import Task @@ -15,26 +15,18 @@ DEFAULT_TIMEOUT: int = 300 # TODO improve this polling approach -async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel: - """Poll for a task result. - - Waits for a task to complete and returns its result. - Uses automatic type reconstruction from serialized class information. - - Args: - task: The Task instance to wait for - timeout: Maximum seconds to wait for result +async def _get_result(agent_id: str) -> BaseModel | None: + key = backend.format_key(*KEY_RESULT, str(agent_id)) + data = await backend.state.get(key) + return backend.deserialize(data) if data else None - Returns: - Deserialized result as BaseModel instance - Raises: - TimeoutError: If result not available within timeout - """ +async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel: + """Poll for a task result.""" start = time.time() while time.time() - start < timeout: - result = await state.aget_result(task.agent_id) + result = await _get_result(task.agent_id) if result is not None: return result await asyncio.sleep(0.5) @@ -43,22 +35,6 @@ async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel: async def gather(*tasks: Task, timeout: int = DEFAULT_TIMEOUT) -> tuple[BaseModel, ...]: - """Wait for multiple tasks and return their results. - - Similar to asyncio.gather, but for background tasks. 
- - Args: - *tasks: Task instances to wait for - timeout: Maximum seconds to wait for each result - - Returns: - Tuple of deserialized results as BaseModel instances - - Example: - brand = await ax.enqueue("brand_research", ctx) - market = await ax.enqueue("market_research", ctx) - - brand_result, market_result = await ax.gather(brand, market) - """ + """Wait for multiple tasks and return their results.""" results = await asyncio.gather(*[get_result(task, timeout) for task in tasks]) return tuple(results) diff --git a/src/agentexec/core/task.py b/src/agentexec/core/task.py index ae386c3..186adfa 100644 --- a/src/agentexec/core/task.py +++ b/src/agentexec/core/task.py @@ -1,13 +1,15 @@ from __future__ import annotations import inspect +from collections.abc import Mapping from typing import Any, Protocol, TypeAlias, TypeVar, cast, get_type_hints from uuid import UUID -from pydantic import BaseModel, ConfigDict, PrivateAttr, field_serializer +from pydantic import BaseModel, ConfigDict -from agentexec import activity, state +from agentexec import activity from agentexec.config import CONF +from agentexec.state import KEY_RESULT, backend TaskResult: TypeAlias = BaseModel @@ -16,8 +18,6 @@ class _SyncTaskHandler(Protocol[ContextT, ResultT]): - """Protocol for sync task handler functions.""" - __name__: str def __call__( @@ -29,8 +29,6 @@ def __call__( class _AsyncTaskHandler(Protocol[ContextT, ResultT]): - """Protocol for async task handler functions.""" - __name__: str async def __call__( @@ -51,28 +49,14 @@ async def __call__( class TaskDefinition: """Definition of a task type (created at registration time). - Encapsulates the handler function and its metadata (context class, etc.). + Encapsulates the handler function and its metadata (context class, lock key). One TaskDefinition can spawn many Task instances. 
- - This object is created once when a task is registered via @pool.task(), - and acts as a factory to reconstruct Task instances from the queue with - properly typed context. - - Example: - @pool.task("research_company") - async def research(agent_id: UUID, context: ResearchContext): - print(context.company_name) - - # TaskDefinition captures ResearchContext from the type hint - # and uses it to deserialize tasks from the queue """ name: str handler: TaskHandler context_type: type[BaseModel] - # Optional: only set if handler returns a BaseModel subclass result_type: type[BaseModel] | None - # Optional: string template evaluated against context for distributed locking lock_key: str | None def __init__( @@ -87,16 +71,17 @@ def __init__( """Initialize task definition. Args: - name: Task type name - handler: Handler function (sync or async) - context_type: Optional explicit context type (inferred from annotations if not provided). - result_type: Optional explicit result type (inferred from annotations if not provided). - lock_key: Optional string template for distributed locking. Evaluated against - context fields (e.g., "user:{user_id}"). When set, only one task with - the same evaluated lock key can run at a time. + name: Task type name. + handler: Handler function (sync or async). + context_type: Explicit context type (inferred from annotations if omitted). + result_type: Explicit result type (inferred from annotations if omitted). + lock_key: String template for distributed locking, evaluated against + context fields (e.g. ``"user:{user_id}"``). When set, only one task + with the same evaluated lock key can run at a time. Raises: - TypeError: If handler doesn't have a typed 'context' parameter with BaseModel subclass + TypeError: If handler doesn't have a typed ``context`` parameter + with a BaseModel subclass. 
""" self.name = name self.handler = handler @@ -104,127 +89,98 @@ def __init__( self.result_type = result_type or self._infer_result_type(handler) self.lock_key = lock_key - async def __call__(self, agent_id: UUID, context: BaseModel) -> TaskResult: - """Delegate calls to the handler function.""" - if inspect.iscoroutinefunction(self.handler): - handler = cast(_AsyncTaskHandler, self.handler) - return await handler(agent_id=agent_id, context=context) - else: - handler = cast(_SyncTaskHandler, self.handler) - return handler(agent_id=agent_id, context=context) + def get_lock_key(self, context: Mapping[str, Any]) -> str | None: + """Evaluate the lock key template against context data.""" + return self.lock_key.format(**context) if self.lock_key else None - def _infer_context_type(self, handler: TaskHandler) -> type[BaseModel]: - """Infer context class from handler's type annotations. + def hydrate_context(self, context: Mapping[str, Any]) -> BaseModel: + """Validate raw context data into the registered Pydantic model.""" + return self.context_type.model_validate(context) - Looks for a 'context' parameter with a Pydantic BaseModel type hint. + async def execute(self, task: Task) -> TaskResult | None: + """Execute the task handler and manage its lifecycle. - Args: - handler: The task handler function + Handles activity tracking (started/complete/error) and result storage. 
+ """ + context = self.hydrate_context(task.context) - Returns: - Context class (BaseModel subclass) + await activity.update( + agent_id=task.agent_id, + message=CONF.activity_message_started, + percentage=0, + ) - Raises: - TypeError: If 'context' parameter is missing or not a BaseModel subclass - """ + try: + if inspect.iscoroutinefunction(self.handler): + handler = cast(_AsyncTaskHandler, self.handler) + result = await handler(agent_id=task.agent_id, context=context) + else: + handler = cast(_SyncTaskHandler, self.handler) + result = handler(agent_id=task.agent_id, context=context) + + if isinstance(result, BaseModel): + key = backend.format_key(*KEY_RESULT, str(task.agent_id)) + await backend.state.set(key, backend.serialize(result), ttl_seconds=CONF.result_ttl) + + await activity.update( + agent_id=task.agent_id, + message=CONF.activity_message_complete, + percentage=100, + status=activity.Status.COMPLETE, + ) + return result + except Exception as e: + await activity.update( + agent_id=task.agent_id, + message=CONF.activity_message_error.format(error=e), + status=activity.Status.ERROR, + ) + raise + + def _infer_context_type(self, handler: TaskHandler) -> type[BaseModel]: hints = get_type_hints(handler) if "context" not in hints: raise TypeError( f"Task handler '{handler.__name__}' must have a 'context' parameter " f"with a BaseModel type annotation" ) - context_type = hints["context"] if not (inspect.isclass(context_type) and issubclass(context_type, BaseModel)): raise TypeError( f"Task handler '{handler.__name__}' context parameter must be a " f"BaseModel subclass, got {context_type}" ) - return context_type def _infer_result_type(self, handler: TaskHandler) -> type[BaseModel] | None: - """Infer result class from handler's return type annotation. - - Looks for a return annotation with a Pydantic BaseModel type hint. 
- - Args: - handler: The task handler function - - Returns: - Result class (BaseModel subclass) or None if return type is not BaseModel - """ hints = get_type_hints(handler) if "return" not in hints: return None - return_type = hints["return"] if not (inspect.isclass(return_type) and issubclass(return_type, BaseModel)): return None - return return_type class Task(BaseModel): - """Represents a background task instance. + """A background task instance — pure data, no behavior. - Tasks are serialized to JSON and enqueued to Redis for workers to process. - Each task has a type (matching a registered TaskDefinition), a typed context, - and an agent_id for tracking. + Tasks are serialized to JSON and pushed to the queue. Workers pop them, + look up the TaskDefinition by task_name, and execute via the definition. - The context is stored as its native Pydantic type. Serialization to dict - happens automatically via field_serializer when dumping to JSON. - - After deserialization, call bind() to attach the TaskDefinition, then - execute() to run the task handler. - - Example: - # Create with typed context - ctx = ResearchContext(company_name="Anthropic") - task = Task.create("research", ctx) - task.context.company_name # Typed access! - - # Serialize to JSON for Redis (context becomes dict) - json_str = task.model_dump_json() - - # Worker deserializes and executes - task = Task.from_serialized(task_def, data) - await task.execute() + Context is stored as a raw dict. The TaskDefinition hydrates it into + the registered Pydantic model at execution time. 
""" model_config = ConfigDict(arbitrary_types_allowed=True) task_name: str - context: BaseModel + context: Mapping[str, Any] agent_id: UUID - _definition: TaskDefinition | None = PrivateAttr(default=None) - - @field_serializer("context") - def serialize_context(self, value: BaseModel) -> dict[str, Any]: - """Serialize context to dict for JSON storage.""" - return value.model_dump(mode="json") + retry_count: int = 0 @classmethod - def from_serialized(cls, definition: TaskDefinition, data: dict[str, Any]) -> Task: - """Create a Task from serialized data with its definition bound. - - Args: - definition: The TaskDefinition containing the handler and context_type - data: Serialized task data with task_name, context, and agent_id - - Returns: - Task instance with typed context and bound definition - """ - task = cls( - task_name=data["task_name"], - context=definition.context_type.model_validate(data["context"]), - agent_id=data["agent_id"], - ) - task._definition = definition - return task - - @classmethod - def create( + async def create( cls, task_name: str, context: BaseModel, @@ -232,31 +188,19 @@ def create( ) -> Task: """Create a new task with automatic activity tracking. - This is a convenience method that creates both a Task instance and - its corresponding activity record in one step. + Creates an activity record and returns a Task ready to be + serialized and pushed to the queue. Args: - task_name: Name/type of the task (e.g., "research", "analysis") - context: Task context as a Pydantic model - metadata: Optional dict of arbitrary metadata to attach to the activity. - Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). + task_name: Name of the registered task. + context: Pydantic model with the task's input data. + metadata: Optional dict attached to the activity record + (e.g. ``{"organization_id": "org-123"}``). 
Returns: - Task instance with agent_id set - - Example: - ctx = ResearchContext(company="Acme") - task = Task.create("research_company", ctx) - task.context.company # Typed access - - # With metadata for multi-tenancy - task = Task.create( - "research_company", - ctx, - metadata={"organization_id": "org-123"} - ) + Task instance with ``agent_id`` set for tracking. """ - agent_id = activity.create( + agent_id = await activity.create( task_name=task_name, message=CONF.activity_message_create, metadata=metadata, @@ -264,73 +208,6 @@ def create( return cls( task_name=task_name, - context=context, + context=context.model_dump(mode="json"), agent_id=agent_id, ) - - def get_lock_key(self) -> str | None: - """Evaluate the lock key template against the task context. - - Returns: - Evaluated lock key string, or None if no lock_key is configured. - - Raises: - RuntimeError: If task has not been bound to a definition. - KeyError: If the template references a field not present in the context. - """ - if self._definition is None: - raise RuntimeError("Task must be bound to a definition before getting lock key") - - if self._definition.lock_key is None: - return None - - return self._definition.lock_key.format(**self.context.model_dump()) - - async def execute(self) -> TaskResult | None: - """Execute the task using its bound definition's handler. - - Manages task lifecycle: marks started, runs handler, marks completed/errored. 
- - Returns: - Handler return value, or None if handler raised an exception - - Raises: - RuntimeError: If task has not been bound to a definition - """ - if self._definition is None: - raise RuntimeError("Task must be bound to a definition before execution") - - activity.update( - agent_id=self.agent_id, - message=CONF.activity_message_started, - percentage=0, - ) - - try: - result = await self._definition( - agent_id=self.agent_id, - context=self.context, - ) - - # TODO ensure we are properly supporting None return values - if isinstance(result, BaseModel): - await state.aset_result( - self.agent_id, - result, - ttl_seconds=CONF.result_ttl, - ) - - activity.update( - agent_id=self.agent_id, - message=CONF.activity_message_complete, - percentage=100, - status=activity.Status.COMPLETE, - ) - return result - except Exception as e: - activity.update( - agent_id=self.agent_id, - message=CONF.activity_message_error.format(error=e), - status=activity.Status.ERROR, - ) - return None diff --git a/src/agentexec/pipeline.py b/src/agentexec/pipeline.py index 9801e79..8491d80 100644 --- a/src/agentexec/pipeline.py +++ b/src/agentexec/pipeline.py @@ -374,7 +374,7 @@ async def _run_task( _context: StepResult = context for i, step in enumerate(steps): - activity.update( + await activity.update( agent_id, f"Started {step.description}", percentage=int((i / total_steps) * 100), diff --git a/src/agentexec/runners/base.py b/src/agentexec/runners/base.py index d7881ad..d0b1278 100644 --- a/src/agentexec/runners/base.py +++ b/src/agentexec/runners/base.py @@ -117,7 +117,7 @@ def report_status(self) -> Any: agent_id = self._agent_id assert agent_id, "agent_id must be set to use report_status tool" - def report_activity(message: str, percentage: int) -> str: + async def report_activity(message: str, percentage: int) -> str: """Report progress and status updates. Use this tool to report your progress as you work through the task. 
@@ -129,7 +129,7 @@ def report_activity(message: str, percentage: int) -> str: Returns: Confirmation message """ - activity.update( + await activity.update( agent_id=agent_id, message=message, percentage=percentage, diff --git a/src/agentexec/schedule.py b/src/agentexec/schedule.py index 8dc5b1d..b742137 100644 --- a/src/agentexec/schedule.py +++ b/src/agentexec/schedule.py @@ -4,18 +4,16 @@ from datetime import datetime from typing import Any from croniter import croniter -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field -from agentexec import state from agentexec.config import CONF from agentexec.core.logging import get_logger -from agentexec.core.queue import enqueue +from agentexec.state import backend logger = get_logger(__name__) __all__ = [ "register", - "tick", ] REPEAT_FOREVER: int = -1 @@ -24,9 +22,10 @@ class ScheduledTask(BaseModel): """A task scheduled to run on a recurring interval. - Stored in Redis with a sorted-set index for efficient due-time polling. - Each time it fires, a fresh Task (with its own agent_id) is enqueued - for the worker pool. Stays in Redis until its repeat budget is exhausted. + Stored in the schedule backend with a time index for efficient + due-time polling. Each time it fires, a fresh Task (with its own + agent_id) is enqueued for the worker pool. Stays registered until + its repeat budget is exhausted. 
""" task_name: str @@ -37,6 +36,13 @@ class ScheduledTask(BaseModel): created_at: float = Field(default_factory=lambda: time.time()) metadata: dict[str, Any] | None = None + @property + def key(self) -> str: + """Unique identity: task_name + cron + context hash.""" + import hashlib + context_hash = hashlib.md5(self.context).hexdigest()[:8] + return f"{self.task_name}:{self.cron}:{context_hash}" + def model_post_init(self, __context: Any) -> None: """Compute next_run from cron if not explicitly set.""" if self.next_run == 0: @@ -58,22 +64,12 @@ def advance(self) -> None: break def _next_after(self, anchor: float) -> float: - """Compute the next cron occurrence after anchor.""" + """Compute the next cron occurrence after the given anchor time.""" dt = datetime.fromtimestamp(anchor, tz=CONF.scheduler_tz) return float(croniter(self.cron, dt).get_next(float)) -def _schedule_key(schedule_id: str) -> str: - """Redis key for a schedule definition.""" - return state.backend.format_key(*state.KEY_SCHEDULE, schedule_id) - - -def _queue_key() -> str: - """Redis sorted-set key that indexes schedules by next_run.""" - return state.backend.format_key(*state.KEY_SCHEDULE_QUEUE) - - -def register( +async def register( task_name: str, every: str, context: BaseModel, @@ -81,7 +77,7 @@ def register( repeat: int = REPEAT_FOREVER, metadata: dict[str, Any] | None = None, ) -> None: - """Register a new scheduled task in Redis. + """Register a new scheduled task. The task will first fire at the next cron occurrence from now. 
@@ -95,50 +91,12 @@ def register( """ task = ScheduledTask( task_name=task_name, - context=state.backend.serialize(context), + context=backend.serialize(context), cron=every, repeat=repeat, metadata=metadata, ) - - state.backend.set( - _schedule_key(task_name), - task.model_dump_json().encode(), - ) - state.backend.zadd(_queue_key(), {task_name: task.next_run}) + await backend.schedule.register(task) logger.info(f"Scheduled {task_name}") -async def tick() -> None: - """Process all scheduled tasks that are due right now. - - For each due task, enqueues it into the normal task queue. If repeats - remain, advances to the next run time. Otherwise removes the schedule. - """ - for _task_name in await state.backend.zrangebyscore(_queue_key(), 0, time.time()): - task_name = _task_name.decode("utf-8") - - try: - data = state.backend.get(_schedule_key(task_name)) - task = ScheduledTask.model_validate_json(data) - except ValidationError: - logger.warning(f"Failed to load schedule {task_name}, skipping") - continue - - await enqueue( - task.task_name, - context=state.backend.deserialize(task.context), - metadata=task.metadata, - ) - - if task.repeat == 0: - state.backend.zrem(_queue_key(), task_name) - state.backend.delete(_schedule_key(task_name)) - logger.info(f"Schedule for '{task_name}' exhausted") - else: - task.advance() - state.backend.set( - _schedule_key(task_name), - task.model_dump_json().encode(), - ) - state.backend.zadd(_queue_key(), {task_name: task.next_run}) diff --git a/src/agentexec/state/__init__.py b/src/agentexec/state/__init__.py index 1bc797e..e670d5b 100644 --- a/src/agentexec/state/__init__.py +++ b/src/agentexec/state/__init__.py @@ -1,266 +1,38 @@ -# cspell:ignore acheck +"""State management layer. -from typing import cast, AsyncGenerator, Coroutine -import importlib -from uuid import UUID +Initializes the configured backend and exposes it as a public reference. 
+All state operations go through ``backend.state``, ``backend.queue``, +and ``backend.schedule`` directly. Activity uses Postgres directly. + +Pick one backend via AGENTEXEC_STATE_BACKEND: + - 'agentexec.state.redis_backend' (default) + - 'agentexec.state.kafka_backend' +""" + +from __future__ import annotations -from pydantic import BaseModel +import importlib from agentexec.config import CONF -from agentexec.state.backend import StateBackend +from agentexec.state.base import BaseBackend KEY_RESULT = (CONF.key_prefix, "result") KEY_EVENT = (CONF.key_prefix, "event") -KEY_LOCK = (CONF.key_prefix, "lock") -KEY_SCHEDULE = (CONF.key_prefix, "schedule") -KEY_SCHEDULE_QUEUE = (CONF.key_prefix, "schedule_queue") -CHANNEL_LOGS = (CONF.key_prefix, "logs") - -__all__ = [ - "backend", - "get_result", - "aget_result", - "set_result", - "aset_result", - "delete_result", - "adelete_result", - "publish_log", - "subscribe_logs", - "set_event", - "clear_event", - "check_event", - "acheck_event", - "acquire_lock", - "release_lock", - "clear_keys", -] - - -def _load_backend(module_name: str) -> StateBackend: - module = cast(StateBackend, importlib.import_module(module_name)) - if not isinstance(module, StateBackend): # type: ignore[invalid-argument-type] - raise RuntimeError(f"State backend ({module_name}) does not conform to protocol.") - return module - - -backend: StateBackend = _load_backend(CONF.state_backend) - - -def get_result(agent_id: UUID | str) -> BaseModel | None: - """Get result for an agent (sync). - - Returns deserialized BaseModel instance with automatic type reconstruction. - - Args: - agent_id: Unique agent identifier (UUID or string) - - Returns: - Deserialized BaseModel or None if not found - """ - data = backend.get(backend.format_key(*KEY_RESULT, str(agent_id))) - return backend.deserialize(data) if data else None - - -def aget_result(agent_id: UUID | str) -> Coroutine[None, None, BaseModel | None]: - """Get result for an agent (async). 
- - Returns deserialized BaseModel instance with automatic type reconstruction. - - Args: - agent_id: Unique agent identifier (UUID or string) - - Returns: - Coroutine that resolves to deserialized BaseModel or None if not found - """ - - async def _get() -> BaseModel | None: - data = await backend.aget(backend.format_key(*KEY_RESULT, str(agent_id))) - return backend.deserialize(data) if data else None - - return _get() - - -def set_result( - agent_id: UUID | str, - data: BaseModel, - ttl_seconds: int | None = None, -) -> bool: - """Set result for an agent (sync). - - Args: - agent_id: Unique agent identifier (UUID or string) - data: Result data (must be Pydantic BaseModel) - ttl_seconds: Optional time-to-live in seconds - - Returns: - True if successful - """ - return backend.set( - backend.format_key(*KEY_RESULT, str(agent_id)), - backend.serialize(data), - ttl_seconds=ttl_seconds, - ) - - -def aset_result( - agent_id: UUID | str, - data: BaseModel, - ttl_seconds: int | None = None, -) -> Coroutine[None, None, bool]: - """Set result for an agent (async). - - Args: - agent_id: Unique agent identifier (UUID or string) - data: Result data (must be Pydantic BaseModel) - ttl_seconds: Optional time-to-live in seconds - - Returns: - Coroutine that resolves to True if successful - """ - return backend.aset( - backend.format_key(*KEY_RESULT, str(agent_id)), - backend.serialize(data), - ttl_seconds=ttl_seconds, - ) - - -def delete_result(agent_id: UUID | str) -> int: - """Delete result for an agent (sync). - - Args: - agent_id: Unique agent identifier (UUID or string) - - Returns: - Number of keys deleted (0 or 1) - """ - return backend.delete(backend.format_key(*KEY_RESULT, str(agent_id))) - - -def adelete_result(agent_id: UUID | str) -> Coroutine[None, None, int]: - """Delete result for an agent (async). 
- - Args: - agent_id: Unique agent identifier (UUID or string) - Returns: - Coroutine that resolves to number of keys deleted (0 or 1) - """ - return backend.adelete(backend.format_key(*KEY_RESULT, str(agent_id))) - - -def publish_log(message: str) -> None: - """Publish a log message to the log channel (sync). - - Args: - message: Log message to publish (should be JSON string) - """ - backend.publish(backend.format_key(*CHANNEL_LOGS), message) - - -def subscribe_logs() -> AsyncGenerator[str, None]: - """Subscribe to log messages (async generator). - - Yields: - Log messages from the channel - """ - return backend.subscribe(backend.format_key(*CHANNEL_LOGS)) - - -def set_event(name: str, id: str) -> bool: - """Set an event flag. - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) - - Returns: - True if successful - """ - return backend.set(backend.format_key(*KEY_EVENT, name, id), b"1") - - -def clear_event(name: str, id: str) -> int: - """Clear an event flag. - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) - - Returns: - Number of keys deleted (0 or 1) - """ - return backend.delete(backend.format_key(*KEY_EVENT, name, id)) - - -def check_event(name: str, id: str) -> bool: - """Check if an event flag is set (sync). - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) +def _create_backend(state_backend: str) -> BaseBackend: + """Instantiate the given backend class. - Returns: - True if event is set, False otherwise + The state_backend string is a fully qualified module path containing + a Backend class (e.g. 'agentexec.state.kafka'). 
""" - return backend.get(backend.format_key(*KEY_EVENT, name, id)) is not None + try: + module = importlib.import_module(state_backend) + return module.Backend() + except ImportError as e: + raise ImportError(f"Could not import backend {state_backend}: {e}") + except AttributeError: + raise ValueError(f"Backend module {state_backend} has no Backend class") -def acheck_event(name: str, id: str) -> Coroutine[None, None, bool]: - """Check if an event flag is set (async). - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) - - Returns: - Coroutine that resolves to True if event is set, False otherwise - """ - - async def _check() -> bool: - return await backend.aget(backend.format_key(*KEY_EVENT, name, id)) is not None - - return _check() - - -async def acquire_lock(lock_key: str, agent_id: str) -> bool: - """Attempt to acquire a task lock. - - Args: - lock_key: The evaluated lock key (e.g., "user:42") - agent_id: The agent_id holding the lock (for debugging) - - Returns: - True if lock was acquired, False if already held - """ - return await backend.acquire_lock( - backend.format_key(*KEY_LOCK, lock_key), - agent_id, - CONF.lock_ttl, - ) - - -async def release_lock(lock_key: str) -> int: - """Release a task lock. - - Args: - lock_key: The evaluated lock key (e.g., "user:42") - - Returns: - Number of keys deleted (0 or 1) - """ - return await backend.release_lock( - backend.format_key(*KEY_LOCK, lock_key), - ) - - -def clear_keys() -> int: - """Clear all state keys managed by this application. - - Removes all keys matching the configured prefix and the task queue. - This is useful during shutdown to prevent stale tasks from being - picked up on restart. 
- - Returns: - Total number of keys deleted - """ - return backend.clear_keys() +backend: BaseBackend = _create_backend(CONF.state_backend) diff --git a/src/agentexec/state/backend.py b/src/agentexec/state/backend.py deleted file mode 100644 index 34eb58c..0000000 --- a/src/agentexec/state/backend.py +++ /dev/null @@ -1,363 +0,0 @@ -from types import ModuleType -from typing import AsyncGenerator, Coroutine, Optional, Protocol, runtime_checkable - -from pydantic import BaseModel - - -@runtime_checkable -class StateBackend(Protocol): - """Protocol defining the state backend interface. - - This protocol defines all the operations needed for: - - Task queue management (priority queue operations) - - Result storage (with TTL support) - - Event coordination (shutdown flags, etc.) - - Pub/sub messaging (worker logging) - - Any module that implements these functions can serve as a state backend. - Methods are defined as @staticmethod to match module-level functions. - - Connection management is handled internally - connections are established - lazily when first accessed. Only cleanup needs to be explicit. - """ - - # Connection management - @staticmethod - async def close() -> None: - """Close all connections to the backend. - - This should close both async and sync connections and clean up - any resources. - """ - ... - - # Queue operations (Redis list commands) - @staticmethod - def rpush(key: str, value: str) -> int: - """Push value to the right (front) of the list - for high priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - ... - - @staticmethod - def lpush(key: str, value: str) -> int: - """Push value to the left (back) of the list - for low priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - ... 
- - @staticmethod - async def brpop(key: str, timeout: int = 0) -> Optional[tuple[str, str]]: - """Pop value from the right of the list with blocking. - - Args: - key: Redis list key - timeout: Timeout in seconds (0 = block forever) - - Returns: - Tuple of (key, value) or None if timeout - """ - ... - - # Key-value operations - @staticmethod - def aget(key: str) -> Coroutine[None, None, Optional[bytes]]: - """Get value for key asynchronously. - - Args: - key: Key to retrieve - - Returns: - Coroutine that resolves to value as bytes or None if not found - """ - ... - - @staticmethod - def get(key: str) -> Optional[bytes]: - """Get value for key synchronously. - - Args: - key: Key to retrieve - - Returns: - Value as bytes or None if not found - """ - ... - - @staticmethod - def aset( - key: str, value: bytes, ttl_seconds: Optional[int] = None - ) -> Coroutine[None, None, bool]: - """Set value for key asynchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - Coroutine that resolves to True if successful - """ - ... - - @staticmethod - def set(key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: - """Set value for key synchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - True if successful - """ - ... - - @staticmethod - def adelete(key: str) -> Coroutine[None, None, int]: - """Delete key asynchronously. - - Args: - key: Key to delete - - Returns: - Coroutine that resolves to number of keys deleted (0 or 1) - """ - ... - - @staticmethod - def delete(key: str) -> int: - """Delete key synchronously. - - Args: - key: Key to delete - - Returns: - Number of keys deleted (0 or 1) - """ - ... - - # Counter operations - @staticmethod - def incr(key: str) -> int: - """Increment a counter atomically. - - Args: - key: Counter key - - Returns: - Value after increment - """ - ... 
- - @staticmethod - def decr(key: str) -> int: - """Decrement a counter atomically. - - Args: - key: Counter key - - Returns: - Value after decrement - """ - ... - - # Pub/sub operations - @staticmethod - def publish(channel: str, message: str) -> None: - """Publish message to a channel. - - Args: - channel: Channel name - message: Message to publish - """ - ... - - @staticmethod - def subscribe(channel: str) -> AsyncGenerator[str, None]: - """Subscribe to a channel and yield messages. - - Args: - channel: Channel name - - Yields: - Messages from the channel - """ - ... - - # Key formatting - @staticmethod - def format_key(*args: str) -> str: - """Format a key by joining parts in a backend-specific way. - - Args: - *args: Parts of the key to join - - Returns: - Formatted key string - """ - ... - - # Serialization - @staticmethod - def serialize(obj: BaseModel) -> bytes: - """Serialize a Pydantic BaseModel to bytes. - - Stores the fully qualified class name alongside the data to enable - automatic type reconstruction during deserialization. - - Args: - obj: Pydantic BaseModel instance to serialize - - Returns: - Serialized bytes - - Raises: - TypeError: If obj is not a BaseModel instance - """ - ... - - @staticmethod - def deserialize(data: bytes) -> BaseModel: - """Deserialize bytes back to a Pydantic BaseModel instance. - - Uses the stored class information to dynamically import and reconstruct - the original type. - - Args: - data: Serialized bytes - - Returns: - Deserialized BaseModel instance - - Raises: - ImportError: If the class module cannot be imported - AttributeError: If the class does not exist in the module - ValueError: If the data is invalid - """ - ... - - # Lock operations - @staticmethod - async def acquire_lock(key: str, value: str, ttl_seconds: int) -> bool: - """Attempt to acquire a distributed lock. - - Uses atomic set-if-not-exists with TTL. 
The TTL is a safety net - for process death — locks should always be explicitly released - via release_lock() on task completion or error. - - Args: - key: Lock key - value: Lock value (typically agent_id for debugging) - ttl_seconds: Lock expiry in seconds (safety net for dead processes) - - Returns: - True if lock was acquired, False if already held - """ - ... - - @staticmethod - async def release_lock(key: str) -> int: - """Release a distributed lock. - - Args: - key: Lock key to release - - Returns: - Number of keys deleted (0 or 1) - """ - ... - - # Sorted set operations - @staticmethod - def zadd(key: str, mapping: dict[str, float]) -> int: - """Add members to a sorted set with scores. - - Args: - key: Sorted set key - mapping: Dict of {member: score} - - Returns: - Number of new members added - """ - ... - - @staticmethod - async def zrangebyscore( - key: str, min_score: float, max_score: float - ) -> list[bytes]: - """Get members with scores between min and max. - - Args: - key: Sorted set key - min_score: Minimum score (inclusive) - max_score: Maximum score (inclusive) - - Returns: - List of members as bytes - """ - ... - - @staticmethod - def zrem(key: str, *members: str) -> int: - """Remove members from a sorted set. - - Args: - key: Sorted set key - *members: Members to remove - - Returns: - Number of members removed - """ - ... - - # Cleanup operations - @staticmethod - def clear_keys() -> int: - """Clear all keys managed by this application. - - Only deletes keys that match the configured prefix and queue name. - This is useful during shutdown to prevent stale tasks from being - picked up on restart. - - Returns: - Total number of keys deleted - """ - ... - - -def load_backend(module: ModuleType) -> StateBackend: - """Load and validate a backend module conforms to StateBackend protocol. - - Uses the Protocol's __protocol_attrs__ to determine required methods. 
- - Args: - module: Backend module to validate - - Returns: - The module typed as StateBackend - - Raises: - TypeError: If the module is missing required functions - """ - required: frozenset[str] = getattr(StateBackend, "__protocol_attrs__") - missing = [name for name in required if not hasattr(module, name)] - if missing: - raise TypeError( - f"Backend module '{module.__name__}' missing required functions: {missing}" - ) - - return module # type: ignore[return-value] diff --git a/src/agentexec/state/base.py b/src/agentexec/state/base.py new file mode 100644 index 0000000..4330a61 --- /dev/null +++ b/src/agentexec/state/base.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import importlib +import json +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional, TypedDict +from pydantic import BaseModel + +if TYPE_CHECKING: + from agentexec.schedule import ScheduledTask + + +class _SerializeWrapper(TypedDict): + __type__: str + data: dict[str, Any] + + +class BaseBackend(ABC): + """Top-level backend interface with namespaced sub-backends.""" + + state: BaseStateBackend + queue: BaseQueueBackend + schedule: BaseScheduleBackend + + @abstractmethod + def format_key(self, *args: str) -> str: ... + + @abstractmethod + async def close(self) -> None: ... 
+ + def serialize(self, obj: BaseModel) -> bytes: + """Serialize a Pydantic model to bytes with type information.""" + wrapper: _SerializeWrapper = { + "__type__": f"{type(obj).__module__}.{type(obj).__qualname__}", + "data": obj.model_dump(mode="json"), + } + return json.dumps(wrapper).encode("utf-8") + + def deserialize(self, data: bytes) -> BaseModel: + """Deserialize bytes back to a typed Pydantic model.""" + wrapper: _SerializeWrapper = json.loads(data.decode("utf-8")) + module_path, class_name = wrapper["__type__"].rsplit(".", 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + return cls.model_validate(wrapper["data"]) + + +class BaseStateBackend(ABC): + """KV store, counters, locks, pub/sub, sorted index.""" + + @abstractmethod + async def get(self, key: str) -> Optional[bytes]: ... + + @abstractmethod + async def set(self, key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: ... + + @abstractmethod + async def delete(self, key: str) -> int: ... + + @abstractmethod + async def counter_incr(self, key: str) -> int: ... + + @abstractmethod + async def counter_decr(self, key: str) -> int: ... + + +class BaseQueueBackend(ABC): + """Task queue with push/pop semantics and partition-level locking.""" + + @abstractmethod + async def push( + self, + value: str, + *, + high_priority: bool = False, + partition_key: str | None = None, + ) -> None: ... + + @abstractmethod + async def pop(self, *, timeout: int = 1) -> dict[str, Any] | None: ... + + @abstractmethod + async def complete(self, partition_key: str | None) -> None: + """Signal that the current task for this partition is done.""" + ... + + + +class BaseScheduleBackend(ABC): + """Schedule storage and retrieval.""" + + @abstractmethod + async def register(self, task: ScheduledTask) -> None: + """Store a scheduled task definition.""" + ... 
+ + @abstractmethod + async def get_due(self) -> list[ScheduledTask]: + """Return all scheduled tasks that are due to fire.""" + ... + + @abstractmethod + async def remove(self, key: str) -> None: + """Remove a schedule by its key.""" + ... diff --git a/src/agentexec/state/kafka.py b/src/agentexec/state/kafka.py new file mode 100644 index 0000000..3274ac7 --- /dev/null +++ b/src/agentexec/state/kafka.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import asyncio +import json +import os +import socket +import time +from typing import Any, Optional + +from aiokafka import AIOKafkaConsumer, AIOKafkaProducer, TopicPartition +from aiokafka.admin import AIOKafkaAdminClient, NewTopic + +from agentexec.config import CONF +from agentexec.state.base import BaseBackend, BaseQueueBackend, BaseScheduleBackend, BaseStateBackend + + + +class Backend(BaseBackend): + """Kafka implementation of the agentexec backend.""" + + def __init__(self) -> None: + self._producer: AIOKafkaProducer | None = None + self._consumers: dict[str, AIOKafkaConsumer] = {} + self._admin: AIOKafkaAdminClient | None = None + + self._initialized_topics: set[str] = set() + + # Sub-backends + self.state = KafkaStateBackend() + self.queue = KafkaQueueBackend(self) + self.schedule = KafkaScheduleBackend(self) + + def format_key(self, *args: str) -> str: + return ".".join(args) + + async def close(self) -> None: + if self._producer is not None: + await self._producer.stop() + self._producer = None + + for consumer in self._consumers.values(): + await consumer.stop() + self._consumers.clear() + + if self._admin is not None: + await self._admin.close() + self._admin = None + + def _get_bootstrap_servers(self) -> str: + if CONF.kafka_bootstrap_servers is None: + raise ValueError( + "KAFKA_BOOTSTRAP_SERVERS must be configured " + "(e.g. 
'localhost:9092' or 'broker1:9092,broker2:9092')" + ) + return CONF.kafka_bootstrap_servers + + def _client_id(self, role: str = "worker") -> str: + return f"{CONF.key_prefix}-{role}-{socket.gethostname()}-{os.getpid()}" + + async def _get_producer(self) -> AIOKafkaProducer: + if self._producer is None: + self._producer = AIOKafkaProducer( + bootstrap_servers=self._get_bootstrap_servers(), + client_id=self._client_id("producer"), + acks="all", + max_batch_size=CONF.kafka_max_batch_size, + linger_ms=CONF.kafka_linger_ms, + ) + await self._producer.start() + return self._producer + + async def _get_admin(self) -> AIOKafkaAdminClient: + if self._admin is None: + self._admin = AIOKafkaAdminClient( + bootstrap_servers=self._get_bootstrap_servers(), + client_id=self._client_id("admin"), + ) + await self._admin.start() + return self._admin + + async def produce( + self, + topic: str, + value: bytes | None, + key: str | bytes | None = None, + headers: dict[str, str] | None = None, + ) -> None: + producer = await self._get_producer() + if isinstance(key, str): + key_bytes = key.encode("utf-8") + else: + key_bytes = key + header_list = [(k, v.encode("utf-8")) for k, v in headers.items()] if headers else None + await producer.send_and_wait(topic, value=value, key=key_bytes, headers=header_list) + + async def ensure_topic(self, topic: str, *, compact: bool = True) -> None: + if topic in self._initialized_topics: + return + + admin = await self._get_admin() + config: dict[str, str] = {} + if compact: + config["cleanup.policy"] = "compact" + config["retention.ms"] = str(CONF.kafka_retention_ms) + + try: + await admin.create_topics( + [ + NewTopic( + name=topic, + num_partitions=CONF.kafka_default_partitions, + replication_factor=CONF.kafka_replication_factor, + topic_configs=config, + ) + ] + ) + except Exception: + pass # Topic already exists + + self._initialized_topics.add(topic) + + async def _get_topic_partitions(self, topic: str) -> list[TopicPartition]: + admin = await 
self._get_admin() + topics_meta = await admin.describe_topics([topic]) + for t in topics_meta: + if t.get("topic") == topic: + parts = t.get("partitions", []) + if parts: + return [ + TopicPartition(topic, p["partition"]) + for p in sorted(parts, key=lambda p: p["partition"]) + ] + return [TopicPartition(topic, 0)] + + def tasks_topic(self, queue_name: str) -> str: + return f"{CONF.key_prefix}.tasks.{queue_name}" + + def schedule_topic(self) -> str: + return f"{CONF.key_prefix}.schedules" + + +class KafkaStateBackend(BaseStateBackend): + """Kafka state: not supported. + + Kafka is not a key-value store. State operations (get/set, counters) + require a proper KV backend like Redis. Use Kafka for queue and + schedule only. + """ + + async def get(self, key: str) -> Optional[bytes]: + raise NotImplementedError("Kafka backend does not support KV state operations") + + async def set(self, key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: + raise NotImplementedError("Kafka backend does not support KV state operations") + + async def delete(self, key: str) -> int: + raise NotImplementedError("Kafka backend does not support KV state operations") + + async def counter_incr(self, key: str) -> int: + raise NotImplementedError("Kafka backend does not support counter operations") + + async def counter_decr(self, key: str) -> int: + raise NotImplementedError("Kafka backend does not support counter operations") + + + +class KafkaQueueBackend(BaseQueueBackend): + """Kafka queue: consumer groups for reliable fan-out.""" + + def __init__(self, backend: Backend) -> None: + self.backend = backend + + async def _get_consumer(self, topic: str) -> AIOKafkaConsumer: + consumers = self.backend._consumers + + if topic not in consumers: + await self.backend.ensure_topic(topic, compact=False) + + consumer = AIOKafkaConsumer( + topic, + bootstrap_servers=self.backend._get_bootstrap_servers(), + group_id=f"{CONF.key_prefix}-workers", + 
client_id=self.backend._client_id("worker"), + auto_offset_reset="earliest", + enable_auto_commit=False, + ) + await consumer.start() + consumers[topic] = consumer + + return consumers[topic] + + async def push( + self, + value: str, + *, + high_priority: bool = False, + partition_key: str | None = None, + ) -> None: + topic = self.backend.tasks_topic(CONF.queue_prefix) + await self.backend.ensure_topic(topic, compact=False) + + # Extract metadata for headers without altering the payload + task_data = json.loads(value) + headers = { + "ax_task_name": task_data.get("task_name", ""), + "ax_agent_id": task_data.get("agent_id", ""), + "ax_retry_count": str(task_data.get("retry_count", 0)), + } + await self.backend.produce(topic, value.encode("utf-8"), key=partition_key, headers=headers) + + async def pop( + self, + *, + timeout: int = 1, + ) -> dict[str, Any] | None: + consumer = await self._get_consumer(self.backend.tasks_topic(CONF.queue_prefix)) + + try: + msg = await asyncio.wait_for( + consumer.getone(), + timeout=timeout, + ) + await consumer.commit() + return json.loads(msg.value.decode("utf-8")) + except asyncio.TimeoutError: + return None + + async def complete(self, partition_key: str | None) -> None: + pass # Kafka uses partition assignment, no explicit locks + + +class KafkaScheduleBackend(BaseScheduleBackend): + """Kafka schedule: compacted topic + in-memory cache.""" + + def __init__(self, backend: Backend) -> None: + self.backend = backend + self._consumer: AIOKafkaConsumer | None = None + self._tps: list[TopicPartition] = [] + + async def _ensure_consumer(self) -> AIOKafkaConsumer: + topic = self.backend.schedule_topic() + await self.backend.ensure_topic(topic) + + if self._consumer is None: + self._tps = await self.backend._get_topic_partitions(topic) + self._consumer = AIOKafkaConsumer( + bootstrap_servers=self.backend._get_bootstrap_servers(), + client_id=self.backend._client_id("scheduler"), + enable_auto_commit=False, + ) + await 
self._consumer.start() + self._consumer.assign(self._tps) + + return self._consumer + + async def register(self, task: ScheduledTask) -> None: + topic = self.backend.schedule_topic() + await self.backend.ensure_topic(topic) + data = task.model_dump_json().encode("utf-8") + headers = { + "ax_task_name": task.task_name, + "ax_cron": task.cron, + "ax_next_run": str(task.next_run), + "ax_repeat": str(task.repeat), + } + await self.backend.produce(topic, data, key=task.key, headers=headers) + + async def get_due(self) -> list[ScheduledTask]: + # TODO: this replays the entire compacted topic on every poll — + # seek, iterate, deserialize, compare for each schedule. Consider + # caching with invalidation or using message timestamps to skip + # schedules that aren't close to due. + from agentexec.schedule import ScheduledTask + from pydantic import ValidationError + + consumer = await self._ensure_consumer() + await consumer.seek_to_beginning(*self._tps) + + now = time.time() + due = [] + records = await consumer.getmany(*self._tps, timeout_ms=1000) + for tp_records in records.values(): + for msg in tp_records: + if msg.value is None: + continue + try: + task = ScheduledTask.model_validate_json(msg.value) + if task.next_run <= now: + due.append(task) + except ValidationError: + continue + + return due + + async def remove(self, key: str) -> None: + topic = self.backend.schedule_topic() + await self.backend.ensure_topic(topic) + await self.backend.produce(topic, None, key=key) diff --git a/src/agentexec/state/redis.py b/src/agentexec/state/redis.py new file mode 100644 index 0000000..244416e --- /dev/null +++ b/src/agentexec/state/redis.py @@ -0,0 +1,264 @@ +"""Redis state backend. + +Provides queue, state (KV/counters/sorted sets), and schedule operations +backed by Redis. 
The queue implementation uses a partitioned design +inspired by Kafka's consumer groups: + +Queue Key Layout +~~~~~~~~~~~~~~~~ + +All queue keys share a common prefix (``CONF.queue_prefix``, default +``agentexec_tasks``):: + + agentexec_tasks ← default queue (no lock, concurrent) + agentexec_tasks:user:42 ← partition queue for lock scope "user:42" + agentexec_tasks:user:42:lock ← lock for that partition (SET NX EX) + +Tasks without a ``lock_key`` go to the default queue, where any worker can +pop them concurrently. Tasks with a ``lock_key`` (evaluated from the +``TaskDefinition.lock_key`` template against the task context) go to a +partition queue keyed by that value. + +Dequeue Strategy +~~~~~~~~~~~~~~~~ + +Workers call ``queue.pop()`` which uses Redis SCAN to iterate all keys +matching the queue prefix. SCAN returns keys in hash-table order, which +is effectively random — providing fair distribution across partitions +without explicit shuffling. + +For each key discovered: + +1. If it ends with ``:lock``, record it in ``locks_seen`` and skip. +2. If it's a partition queue (not the default), check ``locks_seen`` for + an existing lock. If found, skip. Otherwise attempt ``SET NX EX`` to + acquire the lock. If acquisition fails, skip. +3. ``RPOP`` the queue key. If successful, return the task payload. +4. On task completion, the pool calls ``queue.complete(partition_key)`` + which deletes the lock key, allowing the next task in that partition + to be picked up. + +Redis automatically deletes list keys when they become empty, so drained +partitions disappear from future scans. Lock keys expire via TTL as a +safety net for dead worker recovery. 
+""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING, Any, Optional + +import redis +import redis.asyncio + +from agentexec.config import CONF +from agentexec.state.base import BaseBackend, BaseQueueBackend, BaseScheduleBackend, BaseStateBackend + + +class Backend(BaseBackend): + """Redis implementation of the agentexec backend.""" + + _client: redis.asyncio.Redis | None + state: RedisStateBackend + queue: RedisQueueBackend + schedule: RedisScheduleBackend + + def __init__(self) -> None: + self._client = None + self.state = RedisStateBackend(self) + self.queue = RedisQueueBackend(self) + self.schedule = RedisScheduleBackend(self) + + def format_key(self, *args: str) -> str: + return ":".join(args) + + async def close(self) -> None: + if self._client is not None: + await self._client.aclose() + self._client = None + + @property + def client(self) -> redis.asyncio.Redis: + if self._client is None: + if CONF.redis_url is None: + raise ValueError("REDIS_URL must be configured") + self._client = redis.asyncio.Redis.from_url( + CONF.redis_url, + max_connections=CONF.redis_pool_size, + socket_connect_timeout=CONF.redis_pool_timeout, + decode_responses=False, + ) + return self._client + + +class RedisStateBackend(BaseStateBackend): + """Redis state: direct Redis commands.""" + + backend: Backend + + def __init__(self, backend: Backend) -> None: + self.backend = backend + + async def get(self, key: str) -> Optional[bytes]: + return await self.backend.client.get(key) # type: ignore[return-value] + + async def set(self, key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: + if ttl_seconds is not None: + return await self.backend.client.set(key, value, ex=ttl_seconds) # type: ignore[return-value] + else: + return await self.backend.client.set(key, value) # type: ignore[return-value] + + async def delete(self, key: str) -> int: + return await self.backend.client.delete(key) # type: ignore[return-value] + + async def 
counter_incr(self, key: str) -> int: + return await self.backend.client.incr(key) # type: ignore[return-value] + + async def counter_decr(self, key: str) -> int: + return await self.backend.client.decr(key) # type: ignore[return-value] + + +class RedisQueueBackend(BaseQueueBackend): + """Redis queue: partitioned lists with per-group locking. + + Tasks with a partition_key go to {prefix}:{partition_key} and are + serialized by a lock. Tasks without a partition_key go to the + default queue ({prefix}) and execute concurrently. + """ + + backend: Backend + _lock_suffix: bytes = b":lock" + _prefix: str + _default_key: bytes + + def __init__(self, backend: Backend) -> None: + self.backend = backend + self._prefix = CONF.queue_prefix + self._default_key = self._prefix.encode() + + def _queue_key(self, partition_key: str | None = None) -> str: + if partition_key: + return f"{self._prefix}:{partition_key}" + return self._prefix + + def _lock_key(self, queue_key: bytes) -> bytes: + return queue_key + self._lock_suffix + + def _needs_lock(self, queue_key: bytes) -> bool: + return queue_key != self._default_key + + async def _acquire_lock(self, queue_key: bytes) -> bool: + return bool(await self.backend.client.set( + self._lock_key(queue_key), b"1", nx=True, ex=CONF.lock_ttl, + )) + + async def push( + self, + value: str, + *, + high_priority: bool = False, + partition_key: str | None = None, + ) -> None: + """Push a task to the queue. + + Tasks with a ``partition_key`` go to a dedicated partition queue + and are serialized by a lock. Tasks without one go to the default + queue for concurrent processing. + """ + key = self._queue_key(partition_key) + if high_priority: + await self.backend.client.rpush(key, value) + else: + await self.backend.client.lpush(key, value) + + async def pop(self, *, timeout: int = 1) -> dict[str, Any] | None: + """Pop the next eligible task from any queue. 
+ + Scans all queue keys, skips locked partitions, acquires a lock + for the selected partition, and pops the task. Returns ``None`` + if no eligible tasks are available. + """ + import json + + locks_seen: set[bytes] = set() + + # SCAN returns keys in hash-table order (effectively random), + # so we don't need to collect all keys before choosing. + # We try each key eagerly and exit on the first successful pop. + async for key in self.backend.client.scan_iter(match=self._prefix.encode() + b"*", count=100): + if self._needs_lock(key): + if key.endswith(self._lock_suffix): + locks_seen.add(key) + continue # this is a lock record, not executable + + if self._lock_key(key) in locks_seen: + continue # we already observed another worker holds this partition, find another + + if not await self._acquire_lock(key): + continue # another worker holds this partition, find another + + result = await self.backend.client.rpop(key) + if result is None: + if self._needs_lock(key): + # TODO this should never happen; we can improve on the ergonomics of recovery later. + raise RuntimeError(f"Partition queue {key!r} was empty after lock acquired") + + continue # payload was grabbed in a race condition, find another + + return json.loads(result) + + async def complete(self, partition_key: str | None) -> None: + """Signal that the current task for this partition is done. + + Deletes the partition lock so the next task in the same scope + can be picked up. No-op for tasks without a partition key. + """ + if partition_key: + await self.backend.client.delete(self._lock_key(self._queue_key(partition_key).encode())) + + +class RedisScheduleBackend(BaseScheduleBackend): + """Redis schedule: sorted set for time index + hash for payloads. 
+ + Two Redis keys:: + + agentexec:schedules ← sorted set (schedule.key → next_run score) + agentexec:schedules:data ← hash (schedule.key → task JSON) + + ``get_due`` queries the sorted set for keys with score <= now, + then batch-fetches the payloads from the hash. + """ + + backend: Backend + _index_key: str + _data_key: str + + def __init__(self, backend: Backend) -> None: + self.backend = backend + self._index_key = self.backend.format_key(CONF.key_prefix, "schedules") + self._data_key = self.backend.format_key(CONF.key_prefix, "schedules", "data") + + async def register(self, task: ScheduledTask) -> None: + await self.backend.client.hset(self._data_key, task.key, task.model_dump_json().encode()) + await self.backend.client.zadd(self._index_key, {task.key: task.next_run}) + + async def get_due(self) -> list[ScheduledTask]: + import time + from pydantic import ValidationError + from agentexec.schedule import ScheduledTask + + raw = await self.backend.client.zrangebyscore(self._index_key, 0, time.time()) + tasks = [] + for key in raw: + data = await self.backend.client.hget(self._data_key, key) + if data is None: + continue + try: + tasks.append(ScheduledTask.model_validate_json(data)) + except ValidationError: + continue + return tasks + + async def remove(self, key: str) -> None: + await self.backend.client.zrem(self._index_key, key) + await self.backend.client.hdel(self._data_key, key) diff --git a/src/agentexec/state/redis_backend.py b/src/agentexec/state/redis_backend.py deleted file mode 100644 index d7c8dba..0000000 --- a/src/agentexec/state/redis_backend.py +++ /dev/null @@ -1,491 +0,0 @@ -# cspell:ignore rpush lpush brpop RPUSH LPUSH BRPOP -from typing import TypedDict, AsyncGenerator, Coroutine, Optional -import importlib -import json - -import redis -import redis.asyncio -from pydantic import BaseModel - -from agentexec.config import CONF - -__all__ = [ - "format_key", - "serialize", - "deserialize", - "rpush", - "lpush", - "brpop", - "aget", - "get", 
- "aset", - "set", - "adelete", - "delete", - "incr", - "decr", - "publish", - "subscribe", - "close", - "zadd", - "zrangebyscore", - "zrem", - "clear_keys", -] - -_redis_client: redis.asyncio.Redis | None = None -_redis_sync_client: redis.Redis | None = None -_pubsub: redis.asyncio.client.PubSub | None = None - - -def format_key(*args: str) -> str: - """Format a Redis key by joining parts with colons. - - Args: - *args: Parts of the key - - Returns: - Formatted key string - """ - return ":".join(args) - - -class SerializeWrapper(TypedDict): - __class__: str - __data__: str - - -def serialize(obj: BaseModel) -> bytes: - """Serialize a Pydantic BaseModel to JSON bytes with type information. - - Stores the fully qualified class name alongside the data, similar to pickle. - This allows deserialization without needing an external type registry. - - Args: - obj: Pydantic BaseModel instance to serialize - - Returns: - JSON-encoded bytes containing class info and data - - Raises: - TypeError: If obj is not a BaseModel instance - """ - if not isinstance(obj, BaseModel): - raise TypeError(f"Expected BaseModel, got {type(obj)}") - - cls = type(obj) - wrapper: SerializeWrapper = { - "__class__": f"{cls.__module__}.{cls.__qualname__}", - "__data__": obj.model_dump_json(), - } - - return json.dumps(wrapper).encode("utf-8") - - -def deserialize(data: bytes) -> BaseModel: - """Deserialize JSON bytes back to a Pydantic BaseModel instance. - - Uses the stored class information to dynamically import and reconstruct - the original type, similar to pickle. 
- - Args: - data: JSON-encoded bytes containing class info and data - - Returns: - Deserialized BaseModel instance - - Raises: - ImportError: If the class module cannot be imported - AttributeError: If the class does not exist in the module - ValueError: If the data is invalid JSON or missing required fields - """ - wrapper: SerializeWrapper = json.loads(data.decode("utf-8")) - class_path = wrapper["__class__"] - json_data = wrapper["__data__"] - - # Import the class dynamically (e.g., "myapp.models.Result" → myapp.models module) - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - - result: BaseModel = cls.model_validate_json(json_data) - return result - - -def _get_async_client() -> redis.asyncio.Redis: - """Get async Redis client, initializing lazily if needed. - - Returns: - Async Redis client instance - - Raises: - ValueError: If REDIS_URL is not configured - """ - global _redis_client - - if _redis_client is None: - if CONF.redis_url is None: - raise ValueError("REDIS_URL must be configured") - - _redis_client = redis.asyncio.Redis.from_url( - CONF.redis_url, - max_connections=CONF.redis_pool_size, - socket_connect_timeout=CONF.redis_pool_timeout, - decode_responses=False, # Handle binary data (pickled results) - ) - - return _redis_client - - -def _get_sync_client() -> redis.Redis: - """Get sync Redis client, initializing lazily if needed. 
- - Returns: - Sync Redis client instance - - Raises: - ValueError: If REDIS_URL is not configured - """ - global _redis_sync_client - - if _redis_sync_client is None: - if CONF.redis_url is None: - raise ValueError("REDIS_URL must be configured") - - _redis_sync_client = redis.Redis.from_url( - CONF.redis_url, - max_connections=CONF.redis_pool_size, - socket_connect_timeout=CONF.redis_pool_timeout, - decode_responses=False, - ) - - return _redis_sync_client - - -async def close() -> None: - """Close all Redis connections and clean up resources.""" - global _redis_client, _redis_sync_client, _pubsub - - # Close pubsub if active - if _pubsub is not None: - await _pubsub.close() - _pubsub = None - - # Close async client - if _redis_client is not None: - await _redis_client.aclose() - _redis_client = None - - # Close sync client - if _redis_sync_client is not None: - _redis_sync_client.close() - _redis_sync_client = None - - -def rpush(key: str, value: str) -> int: - """Push value to the right (front) of the list - for high priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - client = _get_sync_client() - return client.rpush(key, value) # type: ignore[return-value] - - -def lpush(key: str, value: str) -> int: - """Push value to the left (back) of the list - for low priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - client = _get_sync_client() - return client.lpush(key, value) # type: ignore[return-value] - - -async def brpop(key: str, timeout: int = 0) -> Optional[tuple[str, str]]: - """Pop value from the right of the list with blocking. 
- - Args: - key: Redis list key - timeout: Timeout in seconds (0 = block forever) - - Returns: - Tuple of (key, value) or None if timeout - """ - client = _get_async_client() - result = await client.brpop([key], timeout=timeout) # type: ignore[misc] - if result is None: - return None - # Redis returns bytes, decode to string - list_key, value = result - return (list_key.decode("utf-8"), value.decode("utf-8")) - - -def aget(key: str) -> Coroutine[None, None, Optional[bytes]]: - """Get value for key asynchronously. - - Args: - key: Key to retrieve - - Returns: - Coroutine that resolves to value as bytes or None if not found - """ - client = _get_async_client() - return client.get(key) # type: ignore[return-value] - - -def get(key: str) -> Optional[bytes]: - """Get value for key synchronously. - - Args: - key: Key to retrieve - - Returns: - Value as bytes or None if not found - """ - client = _get_sync_client() - return client.get(key) # type: ignore[return-value] - - -def aset(key: str, value: bytes, ttl_seconds: Optional[int] = None) -> Coroutine[None, None, bool]: - """Set value for key asynchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - Coroutine that resolves to True if successful - """ - client = _get_async_client() - if ttl_seconds is not None: - return client.set(key, value, ex=ttl_seconds) # type: ignore[return-value] - else: - return client.set(key, value) # type: ignore[return-value] - - -def set(key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: - """Set value for key synchronously with optional TTL. 
- - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - True if successful - """ - client = _get_sync_client() - if ttl_seconds is not None: - return client.set(key, value, ex=ttl_seconds) # type: ignore[return-value] - else: - return client.set(key, value) # type: ignore[return-value] - - -def adelete(key: str) -> Coroutine[None, None, int]: - """Delete key asynchronously. - - Args: - key: Key to delete - - Returns: - Coroutine that resolves to number of keys deleted (0 or 1) - """ - client = _get_async_client() - return client.delete(key) # type: ignore[return-value] - - -def delete(key: str) -> int: - """Delete key synchronously. - - Args: - key: Key to delete - - Returns: - Number of keys deleted (0 or 1) - """ - client = _get_sync_client() - return client.delete(key) # type: ignore[return-value] - - -def incr(key: str) -> int: - """Increment a counter atomically. - - Args: - key: Counter key - - Returns: - Value after increment - """ - client = _get_sync_client() - return client.incr(key) # type: ignore[return-value] - - -def decr(key: str) -> int: - """Decrement a counter atomically. - - Args: - key: Counter key - - Returns: - Value after decrement - """ - client = _get_sync_client() - return client.decr(key) # type: ignore[return-value] - - -async def acquire_lock(key: str, value: str, ttl_seconds: int) -> bool: - """Attempt to acquire a distributed lock using SET NX EX. - - Args: - key: Lock key - value: Lock value (typically agent_id for debugging) - ttl_seconds: Lock expiry in seconds (safety net for dead processes) - - Returns: - True if lock was acquired, False if already held - """ - client = _get_async_client() - result = await client.set(key, value, nx=True, ex=ttl_seconds) - return result is not None - - -async def release_lock(key: str) -> int: - """Release a distributed lock. 
- - Args: - key: Lock key to release - - Returns: - Number of keys deleted (0 or 1) - """ - client = _get_async_client() - return await client.delete(key) # type: ignore[return-value] - - -def publish(channel: str, message: str) -> None: - """Publish message to a channel. - - Args: - channel: Channel name - message: Message to publish - """ - client = _get_sync_client() - client.publish(channel, message) - - -async def subscribe(channel: str) -> AsyncGenerator[str, None]: - """Subscribe to a channel and yield messages. - - Args: - channel: Channel name - - Yields: - Messages from the channel as strings - """ - global _pubsub - - client = _get_async_client() - _pubsub = client.pubsub() - await _pubsub.subscribe(channel) - - try: - async for message in _pubsub.listen(): - if message["type"] == "message": - # Decode bytes to string - data = message["data"] - if isinstance(data, bytes): - yield data.decode("utf-8") - else: - yield data - finally: - await _pubsub.unsubscribe(channel) - await _pubsub.close() - _pubsub = None - - -def zadd(key: str, mapping: dict[str, float]) -> int: - """Add members to a sorted set with scores. - - Args: - key: Sorted set key - mapping: Dict of {member: score} - - Returns: - Number of new members added - """ - client = _get_sync_client() - return client.zadd(key, mapping) # type: ignore[return-value] - - -async def zrangebyscore( - key: str, min_score: float, max_score: float -) -> list[bytes]: - """Get members with scores between min and max. - - Args: - key: Sorted set key - min_score: Minimum score (inclusive) - max_score: Maximum score (inclusive) - - Returns: - List of members as bytes - """ - client = _get_async_client() - return await client.zrangebyscore(key, min_score, max_score) # type: ignore[return-value] - - -def zrem(key: str, *members: str) -> int: - """Remove members from a sorted set. 
- - Args: - key: Sorted set key - *members: Members to remove - - Returns: - Number of members removed - """ - client = _get_sync_client() - return client.zrem(key, *members) # type: ignore[return-value] - - -def clear_keys() -> int: - """Clear all Redis keys managed by this application. - - Uses SCAN to safely iterate through keys without blocking Redis. - Only deletes keys that match the configured prefix and queue name. - - Returns: - Total number of keys deleted, or 0 if Redis is not configured - """ - if CONF.redis_url is None: - return 0 - - client = _get_sync_client() - deleted = 0 - - # Delete the task queue - deleted += client.delete(CONF.queue_name) - - # Scan and delete all keys matching the configured prefix - # Pattern: "agentexec:*" (or whatever key_prefix is configured) - pattern = f"{CONF.key_prefix}:*" - cursor = 0 - - while True: - cursor, keys = client.scan(cursor=cursor, match=pattern, count=100) - if keys: - deleted += client.delete(*keys) - if cursor == 0: - break - - return deleted diff --git a/src/agentexec/tracker.py b/src/agentexec/tracker.py index 26a4fa2..6de64f5 100644 --- a/src/agentexec/tracker.py +++ b/src/agentexec/tracker.py @@ -5,63 +5,44 @@ Example: tracker = ax.Tracker("research", batch_id) - tracker.incr() # Count the discovery process itself + await tracker.incr() # Count the discovery process itself @function_tool async def queue_research(company: str) -> str: - tracker.incr() + await tracker.incr() await ax.enqueue("research", ResearchContext(company=company, batch_id=batch_id)) return f"Queued {company}" # When discovery finishes, decrement itself - if tracker.decr() == 0: + if await tracker.decr() == 0: await ax.enqueue("aggregate", AggregateContext(batch_id=batch_id)) # In research task - decrement when done tracker = ax.Tracker("research", context.batch_id) # ... do research ... 
- if tracker.decr() == 0: + if await tracker.decr() == 0: await ax.enqueue("aggregate", AggregateContext(batch_id=context.batch_id)) """ -from agentexec import state from agentexec.config import CONF +from agentexec.state import backend class Tracker: - """Coordinate dynamic fan-out with an atomic counter. - - Args: - *args: Key parts used to construct the tracker's unique key. - Typically includes a name and identifier, e.g., ("research", batch_id) - """ + """Coordinate dynamic fan-out with an atomic counter.""" def __init__(self, *args: str): - self._key = state.backend.format_key(CONF.key_prefix, "tracker", *args) - - def incr(self) -> int: - """Increment the counter. - - Returns: - Counter value after increment. - """ - return state.backend.incr(self._key) + self._key = backend.format_key(CONF.key_prefix, "tracker", *args) - def decr(self) -> int: - """Decrement the counter. + async def incr(self) -> int: + return await backend.state.counter_incr(self._key) - Returns: - Counter value after decrement. 
- """ - return state.backend.decr(self._key) + async def decr(self) -> int: + return await backend.state.counter_decr(self._key) - @property - def count(self) -> int: - """Get current counter value.""" - result = state.backend.get(self._key) + async def count(self) -> int: + result = await backend.state.get(self._key) return int(result) if result else 0 - @property - def complete(self) -> bool: - """Check if counter has reached zero.""" - return self.count == 0 + async def complete(self) -> bool: + return await self.count() == 0 diff --git a/src/agentexec/worker/event.py b/src/agentexec/worker/event.py index 7eede1e..797549d 100644 --- a/src/agentexec/worker/event.py +++ b/src/agentexec/worker/event.py @@ -1,5 +1,7 @@ from __future__ import annotations -from agentexec import state + +from agentexec.config import CONF +from agentexec.state import KEY_EVENT, backend class StateEvent: @@ -7,42 +9,23 @@ class StateEvent: Provides an interface similar to threading.Event/multiprocessing.Event, but backed by the state backend for cross-process and cross-machine coordination. - - This class is fully picklable (just stores name and optional id) and works - across any process that can connect to the same state backend. - - set() and clear() are synchronous for use from pool management code. - is_set() is async for use from worker event loops. - - Example: - event = StateEvent("shutdown", "pool1") - - # In pool (sync context) - event.set() - - # In worker (async context) - if await event.is_set(): - print("Shutdown signal received") """ def __init__(self, name: str, id: str) -> None: - """Initialize the event. 
- - Args: - name: Event name (e.g., "shutdown", "ready") - id: Identifier to scope the event (e.g., pool id) - """ self.name = name self.id = id - def set(self) -> None: + def _key(self) -> str: + return backend.format_key(*KEY_EVENT, self.name, self.id) + + async def set(self) -> None: """Set the event flag to True.""" - state.set_event(self.name, self.id) + await backend.state.set(self._key(), b"1") - def clear(self) -> None: + async def clear(self) -> None: """Reset the event flag to False.""" - state.clear_event(self.name, self.id) + await backend.state.delete(self._key()) async def is_set(self) -> bool: """Check if the event flag is True.""" - return await state.acheck_event(self.name, self.id) + return await backend.state.get(self._key()) is not None diff --git a/src/agentexec/worker/logging.py b/src/agentexec/worker/logging.py index acbb34c..3af7d5a 100644 --- a/src/agentexec/worker/logging.py +++ b/src/agentexec/worker/logging.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging +import multiprocessing as mp from pydantic import BaseModel -from agentexec import state LOGGER_NAME = "agentexec" LOG_CHANNEL = "agentexec:logs" @@ -9,7 +9,7 @@ class LogMessage(BaseModel): - """Schema for log messages sent via state backend pubsub.""" + """Schema for log messages sent via the worker message queue.""" name: str levelno: int @@ -22,7 +22,6 @@ class LogMessage(BaseModel): @classmethod def from_log_record(cls, record: logging.LogRecord) -> LogMessage: - """Create a LogMessage from a logging.LogRecord.""" return cls( name=record.name, levelno=record.levelno, @@ -35,7 +34,6 @@ def from_log_record(cls, record: logging.LogRecord) -> LogMessage: ) def to_log_record(self) -> logging.LogRecord: - """Convert back to a logging.LogRecord.""" record = logging.LogRecord( name=self.name, level=self.levelno, @@ -51,21 +49,18 @@ def to_log_record(self) -> logging.LogRecord: return record -class StateLogHandler(logging.Handler): - """Logging handler that publishes 
log records to state backend pubsub. +class QueueLogHandler(logging.Handler): + """Logging handler that sends log records to the pool via multiprocessing queue.""" - Used by worker processes to send logs to the main process. - """ - - def __init__(self, channel: str = LOG_CHANNEL): + def __init__(self, tx: mp.Queue): super().__init__() - self.channel = channel + self.tx = tx def emit(self, record: logging.LogRecord) -> None: - """Publish log record to log channel.""" try: + from agentexec.worker.pool import LogEntry message = LogMessage.from_log_record(record) - state.publish_log(message.model_dump_json()) + self.tx.put_nowait(LogEntry(record=message)) except Exception: self.handleError(record) @@ -73,29 +68,14 @@ def emit(self, record: logging.LogRecord) -> None: _worker_logging_configured = False -def get_worker_logger(name: str) -> logging.Logger: - """Configure worker logging and return a logger. - - On first call, sets up a state handler that publishes log records - to the main process via state backend pubsub. Subsequent calls just return - a logger under the agentexec namespace. - - Args: - name: Logger name. Typically __name__. - - Returns: - Configured logger instance. 
- - Example: - logger = get_worker_logger(__name__) - logger.info("Worker starting") - """ +def get_worker_logger(name: str, tx: mp.Queue | None = None) -> logging.Logger: + """Configure worker logging and return a logger.""" global _worker_logging_configured - if not _worker_logging_configured: + if not _worker_logging_configured and tx is not None: root = logging.getLogger(LOGGER_NAME) root.setLevel(logging.INFO) - root.addHandler(StateLogHandler()) + root.addHandler(QueueLogHandler(tx)) root.propagate = False _worker_logging_configured = True diff --git a/src/agentexec/worker/pool.py b/src/agentexec/worker/pool.py index 8d4dedc..6aa495c 100644 --- a/src/agentexec/worker/pool.py +++ b/src/agentexec/worker/pool.py @@ -5,15 +5,21 @@ import multiprocessing as mp from dataclasses import dataclass from typing import Any, Callable -from uuid import uuid4 +from uuid import UUID, uuid4 from pydantic import BaseModel from sqlalchemy import Engine, create_engine +from sqlalchemy.orm import Session, sessionmaker -from agentexec import state from agentexec.config import CONF -from agentexec.core.db import remove_global_session, set_global_session -from agentexec.core.queue import dequeue, requeue +from agentexec.state import backend +import queue as stdlib_queue + +from agentexec import activity +from agentexec.activity.events import ActivityUpdated +from agentexec.activity.handlers import IPCHandler +from agentexec.core.db import configure_engine +from agentexec.core.queue import enqueue from agentexec.core.task import Task, TaskDefinition, TaskHandler from agentexec import schedule from agentexec.worker.event import StateEvent @@ -29,6 +35,24 @@ ] +class Message(BaseModel): + """Base event sent from a worker to the pool.""" + pass + + +class TaskFailed(Message): + task: Task + error: str + + @classmethod + def from_exception(cls, task: Task, exception: Exception) -> TaskFailed: + return cls(task=task, error=str(exception)) + + +class LogEntry(Message): + record: LogMessage 
+ + class _EmptyContext(BaseModel): """Default context for scheduled tasks that don't need one.""" @@ -44,22 +68,21 @@ def _get_pool_id() -> str: class WorkerContext: """Shared context passed from Pool to Worker processes.""" - database_url: str shutdown_event: StateEvent tasks: dict[str, TaskDefinition] - queue_name: str + tx: mp.Queue class Worker: """Individual worker process with isolated state. Each worker configures the scoped Session factory on startup. - Task handlers can use get_global_session() to get the process-local session. + Workers don't have database access — all persistence goes through the pool. """ _worker_id: int _context: WorkerContext - _logger: logging.Logger + logger: logging.Logger def __init__(self, worker_id: int, context: WorkerContext): """Initialize worker with isolated state. @@ -70,7 +93,9 @@ def __init__(self, worker_id: int, context: WorkerContext): """ self._worker_id = worker_id self._context = context - self._logger = get_worker_logger(__name__) + self.logger = get_worker_logger(__name__, tx=context.tx) + + activity.handler = IPCHandler(context.tx) @classmethod def run_in_process(cls, worker_id: int, context: WorkerContext) -> None: @@ -85,68 +110,44 @@ def run_in_process(cls, worker_id: int, context: WorkerContext) -> None: def run(self) -> None: """Main worker entry point - sets up async loop and runs.""" - self._logger.info(f"Worker {self._worker_id} starting") - - engine = create_engine(self._context.database_url) - set_global_session(engine) + self.logger.info(f"Worker {self._worker_id} starting") try: asyncio.run(self._run()) except Exception as e: - self._logger.exception(f"Worker {self._worker_id} fatal error: {e}") + self.logger.exception(f"Worker {self._worker_id} fatal error: {e}") raise - - async def _run(self) -> None: - """Async main loop - polls queue and processes tasks.""" - try: - # No sleep needed - dequeue() uses brpop which blocks waiting for tasks - while not await self._context.shutdown_event.is_set(): - 
if (task := await self._dequeue_task()) is not None: - lock_key = task.get_lock_key() - - if lock_key is not None: - acquired = await state.acquire_lock(lock_key, str(task.agent_id)) - if not acquired: - self._logger.debug( - f"Worker {self._worker_id} lock held for {task.task_name} " - f"(lock_key={lock_key}), requeuing" - ) - requeue(task, queue_name=self._context.queue_name) - continue - - try: - self._logger.info(f"Worker {self._worker_id} processing: {task.task_name}") - await task.execute() - self._logger.info(f"Worker {self._worker_id} completed: {task.task_name}") - finally: - if lock_key is not None: - await state.release_lock(lock_key) - except Exception as e: - self._logger.exception(f"Worker {self._worker_id} error: {e}") - # Continue processing other tasks - # TODO allow configurable behavior here (retry, backoff, fail) - # TODO all of the actual logic is handled in task.execute(), so I don't know why we ever end up here. finally: - await state.backend.close() - remove_global_session() - self._logger.info(f"Worker {self._worker_id} shutting down") + asyncio.run(backend.close()) + self.logger.info(f"Worker {self._worker_id} shutting down") - async def _dequeue_task(self) -> Task | None: - """Dequeue and hydrate a task from the Redis queue. + def _send(self, message: Message) -> None: + """Send a message to the pool via the multiprocessing queue.""" + self._context.tx.put_nowait(message) - Reconstructs the typed context using the TaskDefinition - and binds the definition to the task. + async def _run(self) -> None: + """Async main loop - dequeue, execute, complete.""" + while not await self._context.shutdown_event.is_set(): + try: + data = await backend.queue.pop(timeout=1) + if data is None: + continue - Returns: - Hydrated Task instance if available, else None. 
- """ - if (data := await dequeue(queue_name=self._context.queue_name)) is not None: - return Task.from_serialized( - definition=self._context.tasks[data["task_name"]], - data=data, - ) + task = Task.model_validate(data) + definition = self._context.tasks[task.task_name] + partition_key = definition.get_lock_key(task.context) + + try: + self.logger.info(f"Worker {self._worker_id} processing: {task.task_name}") + await definition.execute(task) + self.logger.info(f"Worker {self._worker_id} completed: {task.task_name}") + except Exception as e: + self._send(TaskFailed.from_exception(task, e)) + finally: + await backend.queue.complete(partition_key) + except Exception as e: + self.logger.exception(f"Worker {self._worker_id} error: {e}") - return None class Pool: @@ -177,14 +178,12 @@ def __init__( self, engine: Engine | None = None, database_url: str | None = None, - queue_name: str | None = None, ) -> None: """Initialize the worker pool. Args: engine: SQLAlchemy engine (URL will be extracted for workers). database_url: Database URL string. Alternative to passing engine. - queue_name: Redis queue name. Defaults to CONF.queue_name. Raises: ValueError: If neither engine nor database_url is provided. @@ -194,16 +193,16 @@ def __init__( raise ValueError("Either engine or database_url must be provided") engine = engine or create_engine(database_url) # type: ignore[arg-type] - set_global_session(engine) - + configure_engine(engine) + self._worker_queue: mp.Queue = mp.Queue() self._context = WorkerContext( - database_url=database_url or engine.url.render_as_string(hide_password=False), shutdown_event=StateEvent("shutdown", _get_pool_id()), tasks={}, - queue_name=queue_name or CONF.queue_name, + tx=self._worker_queue, ) self._processes = [] self._log_handler = None + self._pending_schedules: list[dict[str, Any]] = [] def task( self, @@ -345,6 +344,9 @@ def add_schedule( ``pool.add_task()``. The scheduler loop runs automatically inside ``pool.run()`` — no extra setup needed. 
+ Schedules are stored and registered with the backend when + ``start()`` is called. + Args: task_name: Name of a registered task. every: Schedule expression (cron syntax: min hour dom mon dow). @@ -366,58 +368,52 @@ def add_schedule( f"Use @pool.task() or pool.add_task() first." ) - schedule.register( + self._pending_schedules.append(dict( task_name=task_name, every=every, context=context, repeat=repeat, metadata=metadata, - ) + )) - def start(self) -> None: - """Start worker processes (non-blocking). + async def start(self) -> None: + """Start workers and run until they exit. - Spawns N worker processes that poll the Redis queue and execute - tasks from this pool's registry. Returns immediately. - - Workers log to Redis pubsub. Use run() if you want the main - process to collect and display those logs. + Spawns worker processes, forwards logs, and processes scheduled + tasks. This is the foreground entry point — it blocks until all + workers finish. Use ``run()`` for a daemonized version that + handles KeyboardInterrupt and cleanup. """ - # Clear any stale shutdown signal - self._context.shutdown_event.clear() + await self._context.shutdown_event.clear() - # Spawn workers BEFORE setting up log handler to avoid pickling issues - # (StreamHandler has a lock that can't be pickled) + # Spawn workers before log handler to avoid pickling issues self._spawn_workers() - # Set up log handler for receiving worker logs # TODO make this configurable self._log_handler = logging.StreamHandler() self._log_handler.setFormatter(logging.Formatter(DEFAULT_FORMAT)) - def run(self) -> None: - """Start workers and run log collector until interrupted. + await asyncio.gather( + self._process_worker_events(), + self._process_scheduled_tasks(), + ) - Spawns worker processes and runs an async event loop in the main - process that collects logs from workers via Redis pubsub. 
- The scheduler loop also runs automatically alongside the workers, - polling for due scheduled tasks and enqueuing them. + def run(self) -> None: + """Start workers in a managed event loop with graceful shutdown. - Blocks until all workers exit or KeyboardInterrupt, then shuts - down gracefully. + Calls ``start()`` inside ``asyncio.run()`` and handles + KeyboardInterrupt, shutdown, and connection cleanup. """ async def _loop() -> None: try: - await self._collect_logs() + await self.start() except asyncio.CancelledError: pass finally: - self.shutdown() - await state.backend.close() + await self.shutdown() try: - self.start() asyncio.run(_loop()) except KeyboardInterrupt: pass @@ -436,34 +432,66 @@ def _spawn_workers(self) -> None: self._processes.append(process) print(f"Started worker {worker_id} (PID: {process.pid})") - async def _collect_logs(self) -> None: - """Listen for log messages from workers and run scheduler ticks.""" + async def _process_scheduled_tasks(self) -> None: + """Register pending schedules, then poll for due tasks and enqueue them.""" + for _schedule in self._pending_schedules: + await schedule.register(**_schedule) + self._pending_schedules.clear() + + while any(p.is_alive() for p in self._processes): + await asyncio.sleep(CONF.scheduler_poll_interval) + + for scheduled_task in await backend.schedule.get_due(): + await enqueue( + scheduled_task.task_name, + context=backend.deserialize(scheduled_task.context), + metadata=scheduled_task.metadata, + ) + + if scheduled_task.repeat == 0: + await backend.schedule.remove(scheduled_task.key) + else: + scheduled_task.advance() + await backend.schedule.register(scheduled_task) + + def _partition_key_for(self, task: Task) -> str | None: + """Derive the partition/lock key for a task from its definition.""" + return self._context.tasks[task.task_name].get_lock_key(task.context) + + async def _process_worker_events(self) -> None: + """Handle all events from worker processes via multiprocessing queue.""" 
assert self._log_handler, "Log handler not initialized" - # Create task to subscribe to logs - log_task = asyncio.create_task(self._process_log_stream()) - - try: - # Poll worker processes and run scheduler - while any(p.is_alive() for p in self._processes): - await asyncio.sleep(0.1) - await schedule.tick() - finally: - log_task.cancel() + while any(p.is_alive() for p in self._processes): try: - await log_task - except asyncio.CancelledError: - pass - - async def _process_log_stream(self) -> None: - """Process log messages from the state backend.""" - assert self._log_handler, "Log handler not initialized" - - async for message in state.subscribe_logs(): - log_message = LogMessage.model_validate_json(message) - self._log_handler.emit(log_message.to_log_record()) - - def shutdown(self, timeout: int | None = None) -> None: + message = self._worker_queue.get_nowait() + except stdlib_queue.Empty: + await asyncio.sleep(0.05) + continue + + match message: + case LogEntry(record=record): + self._log_handler.emit(record.to_log_record()) + + case TaskFailed(task=task, error=error): + if task.retry_count < CONF.max_task_retries: + task.retry_count += 1 + await backend.queue.push( + task.model_dump_json(), + partition_key=self._partition_key_for(task), + high_priority=True, + ) + else: + # TODO incorporate this messaging into the ax.activity stream. + print( + f"Task {task.task_name} failed " + f"after {task.retry_count + 1} attempts, giving up: {error}" + ) + + case ActivityUpdated(): + activity.handler(message) + + async def shutdown(self, timeout: int | None = None) -> None: """Gracefully shutdown all worker processes. For use with start(). If using run(), shutdown is handled automatically. 
@@ -475,7 +503,7 @@ def shutdown(self, timeout: int | None = None) -> None: timeout = CONF.graceful_shutdown_timeout print("Shutting down worker pool") - self._context.shutdown_event.set() + await self._context.shutdown_event.set() for process in self._processes: process.join(timeout=timeout) @@ -485,4 +513,5 @@ def shutdown(self, timeout: int | None = None) -> None: process.join(timeout=5) self._processes.clear() + await backend.close() print("Worker pool shutdown complete") diff --git a/tests/test_activity_schemas.py b/tests/test_activity_schemas.py index 2addd3d..67120a6 100644 --- a/tests/test_activity_schemas.py +++ b/tests/test_activity_schemas.py @@ -1,5 +1,3 @@ -"""Test activity schema validation and computed fields.""" - import uuid from datetime import datetime, timedelta, UTC diff --git a/tests/test_activity_tracking.py b/tests/test_activity_tracking.py index ab1963f..b40e8a3 100644 --- a/tests/test_activity_tracking.py +++ b/tests/test_activity_tracking.py @@ -1,5 +1,3 @@ -"""Tests for activity tracking functionality.""" - import uuid import pytest @@ -8,21 +6,20 @@ from agentexec import activity from agentexec.activity.models import Activity, ActivityLog, Base, Status -from agentexec.activity.tracker import normalize_agent_id +from agentexec.activity import normalize_agent_id + @pytest.fixture def db_session(): """Set up an in-memory SQLite database for testing.""" - # Create engine and session factory (users manage their own) - engine = create_engine("sqlite:///:memory:", echo=False) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + from agentexec.core.db import configure_engine - # Create tables + engine = create_engine("sqlite:///:memory:", echo=False) Base.metadata.create_all(bind=engine) + configure_engine(engine) - # Provide a session for the test - session = SessionLocal() + session = sessionmaker(bind=engine)() try: yield session session.commit() @@ -32,11 +29,12 @@ def db_session(): finally: session.close() 
engine.dispose() + engine.dispose() -def test_create_activity(db_session: Session): +async def test_create_activity(db_session: Session): """Test creating a new activity record.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Task queued for testing", session=db_session, @@ -84,17 +82,17 @@ def test_database_tables_created(): engine.dispose() -def test_update_activity(db_session: Session): +async def test_update_activity(db_session: Session): """Test updating an activity with a new log message.""" # First create an activity - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Initial message", session=db_session, ) # Update the activity - result = activity.update( + result = await activity.update( agent_id=agent_id, message="Processing...", percentage=50, @@ -112,15 +110,15 @@ def test_update_activity(db_session: Session): assert activity_record.logs[1].percentage == 50 -def test_update_activity_with_custom_status(db_session: Session): +async def test_update_activity_with_custom_status(db_session: Session): """Test updating an activity with a custom status.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Initial", session=db_session, ) - activity.update( + await activity.update( agent_id=agent_id, message="Custom status update", status=Status.RUNNING, @@ -134,15 +132,15 @@ def test_update_activity_with_custom_status(db_session: Session): assert latest_log.status == Status.RUNNING -def test_complete_activity(db_session: Session): +async def test_complete_activity(db_session: Session): """Test marking an activity as complete.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Started", session=db_session, ) - result = activity.complete( + result = await activity.complete( agent_id=agent_id, message="Successfully completed", session=db_session, @@ -158,15 +156,15 @@ def 
test_complete_activity(db_session: Session): assert latest_log.percentage == 100 -def test_complete_activity_custom_percentage(db_session: Session): +async def test_complete_activity_custom_percentage(db_session: Session): """Test marking an activity complete with custom percentage.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Started", session=db_session, ) - activity.complete( + await activity.complete( agent_id=agent_id, message="Done", percentage=95, @@ -179,15 +177,15 @@ def test_complete_activity_custom_percentage(db_session: Session): assert latest_log.percentage == 95 -def test_error_activity(db_session: Session): +async def test_error_activity(db_session: Session): """Test marking an activity as errored.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Started", session=db_session, ) - result = activity.error( + result = await activity.error( agent_id=agent_id, message="Task failed: connection timeout", session=db_session, @@ -203,36 +201,36 @@ def test_error_activity(db_session: Session): assert latest_log.percentage == 100 -def test_cancel_pending_activities(db_session: Session): +async def test_cancel_pending_activities(db_session: Session): """Test canceling all pending activities.""" # Create some activities in different states - queued_id = activity.create( + queued_id = await activity.create( task_name="queued_task", message="Waiting", session=db_session, ) - running_id = activity.create( + running_id = await activity.create( task_name="running_task", message="Started", session=db_session, ) - activity.update( + await activity.update( agent_id=running_id, message="Running...", status=Status.RUNNING, session=db_session, ) - complete_id = activity.create( + complete_id = await activity.create( task_name="complete_task", message="Started", session=db_session, ) - activity.complete(agent_id=complete_id, session=db_session) + await 
activity.complete(agent_id=complete_id, session=db_session) # Cancel pending activities - canceled_count = activity.cancel_pending(session=db_session) + canceled_count = await activity.cancel_pending(session=db_session) # Should have canceled the queued and running activities assert canceled_count == 2 @@ -250,18 +248,18 @@ def test_cancel_pending_activities(db_session: Session): assert complete_record.logs[-1].status == Status.COMPLETE # Not changed -def test_list_activities(db_session: Session): +async def test_list_activities(db_session: Session): """Test listing activities with pagination.""" # Create several activities for i in range(5): - activity.create( + await activity.create( task_name=f"task_{i}", message=f"Message {i}", session=db_session, ) # List activities - result = activity.list(db_session, page=1, page_size=3) + result = await activity.list(db_session, page=1, page_size=3) assert len(result.items) == 3 assert result.total == 5 @@ -269,38 +267,38 @@ def test_list_activities(db_session: Session): assert result.page_size == 3 -def test_list_activities_second_page(db_session: Session): +async def test_list_activities_second_page(db_session: Session): """Test listing activities on second page.""" for i in range(5): - activity.create( + await activity.create( task_name=f"task_{i}", message=f"Message {i}", session=db_session, ) - result = activity.list(db_session, page=2, page_size=3) + result = await activity.list(db_session, page=2, page_size=3) assert len(result.items) == 2 # Remaining items assert result.total == 5 assert result.page == 2 -def test_detail_activity(db_session: Session): +async def test_detail_activity(db_session: Session): """Test getting activity detail with all logs.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="detailed_task", message="Initial", session=db_session, ) - activity.update( + await activity.update( agent_id=agent_id, message="Processing", percentage=50, session=db_session, ) - 
activity.complete(agent_id=agent_id, session=db_session) + await activity.complete(agent_id=agent_id, session=db_session) - result = activity.detail(db_session, agent_id) + result = await activity.detail(db_session, agent_id) assert result is not None assert result.agent_id == agent_id @@ -311,33 +309,33 @@ def test_detail_activity(db_session: Session): assert result.logs[2].status == Status.COMPLETE -def test_detail_activity_not_found(db_session: Session): +async def test_detail_activity_not_found(db_session: Session): """Test getting detail for non-existent activity returns None.""" fake_id = uuid.uuid4() - result = activity.detail(db_session, fake_id) + result = await activity.detail(db_session, fake_id) assert result is None -def test_detail_activity_with_string_id(db_session: Session): +async def test_detail_activity_with_string_id(db_session: Session): """Test getting activity detail with string agent_id.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="string_id_task", message="Test", session=db_session, ) # Use string ID - result = activity.detail(db_session, str(agent_id)) + result = await activity.detail(db_session, str(agent_id)) assert result is not None assert result.agent_id == agent_id -def test_create_activity_with_custom_agent_id(db_session: Session): +async def test_create_activity_with_custom_agent_id(db_session: Session): """Test creating activity with a custom agent_id.""" custom_id = uuid.uuid4() - agent_id = activity.create( + agent_id = await activity.create( task_name="custom_id_task", message="Test", agent_id=custom_id, @@ -350,10 +348,10 @@ def test_create_activity_with_custom_agent_id(db_session: Session): assert activity_record is not None -def test_create_activity_with_string_agent_id(db_session: Session): +async def test_create_activity_with_string_agent_id(db_session: Session): """Test creating activity with a string agent_id.""" custom_id = uuid.uuid4() - agent_id = activity.create( + agent_id = await 
activity.create( task_name="string_agent_id_task", message="Test", agent_id=str(custom_id), @@ -363,12 +361,9 @@ def test_create_activity_with_string_agent_id(db_session: Session): assert agent_id == custom_id -# --- Metadata Tests --- - - -def test_create_activity_with_metadata(db_session: Session): +async def test_create_activity_with_metadata(db_session: Session): """Test creating activity with metadata.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="metadata_task", message="Test with metadata", session=db_session, @@ -380,9 +375,9 @@ def test_create_activity_with_metadata(db_session: Session): assert activity_record.metadata_ == {"organization_id": "org-123", "user_id": "user-456"} -def test_create_activity_without_metadata(db_session: Session): +async def test_create_activity_without_metadata(db_session: Session): """Test that metadata is None by default.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="no_metadata_task", message="Test without metadata", session=db_session, @@ -393,22 +388,22 @@ def test_create_activity_without_metadata(db_session: Session): assert activity_record.metadata_ is None -def test_list_activities_with_metadata_filter(db_session: Session): +async def test_list_activities_with_metadata_filter(db_session: Session): """Test filtering activities by metadata.""" # Create activities for different organizations - activity.create( + await activity.create( task_name="task_org_a", message="Org A task", session=db_session, metadata={"organization_id": "org-A"}, ) - activity.create( + await activity.create( task_name="task_org_a_2", message="Org A task 2", session=db_session, metadata={"organization_id": "org-A"}, ) - activity.create( + await activity.create( task_name="task_org_b", message="Org B task", session=db_session, @@ -416,7 +411,7 @@ def test_list_activities_with_metadata_filter(db_session: Session): ) # Filter by org-A - result = activity.list( + result = await 
activity.list( db_session, metadata_filter={"organization_id": "org-A"}, ) @@ -427,7 +422,7 @@ def test_list_activities_with_metadata_filter(db_session: Session): assert item.metadata["organization_id"] == "org-A" # Filter by org-B - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-B"}, ) @@ -436,22 +431,22 @@ def test_list_activities_with_metadata_filter(db_session: Session): assert result.items[0].metadata["organization_id"] == "org-B" # Filter by non-existent org - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-C"}, ) assert result.total == 0 -def test_list_activities_with_multiple_metadata_filters(db_session: Session): +async def test_list_activities_with_multiple_metadata_filters(db_session: Session): """Test filtering activities by multiple metadata fields.""" - activity.create( + await activity.create( task_name="task_1", message="User 1 in Org A", session=db_session, metadata={"organization_id": "org-A", "user_id": "user-1"}, ) - activity.create( + await activity.create( task_name="task_2", message="User 2 in Org A", session=db_session, @@ -459,37 +454,37 @@ def test_list_activities_with_multiple_metadata_filters(db_session: Session): ) # Filter by both org and user - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-A", "user_id": "user-1"}, ) assert result.total == 1 -def test_detail_activity_with_metadata(db_session: Session): +async def test_detail_activity_with_metadata(db_session: Session): """Test getting activity detail includes metadata.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="detailed_metadata_task", message="Test", session=db_session, metadata={"organization_id": "org-123"}, ) - result = activity.detail(db_session, agent_id) + result = await activity.detail(db_session, agent_id) assert result is not None assert result.metadata == 
{"organization_id": "org-123"} -def test_detail_activity_with_metadata_filter_match(db_session: Session): +async def test_detail_activity_with_metadata_filter_match(db_session: Session): """Test detail returns activity when metadata filter matches.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="filter_match_task", message="Test", session=db_session, metadata={"organization_id": "org-A"}, ) - result = activity.detail( + result = await activity.detail( db_session, agent_id, metadata_filter={"organization_id": "org-A"}, @@ -498,9 +493,9 @@ def test_detail_activity_with_metadata_filter_match(db_session: Session): assert result.agent_id == agent_id -def test_detail_activity_with_metadata_filter_no_match(db_session: Session): +async def test_detail_activity_with_metadata_filter_no_match(db_session: Session): """Test detail returns None when metadata filter doesn't match.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="filter_no_match_task", message="Test", session=db_session, @@ -508,7 +503,7 @@ def test_detail_activity_with_metadata_filter_no_match(db_session: Session): ) # Try to access with wrong organization - result = activity.detail( + result = await activity.detail( db_session, agent_id, metadata_filter={"organization_id": "org-B"}, @@ -516,15 +511,15 @@ def test_detail_activity_with_metadata_filter_no_match(db_session: Session): assert result is None -def test_detail_activity_no_metadata_with_filter(db_session: Session): +async def test_detail_activity_no_metadata_with_filter(db_session: Session): """Test detail returns None when activity has no metadata but filter is applied.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="no_metadata_with_filter", message="Test", session=db_session, ) - result = activity.detail( + result = await activity.detail( db_session, agent_id, metadata_filter={"organization_id": "org-A"}, @@ -532,28 +527,28 @@ def 
test_detail_activity_no_metadata_with_filter(db_session: Session): assert result is None -def test_list_metadata_accessible_as_attribute(db_session: Session): +async def test_list_metadata_accessible_as_attribute(db_session: Session): """Test that metadata is accessible as an attribute on schema objects.""" - activity.create( + await activity.create( task_name="list_metadata_task", message="Test", session=db_session, metadata={"key1": "value1", "key2": "value2"}, ) - result = activity.list(db_session) + result = await activity.list(db_session) assert result.total == 1 # Metadata is accessible as attribute for programmatic use assert result.items[0].metadata == {"key1": "value1", "key2": "value2"} -def test_metadata_excluded_from_serialization(db_session: Session): +async def test_metadata_excluded_from_serialization(db_session: Session): """Test that metadata is excluded from JSON/dict serialization by default. This prevents accidental leakage of tenant info through API responses. Users who want metadata in responses should explicitly include it. 
""" - agent_id = activity.create( + agent_id = await activity.create( task_name="serialization_test", message="Test", session=db_session, @@ -561,12 +556,12 @@ def test_metadata_excluded_from_serialization(db_session: Session): ) # List view - metadata excluded from serialization - result = activity.list(db_session) + result = await activity.list(db_session) item_dict = result.items[0].model_dump() assert "metadata" not in item_dict # Detail view - metadata excluded from serialization - detail = activity.detail(db_session, agent_id) + detail = await activity.detail(db_session, agent_id) assert detail is not None detail_dict = detail.model_dump() assert "metadata" not in detail_dict diff --git a/tests/test_config.py b/tests/test_config.py index af280fc..3aa6ec4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,3 @@ -"""Test configuration handling.""" - import os import pytest @@ -15,10 +13,10 @@ def test_default_table_prefix(self): config = Config() assert config.table_prefix == "agentexec_" - def test_default_queue_name(self): - """Test default queue name.""" + def test_default_queue_prefix(self): + """Test default queue prefix.""" config = Config() - assert config.queue_name == "agentexec_tasks" + assert config.queue_prefix == "agentexec_tasks" def test_default_num_workers(self): """Test default number of workers.""" @@ -107,11 +105,11 @@ def test_table_prefix_from_env(self): config = Config() assert config.table_prefix == "custom_" - def test_queue_name_from_env(self): - """Test queue_name from environment variable.""" + def test_queue_prefix_from_env(self): + """Test queue_prefix from environment variable (with backwards compat alias).""" os.environ["AGENTEXEC_QUEUE_NAME"] = "my_queue" config = Config() - assert config.queue_name == "my_queue" + assert config.queue_prefix == "my_queue" def test_graceful_shutdown_timeout_from_env(self): """Test graceful_shutdown_timeout from environment variable.""" @@ -191,7 +189,7 @@ def 
test_conf_is_config_instance(self): def test_conf_has_expected_attributes(self): """Test that CONF has all expected attributes.""" assert hasattr(CONF, "table_prefix") - assert hasattr(CONF, "queue_name") + assert hasattr(CONF, "queue_prefix") assert hasattr(CONF, "num_workers") assert hasattr(CONF, "graceful_shutdown_timeout") assert hasattr(CONF, "redis_url") diff --git a/tests/test_db.py b/tests/test_db.py index 4714751..01df665 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,134 +1,38 @@ -"""Test database session management.""" - import pytest -from sqlalchemy import create_engine, text -from sqlalchemy.orm import Session +from sqlalchemy import create_engine -from agentexec.core.db import ( - Base, - get_global_session, - remove_global_session, - set_global_session, -) +from agentexec.core.db import Base, configure_engine, get_session @pytest.fixture def test_engine(): - """Create a test SQLite engine.""" - engine = create_engine("sqlite:///:memory:", echo=False) + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(bind=engine) + configure_engine(engine) yield engine engine.dispose() -@pytest.fixture(autouse=True) -def cleanup_session(): - """Cleanup global session after each test.""" - yield - try: - remove_global_session() - except Exception: - pass - - -def test_base_class_exists(): - """Test that Base class is exported and usable.""" - assert Base is not None - assert hasattr(Base, "metadata") - - -def test_set_global_session(test_engine): - """Test that set_global_session configures the session factory.""" - set_global_session(test_engine) - - # Should be able to get a session now - session = get_global_session() - assert isinstance(session, Session) - - -def test_get_global_session_returns_session(test_engine): - """Test that get_global_session returns a working session.""" - set_global_session(test_engine) - - session = get_global_session() - - # Verify it's a working session - result = session.execute(text("SELECT 1")) - 
assert result.scalar() == 1 - - -def test_get_global_session_singleton(test_engine): - """Test that get_global_session returns the same session instance.""" - set_global_session(test_engine) - - session1 = get_global_session() - session2 = get_global_session() - - # Should be the same session (scoped_session behavior) - assert session1 is session2 - - -def test_remove_global_session(test_engine): - """Test that remove_global_session closes the session.""" - set_global_session(test_engine) - - session1 = get_global_session() - remove_global_session() +def test_configure_engine(test_engine): + """configure_engine makes get_session available.""" + session = get_session() + assert session is not None + session.close() - # Getting session again should return a different instance - session2 = get_global_session() - # They should be different sessions after remove - assert session1 is not session2 +def test_get_session_context_manager(test_engine): + """get_session works as a context manager.""" + with get_session() as session: + assert session is not None -def test_session_with_tables(test_engine): - """Test that session works with table creation.""" - # Create the tables - Base.metadata.create_all(bind=test_engine) - - set_global_session(test_engine) - session = get_global_session() - - # Session should be able to query (though tables may be empty) - result = session.execute(text("SELECT name FROM sqlite_master WHERE type='table'")) - tables = [row[0] for row in result] - - # Tables from Base.metadata should exist - assert isinstance(tables, list) - - -def test_multiple_set_global_session_calls(test_engine): - """Test that multiple set_global_session calls work correctly.""" - set_global_session(test_engine) - session1 = get_global_session() - - # Create another engine - engine2 = create_engine("sqlite:///:memory:") - - # Reconfigure with new engine - set_global_session(engine2) - session2 = get_global_session() - - # Sessions should work with their respective engines - 
result = session2.execute(text("SELECT 1")) - assert result.scalar() == 1 - - engine2.dispose() - - -def test_session_lifecycle(): - """Test complete session lifecycle: set -> use -> remove.""" - engine = create_engine("sqlite:///:memory:") - - # Set - set_global_session(engine) - - # Use - session = get_global_session() - session.execute(text("SELECT 1")) - - # Remove - remove_global_session() - - # Cleanup - engine.dispose() +def test_get_session_without_configure_raises(): + """get_session raises if configure_engine hasn't been called.""" + import agentexec.core.db as db_module + old_factory = db_module._session_factory + db_module._session_factory = None + try: + with pytest.raises(RuntimeError, match="Database engine not configured"): + get_session() + finally: + db_module._session_factory = old_factory diff --git a/tests/test_kafka_integration.py b/tests/test_kafka_integration.py new file mode 100644 index 0000000..c61c5a7 --- /dev/null +++ b/tests/test_kafka_integration.py @@ -0,0 +1,158 @@ +"""Kafka backend integration tests. + +These tests run against a real Kafka broker. They are skipped if the +``aiokafka`` package is not installed or ``KAFKA_BOOTSTRAP_SERVERS`` is +not set. 
+ +Run locally: + + docker compose -f docker-compose.kafka.yml up -d + + AGENTEXEC_STATE_BACKEND=agentexec.state.kafka \\ + KAFKA_BOOTSTRAP_SERVERS=localhost:9092 \\ + uv run pytest tests/test_kafka_integration.py -v + + docker compose -f docker-compose.kafka.yml down +""" + +from __future__ import annotations + +import asyncio +import os +import uuid + +import pytest +from pydantic import BaseModel + +_skip_reason = None + +if not os.environ.get("KAFKA_BOOTSTRAP_SERVERS"): + _skip_reason = "KAFKA_BOOTSTRAP_SERVERS not set" +else: + try: + import aiokafka # noqa: F401 + except ImportError: + _skip_reason = "aiokafka not installed (pip install agentexec[kafka])" + +if _skip_reason: + pytest.skip(_skip_reason, allow_module_level=True) + + +from agentexec.state import backend # noqa: E402 +from agentexec.state.kafka import Backend as KafkaBackend # noqa: E402 + +_kb: KafkaBackend = backend # type: ignore[assignment] + + +class SampleResult(BaseModel): + status: str + value: int + + +class TaskContext(BaseModel): + query: str + + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +@pytest.fixture(autouse=True, scope="module") +async def close_connections(): + """Close all Kafka connections once after the module completes.""" + yield + await _kb.close() + + +class TestStateNotSupported: + async def test_get_raises(self): + """KV get raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + await _kb.state.get("any-key") + + async def test_set_raises(self): + """KV set raises NotImplementedError.""" + with pytest.raises(NotImplementedError): + await _kb.state.set("any-key", b"value") + + async def test_counter_raises(self): + """Counter operations raise NotImplementedError.""" + with pytest.raises(NotImplementedError): + await _kb.state.counter_incr("any-key") + + +class TestSerialization: + def test_roundtrip(self): + """serialize → deserialize preserves type and data.""" + original = SampleResult(status="ok", value=42) + data = 
_kb.serialize(original)
+        restored = _kb.deserialize(data)
+        assert type(restored) is SampleResult
+        assert restored == original
+
+    def test_format_key_joins_with_dots(self):
+        """Kafka backend uses dots as key separators."""
+        assert _kb.format_key("agentexec", "result", "123") == "agentexec.result.123"
+
+
+class TestQueue:
+    async def test_push_and_pop(self):
+        """A pushed task can be popped from the queue."""
+        import json
+
+        task_data = {
+            "task_name": "test_task",
+            "context": {"query": "hello"},
+            "agent_id": str(uuid.uuid4()),
+        }
+        await _kb.queue.push(json.dumps(task_data))
+
+        result = await _kb.queue.pop(timeout=10)
+        assert result is not None
+        assert result["task_name"] == "test_task"
+        assert result["context"]["query"] == "hello"
+
+    async def test_pop_empty_queue_returns_none(self):
+        """Popping an empty queue returns None after timeout."""
+        result = await _kb.queue.pop(timeout=1)
+        # Prior tests may leave messages in-flight; assert pop returns cleanly.
+        assert result is None or isinstance(result, dict)
+
+    async def test_push_with_partition_key(self):
+        """Tasks with partition_key are routed deterministically."""
+        import json
+
+        task_data = {
+            "task_name": "keyed_task",
+            "context": {"query": "keyed"},
+            "agent_id": str(uuid.uuid4()),
+        }
+        await _kb.queue.push(json.dumps(task_data), partition_key="user-123")
+
+        result = await _kb.queue.pop(timeout=10)
+        assert result is not None
+        assert result["task_name"] == "keyed_task"
+
+    async def test_complete_is_noop(self):
+        """complete() is a no-op for Kafka (partition assignment handles it)."""
+        await _kb.queue.complete("any-key")
+        await _kb.queue.complete(None)
+
+
+class TestConnection:
+    async def test_ensure_topic_idempotent(self):
+        """ensure_topic can be called multiple times without error."""
+        topic = f"test_ensure_{uuid.uuid4().hex[:8]}"
+        await _kb.ensure_topic(topic)
+        await _kb.ensure_topic(topic)  # Should not raise
+
+    async def test_client_id_includes_pid(self):
+        """client_id includes PID for uniqueness."""
+        cid = 
_kb._client_id("producer") + assert str(os.getpid()) in cid + assert "producer" in cid + + async def test_produce_and_topic_creation(self): + """produce() auto-creates the topic if needed.""" + topic = f"test_produce_{uuid.uuid4().hex[:8]}" + await _kb.produce(topic, b"test-value", key=b"test-key") + # If we got here without error, produce and topic creation worked diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index cf81680..bf9213d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,3 @@ -"""Test Pipeline orchestration functionality.""" - from dataclasses import dataclass, field from unittest.mock import MagicMock diff --git a/tests/test_pipeline_flow.py b/tests/test_pipeline_flow.py index b8f0ff6..d7b9481 100644 --- a/tests/test_pipeline_flow.py +++ b/tests/test_pipeline_flow.py @@ -1,12 +1,3 @@ -"""Test Pipeline type flow validation. - -This module tests the type checking between pipeline steps, ensuring that: -- Return types from one step match parameter types of the next step -- Tuple returns are properly unpacked into multiple parameters -- Type mismatches are caught at validation time -- Subclass relationships are respected -""" - from dataclasses import dataclass, field from unittest.mock import MagicMock @@ -16,11 +7,6 @@ from agentexec.pipeline import Pipeline -# ============================================================================= -# Test Models -# ============================================================================= - - class Context(BaseModel): """Input context for pipeline tests.""" @@ -52,11 +38,6 @@ class Combined(BaseModel): b: str -# ============================================================================= -# Fixtures -# ============================================================================= - - @dataclass class MockWorkerContext: """Mock context for testing.""" @@ -78,11 +59,6 @@ def pipeline(mock_pool): return Pipeline(mock_pool) -# 
============================================================================= -# Valid Flows - Single Value -# ============================================================================= - - class TestValidSingleValueFlows: """Test valid single-value flows between steps.""" @@ -148,11 +124,6 @@ async def consume_base(self, result: ResultA) -> ResultC: assert result.value == "from_derived" -# ============================================================================= -# Valid Flows - Tuple Unpacking -# ============================================================================= - - class TestValidTupleFlows: """Test valid tuple return/parameter flows.""" @@ -215,11 +186,6 @@ async def combine(self, left: ResultA, right: ResultA) -> Combined: assert result.b == "right:data" -# ============================================================================= -# Invalid Flows - Count Mismatches -# ============================================================================= - - class TestInvalidCountMismatches: """Test that count mismatches between steps are caught.""" @@ -279,11 +245,6 @@ async def second(self) -> ResultC: await pipeline.run(Context(value="42")) -# ============================================================================= -# Invalid Flows - Type Mismatches -# ============================================================================= - - class TestInvalidTypeMismatches: """Test that type mismatches between steps are caught.""" @@ -328,11 +289,6 @@ class UnrelatedPipeline(pipeline.Base): pipeline._validate_type_flow() -# ============================================================================= -# Invalid Flows - Final Step Returns Tuple -# ============================================================================= - - class TestInvalidFinalStepTuple: """Test that final step returning tuple is rejected.""" @@ -357,11 +313,6 @@ class TupleFinalPipeline(pipeline.Base): pipeline._validate_type_flow() -# 
============================================================================= -# Edge Cases -# ============================================================================= - - class TestInvalidNoSteps: """Test that pipelines with no steps are rejected.""" diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 8ce615f..4d32c44 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -1,5 +1,3 @@ -"""Test that the public API is properly exposed.""" - import uuid import pytest diff --git a/tests/test_queue.py b/tests/test_queue.py index db71c3f..6217727 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -1,5 +1,3 @@ -"""Test task queue operations.""" - import json import uuid @@ -8,7 +6,8 @@ from pydantic import BaseModel import agentexec as ax -from agentexec.core.queue import Priority, dequeue, enqueue +from agentexec.core.queue import Priority, enqueue +from agentexec.state import backend class SampleContext(BaseModel): @@ -20,31 +19,17 @@ class SampleContext(BaseModel): @pytest.fixture def fake_redis(monkeypatch): - """Setup fake redis for state backend with shared state.""" - import fakeredis - - # Create a shared FakeServer so sync and async clients share data - server = fakeredis.FakeServer() - fake_redis_sync = fakeredis.FakeRedis(server=server, decode_responses=False) - fake_redis = fake_aioredis.FakeRedis(server=server, decode_responses=False) - - def get_fake_sync_client(): - return fake_redis_sync - - def get_fake_async_client(): - return fake_redis - - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", get_fake_sync_client) - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", get_fake_async_client) - - yield fake_redis + """Setup fake redis for state backend.""" + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake @pytest.fixture def mock_activity_create(monkeypatch): """Mock activity.create to avoid database 
dependency.""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -65,9 +50,8 @@ async def test_enqueue_creates_task(fake_redis, mock_activity_create) -> None: assert task is not None assert task.task_name == "test_task" assert isinstance(task.agent_id, uuid.UUID) - assert isinstance(task.context, SampleContext) - assert task.context.message == "test" - assert task.context.value == 42 + assert task.context["message"] == "test" + assert task.context["value"] == 42 async def test_enqueue_pushes_to_redis(fake_redis, mock_activity_create) -> None: @@ -77,7 +61,7 @@ async def test_enqueue_pushes_to_redis(fake_redis, mock_activity_create) -> None task = await enqueue("test_task", ctx) # Check Redis has the task - task_json = await fake_redis.rpop(ax.CONF.queue_name) + task_json = await fake_redis.rpop(ax.CONF.queue_prefix) assert task_json is not None task_data = json.loads(task_json) @@ -94,7 +78,7 @@ async def test_enqueue_low_priority_lpush(fake_redis, mock_activity_create) -> N # LPUSH adds to left, RPOP takes from right # So we should use LPOP to see it - task_json = await fake_redis.lpop(ax.CONF.queue_name) + task_json = await fake_redis.lpop(ax.CONF.queue_prefix) assert task_json is not None @@ -107,21 +91,11 @@ async def test_enqueue_high_priority_rpush(fake_redis, mock_activity_create) -> await enqueue("high_task", SampleContext(message="high"), priority=Priority.HIGH) # High priority should be at the front (RPOP side) - task_json = await fake_redis.rpop(ax.CONF.queue_name) + task_json = await fake_redis.rpop(ax.CONF.queue_prefix) task_data = json.loads(task_json) assert task_data["task_name"] == "high_task" -async def test_enqueue_custom_queue_name(fake_redis, mock_activity_create) -> None: - """Test enqueue with custom queue name.""" - ctx = SampleContext(message="custom") - - await enqueue("test_task", ctx, queue_name="custom_queue") - - # 
Check custom queue - task_json = await fake_redis.rpop("custom_queue") - assert task_json is not None - async def test_dequeue_returns_task_data(fake_redis) -> None: """Test that dequeue returns parsed task data.""" @@ -131,10 +105,10 @@ async def test_dequeue_returns_task_data(fake_redis) -> None: "context": {"message": "dequeued", "value": 100}, "agent_id": str(uuid.uuid4()), } - await fake_redis.lpush(ax.CONF.queue_name, json.dumps(task_data).encode()) + await fake_redis.lpush(ax.CONF.queue_prefix, json.dumps(task_data).encode()) # Dequeue - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == "test_task" @@ -145,25 +119,11 @@ async def test_dequeue_returns_task_data(fake_redis) -> None: async def test_dequeue_returns_none_on_empty_queue(fake_redis) -> None: """Test that dequeue returns None when queue is empty.""" # timeout=1 because timeout=0 means block indefinitely in Redis BRPOP - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is None -async def test_dequeue_custom_queue_name(fake_redis) -> None: - """Test dequeue with custom queue name.""" - task_data = { - "task_name": "custom_task", - "context": {"message": "test"}, - "agent_id": str(uuid.uuid4()), - } - await fake_redis.lpush("custom_queue", json.dumps(task_data).encode()) - - result = await dequeue(queue_name="custom_queue", timeout=1) - - assert result is not None - assert result["task_name"] == "custom_task" - async def test_dequeue_brpop_behavior(fake_redis) -> None: """Test that dequeue uses BRPOP (right side of list).""" @@ -171,11 +131,11 @@ async def test_dequeue_brpop_behavior(fake_redis) -> None: task1 = {"task_name": "first", "context": {}, "agent_id": str(uuid.uuid4())} task2 = {"task_name": "second", "context": {}, "agent_id": str(uuid.uuid4())} - await fake_redis.lpush(ax.CONF.queue_name, json.dumps(task1).encode()) - await fake_redis.lpush(ax.CONF.queue_name, 
json.dumps(task2).encode()) + await fake_redis.lpush(ax.CONF.queue_prefix, json.dumps(task1).encode()) + await fake_redis.lpush(ax.CONF.queue_prefix, json.dumps(task2).encode()) # BRPOP should get the first task (oldest) from the right - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == "first" @@ -188,7 +148,7 @@ async def test_enqueue_dequeue_roundtrip(fake_redis, mock_activity_create) -> No task = await enqueue("roundtrip_task", ctx) # Dequeue - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == "roundtrip_task" @@ -207,6 +167,6 @@ async def test_multiple_enqueue_fifo_order(fake_redis, mock_activity_create) -> # Dequeue should be in FIFO order for i in range(3): - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == f"task_{i}" diff --git a/tests/test_queue_partitions.py b/tests/test_queue_partitions.py new file mode 100644 index 0000000..deac6a7 --- /dev/null +++ b/tests/test_queue_partitions.py @@ -0,0 +1,173 @@ +"""Tests for the partitioned Redis queue — SCAN-based dequeue with per-partition locking.""" + +import json +import uuid + +import pytest +from fakeredis import aioredis as fake_aioredis +from pydantic import BaseModel + +import agentexec as ax +from agentexec.state import backend + + +def _task_json(task_name: str = "test", **overrides) -> str: + data = { + "task_name": task_name, + "context": {"message": "hello"}, + "agent_id": str(uuid.uuid4()), + **overrides, + } + return json.dumps(data) + + +@pytest.fixture +def fake_redis(monkeypatch): + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake + + +class TestPartitionRouting: + async def test_push_to_default_queue(self, fake_redis): + await backend.queue.push(_task_json("t1")) + assert 
await fake_redis.llen(ax.CONF.queue_prefix) == 1 + + async def test_push_to_partition_queue(self, fake_redis): + await backend.queue.push(_task_json("t1"), partition_key="user:42") + partition_key = f"{ax.CONF.queue_prefix}:user:42" + assert await fake_redis.llen(partition_key) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 0 + + async def test_pop_from_default_queue_no_lock(self, fake_redis): + """Default queue tasks are popped without acquiring a lock.""" + await backend.queue.push(_task_json("t1")) + result = await backend.queue.pop(timeout=1) + + assert result is not None + assert result["task_name"] == "t1" + + # No lock key should exist for the default queue + keys = [k async for k in fake_redis.scan_iter(match=b"*:lock")] + assert len(keys) == 0 + + +class TestPartitionLocking: + async def test_pop_acquires_lock_for_partition(self, fake_redis): + """Popping a partitioned task acquires its lock.""" + await backend.queue.push(_task_json("t1"), partition_key="user:42") + result = await backend.queue.pop(timeout=1) + + assert result is not None + lock_key = f"{ax.CONF.queue_prefix}:user:42:lock".encode() + assert await fake_redis.exists(lock_key) + + async def test_locked_partition_is_skipped(self, fake_redis): + """A partition with a held lock is skipped during pop.""" + await backend.queue.push(_task_json("t1"), partition_key="user:42") + + # Pre-acquire the lock + lock_key = f"{ax.CONF.queue_prefix}:user:42:lock" + await fake_redis.set(lock_key, b"1") + + result = await backend.queue.pop(timeout=1) + assert result is None + + async def test_complete_releases_lock(self, fake_redis): + """complete() deletes the partition lock.""" + await backend.queue.push(_task_json("t1"), partition_key="user:42") + await backend.queue.pop(timeout=1) + + lock_key = f"{ax.CONF.queue_prefix}:user:42:lock".encode() + assert await fake_redis.exists(lock_key) + + await backend.queue.complete("user:42") + assert not await fake_redis.exists(lock_key) + + async def 
test_complete_noop_for_none(self, fake_redis): + """complete(None) is a no-op for unpartitioned tasks.""" + await backend.queue.complete(None) + # No exception, no lock keys created + keys = [k async for k in fake_redis.scan_iter(match=b"*:lock")] + assert len(keys) == 0 + + +class TestMultiPartitionDequeue: + async def test_pops_from_unlocked_partition(self, fake_redis): + """With one locked and one unlocked partition, pop picks the unlocked one.""" + await backend.queue.push(_task_json("locked"), partition_key="user:1") + await backend.queue.push(_task_json("unlocked"), partition_key="user:2") + + # Lock user:1 + lock_key = f"{ax.CONF.queue_prefix}:user:1:lock" + await fake_redis.set(lock_key, b"1") + + result = await backend.queue.pop(timeout=1) + assert result is not None + assert result["task_name"] == "unlocked" + + async def test_pops_default_and_partition_interleaved(self, fake_redis): + """Tasks from default and partitioned queues are both reachable.""" + await backend.queue.push(_task_json("default_task")) + await backend.queue.push(_task_json("partitioned_task"), partition_key="org:99") + + results = [] + for _ in range(3): + r = await backend.queue.pop(timeout=1) + if r: + results.append(r["task_name"]) + # Release the lock if it was a partition task + if r["task_name"] == "partitioned_task": + await backend.queue.complete("org:99") + + assert sorted(results) == ["default_task", "partitioned_task"] + + async def test_serialization_within_partition(self, fake_redis): + """Only one task per partition can be in-flight at a time.""" + await backend.queue.push(_task_json("first"), partition_key="user:1") + await backend.queue.push(_task_json("second"), partition_key="user:1") + + # Pop first task — acquires lock + first = await backend.queue.pop(timeout=1) + assert first is not None + assert first["task_name"] == "first" + + # Second pop should skip user:1 (locked) and find nothing else + second = await backend.queue.pop(timeout=1) + assert second is None + 
+ # After completing, second task becomes available + await backend.queue.complete("user:1") + second = await backend.queue.pop(timeout=1) + assert second is not None + assert second["task_name"] == "second" + + async def test_independent_partitions_are_concurrent(self, fake_redis): + """Different partitions can have tasks in-flight simultaneously.""" + await backend.queue.push(_task_json("user1_task"), partition_key="user:1") + await backend.queue.push(_task_json("user2_task"), partition_key="user:2") + + first = await backend.queue.pop(timeout=1) + assert first is not None + + second = await backend.queue.pop(timeout=1) + assert second is not None + + # Both tasks popped — different partitions, different locks + names = sorted([first["task_name"], second["task_name"]]) + assert names == ["user1_task", "user2_task"] + + async def test_empty_queue_returns_none(self, fake_redis): + result = await backend.queue.pop(timeout=1) + assert result is None + + async def test_high_priority_goes_to_front_of_partition(self, fake_redis): + """High priority tasks within a partition are popped first.""" + await backend.queue.push(_task_json("low"), partition_key="user:1") + await backend.queue.push( + _task_json("high"), partition_key="user:1", high_priority=True, + ) + + result = await backend.queue.pop(timeout=1) + assert result is not None + assert result["task_name"] == "high" diff --git a/tests/test_results.py b/tests/test_results.py index 01c9b43..43c4197 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -1,5 +1,3 @@ -"""Test task result storage and retrieval.""" - import asyncio import uuid from unittest.mock import AsyncMock, patch @@ -8,64 +6,53 @@ from pydantic import BaseModel import agentexec as ax -from agentexec.core.results import gather, get_result +from agentexec.core.results import gather, get_result, _get_result class SampleContext(BaseModel): - """Sample context for result tests.""" - message: str class SampleResult(BaseModel): - """Sample 
result model for tests.""" - status: str value: int class ComplexResult(BaseModel): - """Complex result model with nested data.""" - items: list[dict[str, int]] nested: dict[str, list[int]] @pytest.fixture -def mock_state(): - """Mock the state module's aget_result function.""" - with patch("agentexec.core.results.state") as mock: +def mock_get_result(): + """Mock the internal _get_result function.""" + with patch("agentexec.core.results._get_result") as mock: yield mock -async def test_get_result_returns_deserialized_data(mock_state) -> None: - """Test that get_result retrieves data from state.""" +async def test_get_result_returns_deserialized_data(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) expected_result = SampleResult(status="success", value=42) - - # Mock aget_result to return the expected result - mock_state.aget_result = AsyncMock(return_value=expected_result) + mock_get_result.return_value = expected_result result = await get_result(task, timeout=1) assert result == expected_result - mock_state.aget_result.assert_called_once_with(task.agent_id) + mock_get_result.assert_called_once_with(task.agent_id) -async def test_get_result_polls_until_available(mock_state) -> None: - """Test that get_result polls until result is available.""" +async def test_get_result_polls_until_available(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) expected_result = SampleResult(status="delayed", value=100) - # Return None first, then the result call_count = 0 async def delayed_result(agent_id): @@ -75,7 +62,7 @@ async def delayed_result(agent_id): return None return expected_result - mock_state.aget_result = delayed_result + mock_get_result.side_effect = delayed_result result = await get_result(task, timeout=5) @@ -83,46 +70,41 @@ async def 
delayed_result(agent_id): assert call_count == 3 -async def test_get_result_timeout(mock_state) -> None: - """Test that get_result raises TimeoutError if result not available.""" +async def test_get_result_timeout(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) - - # Always return None to trigger timeout - mock_state.aget_result = AsyncMock(return_value=None) + mock_get_result.return_value = None with pytest.raises(TimeoutError, match=f"Result for {task.agent_id} not available"): await get_result(task, timeout=1) -async def test_gather_multiple_tasks(mock_state) -> None: - """Test that gather waits for multiple tasks and returns results.""" +async def test_gather_multiple_tasks(mock_get_result) -> None: task1 = ax.Task( task_name="task1", - context=SampleContext(message="test1"), + context={"message": "test1"}, agent_id=uuid.uuid4(), ) task2 = ax.Task( task_name="task2", - context=SampleContext(message="test2"), + context={"message": "test2"}, agent_id=uuid.uuid4(), ) result1 = SampleResult(status="task1", value=100) result2 = SampleResult(status="task2", value=200) - # Mock to return different results for different agent_ids - async def mock_aget_result(agent_id): + async def mock_result(agent_id): if agent_id == task1.agent_id: return result1 elif agent_id == task2.agent_id: return result2 return None - mock_state.aget_result = mock_aget_result + mock_get_result.side_effect = mock_result results = await gather(task1, task2) @@ -130,53 +112,48 @@ async def mock_aget_result(agent_id): assert len(results) == 2 -async def test_gather_single_task(mock_state) -> None: - """Test that gather works with a single task.""" +async def test_gather_single_task(mock_get_result) -> None: task = ax.Task( task_name="single_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) expected = SampleResult(status="single", 
value=1) - mock_state.aget_result = AsyncMock(return_value=expected) + mock_get_result.return_value = expected results = await gather(task) assert results == (expected,) -async def test_gather_preserves_order(mock_state) -> None: - """Test that gather returns results in the same order as input tasks.""" +async def test_gather_preserves_order(mock_get_result) -> None: tasks = [ ax.Task( task_name=f"task{i}", - context=SampleContext(message=f"msg{i}"), + context={"message": f"msg{i}"}, agent_id=uuid.uuid4(), ) for i in range(5) ] - # Create results mapped to task agent_ids results_map = {task.agent_id: SampleResult(status=f"result_{i}", value=i) for i, task in enumerate(tasks)} - async def mock_aget_result(agent_id): + async def mock_result(agent_id): return results_map.get(agent_id) - mock_state.aget_result = mock_aget_result + mock_get_result.side_effect = mock_result results = await gather(*tasks) - # Results should be in task order expected = tuple(SampleResult(status=f"result_{i}", value=i) for i in range(5)) assert results == expected -async def test_get_result_with_complex_object(mock_state) -> None: - """Test that get_result handles complex BaseModel objects.""" +async def test_get_result_with_complex_object(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) @@ -184,7 +161,7 @@ async def test_get_result_with_complex_object(mock_state) -> None: items=[{"a": 1}, {"b": 2}], nested={"key": [1, 2, 3]}, ) - mock_state.aget_result = AsyncMock(return_value=expected) + mock_get_result.return_value = expected result = await get_result(task, timeout=1) diff --git a/tests/test_runners.py b/tests/test_runners.py index dd9d182..cd1763a 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -1,5 +1,3 @@ -"""Test runner base classes and functionality.""" - import uuid import pytest @@ -99,14 +97,14 @@ def test_report_status_function_docstring(self): 
assert report_fn.__doc__ is not None assert "progress" in report_fn.__doc__.lower() - def test_report_status_updates_activity(self, db_session, monkeypatch): + async def test_report_status_updates_activity(self, db_session, monkeypatch): """Test that report_status function calls activity.update.""" agent_id = uuid.uuid4() # Track calls to activity.update update_calls = [] - def mock_update(*args, **kwargs): + async def mock_update(*args, **kwargs): update_calls.append(kwargs) return True @@ -115,7 +113,7 @@ def mock_update(*args, **kwargs): tools = _RunnerTools(agent_id) report_fn = tools.report_status - result = report_fn("Working on task", 50) + result = await report_fn("Working on task", 50) assert result == "Status updated" assert len(update_calls) == 1 diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 5142dad..ebd6fb2 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -1,5 +1,3 @@ -"""Tests for scheduled task support.""" - import time import uuid from datetime import datetime @@ -11,14 +9,29 @@ from pydantic import BaseModel import agentexec as ax -from agentexec import state +from agentexec import state, schedule +from agentexec.core.queue import enqueue from agentexec.schedule import ( REPEAT_FOREVER, ScheduledTask, - tick, - _queue_key, - _schedule_key, + register, ) +from agentexec.state import backend + + +async def tick(): + """Test helper — replicates the pool's schedule tick logic.""" + for task in await backend.schedule.get_due(): + await enqueue( + task.task_name, + context=backend.deserialize(task.context), + metadata=task.metadata, + ) + if task.repeat == 0: + await backend.schedule.remove(task.key) + else: + task.advance() + await backend.schedule.register(task) class RefreshContext(BaseModel): @@ -26,40 +39,51 @@ class RefreshContext(BaseModel): ttl: int = 300 -@pytest.fixture -def fake_redis(monkeypatch): - """Setup fake redis for state backend with shared state.""" - import fakeredis +def _index_key() -> str: + 
return backend.format_key(ax.CONF.key_prefix, "schedules") - server = fakeredis.FakeServer() - fake_redis_sync = fakeredis.FakeRedis(server=server, decode_responses=False) - fake_redis_async = fake_aioredis.FakeRedis(server=server, decode_responses=False) - def get_fake_sync_client(): - return fake_redis_sync +def _data_key() -> str: + return backend.format_key(ax.CONF.key_prefix, "schedules", "data") - def get_fake_async_client(): - return fake_redis_async - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", get_fake_sync_client) - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", get_fake_async_client) +async def _get_schedule(fake_redis, task_name: str) -> ScheduledTask | None: + """Find a schedule by task_name in the hash.""" + all_data = await fake_redis.hgetall(_data_key()) + for key, data in all_data.items(): + st = ScheduledTask.model_validate_json(data) + if st.task_name == task_name: + return st + return None - yield fake_redis_sync + +async def _force_due(fake_redis, task_name: str) -> ScheduledTask: + """Set a schedule's next_run to the past so tick() picks it up.""" + st = await _get_schedule(fake_redis, task_name) + if st is None: + raise ValueError(f"No schedule found for {task_name}") + st.next_run = time.time() - 10 + await fake_redis.hset(_data_key(), st.key, st.model_dump_json().encode()) + await fake_redis.zadd(_index_key(), {st.key: st.next_run}) + return st @pytest.fixture -def mock_activity_create(monkeypatch): - """Mock activity.create to avoid database dependency.""" +def fake_redis(monkeypatch): + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake - def mock_create(*args, **kwargs): - return uuid.uuid4() +@pytest.fixture +def mock_activity_create(monkeypatch): + async def mock_create(*args, **kwargs): + return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @pytest.fixture def pool(): - """Create a 
Pool with a registered task for scheduling tests.""" p = ax.Pool(database_url="sqlite:///") @p.task("refresh_cache") @@ -69,11 +93,6 @@ async def refresh(agent_id: UUID, context: RefreshContext): return p -# --------------------------------------------------------------------------- -# ScheduledTask model -# --------------------------------------------------------------------------- - - class TestScheduledTaskModel: def test_default_repeat_is_forever(self): ctx = RefreshContext(scope="test") @@ -81,7 +100,6 @@ def test_default_repeat_is_forever(self): task_name="test", context=state.backend.serialize(ctx), cron="*/5 * * * *", - ) assert st.repeat == REPEAT_FOREVER assert st.repeat == -1 @@ -92,27 +110,22 @@ def test_next_run_returns_future_timestamp(self): task_name="test", context=state.backend.serialize(ctx), cron="*/5 * * * *", - ) now = time.time() nxt = st._next_after(now) assert nxt > now def test_next_run_respects_anchor(self): - """Two calls with different anchors produce different results.""" ctx = RefreshContext(scope="test") st = ScheduledTask( task_name="test", context=state.backend.serialize(ctx), - cron="0 * * * *", # top of every hour - + cron="0 * * * *", ) anchor_a = 1_700_000_000.0 anchor_b = anchor_a + 3600 - next_a = st._next_after(anchor_a) next_b = st._next_after(anchor_b) - assert next_b > next_a assert next_b - next_a == pytest.approx(3600, abs=1) @@ -122,7 +135,6 @@ def test_cron_every_minute(self): task_name="test", context=state.backend.serialize(ctx), cron="* * * * *", - ) now = time.time() nxt = st._next_after(now) @@ -137,10 +149,8 @@ def test_roundtrip_serialization(self): repeat=5, next_run=time.time() + 600, ) - json_str = st.model_dump_json() restored = ScheduledTask.model_validate_json(json_str) - assert restored.task_name == "refresh" restored_ctx = state.backend.deserialize(restored.context) assert isinstance(restored_ctx, RefreshContext) @@ -159,115 +169,80 @@ def test_auto_generated_fields(self): assert st.created_at > 0 
assert st.next_run > 0 - -# --------------------------------------------------------------------------- -# pool.add_schedule() -# --------------------------------------------------------------------------- + def test_key_includes_task_name_and_cron(self): + ctx = RefreshContext(scope="test") + st = ScheduledTask( + task_name="research", + context=state.backend.serialize(ctx), + cron="*/5 * * * *", + ) + assert st.key.startswith("research:*/5 * * * *:") class TestPoolAddSchedule: - def test_schedule_stores_in_redis(self, fake_redis, pool): + def test_schedule_defers_registration(self, pool): pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + assert len(pool._pending_schedules) == 1 + sched = pool._pending_schedules[0] + assert sched["task_name"] == "refresh_cache" + assert sched["every"] == "*/5 * * * *" - data = fake_redis.get(_schedule_key("refresh_cache")) - assert data is not None - - st = ScheduledTask.model_validate_json(data) - assert st.task_name == "refresh_cache" - ctx = state.backend.deserialize(st.context) - assert isinstance(ctx, RefreshContext) - assert ctx.scope == "all" - - def test_schedule_indexes_in_sorted_set(self, fake_redis, pool): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - - members = fake_redis.zrange(_queue_key(), 0, -1, withscores=True) - assert len(members) == 1 - - def test_schedule_rejects_unregistered_task(self, fake_redis, pool): + def test_schedule_rejects_unregistered_task(self, pool): with pytest.raises(ValueError, match="not registered"): pool.add_schedule("nonexistent_task", "*/5 * * * *", RefreshContext(scope="all")) - def test_schedule_with_metadata(self, fake_redis, pool): + def test_schedule_with_metadata(self, pool): pool.add_schedule( "refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), metadata={"org_id": "org-123"}, ) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) - assert st.metadata == 
{"org_id": "org-123"} + assert pool._pending_schedules[0]["metadata"] == {"org_id": "org-123"} - def test_schedule_with_repeat(self, fake_redis, pool): + def test_schedule_with_repeat(self, pool): pool.add_schedule( "refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), repeat=3, ) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) - assert st.repeat == 3 + assert pool._pending_schedules[0]["repeat"] == 3 - def test_schedule_is_idempotent(self, fake_redis, pool): - """Calling add_schedule twice for the same task overwrites, not duplicates.""" - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="v1")) - pool.add_schedule("refresh_cache", "*/10 * * * *", RefreshContext(scope="v2")) - members = fake_redis.zrange(_queue_key(), 0, -1) - assert len(members) == 1 +class TestScheduleRegister: + async def test_register_stores_in_redis(self, fake_redis): + await register( + task_name="refresh_cache", + every="*/5 * * * *", + context=RefreshContext(scope="all"), + ) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) - assert st.cron == "*/10 * * * *" + st = await _get_schedule(fake_redis, "refresh_cache") + assert st is not None + assert st.task_name == "refresh_cache" ctx = state.backend.deserialize(st.context) assert isinstance(ctx, RefreshContext) - assert ctx.scope == "v2" + assert ctx.scope == "all" + async def test_register_indexes_in_sorted_set(self, fake_redis): + await register( + task_name="refresh_cache", + every="*/5 * * * *", + context=RefreshContext(scope="all"), + ) -# --------------------------------------------------------------------------- -# @pool.schedule() decorator -# --------------------------------------------------------------------------- + members = await fake_redis.zrange(_index_key(), 0, -1, withscores=True) + assert len(members) == 1 class TestPoolScheduleDecorator: - def 
test_decorator_registers_task_and_schedule(self, fake_redis): - """@pool.schedule registers the task and schedules it.""" + def test_decorator_registers_task_and_defers_schedule(self): p = ax.Pool(database_url="sqlite:///") @p.schedule("refresh_cache", "*/5 * * * *", context=RefreshContext(scope="all")) async def refresh(agent_id: uuid.UUID, context: RefreshContext): pass - # Task is registered assert "refresh_cache" in p._context.tasks + assert len(p._pending_schedules) == 1 - # Schedule is in Redis - members = fake_redis.zrange(_queue_key(), 0, -1) - assert len(members) == 1 - - def test_decorator_without_context(self, fake_redis): - """@pool.schedule works without explicit context (defaults to empty BaseModel).""" - p = ax.Pool(database_url="sqlite:///") - - @p.schedule("simple_task", "0 * * * *") - async def simple(agent_id: uuid.UUID, context: BaseModel): - pass - - assert "simple_task" in p._context.tasks - members = fake_redis.zrange(_queue_key(), 0, -1) - assert len(members) == 1 - - def test_decorator_with_repeat(self, fake_redis): - """@pool.schedule passes repeat through.""" - p = ax.Pool(database_url="sqlite:///") - - @p.schedule("limited_task", "*/10 * * * *", context=RefreshContext(scope="all"), repeat=5) - async def limited(agent_id: uuid.UUID, context: RefreshContext): - pass - - data = fake_redis.get(_schedule_key("limited_task")) - st = ScheduledTask.model_validate_json(data) - assert st.repeat == 5 - - def test_decorator_with_lock_key(self, fake_redis): - """@pool.schedule passes lock_key to the task registration.""" + def test_decorator_with_lock_key(self): p = ax.Pool(database_url="sqlite:///") @p.schedule("locked_task", "*/5 * * * *", lock_key="user:{user_id}") @@ -277,8 +252,7 @@ async def locked(agent_id: uuid.UUID, context: RefreshContext): defn = p._context.tasks["locked_task"] assert defn.lock_key == "user:{user_id}" - def test_decorator_returns_handler(self, fake_redis): - """@pool.schedule returns the original handler function.""" + def 
test_decorator_returns_handler(self): p = ax.Pool(database_url="sqlite:///") @p.schedule("my_task", "*/5 * * * *") @@ -289,156 +263,120 @@ async def my_handler(agent_id: uuid.UUID, context: BaseModel): assert my_handler.__name__ == "my_handler" -# --------------------------------------------------------------------------- -# tick — the scheduler heartbeat -# --------------------------------------------------------------------------- - - -def _force_due(fake_redis, task_name): - """Helper: set a schedule's next_run to the past so tick() picks it up.""" - data = fake_redis.get(_schedule_key(task_name)) - st = ScheduledTask.model_validate_json(data) - st.next_run = time.time() - 10 - fake_redis.set(_schedule_key(task_name), st.model_dump_json().encode()) - fake_redis.zadd(_queue_key(), {task_name: st.next_run}) - return st - - class TestTick: - async def test_tick_enqueues_due_task(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - _force_due(fake_redis, "refresh_cache") + async def test_tick_enqueues_due_task(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + await _force_due(fake_redis, "refresh_cache") await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 - async def test_tick_skips_future_tasks(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + async def test_tick_skips_future_tasks(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 0 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 0 - async def test_tick_removes_one_shot_schedule(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "* * * * *", 
RefreshContext(scope="all"), repeat=0) - _force_due(fake_redis, "refresh_cache") + async def test_tick_removes_one_shot_schedule(self, fake_redis, mock_activity_create): + await register("refresh_cache", "* * * * *", RefreshContext(scope="all"), repeat=0) + await _force_due(fake_redis, "refresh_cache") await tick() - assert fake_redis.get(_schedule_key("refresh_cache")) is None - assert fake_redis.zcard(_queue_key()) == 0 + assert await _get_schedule(fake_redis, "refresh_cache") is None + assert await fake_redis.zcard(_index_key()) == 0 - async def test_tick_decrements_repeat_count(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), repeat=3) - old_st = _force_due(fake_redis, "refresh_cache") + async def test_tick_decrements_repeat_count(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), repeat=3) + old_st = await _force_due(fake_redis, "refresh_cache") await tick() - data = fake_redis.get(_schedule_key("refresh_cache")) - updated = ScheduledTask.model_validate_json(data) - assert updated.repeat == 2 + updated = await _get_schedule(fake_redis, "refresh_cache") + assert updated.repeat < 3 assert updated.next_run > old_st.next_run - async def test_tick_infinite_repeat_stays_negative(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - _force_due(fake_redis, "refresh_cache") + async def test_tick_infinite_repeat_stays_negative(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + await _force_due(fake_redis, "refresh_cache") await tick() - data = fake_redis.get(_schedule_key("refresh_cache")) - updated = ScheduledTask.model_validate_json(data) + updated = await _get_schedule(fake_redis, "refresh_cache") assert updated.repeat == -1 - async def test_tick_anchor_based_rescheduling(self, 
fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - old_st = _force_due(fake_redis, "refresh_cache") + async def test_tick_anchor_based_rescheduling(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + old_st = await _force_due(fake_redis, "refresh_cache") await tick() - data = fake_redis.get(_schedule_key("refresh_cache")) - updated = ScheduledTask.model_validate_json(data) + updated = await _get_schedule(fake_redis, "refresh_cache") assert updated.next_run > old_st.next_run - async def test_tick_skips_orphaned_entries(self, fake_redis, pool, mock_activity_create): - """Orphaned queue entries are skipped (not deleted) with a warning.""" - fake_redis.zadd(_queue_key(), {"orphan-id": time.time() - 100}) + async def test_tick_skips_orphaned_entries(self, fake_redis, mock_activity_create): + """Orphaned index entries are skipped with a warning.""" + await fake_redis.zadd(_index_key(), {"orphan-id": time.time() - 100}) await tick() - assert fake_redis.zcard(_queue_key()) == 1 - assert fake_redis.llen(ax.CONF.queue_name) == 0 + assert await fake_redis.zcard(_index_key()) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 0 - async def test_tick_skips_missed_intervals(self, fake_redis, pool, mock_activity_create): - """After downtime, advance() skips to the next future run — no burst of catch-up tasks.""" - pool.add_schedule("refresh_cache", "*/1 * * * *", RefreshContext(scope="all")) + async def test_tick_skips_missed_intervals(self, fake_redis, mock_activity_create): + """After downtime, advance() skips to the next future run.""" + await register("refresh_cache", "*/1 * * * *", RefreshContext(scope="all")) - # Simulate 10 minutes of downtime - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) + st = await _get_schedule(fake_redis, "refresh_cache") st.next_run = 
time.time() - 600 - fake_redis.set(_schedule_key("refresh_cache"), st.model_dump_json().encode()) - fake_redis.zadd(_queue_key(), {"refresh_cache": st.next_run}) + await fake_redis.hset(_data_key(), st.key, st.model_dump_json().encode()) + await fake_redis.zadd(_index_key(), {st.key: st.next_run}) await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 - # Second tick should not enqueue again (next_run is in the future now) await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 - async def test_context_payload_preserved(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="users", ttl=999)) + async def test_context_payload_preserved(self, fake_redis): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="users", ttl=999)) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) + st = await _get_schedule(fake_redis, "refresh_cache") ctx = state.backend.deserialize(st.context) assert isinstance(ctx, RefreshContext) assert ctx.scope == "users" assert ctx.ttl == 999 -# --------------------------------------------------------------------------- -# Timezone configuration -# --------------------------------------------------------------------------- - - class TestTimezone: def test_default_timezone_is_utc(self): - """Default should be UTC.""" from agentexec.config import CONF - assert CONF.scheduler_timezone == "UTC" def test_scheduler_tz_returns_zoneinfo(self): from agentexec.config import CONF - tz = CONF.scheduler_tz assert isinstance(tz, ZoneInfo) def test_cron_respects_configured_timezone(self, monkeypatch): - """Cron evaluation should use the configured timezone.""" from agentexec.config import CONF - monkeypatch.setattr(CONF, "scheduler_timezone", "America/New_York") ctx = RefreshContext(scope="test") st = 
ScheduledTask( task_name="test", context=state.backend.serialize(ctx), - cron="0 9 * * *", # 9 AM - + cron="0 9 * * *", ) - # Use a known timestamp: 2024-01-15 9:00 AM ET anchor = datetime(2024, 1, 15, 9, 0, 0, tzinfo=ZoneInfo("America/New_York")).timestamp() nxt = st._next_after(anchor) - # Next 9 AM ET should be ~24h later next_dt = datetime.fromtimestamp(nxt, tz=ZoneInfo("America/New_York")) assert next_dt.hour == 9 assert next_dt.day == 16 def test_timezone_env_override(self, monkeypatch): - """AGENTEXEC_SCHEDULER_TIMEZONE env var should override default.""" monkeypatch.setenv("AGENTEXEC_SCHEDULER_TIMEZONE", "Asia/Tokyo") from agentexec.config import Config diff --git a/tests/test_self_describing_results.py b/tests/test_self_describing_results.py index acf56d6..923cdae 100644 --- a/tests/test_self_describing_results.py +++ b/tests/test_self_describing_results.py @@ -1,94 +1,73 @@ -"""Test self-describing result serialization (pickle-like behavior with JSON).""" - import uuid import pytest from pydantic import BaseModel import agentexec as ax -from agentexec import state +from agentexec.state import KEY_RESULT, backend class DummyContext(BaseModel): - """Dummy context for testing.""" - pass class ResearchResult(BaseModel): - """Sample result model.""" - company: str valuation: int class AnalysisResult(BaseModel): - """Another result model.""" - conclusion: str confidence: float class NestedData(BaseModel): - """Nested data structure for testing.""" - items: list[str] metadata: dict[str, int] class ComplexResult(BaseModel): - """Complex result with nested structure.""" - status: str data: NestedData async def test_gather_without_task_definitions(monkeypatch) -> None: - """Test that gather() works without needing TaskDefinitions. - - This demonstrates that results are self-describing - they include - their type information, so we can deserialize without a registry. 
- """ - # Create tasks without TaskDefinitions (as enqueue() does) + """Test that gather() works without needing TaskDefinitions.""" task1 = ax.Task( task_name="research", - context=DummyContext(), + context={}, agent_id=uuid.uuid4(), ) task2 = ax.Task( task_name="analysis", - context=DummyContext(), + context={}, agent_id=uuid.uuid4(), ) - # Store results with type information result1 = ResearchResult(company="Anthropic", valuation=1000000) result2 = AnalysisResult(conclusion="Strong", confidence=0.95) # Mock backend storage storage = {} - def mock_format_key(*args): - return ":".join(args) - - async def mock_aset(key, value, ttl_seconds=None): + async def mock_state_set(key, value, ttl_seconds=None): storage[key] = value return True - async def mock_aget(key): + async def mock_state_get(key): return storage.get(key) - monkeypatch.setattr(state.backend, "format_key", mock_format_key) - monkeypatch.setattr(state.backend, "aset", mock_aset) - monkeypatch.setattr(state.backend, "aget", mock_aget) + monkeypatch.setattr(backend.state, "set", mock_state_set) + monkeypatch.setattr(backend.state, "get", mock_state_get) - await state.aset_result(task1.agent_id, result1) - await state.aset_result(task2.agent_id, result2) + # Store results via the same path task.execute() would + for task, result in [(task1, result1), (task2, result2)]: + key = backend.format_key(*KEY_RESULT, str(task.agent_id)) + await backend.state.set(key, backend.serialize(result)) - # Gather results - no TaskDefinition needed! 
+ # Gather results results = await ax.gather(task1, task2) - # Results are correctly typed assert isinstance(results[0], ResearchResult) assert isinstance(results[1], AnalysisResult) assert results[0].company == "Anthropic" @@ -99,11 +78,8 @@ async def test_result_roundtrip_preserves_type() -> None: """Test that serialize → deserialize preserves exact type.""" original = ResearchResult(company="Acme", valuation=500000) - # Serialize - serialized = state.backend.serialize(original) - - # Deserialize - should get back the same type - deserialized = state.backend.deserialize(serialized) + serialized = backend.serialize(original) + deserialized = backend.deserialize(serialized) assert type(deserialized) is ResearchResult assert deserialized == original @@ -116,9 +92,8 @@ async def test_nested_models_preserve_structure() -> None: data=NestedData(items=["a", "b"], metadata={"count": 2}), ) - # Roundtrip - serialized = state.backend.serialize(original) - deserialized = state.backend.deserialize(serialized) + serialized = backend.serialize(original) + deserialized = backend.deserialize(serialized) assert type(deserialized) is ComplexResult assert type(deserialized.data) is NestedData diff --git a/tests/test_state.py b/tests/test_state.py index 1a54e0d..4f53b92 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,185 +1,59 @@ -"""Tests for state module public API.""" - -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from pydantic import BaseModel -from agentexec import state +from agentexec.state import KEY_RESULT, backend -# Test models for result serialization class ResultModel(BaseModel): - """Test result model.""" - status: str value: int -class OutputModel(BaseModel): - """Test output model.""" +class TestSerialization: + """Tests for serialize/deserialize on the backend.""" - status: str - output: str + def test_roundtrip(self): + model = ResultModel(status="success", value=42) + data = 
backend.serialize(model) + restored = backend.deserialize(data) + assert isinstance(restored, ResultModel) + assert restored == model -class TestResultOperations: - """Tests for result get/set/delete operations.""" +class TestFormatKey: + """Tests for key formatting.""" - def test_get_result_found(self): - """Test getting an existing result returns deserialized BaseModel.""" - result_model = ResultModel(status="success", value=42) - # Serialize with type information (mimicking backend.serialize) - serialized = state.backend.serialize(result_model) + def test_result_key(self): + key = backend.format_key(*KEY_RESULT, "agent-123") + assert "result" in key + assert "agent-123" in key - with patch.object(state.backend, "get", return_value=serialized) as mock_get: - result = state.get_result("agent123") - mock_get.assert_called_once_with("agentexec:result:agent123") - # Result should be deserialized BaseModel - assert isinstance(result, ResultModel) - assert result == result_model - def test_get_result_not_found(self): - """Test getting a non-existent result returns None.""" - with patch.object(state.backend, "get", return_value=None) as mock_get: - result = state.get_result("agent456") - - mock_get.assert_called_once_with("agentexec:result:agent456") - assert result is None +class TestStateBackend: + """Tests for state.get/set/delete via backend.state.""" - async def test_aget_result_found(self): - """Test async getting an existing result returns deserialized BaseModel.""" - result_model = OutputModel(status="complete", output="test") - serialized = state.backend.serialize(result_model) + async def test_set_and_get(self): + result = ResultModel(status="success", value=42) + serialized = backend.serialize(result) - async def mock_aget(key): + async def mock_get(key): return serialized - with patch.object(state.backend, "aget", side_effect=mock_aget): - result = await state.aget_result("agent789") + with patch.object(backend.state, "get", side_effect=mock_get): + data = 
await backend.state.get("test-key") + restored = backend.deserialize(data) + assert isinstance(restored, ResultModel) + assert restored == result - # Result should be deserialized BaseModel - assert isinstance(result, OutputModel) - assert result == result_model - - async def test_aget_result_not_found(self): - """Test async getting a non-existent result.""" - async def mock_aget(key): + async def test_get_missing(self): + async def mock_get(key): return None - with patch.object(state.backend, "aget", side_effect=mock_aget): - result = await state.aget_result("missing") - + with patch.object(backend.state, "get", side_effect=mock_get): + result = await backend.state.get("missing-key") assert result is None - def test_set_result_without_ttl(self): - """Test setting a result without TTL.""" - result_model = ResultModel(status="success", value=42) - - with patch.object(state.backend, "set", return_value=True) as mock_set: - success = state.set_result("agent123", result_model) - - mock_set.assert_called_once() - call_args = mock_set.call_args - assert call_args[0][0] == "agentexec:result:agent123" - # Should be JSON bytes with type information - stored_value = call_args[0][1] - assert isinstance(stored_value, bytes) - # Verify it can be deserialized back - deserialized = state.backend.deserialize(stored_value) - assert isinstance(deserialized, ResultModel) - assert deserialized == result_model - assert call_args[1]["ttl_seconds"] is None - assert success is True - - def test_set_result_with_ttl(self): - """Test setting a result with TTL.""" - result_model = ResultModel(status="success", value=100) - - with patch.object(state.backend, "set", return_value=True) as mock_set: - success = state.set_result("agent456", result_model, ttl_seconds=3600) - - call_args = mock_set.call_args - assert call_args[0][0] == "agentexec:result:agent456" - assert call_args[1]["ttl_seconds"] == 3600 - assert success is True - - async def test_aset_result(self): - """Test async setting a 
result.""" - result_model = OutputModel(status="complete", output="test") - - async def mock_aset(key, value, ttl_seconds=None): - return True - - with patch.object(state.backend, "aset", side_effect=mock_aset): - success = await state.aset_result("agent789", result_model, ttl_seconds=7200) - - assert success is True - - def test_delete_result(self): - """Test deleting a result.""" - with patch.object(state.backend, "delete", return_value=1) as mock_delete: - count = state.delete_result("agent123") - - mock_delete.assert_called_once_with("agentexec:result:agent123") - assert count == 1 - - async def test_adelete_result(self): - """Test async deleting a result.""" - async def mock_adelete(key): - return 1 - - with patch.object(state.backend, "adelete", side_effect=mock_adelete): - count = await state.adelete_result("agent456") - - assert count == 1 - - -class TestLogOperations: - """Tests for log pub/sub operations.""" - - def test_publish_log(self): - """Test publishing a log message.""" - log_message = '{"level": "info", "message": "test log"}' - - with patch.object(state.backend, "publish") as mock_publish: - state.publish_log(log_message) - - mock_publish.assert_called_once_with("agentexec:logs", log_message) - - async def test_subscribe_logs(self): - """Test subscribing to logs.""" - log_messages = [ - '{"level": "info", "message": "log1"}', - '{"level": "error", "message": "log2"}' - ] - - async def mock_subscribe(channel): - for msg in log_messages: - yield msg - - with patch.object(state.backend, "subscribe", side_effect=mock_subscribe): - messages = [] - async for msg in state.subscribe_logs(): - messages.append(msg) - - assert messages == log_messages - - -class TestKeyGeneration: - """Tests for key generation with format_key.""" - - def test_result_key_format(self): - """Test that result keys are formatted correctly.""" - with patch.object(state.backend, "get", return_value=None) as mock_get: - state.get_result("test-id") - - 
mock_get.assert_called_once_with("agentexec:result:test-id") - - def test_logs_channel_format(self): - """Test that log channel is formatted correctly.""" - with patch.object(state.backend, "publish") as mock_publish: - state.publish_log("test") - mock_publish.assert_called_once_with("agentexec:logs", "test") diff --git a/tests/test_state_backend.py b/tests/test_state_backend.py index 3c00787..464ce46 100644 --- a/tests/test_state_backend.py +++ b/tests/test_state_backend.py @@ -1,292 +1,120 @@ -"""Tests for state backend module.""" - -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest from pydantic import BaseModel -from agentexec.state import redis_backend +from agentexec.state import backend +from agentexec.state.redis import Backend as RedisBackend class SampleModel(BaseModel): - """Sample model for serialization tests.""" - status: str value: int class NestedModel(BaseModel): - """Model with nested structure for serialization tests.""" - items: list[int] metadata: dict[str, str] -@pytest.fixture(autouse=True) -def reset_redis_clients(): - """Reset Redis client state before and after each test.""" - redis_backend._redis_client = None - redis_backend._redis_sync_client = None - redis_backend._pubsub = None - yield - redis_backend._redis_client = None - redis_backend._redis_sync_client = None - redis_backend._pubsub = None - - -@pytest.fixture -def mock_sync_client(): - """Mock synchronous Redis client.""" - client = MagicMock() - with patch.object(redis_backend, "_get_sync_client", return_value=client): - yield client - - @pytest.fixture -def mock_async_client(): - """Mock asynchronous Redis client.""" +def mock_client(monkeypatch): + """Inject a mock async Redis client into the backend.""" client = AsyncMock() - with patch.object(redis_backend, "_get_async_client", return_value=client): - yield client + monkeypatch.setattr(backend, "_client", client) + yield client class TestFormatKey: - """Tests 
for format_key function.""" - def test_format_single_part(self): - """Test formatting key with single part.""" - result = redis_backend.format_key("result") - assert result == "result" + assert backend.format_key("result") == "result" def test_format_multiple_parts(self): - """Test formatting key with multiple parts.""" - result = redis_backend.format_key("agentexec", "result", "123") - assert result == "agentexec:result:123" + assert backend.format_key("agentexec", "result", "123") == "agentexec:result:123" def test_format_empty_parts(self): - """Test formatting with no parts returns empty string.""" - result = redis_backend.format_key() - assert result == "" + assert backend.format_key() == "" class TestSerialization: - """Tests for serialize and deserialize functions.""" - def test_serialize_basemodel(self): - """Test serializing a BaseModel.""" data = SampleModel(status="success", value=42) - result = redis_backend.serialize(data) + result = backend.serialize(data) assert isinstance(result, bytes) - def test_serialize_rejects_dict(self): - """Test that serialize rejects raw dicts.""" - with pytest.raises(TypeError, match="Expected BaseModel"): - redis_backend.serialize({"key": "value"}) # type: ignore[arg-type] - - def test_serialize_rejects_list(self): - """Test that serialize rejects raw lists.""" - with pytest.raises(TypeError, match="Expected BaseModel"): - redis_backend.serialize([1, 2, 3]) # type: ignore[arg-type] - def test_serialize_deserialize_roundtrip(self): - """Test serialize then deserialize returns equivalent model.""" data = SampleModel(status="success", value=42) - serialized = redis_backend.serialize(data) - deserialized = redis_backend.deserialize(serialized) + serialized = backend.serialize(data) + deserialized = backend.deserialize(serialized) assert isinstance(deserialized, SampleModel) - assert deserialized.status == data.status - assert deserialized.value == data.value + assert deserialized == data def 
test_serialize_deserialize_nested_model(self): - """Test roundtrip with nested structures.""" data = NestedModel(items=[1, 2, 3], metadata={"key": "value"}) - serialized = redis_backend.serialize(data) - deserialized = redis_backend.deserialize(serialized) + serialized = backend.serialize(data) + deserialized = backend.deserialize(serialized) assert isinstance(deserialized, NestedModel) - assert deserialized.items == data.items - assert deserialized.metadata == data.metadata - - -class TestQueueOperations: - """Tests for queue operations (rpush, lpush, brpop).""" - - def test_rpush(self, mock_sync_client): - """Test rpush adds to right of list.""" - mock_sync_client.rpush.return_value = 5 - - result = redis_backend.rpush("tasks", "task_data") - - mock_sync_client.rpush.assert_called_once_with("tasks", "task_data") - assert result == 5 - - def test_lpush(self, mock_sync_client): - """Test lpush adds to left of list.""" - mock_sync_client.lpush.return_value = 3 - - result = redis_backend.lpush("tasks", "task_data") - - mock_sync_client.lpush.assert_called_once_with("tasks", "task_data") - assert result == 3 - - async def test_brpop_with_result(self, mock_async_client): - """Test brpop returns decoded result.""" - mock_async_client.brpop.return_value = (b"tasks", b"task_value") - - result = await redis_backend.brpop("tasks", timeout=5) - - mock_async_client.brpop.assert_called_once_with(["tasks"], timeout=5) - assert result == ("tasks", "task_value") - - async def test_brpop_timeout(self, mock_async_client): - """Test brpop returns None on timeout.""" - mock_async_client.brpop.return_value = None - - result = await redis_backend.brpop("tasks", timeout=1) - - assert result is None + assert deserialized == data class TestKeyValueOperations: - """Tests for get/set/delete operations.""" - - def test_get_sync(self, mock_sync_client): - """Test synchronous get.""" - mock_sync_client.get.return_value = b"value" - - result = redis_backend.get("mykey") - - 
mock_sync_client.get.assert_called_once_with("mykey") + async def test_get(self, mock_client): + mock_client.get.return_value = b"value" + result = await backend.state.get("mykey") + mock_client.get.assert_called_once_with("mykey") assert result == b"value" - def test_get_sync_missing_key(self, mock_sync_client): - """Test get returns None for missing key.""" - mock_sync_client.get.return_value = None - - result = redis_backend.get("missing") - + async def test_get_missing_key(self, mock_client): + mock_client.get.return_value = None + result = await backend.state.get("missing") assert result is None - async def test_aget(self, mock_async_client): - """Test asynchronous get.""" - mock_async_client.get.return_value = b"async_value" - - result = await redis_backend.aget("mykey") - - mock_async_client.get.assert_called_once_with("mykey") - assert result == b"async_value" - - def test_set_sync(self, mock_sync_client): - """Test synchronous set without TTL.""" - mock_sync_client.set.return_value = True - - result = redis_backend.set("mykey", b"value") - - mock_sync_client.set.assert_called_once_with("mykey", b"value") + async def test_set_without_ttl(self, mock_client): + mock_client.set.return_value = True + result = await backend.state.set("mykey", b"value") + mock_client.set.assert_called_once_with("mykey", b"value") assert result is True - def test_set_sync_with_ttl(self, mock_sync_client): - """Test synchronous set with TTL.""" - mock_sync_client.set.return_value = True - - result = redis_backend.set("mykey", b"value", ttl_seconds=3600) - - mock_sync_client.set.assert_called_once_with("mykey", b"value", ex=3600) - assert result is True - - async def test_aset(self, mock_async_client): - """Test asynchronous set with TTL.""" - mock_async_client.set.return_value = True - - result = await redis_backend.aset("mykey", b"value", ttl_seconds=7200) - - mock_async_client.set.assert_called_once_with("mykey", b"value", ex=7200) + async def test_set_with_ttl(self, 
mock_client): + mock_client.set.return_value = True + result = await backend.state.set("mykey", b"value", ttl_seconds=3600) + mock_client.set.assert_called_once_with("mykey", b"value", ex=3600) assert result is True - def test_delete_sync(self, mock_sync_client): - """Test synchronous delete.""" - mock_sync_client.delete.return_value = 1 - - result = redis_backend.delete("mykey") - - mock_sync_client.delete.assert_called_once_with("mykey") + async def test_delete(self, mock_client): + mock_client.delete.return_value = 1 + result = await backend.state.delete("mykey") + mock_client.delete.assert_called_once_with("mykey") assert result == 1 - async def test_adelete(self, mock_async_client): - """Test asynchronous delete.""" - mock_async_client.delete.return_value = 1 - result = await redis_backend.adelete("mykey") - - mock_async_client.delete.assert_called_once_with("mykey") - assert result == 1 - - -class TestPubSubOperations: - """Tests for pub/sub operations.""" - - def test_publish(self, mock_sync_client): - """Test publishing message to channel.""" - redis_backend.publish("logs", "log message") - - mock_sync_client.publish.assert_called_once_with("logs", "log message") - - async def test_subscribe(self, mock_async_client): - """Test subscribing to channel.""" - mock_pubsub = AsyncMock() - # Make pubsub() return the mock directly (not a coroutine) - mock_async_client.pubsub = MagicMock(return_value=mock_pubsub) - - # Create async iterator for messages - async def mock_listen(): - yield {"type": "subscribe"} - yield {"type": "message", "data": b"message1"} - yield {"type": "message", "data": "message2"} - - # Make listen() return the generator directly (not wrapped in AsyncMock) - mock_pubsub.listen = MagicMock(return_value=mock_listen()) - - messages = [] - async for msg in redis_backend.subscribe("test_channel"): - messages.append(msg) +class TestCounterOperations: + async def test_counter_incr(self, mock_client): + mock_client.incr.return_value = 5 + result = 
await backend.state.counter_incr("mycount") + mock_client.incr.assert_called_once_with("mycount") + assert result == 5 - assert messages == ["message1", "message2"] - mock_pubsub.subscribe.assert_called_once_with("test_channel") - mock_pubsub.unsubscribe.assert_called_once_with("test_channel") - mock_pubsub.close.assert_called_once() + async def test_counter_decr(self, mock_client): + mock_client.decr.return_value = 3 + result = await backend.state.counter_decr("mycount") + mock_client.decr.assert_called_once_with("mycount") + assert result == 3 class TestConnectionManagement: - """Tests for connection lifecycle.""" - async def test_close_all_connections(self): - """Test close cleans up all resources.""" - # Set up mock clients - mock_async = AsyncMock() - mock_sync = MagicMock() - mock_ps = AsyncMock() - - redis_backend._redis_client = mock_async - redis_backend._redis_sync_client = mock_sync - redis_backend._pubsub = mock_ps - - await redis_backend.close() + mock_client = AsyncMock() + backend._client = mock_client - mock_ps.close.assert_called_once() - mock_async.aclose.assert_called_once() - mock_sync.close.assert_called_once() + await backend.close() - assert redis_backend._redis_client is None - assert redis_backend._redis_sync_client is None - assert redis_backend._pubsub is None + mock_client.aclose.assert_called_once() + assert backend._client is None async def test_close_handles_none_clients(self): - """Test close handles None clients gracefully.""" - redis_backend._redis_client = None - redis_backend._redis_sync_client = None - redis_backend._pubsub = None + backend._client = None - # Should not raise - await redis_backend.close() + await backend.close() - assert redis_backend._redis_client is None - assert redis_backend._redis_sync_client is None + assert backend._client is None diff --git a/tests/test_task.py b/tests/test_task.py index 52c4dfd..ebc7725 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,5 +1,3 @@ -"""Test Task data structure 
and serialization.""" - import json import uuid @@ -7,44 +5,36 @@ from pydantic import BaseModel import agentexec as ax +from agentexec.core.task import TaskDefinition class SampleContext(BaseModel): - """Sample context for task tests.""" - message: str value: int = 0 class NestedContext(BaseModel): - """Sample context with nested data.""" - message: str nested: dict class TaskResult(BaseModel): - """Sample result model for task tests.""" - status: str @pytest.fixture def pool(): - """Create a Pool for testing.""" from sqlalchemy import create_engine - engine = create_engine("sqlite:///:memory:") return ax.Pool(engine=engine) def test_task_serialization() -> None: - """Test that tasks can be serialized to JSON.""" + """Task serializes to JSON with context as a dict.""" agent_id = uuid.uuid4() - ctx = SampleContext(message="hello", value=42) task = ax.Task( task_name="test_task", - context=ctx, + context={"message": "hello", "value": 42}, agent_id=agent_id, ) @@ -57,15 +47,8 @@ def test_task_serialization() -> None: assert task_data["agent_id"] == str(agent_id) -def test_task_deserialization(pool) -> None: - """Test that tasks can be deserialized using Task.from_serialized.""" - # Register a task to get a TaskDefinition - @pool.task("test_task") - async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: - return TaskResult(status="success") - - task_def = pool._context.tasks["test_task"] - +def test_task_deserialization() -> None: + """Task deserializes from raw queue data.""" agent_id = uuid.uuid4() data = { "task_name": "test_task", @@ -73,124 +56,82 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: "agent_id": str(agent_id), } - task = ax.Task.from_serialized(task_def, data) + task = ax.Task.model_validate(data) assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "hello" - assert task.context.value == 42 + assert task.context == {"message": 
"hello", "value": 42} assert task.agent_id == agent_id -def test_task_round_trip(pool) -> None: - """Test that tasks can be serialized and deserialized.""" - # Register task for deserialization - @pool.task("round_trip_task") - async def handler(agent_id: uuid.UUID, context: NestedContext) -> TaskResult: - return TaskResult(status="success") - - task_def = pool._context.tasks["round_trip_task"] - - original_ctx = NestedContext(message="hello", nested={"key": "value"}) +def test_task_round_trip() -> None: + """Task survives serialize → JSON → deserialize.""" original = ax.Task( task_name="round_trip_task", - context=original_ctx, + context={"message": "hello", "nested": {"key": "value"}}, agent_id=uuid.uuid4(), ) - # Serialize → JSON → Deserialize serialized = original.model_dump_json() - data = json.loads(serialized) - deserialized = ax.Task.from_serialized(task_def, data) + deserialized = ax.Task.model_validate_json(serialized) assert deserialized.task_name == original.task_name - # Cast to access typed attributes (Task.context is typed as BaseModel) - assert isinstance(deserialized.context, NestedContext) - assert isinstance(original.context, NestedContext) - assert deserialized.context.message == original.context.message - assert deserialized.context.nested == original.context.nested + assert deserialized.context == original.context assert deserialized.agent_id == original.agent_id -def test_task_create_with_basemodel(monkeypatch) -> None: - """Test Task.create() with a BaseModel context.""" - # Mock activity.create to avoid database dependency - def mock_create(*args, **kwargs): +async def test_task_create_with_basemodel(monkeypatch) -> None: + """Task.create() serializes context to dict.""" + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) ctx = SampleContext(message="hello", value=42) - task = ax.Task.create("test_task", ctx) + task = await ax.Task.create("test_task", ctx) 
assert task.task_name == "test_task" - # Context is the typed object - assert isinstance(task.context, SampleContext) - assert task.context.message == "hello" - assert task.context.value == 42 + assert task.context == {"message": "hello", "value": 42} -def test_task_create_preserves_nested(monkeypatch) -> None: - """Test Task.create() preserves nested Pydantic models.""" - # Mock activity.create to avoid database dependency - def mock_create(*args, **kwargs): +async def test_task_create_preserves_nested(monkeypatch) -> None: + """Task.create() preserves nested structures in the dict.""" + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) ctx = NestedContext(message="hello", nested={"key": "value"}) - task = ax.Task.create("test_task", ctx) + task = await ax.Task.create("test_task", ctx) - assert isinstance(task.context, NestedContext) - assert task.context.message == "hello" - assert task.context.nested == {"key": "value"} + assert task.context == {"message": "hello", "nested": {"key": "value"}} -def test_task_from_serialized(pool) -> None: - """Test Task.from_serialized creates a task with typed context.""" - from agentexec.core.task import TaskDefinition - +def test_definition_hydrates_context(pool) -> None: + """TaskDefinition.hydrate_context validates dict into typed model.""" @pool.task("test_task") async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: - return TaskResult(status=f"Result: {context.message}") - - task_def = pool._context.tasks["test_task"] - agent_id = uuid.uuid4() - - data = { - "task_name": "test_task", - "context": {"message": "hello", "value": 42}, - "agent_id": str(agent_id), - } - - task = ax.Task.from_serialized(task_def, data) + return TaskResult(status="success") - assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "hello" - assert task.context.value == 42 - assert 
task.agent_id == agent_id - assert task._definition is task_def + definition = pool._context.tasks["test_task"] + typed = definition.hydrate_context({"message": "hello", "value": 42}) + assert isinstance(typed, SampleContext) + assert typed.message == "hello" + assert typed.value == 42 -async def test_task_execute_async_handler(pool, monkeypatch) -> None: - """Test Task.execute with an async handler.""" - from unittest.mock import AsyncMock - # Track activity updates +async def test_definition_execute_async(pool, monkeypatch) -> None: + """TaskDefinition.execute() runs async handler and tracks activity.""" activity_updates = [] - def mock_update(**kwargs): + async def mock_update(**kwargs): activity_updates.append(kwargs) - # Mock state.aset_result - aset_result_calls = [] - - async def mock_aset_result(agent_id, data, ttl_seconds=None): - aset_result_calls.append((agent_id, data, ttl_seconds)) + async def mock_state_set(key, value, ttl_seconds=None): + pass monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) - monkeypatch.setattr("agentexec.core.task.state.aset_result", mock_aset_result) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) execution_result = TaskResult(status="success") @@ -198,64 +139,46 @@ async def mock_aset_result(agent_id, data, ttl_seconds=None): async def async_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: return execution_result - task_def = pool._context.tasks["async_task"] - agent_id = uuid.uuid4() - - task = ax.Task.from_serialized( - task_def, - { - "task_name": "async_task", - "context": {"message": "test"}, - "agent_id": str(agent_id), - }, + definition = pool._context.tasks["async_task"] + task = ax.Task( + task_name="async_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), ) - result = await task.execute() + result = await definition.execute(task) assert result == execution_result - # Verify activity was updated (started and completed) assert 
len(activity_updates) == 2 - # First update marks task as started assert activity_updates[0]["percentage"] == 0 - # Second update marks task as completed assert activity_updates[1]["percentage"] == 100 - # Verify result was stored - assert len(aset_result_calls) == 1 - assert aset_result_calls[0][0] == agent_id # Can be UUID or str - assert aset_result_calls[0][1] == execution_result - -async def test_task_execute_sync_handler(pool, monkeypatch) -> None: - """Test Task.execute with a sync handler.""" +async def test_definition_execute_sync(pool, monkeypatch) -> None: + """TaskDefinition.execute() runs sync handler.""" activity_updates = [] - def mock_update(**kwargs): + async def mock_update(**kwargs): activity_updates.append(kwargs) - async def mock_aset_result(agent_id, data, ttl_seconds=None): + async def mock_state_set(key, value, ttl_seconds=None): pass monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) - monkeypatch.setattr("agentexec.core.task.state.aset_result", mock_aset_result) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) @pool.task("sync_task") def sync_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: return TaskResult(status=f"Sync result: {context.message}") - task_def = pool._context.tasks["sync_task"] - agent_id = uuid.uuid4() - - task = ax.Task.from_serialized( - task_def, - { - "task_name": "sync_task", - "context": {"message": "test"}, - "agent_id": str(agent_id), - }, + definition = pool._context.tasks["sync_task"] + task = ax.Task( + task_name="sync_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), ) - result = await task.execute() + result = await definition.execute(task) assert result is not None assert isinstance(result, TaskResult) @@ -263,54 +186,193 @@ def sync_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: assert len(activity_updates) == 2 -async def test_task_execute_without_definition_raises() -> None: - """Test Task.execute 
raises RuntimeError if not bound to definition.""" - task = ax.Task( - task_name="test_task", - context=SampleContext(message="test"), - agent_id=uuid.uuid4(), - ) - - with pytest.raises(RuntimeError, match="must be bound to a definition"): - await task.execute() - - -async def test_task_execute_error_marks_activity_errored(pool, monkeypatch) -> None: - """Test Task.execute marks activity as errored on exception.""" - from agentexec.activity.models import Status +async def test_definition_execute_error(pool, monkeypatch) -> None: + """TaskDefinition.execute() marks activity as errored on exception.""" + from agentexec.activity.status import Status activity_updates = [] - def mock_update(**kwargs): + async def mock_update(**kwargs): activity_updates.append(kwargs) - async def mock_aset_result(agent_id, data, ttl_seconds=None): + async def mock_state_set(key, value, ttl_seconds=None): pass monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) - monkeypatch.setattr("agentexec.core.task.state.aset_result", mock_aset_result) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) @pool.task("failing_task") async def failing_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: raise ValueError("Task failed!") - task_def = pool._context.tasks["failing_task"] - agent_id = uuid.uuid4() - - task = ax.Task.from_serialized( - task_def, - { - "task_name": "failing_task", - "context": {"message": "test"}, - "agent_id": str(agent_id), - }, + definition = pool._context.tasks["failing_task"] + task = ax.Task( + task_name="failing_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), ) - # execute() catches the exception and marks activity as errored, returns None - result = await task.execute() + with pytest.raises(ValueError, match="Task failed!"): + await definition.execute(task) - assert result is None # Handler exception results in None return - # First update marks started, second marks errored assert 
len(activity_updates) == 2 assert activity_updates[1]["status"] == Status.ERROR assert "Task failed!" in activity_updates[1]["message"] + + +async def test_definition_execute_none_result_not_stored(pool, monkeypatch) -> None: + """Handler returning None does not write to result storage.""" + state_set_calls = [] + + async def mock_update(**kwargs): + pass + + async def mock_state_set(key, value, ttl_seconds=None): + state_set_calls.append(key) + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) + + @pool.task("void_task") + async def void_handler(agent_id: uuid.UUID, context: SampleContext) -> None: + pass + + definition = pool._context.tasks["void_task"] + task = ax.Task( + task_name="void_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + ) + + result = await definition.execute(task) + assert result is None + assert len(state_set_calls) == 0 + + +async def test_definition_execute_stores_result_with_ttl(pool, monkeypatch) -> None: + """Handler result is stored in state with the configured TTL.""" + from agentexec.config import CONF + + state_set_calls = [] + + async def mock_update(**kwargs): + pass + + async def mock_state_set(key, value, ttl_seconds=None): + state_set_calls.append({"key": key, "ttl_seconds": ttl_seconds}) + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) + + @pool.task("result_task") + async def result_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: + return TaskResult(status="done") + + definition = pool._context.tasks["result_task"] + task = ax.Task( + task_name="result_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + ) + + await definition.execute(task) + assert len(state_set_calls) == 1 + assert state_set_calls[0]["ttl_seconds"] == CONF.result_ttl + assert str(task.agent_id) in 
state_set_calls[0]["key"] + + +async def test_definition_execute_hydrates_context(pool, monkeypatch) -> None: + """execute() passes a typed context model to the handler, not a raw dict.""" + received_context = [] + + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("typed_task") + async def typed_handler(agent_id: uuid.UUID, context: SampleContext) -> None: + received_context.append(context) + + definition = pool._context.tasks["typed_task"] + task = ax.Task( + task_name="typed_task", + context={"message": "typed", "value": 7}, + agent_id=uuid.uuid4(), + ) + + await definition.execute(task) + assert len(received_context) == 1 + assert isinstance(received_context[0], SampleContext) + assert received_context[0].message == "typed" + assert received_context[0].value == 7 + + +async def test_definition_execute_passes_agent_id(pool, monkeypatch) -> None: + """execute() passes the task's agent_id to the handler.""" + received_ids = [] + + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("id_task") + async def id_handler(agent_id: uuid.UUID, context: SampleContext) -> None: + received_ids.append(agent_id) + + definition = pool._context.tasks["id_task"] + expected_id = uuid.uuid4() + task = ax.Task( + task_name="id_task", + context={"message": "test"}, + agent_id=expected_id, + ) + + await definition.execute(task) + assert received_ids == [expected_id] + + +async def test_definition_execute_error_reraises(pool, monkeypatch) -> None: + """execute() re-raises the original exception after marking activity as errored.""" + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("reraise_task") + async def bad_handler(agent_id: uuid.UUID, context: SampleContext): + raise RuntimeError("original error") + + definition = 
pool._context.tasks["reraise_task"] + task = ax.Task( + task_name="reraise_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + ) + + with pytest.raises(RuntimeError, match="original error"): + await definition.execute(task) + + +async def test_definition_execute_bad_context_raises(pool, monkeypatch) -> None: + """execute() raises ValidationError when context doesn't match the registered type.""" + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("strict_task") + async def strict_handler(agent_id: uuid.UUID, context: SampleContext): + pass + + definition = pool._context.tasks["strict_task"] + task = ax.Task( + task_name="strict_task", + context={"wrong_field": "oops"}, # missing required 'message' + agent_id=uuid.uuid4(), + ) + + from pydantic import ValidationError + with pytest.raises(ValidationError): + await definition.execute(task) diff --git a/tests/test_task_locking.py b/tests/test_task_locking.py index ab1f853..d71e9a5 100644 --- a/tests/test_task_locking.py +++ b/tests/test_task_locking.py @@ -1,16 +1,11 @@ -"""Tests for task-level distributed locking.""" - -import json import uuid import pytest from fakeredis import aioredis as fake_aioredis from pydantic import BaseModel -from unittest.mock import AsyncMock, patch import agentexec as ax -from agentexec import state -from agentexec.core.queue import requeue +from agentexec.state import backend from agentexec.core.task import TaskDefinition @@ -36,20 +31,10 @@ def pool(): @pytest.fixture def fake_redis(monkeypatch): - """Setup fake redis for state backend with shared state.""" - import fakeredis - - server = fakeredis.FakeServer() - fake_redis_sync = fakeredis.FakeRedis(server=server, decode_responses=False) - fake_redis_async = fake_aioredis.FakeRedis(server=server, decode_responses=False) - - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", lambda: fake_redis_sync) - 
monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", lambda: fake_redis_async) - - yield fake_redis_async - - -# --- TaskDefinition lock_key --- + """Setup fake redis for state backend.""" + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake def test_task_definition_lock_key_default(): @@ -72,9 +57,6 @@ async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: assert defn.lock_key == "user:{user_id}" -# --- Pool registration with lock_key --- - - def test_pool_task_decorator_with_lock_key(pool): """@pool.task() passes lock_key to TaskDefinition.""" @@ -109,130 +91,69 @@ async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: assert defn.lock_key == "user:{user_id}" -# --- Task.get_lock_key() --- - - def test_get_lock_key_evaluates_template(pool): - """get_lock_key() evaluates template against context fields.""" + """definition.get_lock_key() evaluates template against context.""" @pool.task("locked_task", lock_key="user:{user_id}") async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: return TaskResult(status="ok") - defn = pool._context.tasks["locked_task"] - task = ax.Task.from_serialized( - defn, - { - "task_name": "locked_task", - "context": {"user_id": "42", "message": "hello"}, - "agent_id": str(uuid.uuid4()), - }, - ) - - assert task.get_lock_key() == "user:42" + definition = pool._context.tasks["locked_task"] + assert definition.get_lock_key({"user_id": "42", "message": "hello"}) == "user:42" def test_get_lock_key_returns_none_when_no_lock(pool): - """get_lock_key() returns None when no lock_key configured.""" + """definition.get_lock_key() returns None when no lock_key configured.""" @pool.task("unlocked_task") async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: return TaskResult(status="ok") - defn = pool._context.tasks["unlocked_task"] - task = ax.Task.from_serialized( - defn, - { - 
"task_name": "unlocked_task", - "context": {"user_id": "42"}, - "agent_id": str(uuid.uuid4()), - }, - ) - - assert task.get_lock_key() is None - - -def test_get_lock_key_raises_without_definition(): - """get_lock_key() raises RuntimeError if task not bound to definition.""" - task = ax.Task( - task_name="test", - context=UserContext(user_id="42"), - agent_id=uuid.uuid4(), - ) - - with pytest.raises(RuntimeError, match="must be bound to a definition"): - task.get_lock_key() + definition = pool._context.tasks["unlocked_task"] + assert definition.get_lock_key({"user_id": "42"}) is None def test_get_lock_key_raises_on_missing_field(pool): - """get_lock_key() raises KeyError if template references missing field.""" + """definition.get_lock_key() raises KeyError if template references missing field.""" @pool.task("bad_template", lock_key="org:{organization_id}") async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: return TaskResult(status="ok") - defn = pool._context.tasks["bad_template"] - task = ax.Task.from_serialized( - defn, - { - "task_name": "bad_template", - "context": {"user_id": "42"}, - "agent_id": str(uuid.uuid4()), - }, - ) - + definition = pool._context.tasks["bad_template"] with pytest.raises(KeyError): - task.get_lock_key() - - -# --- Redis lock acquire/release --- + definition.get_lock_key({"user_id": "42"}) async def test_acquire_lock_success(fake_redis): - """acquire_lock returns True when lock is free.""" - result = await state.acquire_lock("user:42", "agent-1") + """Queue backend acquire_lock returns True when lock is free.""" + queue_key = backend.queue._queue_key("user:42").encode() + result = await backend.queue._acquire_lock(queue_key) assert result is True async def test_acquire_lock_already_held(fake_redis): - """acquire_lock returns False when lock is already held.""" - await state.acquire_lock("user:42", "agent-1") - result = await state.acquire_lock("user:42", "agent-2") + """Queue backend acquire_lock returns False 
when already held.""" + queue_key = backend.queue._queue_key("user:42").encode() + await backend.queue._acquire_lock(queue_key) + result = await backend.queue._acquire_lock(queue_key) assert result is False async def test_release_lock(fake_redis): """release_lock frees the lock so it can be re-acquired.""" - await state.acquire_lock("user:42", "agent-1") - await state.release_lock("user:42") + queue_key = backend.queue._queue_key("user:42").encode() + await backend.queue._acquire_lock(queue_key) + await backend.queue.complete("user:42") - result = await state.acquire_lock("user:42", "agent-2") + result = await backend.queue._acquire_lock(queue_key) assert result is True -async def test_release_lock_nonexistent(fake_redis): - """release_lock on a non-existent key returns 0.""" - result = await state.release_lock("nonexistent") - assert result == 0 - - -async def test_lock_key_uses_prefix(fake_redis): - """Lock keys are prefixed with agentexec:lock:.""" - await state.acquire_lock("user:42", "agent-1") - - # Check the raw Redis key - value = await fake_redis.get("agentexec:lock:user:42") - assert value is not None - assert value.decode() == "agent-1" - - -# --- Requeue --- - - async def test_requeue_pushes_to_back(fake_redis, monkeypatch): """requeue() pushes task to the back of the queue (lpush).""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -240,21 +161,18 @@ def mock_create(*args, **kwargs): # Enqueue a normal task first task1 = await ax.enqueue("task_1", UserContext(user_id="1", message="first")) - # Create and requeue a second task + # Push a second task directly (simulating a requeue) task2 = ax.Task( task_name="task_2", - context=UserContext(user_id="2", message="requeued"), + context={"user_id": "2", "message": "requeued"}, agent_id=uuid.uuid4(), ) - requeue(task2) - - # Dequeue should return task_1 first (from front/right), then 
task_2 (from back/left) - from agentexec.core.queue import dequeue + await backend.queue.push(task2.model_dump_json()) - result1 = await dequeue(timeout=1) + result1 = await backend.queue.pop(timeout=1) assert result1 is not None assert result1["task_name"] == "task_1" - result2 = await dequeue(timeout=1) + result2 = await backend.queue.pop(timeout=1) assert result2 is not None assert result2["task_name"] == "task_2" diff --git a/tests/test_worker_event.py b/tests/test_worker_event.py index 950bc06..4e83eb4 100644 --- a/tests/test_worker_event.py +++ b/tests/test_worker_event.py @@ -1,133 +1,71 @@ -"""Test state-backed event for cross-process coordination.""" - import pytest from fakeredis import aioredis as fake_aioredis -import fakeredis +from agentexec.state import backend from agentexec.worker.event import StateEvent @pytest.fixture -def fake_redis_sync(monkeypatch): - """Setup fake sync redis for state backend.""" - fake_redis = fakeredis.FakeRedis(decode_responses=False) - - def get_fake_sync_client(): - return fake_redis - - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", get_fake_sync_client) - - yield fake_redis - - -@pytest.fixture -def fake_redis_async(monkeypatch): - """Setup fake async redis for state backend.""" - fake_redis = fake_aioredis.FakeRedis(decode_responses=False) - - def get_fake_async_client(): - return fake_redis - - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", get_fake_async_client) - - yield fake_redis +def fake_redis(monkeypatch): + """Inject fake redis into the backend.""" + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake def test_state_event_initialization(): - """Test StateEvent can be initialized with name and id.""" event = StateEvent("test", "event123") - assert event.name == "test" assert event.id == "event123" -def test_redis_event_set(fake_redis_sync): - """Test StateEvent.set() sets the key in Redis.""" +async 
def test_redis_event_set(fake_redis): event = StateEvent("shutdown", "pool1") - - event.set() - - # Verify the key was set (with event prefix and formatted name:id) - value = fake_redis_sync.get("agentexec:event:shutdown:pool1") + await event.set() + value = await fake_redis.get("agentexec:event:shutdown:pool1") assert value == b"1" -def test_redis_event_clear(fake_redis_sync): - """Test StateEvent.clear() removes the key from Redis.""" +async def test_redis_event_clear(fake_redis): event = StateEvent("shutdown", "pool2") - - # Set then clear - fake_redis_sync.set("agentexec:event:shutdown:pool2", "1") - event.clear() - - # Verify the key was removed - value = fake_redis_sync.get("agentexec:event:shutdown:pool2") + await fake_redis.set("agentexec:event:shutdown:pool2", "1") + await event.clear() + value = await fake_redis.get("agentexec:event:shutdown:pool2") assert value is None -def test_redis_event_clear_nonexistent(fake_redis_sync): - """Test StateEvent.clear() handles non-existent keys gracefully.""" +async def test_redis_event_clear_nonexistent(fake_redis): event = StateEvent("nonexistent", "id123") - - # Should not raise an error - event.clear() + await event.clear() -async def test_redis_event_is_set_true(fake_redis_async): - """Test StateEvent.is_set() returns True when key exists.""" +async def test_redis_event_is_set_true(fake_redis): event = StateEvent("shutdown", "pool3") + await fake_redis.set("agentexec:event:shutdown:pool3", "1") + assert await event.is_set() is True - # Set the key - await fake_redis_async.set("agentexec:event:shutdown:pool3", "1") - # Check is_set - result = await event.is_set() - assert result is True - - -async def test_redis_event_is_set_false(fake_redis_async): - """Test StateEvent.is_set() returns False when key doesn't exist.""" +async def test_redis_event_is_set_false(fake_redis): event = StateEvent("shutdown", "pool4") - - # Don't set the key - result = await event.is_set() - assert result is False + assert await 
event.is_set() is False -async def test_redis_event_is_set_after_clear(fake_redis_sync, fake_redis_async): - """Test StateEvent.is_set() returns False after clear().""" +async def test_redis_event_is_set_after_clear(fake_redis): event = StateEvent("shutdown", "pool5") - - # Set then clear - event.set() - event.clear() - - # Check is_set - result = await event.is_set() - assert result is False + await event.set() + await event.clear() + assert await event.is_set() is False def test_redis_event_picklable(): - """Test StateEvent is picklable (for multiprocessing).""" import pickle - event = StateEvent("shutdown", "pickle123") - - # Pickle and unpickle - pickled = pickle.dumps(event) - unpickled = pickle.loads(pickled) - + unpickled = pickle.loads(pickle.dumps(event)) assert unpickled.name == "shutdown" assert unpickled.id == "pickle123" def test_redis_event_multiple_events(): - """Test multiple StateEvent instances with different names.""" event1 = StateEvent("event", "id1") event2 = StateEvent("event", "id2") - assert event1.id != event2.id - assert event1.name == "event" - assert event2.name == "event" - assert event1.id == "id1" - assert event2.id == "id2" diff --git a/tests/test_worker_logging.py b/tests/test_worker_logging.py index dc9662e..be6489b 100644 --- a/tests/test_worker_logging.py +++ b/tests/test_worker_logging.py @@ -1,17 +1,15 @@ -"""Test worker logging functionality.""" - import logging import time import pytest -import fakeredis +from fakeredis import aioredis as fake_aioredis from agentexec.worker.logging import ( DEFAULT_FORMAT, LOG_CHANNEL, LOGGER_NAME, LogMessage, - StateLogHandler, + QueueLogHandler, get_worker_logger, ) @@ -135,44 +133,25 @@ def test_log_message_with_none_values(self): assert log_message.thread is None -class TestStateLogHandler: - """Tests for StateLogHandler.""" +class TestQueueLogHandler: + """Tests for QueueLogHandler.""" - @pytest.fixture - def fake_redis_backend(self, monkeypatch): - """Setup fake redis backend for 
state.""" - fake_redis = fakeredis.FakeRedis(decode_responses=False) + def test_handler_initialization(self): + """Test QueueLogHandler initializes with a queue.""" + import multiprocessing as mp + tx = mp.Queue() + handler = QueueLogHandler(tx) + assert handler.tx is tx - def get_fake_sync_client(): - return fake_redis + def test_handler_emit(self): + """Test QueueLogHandler.emit() puts LogEntry on the queue.""" + import multiprocessing as mp + import time + from agentexec.worker.pool import LogEntry - monkeypatch.setattr( - "agentexec.state.redis_backend._get_sync_client", get_fake_sync_client - ) + tx = mp.Queue() + handler = QueueLogHandler(tx) - return fake_redis - - def test_handler_initialization(self): - """Test StateLogHandler initializes with default channel.""" - handler = StateLogHandler() - assert handler.channel == LOG_CHANNEL - - def test_handler_custom_channel(self): - """Test StateLogHandler with custom channel.""" - handler = StateLogHandler(channel="custom:logs") - assert handler.channel == "custom:logs" - - def test_handler_emit(self, fake_redis_backend): - """Test StateLogHandler.emit() publishes to state backend.""" - handler = StateLogHandler() - - # Subscribe to the channel to capture the message - pubsub = fake_redis_backend.pubsub() - pubsub.subscribe(LOG_CHANNEL) - # Get the subscribe message - pubsub.get_message() - - # Create and emit a log record record = logging.LogRecord( name="emit.test", level=logging.INFO, @@ -184,18 +163,12 @@ def test_handler_emit(self, fake_redis_backend): ) handler.emit(record) + time.sleep(0.1) # mp.Queue uses a background thread to flush - # Get the published message - message = pubsub.get_message() - - assert message is not None - assert message["type"] == "message" - assert message["channel"] == LOG_CHANNEL.encode() - - # Verify the message content - log_message = LogMessage.model_validate_json(message["data"]) - assert log_message.msg == "Emitted message" - assert log_message.levelno == logging.INFO + 
message = tx.get_nowait() + assert isinstance(message, LogEntry) + assert message.record.msg == "Emitted message" + assert message.record.levelno == logging.INFO class TestGetWorkerLogger: @@ -204,18 +177,10 @@ class TestGetWorkerLogger: @pytest.fixture(autouse=True) def reset_logging_state(self, monkeypatch): """Reset the worker logging configured state before each test.""" - # Reset the global state monkeypatch.setattr("agentexec.worker.logging._worker_logging_configured", False) - # Setup fake redis backend - fake_redis = fakeredis.FakeRedis(decode_responses=False) - monkeypatch.setattr( - "agentexec.state.redis_backend._get_sync_client", lambda: fake_redis - ) - yield - # Cleanup handlers added during tests root = logging.getLogger(LOGGER_NAME) root.handlers.clear() @@ -239,18 +204,21 @@ def test_get_worker_logger_existing_namespace(self): assert logger.name == f"{LOGGER_NAME}.submodule" def test_get_worker_logger_configures_handler(self): - """Test get_worker_logger adds StateLogHandler on first call.""" - logger = get_worker_logger("first.call") + """Test get_worker_logger adds QueueLogHandler on first call.""" + import multiprocessing as mp + tx = mp.Queue() + get_worker_logger("first.call", tx=tx) root = logging.getLogger(LOGGER_NAME) handler_types = [type(h).__name__ for h in root.handlers] - assert "StateLogHandler" in handler_types + assert "QueueLogHandler" in handler_types def test_get_worker_logger_idempotent(self): """Test get_worker_logger only configures once.""" - # First call - get_worker_logger("first") + import multiprocessing as mp + tx = mp.Queue() + get_worker_logger("first", tx=tx) root = logging.getLogger(LOGGER_NAME) initial_handler_count = len(root.handlers) diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py index 20f5bc1..b0b2ced 100644 --- a/tests/test_worker_pool.py +++ b/tests/test_worker_pool.py @@ -1,12 +1,13 @@ -"""Test Pool implementation.""" - import json +import multiprocessing as mp import uuid +from 
unittest.mock import AsyncMock import pytest from pydantic import BaseModel import agentexec as ax +from agentexec.state import backend class SampleContext(BaseModel): @@ -24,22 +25,19 @@ class TaskResult(BaseModel): @pytest.fixture def mock_state_backend(monkeypatch): - """Mock the state backend for queue operations.""" + """Mock the queue ops for push operations.""" queue_data = [] - def mock_lpush(key, value): - queue_data.insert(0, value) - return len(queue_data) - - def mock_rpush(key, value): - queue_data.append(value) - return len(queue_data) + async def mock_queue_push(value, *, high_priority=False, partition_key=None): + if high_priority: + queue_data.append(value) + else: + queue_data.insert(0, value) def pop_right(): return queue_data.pop() if queue_data else None - monkeypatch.setattr("agentexec.state.backend.lpush", mock_lpush) - monkeypatch.setattr("agentexec.state.backend.rpush", mock_rpush) + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_queue_push) return {"queue": queue_data, "pop": pop_right} @@ -55,8 +53,7 @@ def pool(): async def test_enqueue_task(mock_state_backend, pool, monkeypatch) -> None: """Test that tasks can be enqueued.""" - # Mock activity.create to avoid database dependency - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -74,8 +71,7 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: assert task is not None assert isinstance(task.agent_id, uuid.UUID) assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "Hello World" + assert task.context["message"] == "Hello World" # Verify task was pushed to queue task_json = mock_state_backend["pop"]() @@ -89,7 +85,7 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: async def test_enqueue_high_priority_task(mock_state_backend, pool, 
monkeypatch) -> None: """Test that high priority tasks are enqueued to the front.""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -111,7 +107,7 @@ async def high_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResul ctx2 = SampleContext(message="high", value=2) task2 = await ax.enqueue("high_task", ctx2, priority=ax.Priority.HIGH) - # High priority task should be at the end (RPUSH) so it's processed first (BRPOP) + # High priority task should be at the end (popped first) task_json = mock_state_backend["pop"]() task_data = json.loads(task_json) assert task_data["agent_id"] == str(task2.agent_id) @@ -119,7 +115,7 @@ async def high_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResul async def test_add_task_registers_handler(mock_state_backend, pool, monkeypatch) -> None: """Test that pool.add_task() registers a task handler.""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -139,8 +135,7 @@ async def handler(*, agent_id: uuid.UUID, context: SampleContext) -> TaskResult: assert task is not None assert task.task_name == "added_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "Added via add_task" + assert task.context["message"] == "Added via add_task" def test_add_task_duplicate_raises(pool) -> None: @@ -182,19 +177,10 @@ def test_pool_with_database_url() -> None: """Test that Pool can be created with database_url.""" pool = ax.Pool(database_url="sqlite:///:memory:") - assert pool._context.database_url == "sqlite:///:memory:" + assert pool._processes == [] assert pool._processes == [] -def test_pool_with_custom_queue_name() -> None: - """Test that Pool can use a custom queue name.""" - pool = ax.Pool( - database_url="sqlite:///:memory:", - 
queue_name="custom_queue", - ) - - assert pool._context.queue_name == "custom_queue" - async def test_worker_dequeue_task(pool, monkeypatch) -> None: """Test Worker._dequeue_task method.""" @@ -206,15 +192,12 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: return TaskResult() context = WorkerContext( - database_url="sqlite:///:memory:", shutdown_event=StateEvent("shutdown", "test-worker"), tasks=pool._context.tasks, - queue_name="test_queue", + tx=mp.Queue(), ) - worker = Worker(worker_id=0, context=context) - - # Mock dequeue to return task data + # Mock queue_pop to return task data agent_id = uuid.uuid4() task_data = { "task_name": "test_task", @@ -222,53 +205,40 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: "agent_id": str(agent_id), } - async def mock_dequeue(**kwargs): + async def mock_queue_pop(*args, **kwargs): return task_data - monkeypatch.setattr("agentexec.worker.pool.dequeue", mock_dequeue) + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_queue_pop) - task = await worker._dequeue_task() + data = await backend.queue.pop(timeout=1) + assert data is not None - assert task is not None + task = ax.Task.model_validate(data) assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "test" + assert task.context == {"message": "test", "value": 42} assert task.agent_id == agent_id -async def test_worker_dequeue_task_returns_none_on_empty_queue(pool, monkeypatch) -> None: - """Test Worker._dequeue_task returns None when queue is empty.""" - from agentexec.worker.pool import Worker, WorkerContext - from agentexec.worker.event import StateEvent - - context = WorkerContext( - database_url="sqlite:///:memory:", - shutdown_event=StateEvent("shutdown", "test-worker"), - tasks=pool._context.tasks, - queue_name="test_queue", - ) - - worker = Worker(worker_id=0, context=context) +async def 
test_dequeue_returns_none_on_empty_queue(pool, monkeypatch) -> None: + """Test pop returns None when queue is empty.""" - async def mock_dequeue(**kwargs): + async def mock_queue_pop(*args, **kwargs): return None - monkeypatch.setattr("agentexec.worker.pool.dequeue", mock_dequeue) - - task = await worker._dequeue_task() + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_queue_pop) - assert task is None + data = await backend.queue.pop(timeout=1) + assert data is None -def test_worker_pool_shutdown_with_no_processes(pool, monkeypatch) -> None: +async def test_worker_pool_shutdown_with_no_processes(pool) -> None: """Test shutdown when no processes have been started.""" - # Mock the shutdown event to avoid Redis dependency - from unittest.mock import MagicMock + from unittest.mock import AsyncMock - pool._context.shutdown_event = MagicMock() + pool._context.shutdown_event = AsyncMock() # Should not raise even with empty process list - pool.shutdown(timeout=1) + await pool.shutdown(timeout=1) assert pool._processes == [] pool._context.shutdown_event.set.assert_called_once() @@ -282,3 +252,285 @@ def test_get_pool_id() -> None: id2 = _get_pool_id() assert id1 != id2 + + +class TestTaskFailed: + def test_from_exception(self): + """TaskFailed.from_exception captures the error string.""" + from agentexec.worker.pool import TaskFailed + + task = ax.Task( + task_name="test_task", + context={"message": "hello"}, + agent_id=uuid.uuid4(), + ) + exc = RuntimeError("something broke") + msg = TaskFailed.from_exception(task, exc) + + assert msg.task == task + assert msg.error == "something broke" + + def test_preserves_retry_count(self): + """TaskFailed preserves the task's current retry_count.""" + from agentexec.worker.pool import TaskFailed + + task = ax.Task( + task_name="test_task", + context={"message": "hello"}, + agent_id=uuid.uuid4(), + retry_count=2, + ) + msg = TaskFailed.from_exception(task, ValueError("bad")) + assert msg.task.retry_count == 2 + + +class 
TestWorkerFailurePath: + """Test that Worker._run sends TaskFailed on handler exception.""" + + async def test_exception_sends_task_failed(self, pool, monkeypatch): + """Handler exception → TaskFailed sent via IPC queue.""" + from agentexec.worker.pool import Worker, WorkerContext, TaskFailed + + @pool.task("failing_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + raise RuntimeError("handler exploded") + + tx = mp.Queue() + call_count = 0 + shutdown = AsyncMock() + + async def is_set(): + nonlocal call_count + return call_count > 1 + + shutdown.is_set = is_set + + context = WorkerContext( + shutdown_event=shutdown, + tasks=pool._context.tasks, + tx=tx, + ) + + async def mock_pop(*, timeout=1): + nonlocal call_count + call_count += 1 + if call_count == 1: + return { + "task_name": "failing_task", + "context": {"message": "boom"}, + "agent_id": str(uuid.uuid4()), + } + return None + + import agentexec.activity as activity_mod + monkeypatch.setattr(activity_mod, "update", AsyncMock()) + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_pop) + monkeypatch.setattr("agentexec.state.backend.queue.complete", AsyncMock()) + + worker = Worker(0, context) + + # Capture _send calls directly to avoid mp.Queue reliability issues + sent_messages = [] + original_send = worker._send + def capture_send(message): + sent_messages.append(message) + original_send(message) + monkeypatch.setattr(worker, "_send", capture_send) + + await worker._run() + + failed = [m for m in sent_messages if isinstance(m, TaskFailed)] + assert len(failed) == 1 + assert failed[0].error == "handler exploded" + assert failed[0].task.task_name == "failing_task" + + async def test_complete_called_after_failure(self, pool, monkeypatch): + """queue.complete is called even when the handler throws.""" + from agentexec.worker.pool import Worker, WorkerContext + + @pool.task("locked_fail") + async def handler(agent_id: uuid.UUID, context: SampleContext): + raise 
ValueError("oops") + + pool._context.tasks["locked_fail"].lock_key = "msg:{message}" + + tx = mp.Queue() + call_count = 0 + shutdown = AsyncMock() + + async def is_set(): + nonlocal call_count + return call_count > 1 + + shutdown.is_set = is_set + + context = WorkerContext( + shutdown_event=shutdown, + tasks=pool._context.tasks, + tx=tx, + ) + + async def mock_pop(*, timeout=1): + nonlocal call_count + call_count += 1 + if call_count == 1: + return { + "task_name": "locked_fail", + "context": {"message": "test"}, + "agent_id": str(uuid.uuid4()), + } + return None + + completed_keys = [] + + async def mock_complete(partition_key): + completed_keys.append(partition_key) + + import agentexec.activity as activity_mod + monkeypatch.setattr(activity_mod, "update", AsyncMock()) + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_pop) + monkeypatch.setattr("agentexec.state.backend.queue.complete", mock_complete) + + worker = Worker(0, context) + await worker._run() + + assert completed_keys == ["msg:test"] + + +class TestPoolRetryLogic: + """Test that Pool._process_worker_events handles TaskFailed correctly.""" + + async def test_requeues_with_incremented_retry(self, pool, monkeypatch): + """Failed task with retries remaining is requeued as high priority.""" + from agentexec.worker.pool import TaskFailed + + @pool.task("retry_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + pass + + task = ax.Task( + task_name="retry_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + retry_count=0, + ) + + pushed = [] + + async def mock_push(value, *, high_priority=False, partition_key=None): + pushed.append({"value": value, "high_priority": high_priority, "partition_key": partition_key}) + + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_push) + monkeypatch.setattr(ax.CONF, "max_task_retries", 3) + + # Put a TaskFailed message in the worker queue + pool._worker_queue.put_nowait(TaskFailed(task=task, error="boom")) + + # 
Simulate one iteration of _process_worker_events + # We need a fake process that reports alive once then dead + class FakeProcess: + def __init__(self): + self._calls = 0 + + def is_alive(self): + self._calls += 1 + return self._calls <= 2 # alive for first check, dead on second + + pool._processes = [FakeProcess()] + pool._log_handler = __import__("logging").StreamHandler() + + await pool._process_worker_events() + + assert len(pushed) == 1 + requeued = json.loads(pushed[0]["value"]) + assert requeued["retry_count"] == 1 + assert pushed[0]["high_priority"] is True + + async def test_gives_up_after_max_retries(self, pool, monkeypatch, capsys): + """Failed task at max retries is not requeued.""" + from agentexec.worker.pool import TaskFailed + + @pool.task("doomed_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + pass + + task = ax.Task( + task_name="doomed_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + retry_count=3, + ) + + pushed = [] + + async def mock_push(value, *, high_priority=False, partition_key=None): + pushed.append(value) + + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_push) + monkeypatch.setattr(ax.CONF, "max_task_retries", 3) + + pool._worker_queue.put_nowait(TaskFailed(task=task, error="fatal")) + + class FakeProcess: + def __init__(self): + self._calls = 0 + + def is_alive(self): + self._calls += 1 + return self._calls <= 2 + + pool._processes = [FakeProcess()] + pool._log_handler = __import__("logging").StreamHandler() + + await pool._process_worker_events() + + # Should NOT have requeued + assert len(pushed) == 0 + + # Should have printed the give-up message + captured = capsys.readouterr() + assert "doomed_task" in captured.out + assert "4 attempts" in captured.out + assert "fatal" in captured.out + + async def test_retry_preserves_partition_key(self, pool, monkeypatch): + """Requeued task uses the correct partition key from its definition.""" + from agentexec.worker.pool import 
TaskFailed + + @pool.task("partitioned_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + pass + + pool._context.tasks["partitioned_task"].lock_key = "msg:{message}" + + task = ax.Task( + task_name="partitioned_task", + context={"message": "hello"}, + agent_id=uuid.uuid4(), + retry_count=0, + ) + + pushed = [] + + async def mock_push(value, *, high_priority=False, partition_key=None): + pushed.append({"partition_key": partition_key}) + + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_push) + monkeypatch.setattr(ax.CONF, "max_task_retries", 3) + + pool._worker_queue.put_nowait(TaskFailed(task=task, error="transient")) + + class FakeProcess: + def __init__(self): + self._calls = 0 + + def is_alive(self): + self._calls += 1 + return self._calls <= 2 + + pool._processes = [FakeProcess()] + pool._log_handler = __import__("logging").StreamHandler() + + await pool._process_worker_events() + + assert pushed[0]["partition_key"] == "msg:hello" diff --git a/uv.lock b/uv.lock index cc95411..0e2ab2e 100644 --- a/uv.lock +++ b/uv.lock @@ -4,7 +4,7 @@ requires-python = ">=3.12" [[package]] name = "agentexec" -version = "0.1.6" +version = "0.1.7" source = { editable = "." 
} dependencies = [ { name = "croniter" }, @@ -15,6 +15,11 @@ dependencies = [ { name = "sqlalchemy" }, ] +[package.optional-dependencies] +kafka = [ + { name = "aiokafka" }, +] + [package.dev-dependencies] dev = [ { name = "fakeredis" }, @@ -28,6 +33,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiokafka", marker = "extra == 'kafka'", specifier = ">=0.11.0" }, { name = "croniter", specifier = ">=6.0.0" }, { name = "openai-agents", specifier = ">=0.1.0" }, { name = "pydantic", specifier = ">=2.12.0" }, @@ -35,6 +41,7 @@ requires-dist = [ { name = "redis", specifier = ">=7.0.1" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, ] +provides-extras = ["kafka"] [package.metadata.requires-dev] dev = [ @@ -47,6 +54,37 @@ dev = [ { name = "ty", specifier = ">=0.0.1a7" }, ] +[[package]] +name = "aiokafka" +version = "0.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/18/d3a4f8f9ad099fc59217b8cdf66eeecde3a9ef3bb31fe676e431a3b0010f/aiokafka-0.13.0.tar.gz", hash = "sha256:7d634af3c8d694a37a6c8535c54f01a740e74cccf7cc189ecc4a3d64e31ce122", size = 598580, upload-time = "2026-01-02T13:55:18.911Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/17/715ac23b4f8df3ff8d7c0a6f1c5fd3a179a8a675205be62d1d1bb27dffa2/aiokafka-0.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:231ecc0038c2736118f1c95149550dbbdf7b7a12069f70c005764fa1824c35d4", size = 346168, upload-time = "2026-01-02T13:54:49.128Z" }, + { url = "https://files.pythonhosted.org/packages/00/26/71c6f4cce2c710c6ffa18b9e294384157f46b0491d5b020de300802d167e/aiokafka-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e2817593cab4c71c1d3b265b2446da91121a467ff7477c65f0f39a80047bc28", size = 349037, upload-time = "2026-01-02T13:54:50.48Z" }, + { url = 
"https://files.pythonhosted.org/packages/82/18/7b86418a4d3dc1303e89c0391942258ead31c02309e90eb631f3081eec1d/aiokafka-0.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b80e0aa1c811a9a12edb0b94445a0638d61a345932f785d47901d28b8aad86c8", size = 1140066, upload-time = "2026-01-02T13:54:52.33Z" }, + { url = "https://files.pythonhosted.org/packages/f9/51/45e46b4407d39b950c8493e19498aeeb5af4fc461fb54fa0247da16bfd75/aiokafka-0.13.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:79672c456bd1642769e74fc2db1c34f23b15500e978fd38411662e8ca07590ad", size = 1130088, upload-time = "2026-01-02T13:54:53.786Z" }, + { url = "https://files.pythonhosted.org/packages/49/7f/6a66f6fd6fb73e15bd34f574e38703ba36d3f9256c80e7aba007bd8a9256/aiokafka-0.13.0-cp312-cp312-win32.whl", hash = "sha256:00bb4e3d5a237b8618883eb1dd8c08d671db91d3e8e33ac98b04edf64225658c", size = 309581, upload-time = "2026-01-02T13:54:55.444Z" }, + { url = "https://files.pythonhosted.org/packages/d3/e0/a2d5a8912699dd0fee28e6fb780358c63c7a4727517fffc110cb7e43f874/aiokafka-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:0f0cccdf2fd16927fbe077279524950676fbffa7b102d6b117041b3461b5d927", size = 329327, upload-time = "2026-01-02T13:54:56.981Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f6/a74c49759233e98b61182ba3d49d5ac9c8de0643651892acba2704fba1cc/aiokafka-0.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:39d71c40cff733221a6b2afff4beeac5dacbd119fb99eec5198af59115264a1a", size = 343733, upload-time = "2026-01-02T13:54:58.536Z" }, + { url = "https://files.pythonhosted.org/packages/cf/52/4f7e80eee2c69cd8b047c18145469bf0dc27542a5dca3f96ff81ade575b0/aiokafka-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:faa2f5f3d0d2283a0c1a149748cc7e3a3862ef327fa5762e2461088eedde230a", size = 346258, upload-time = "2026-01-02T13:55:00.947Z" }, + { url = 
"https://files.pythonhosted.org/packages/81/9b/d2766bb3b0bad53eb25a88e51a884be4b77a1706053ad717b893b4daea4b/aiokafka-0.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b890d535e55f5073f939585bef5301634df669e97832fda77aa743498f008662", size = 1114744, upload-time = "2026-01-02T13:55:02.475Z" }, + { url = "https://files.pythonhosted.org/packages/8f/00/12e0a39cd4809149a09b4a52b629abc9bf80e7b8bad9950040b1adae99fc/aiokafka-0.13.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e22eb8a1475b9c0f45b553b6e2dcaf4ec3c0014bf4e389e00a0a0ec85d0e3bdc", size = 1105676, upload-time = "2026-01-02T13:55:04.036Z" }, + { url = "https://files.pythonhosted.org/packages/38/4a/0bc91e90faf55533fe6468461c2dd31c22b0e1d274b9386f341cca3f7eb7/aiokafka-0.13.0-cp313-cp313-win32.whl", hash = "sha256:ae507c7b09e882484f709f2e7172b3a4f75afffcd896d00517feb35c619495bb", size = 308257, upload-time = "2026-01-02T13:55:05.873Z" }, + { url = "https://files.pythonhosted.org/packages/23/63/5433d1aa10c4fb4cf85bd73013263c36d7da4604b0c77ed4d1ad42fae70c/aiokafka-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:fec1a7e3458365a72809edaa2b990f65ca39b01a2a579f879ac4da6c9b2dbc5c", size = 326968, upload-time = "2026-01-02T13:55:07.351Z" }, + { url = "https://files.pythonhosted.org/packages/3c/cc/45b04c3a5fd3d2d5f444889ecceb80b2f78d6d66aa45e3042767e55579e2/aiokafka-0.13.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9a403785f7092c72906c37f7618f7b16a4219eba8ed0bdda90fba410a7dd50b5", size = 344503, upload-time = "2026-01-02T13:55:08.723Z" }, + { url = "https://files.pythonhosted.org/packages/76/df/0b76fe3b93558ae71b856940e384909c4c2c7a1c330423003191e4ba7782/aiokafka-0.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:256807326831b7eee253ea1017bd2b19ab1c2298ce6b20a87fde97c253c572bc", size = 347621, upload-time = "2026-01-02T13:55:10.147Z" }, + { url = 
"https://files.pythonhosted.org/packages/34/1a/d59932f98fd3c106e2a7c8d4d5ebd8df25403436dfc27b3031918a37385e/aiokafka-0.13.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:64d90f91291da265d7f25296ba68fc6275684eebd6d1cf05a1b2abe6c2ba3543", size = 1111410, upload-time = "2026-01-02T13:55:11.763Z" }, + { url = "https://files.pythonhosted.org/packages/7e/04/fbf3e34ab3bc21e6e760c3fcd089375052fccc04eb8745459a82a58a647b/aiokafka-0.13.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b5a33cc043c8d199bcf101359d86f2d31fd54f4b157ac12028bdc34e3e1cf74a", size = 1094799, upload-time = "2026-01-02T13:55:13.795Z" }, + { url = "https://files.pythonhosted.org/packages/85/10/509f709fd3b7c3e568a5b8044be0e80a1504f8da6ddc72c128b21e270913/aiokafka-0.13.0-cp314-cp314-win32.whl", hash = "sha256:538950384b539ba2333d35a853f09214c0409e818e5d5f366ef759eea50bae9c", size = 311553, upload-time = "2026-01-02T13:55:15.928Z" }, + { url = "https://files.pythonhosted.org/packages/2b/18/424d6a4eb6f4835a371c1e2cfafce800540b33d957c6638795d911f98973/aiokafka-0.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:c906dd42daadd14b4506a2e6c62dfef3d4919b5953d32ae5e5f0d99efd103c89", size = 330648, upload-time = "2026-01-02T13:55:17.421Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -70,6 +108,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = 
"sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274, upload-time = "2024-11-06T16:41:39.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, +] + [[package]] name = "attrs" version = "25.4.0"