1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"]
pydantic = ["pydantic>=2.0.0,<3"]
openai-agents = ["openai-agents>=0.3,<0.7", "mcp>=1.9.4, <2"]
google-adk = ["google-adk>=1.27.0,<2"]
activity-cache = ["fsspec>=2024.1.0"]

[project.urls]
Homepage = "https://github.com/temporalio/sdk-python"
110 changes: 110 additions & 0 deletions temporalio/contrib/activity_cache/README.md
@@ -0,0 +1,110 @@
# Activity Cache for Temporal

Content-addressed activity memoization with shared remote storage: repeated calls with the same inputs skip execution and return the cached result. This makes workflows idempotent across retries, re-runs, and distributed workers.

## Problem

Temporal activities are not idempotent by default. If a workflow retries or re-runs, activities execute again — even when their inputs haven't changed. For expensive activities (AI inference, API calls, data processing), this wastes time and money. This module caches activity results so repeated calls with the same inputs return instantly.

## Install

```bash
pip install "temporalio[activity-cache]"

# With a specific cloud backend:
pip install "temporalio[activity-cache]" gcsfs  # Google Cloud Storage
pip install "temporalio[activity-cache]" s3fs   # Amazon S3
```

## Usage

### As a decorator (explicit, per-activity)

```python
from datetime import timedelta

from temporalio import activity
from temporalio.contrib.activity_cache import cached

@cached("gs://my-bucket/cache", ttl=timedelta(days=90))
@activity.defn
async def extract(input: ExtractInput) -> ExtractOutput:
    ...  # Only runs on cache miss
```

### With custom key function

By default, all arguments are included in the cache key. Use `key_fn` to select specific args:

```python
@cached(
    "gs://my-bucket/cache",
    key_fn=lambda input: {
        "component": input.component_name,
        "content_hash": input.content_hash,
    },
)
@activity.defn
async def extract(input: ExtractInput) -> ExtractOutput:
    # Cache key only considers component + content_hash.
    # Changes to input.timestamp or input.run_id won't bust the cache.
    ...
```

### As an interceptor (transparent, all activities)

```python
from datetime import timedelta

from temporalio import activity
from temporalio.contrib.activity_cache import CachingInterceptor, no_cache
from temporalio.worker import Worker

worker = Worker(
    client,
    task_queue="my-queue",
    activities=[extract, register, verify, commit],
    interceptors=[CachingInterceptor("gs://my-bucket/cache", ttl=timedelta(days=90))],
)

# Opt out specific activities:
@no_cache
@activity.defn
async def commit(input: CommitInput) -> CommitOutput:
    ...  # Always executes, never cached
```

## How It Works

1. **Key computation**: SHA256 hash of function name + serialized arguments
2. **Cache check**: Look up `{base_url}/{fn_name}/{key}.pkl` in remote store
3. **Cache hit**: Unpickle and return the stored result — activity body never runs
4. **Cache miss**: Execute activity, pickle result, upload to remote store
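The key and path scheme from steps 1–2 can be sketched as a stand-alone function (a simplified illustration; the real logic lives in `_keys.py` and the cache store, and serialization of arguments is covered in the next section):

```python
import hashlib
import json


def cache_path(base_url: str, fn_name: str, serialized_args: list) -> str:
    # Step 1: key = SHA256 over function name + deterministically serialized args.
    payload = json.dumps({"fn": fn_name, "args": serialized_args}, sort_keys=True)
    key = hashlib.sha256(payload.encode()).hexdigest()[:32]
    # Step 2: each entry lives at {base_url}/{fn_name}/{key}.pkl.
    return f"{base_url}/{fn_name}/{key}.pkl"
```

Because `sort_keys=True` fixes the JSON layout, the same inputs always map to the same path, which is what lets any worker find results written by another.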

## Serialization

Arguments are serialized deterministically for cache key computation:

| Type | Serialization |
|------|--------------|
| `str`, `int`, `float`, `bool`, `None` | Pass-through |
| `bytes` | SHA256 hash (first 16 chars) |
| `Path` (file) | SHA256 of file content (first 16 chars) |
| Pydantic `BaseModel` | `.model_dump()` (recursive) |
| `dataclass` | `dataclasses.asdict()` (recursive) |
| `dict` | Sorted keys, recursive values |
| `list`, `tuple` | Recursive elements |
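The dispatch rules in the table can be sketched roughly as follows (a hypothetical, dependency-free version; the Pydantic branch is omitted, and the actual implementation lives in `_serializers.py`):

```python
import dataclasses
import hashlib
from pathlib import Path
from typing import Any


def serialize_for_hash(obj: Any) -> Any:
    """Reduce obj to JSON-friendly, deterministic data for hashing."""
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj  # pass-through
    if isinstance(obj, bytes):
        return hashlib.sha256(obj).hexdigest()[:16]  # hash content, not bytes
    if isinstance(obj, Path):
        return hashlib.sha256(obj.read_bytes()).hexdigest()[:16]
    if dataclasses.is_dataclass(obj):
        return serialize_for_hash(dataclasses.asdict(obj))
    if isinstance(obj, dict):
        # Sorted keys make dict ordering irrelevant to the cache key.
        return {k: serialize_for_hash(obj[k]) for k in sorted(obj)}
    if isinstance(obj, (list, tuple)):
        return [serialize_for_hash(x) for x in obj]
    raise TypeError(f"No serializer registered for {type(obj)!r}")
```

Hashing `bytes` and file contents instead of embedding them keeps cache keys small while still changing the key whenever the content changes.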

Register custom serializers for domain types:

```python
from temporalio.contrib.activity_cache import register_serializer

register_serializer(MyType, lambda obj: {"id": obj.id, "version": obj.version})
```

## Storage Backends

Any backend supported by [fsspec](https://filesystem-spec.readthedocs.io/):

| Scheme | Backend | Extra package |
|--------|---------|--------------|
| `gs://` | Google Cloud Storage | `gcsfs` |
| `s3://` | Amazon S3 | `s3fs` |
| `az://` | Azure Blob Storage | `adlfs` |
| `file://` | Local filesystem | (none) |
| `memory://` | In-memory (testing) | (none) |
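The `memory://` backend is convenient for unit tests because it needs no extra packages. A quick fsspec round-trip mirroring the store's pickle read/write (paths and key are illustrative):

```python
import pickle

import fsspec  # pip install fsspec

# In-memory filesystem; contents live only for the process lifetime.
fs = fsspec.filesystem("memory")

# Simulate a cache write: pickle a result under {base}/{fn_name}/{key}.pkl.
path = "/cache/extract/0a1b2c.pkl"  # illustrative key
with fs.open(path, "wb") as f:
    pickle.dump({"status": "ok"}, f)

# Simulate a cache read on a later run.
with fs.open(path, "rb") as f:
    result = pickle.load(f)
assert result == {"status": "ok"}
```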
23 changes: 23 additions & 0 deletions temporalio/contrib/activity_cache/__init__.py
@@ -0,0 +1,23 @@
"""Content-addressed activity memoization with shared remote storage.

This package provides activity-level caching for Temporal workflows. Cached
activities return stored results on repeated calls with the same inputs,
making workflows idempotent across retries, re-runs, and multi-worker
deployments.

The cache is stored remotely (GCS, S3, Azure, local, etc.) via `fsspec`_,
so it is shared across all workers.

.. _fsspec: https://filesystem-spec.readthedocs.io/
"""

from temporalio.contrib.activity_cache._decorator import cached, no_cache
from temporalio.contrib.activity_cache._interceptor import CachingInterceptor
from temporalio.contrib.activity_cache._serializers import register_serializer

__all__ = [
    "CachingInterceptor",
    "cached",
    "no_cache",
    "register_serializer",
]
104 changes: 104 additions & 0 deletions temporalio/contrib/activity_cache/_decorator.py
@@ -0,0 +1,104 @@
"""Cached decorator for activity memoization."""

from __future__ import annotations

import functools
import logging
from collections.abc import Callable
from datetime import timedelta
from typing import Any, TypeVar

from temporalio.contrib.activity_cache._keys import compute_cache_key
from temporalio.contrib.activity_cache._store import CacheStore

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])

# Marker attribute to opt out of caching via interceptor
NO_CACHE_ATTR = "__temporal_no_cache__"


def no_cache(fn: F) -> F:
    """Mark an activity to be excluded from interceptor-based caching.

    Use this when :class:`CachingInterceptor` is installed but specific
    activities should always execute without caching.

    Example::

        @no_cache
        @activity.defn
        async def always_runs(input: Input) -> Output:
            ...
    """
    setattr(fn, NO_CACHE_ATTR, True)
    return fn


def cached(
    store_url: str,
    ttl: timedelta | None = None,
    key_fn: Callable[..., dict[str, Any]] | None = None,
    **storage_options: object,
) -> Callable[[F], F]:
    """Decorator for content-addressed activity memoization.

    Wraps an async function so that repeated calls with the same inputs
    return the cached result without re-executing. The cache is stored
    remotely (GCS, S3, local, etc.) so it is shared across distributed
    workers.

    Args:
        store_url: Base URL for the cache store. The scheme determines the
            backend (``gs://`` for GCS, ``s3://`` for S3, etc.).
        ttl: Optional time-to-live for cached entries. If None, entries
            never expire.
        key_fn: Optional function that receives the activity arguments and
            returns a dict of values to hash for the cache key. If not
            provided, all arguments are included in the key.
        **storage_options: Extra keyword arguments passed to the fsspec
            filesystem (e.g., ``project="my-gcp-project"``).

    Example::

        @cached("gs://my-bucket/cache", ttl=timedelta(days=90))
        @activity.defn
        async def extract(input: ExtractInput) -> ExtractOutput:
            ...  # Only runs on cache miss

        # With key_fn to select which args matter:
        @cached(
            "gs://my-bucket/cache",
            key_fn=lambda input: {
                "component": input.component,
                "content_hash": input.content_hash,
            },
        )
        @activity.defn
        async def extract(input: ExtractInput) -> ExtractOutput:
            ...
    """
    store = CacheStore(store_url, **storage_options)

    def decorator(fn: F) -> F:
        fn_name = getattr(fn, "__name__", str(fn))

        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Only positional args are hashed; Temporal delivers activity
            # arguments positionally.
            key = compute_cache_key(fn_name, args, key_fn)

            hit, value = await store.get(fn_name, key)
            if hit:
                logger.debug("Cache hit for %s (key=%s)", fn_name, key[:8])
                return value

            logger.debug("Cache miss for %s (key=%s)", fn_name, key[:8])
            result = await fn(*args, **kwargs)

            await store.set(fn_name, key, result, ttl)
            return result

        return wrapper  # type: ignore[return-value]

    return decorator
97 changes: 97 additions & 0 deletions temporalio/contrib/activity_cache/_interceptor.py
@@ -0,0 +1,97 @@
"""Temporal ActivityInboundInterceptor for transparent activity caching."""

from __future__ import annotations

import logging
from collections.abc import Callable
from datetime import timedelta
from typing import Any

import temporalio.worker

from temporalio.contrib.activity_cache._decorator import NO_CACHE_ATTR
from temporalio.contrib.activity_cache._keys import compute_cache_key
from temporalio.contrib.activity_cache._store import CacheStore

logger = logging.getLogger(__name__)


class CachingInterceptor(temporalio.worker.Interceptor):
    """Temporal interceptor that caches all activity results.

    When added to a worker's interceptor list, all activities are
    automatically cached. Use :func:`~temporalio.contrib.activity_cache.no_cache`
    to opt out specific activities.

    Args:
        store_url: Base URL for the cache store.
        ttl: Optional default TTL for all cached entries.
        key_fn: Optional function to compute cache keys from activity args.
            If not provided, all arguments are included in the key.
        **storage_options: Extra keyword arguments passed to the fsspec
            filesystem.

    Example::

        from temporalio.contrib.activity_cache import CachingInterceptor

        worker = Worker(
            client,
            task_queue="my-queue",
            activities=[extract, register, verify],
            interceptors=[CachingInterceptor("gs://bucket/cache")],
        )
    """

    def __init__(
        self,
        store_url: str,
        ttl: timedelta | None = None,
        key_fn: Callable[..., dict[str, Any]] | None = None,
        **storage_options: object,
    ) -> None:
        self._store = CacheStore(store_url, **storage_options)
        self._ttl = ttl
        self._key_fn = key_fn

    def intercept_activity(
        self,
        next: temporalio.worker.ActivityInboundInterceptor,
    ) -> temporalio.worker.ActivityInboundInterceptor:
        """Wrap the activity inbound interceptor with caching."""
        return _CachingActivityInboundInterceptor(next, self)


class _CachingActivityInboundInterceptor(
    temporalio.worker.ActivityInboundInterceptor,
):
    def __init__(
        self,
        next: temporalio.worker.ActivityInboundInterceptor,
        root: CachingInterceptor,
    ) -> None:
        super().__init__(next)
        self._root = root

    async def execute_activity(
        self,
        input: temporalio.worker.ExecuteActivityInput,
    ) -> Any:
        """Execute the activity with caching."""
        # Check opt-out marker
        if getattr(input.fn, NO_CACHE_ATTR, False):
            return await super().execute_activity(input)

        fn_name = getattr(input.fn, "__name__", str(input.fn))
        key = compute_cache_key(fn_name, input.args, self._root._key_fn)

        hit, value = await self._root._store.get(fn_name, key)
        if hit:
            logger.debug("Cache hit for %s (key=%s)", fn_name, key[:8])
            return value

        logger.debug("Cache miss for %s (key=%s)", fn_name, key[:8])
        result = await super().execute_activity(input)

        await self._root._store.set(fn_name, key, result, self._root._ttl)
        return result
38 changes: 38 additions & 0 deletions temporalio/contrib/activity_cache/_keys.py
@@ -0,0 +1,38 @@
"""Cache key computation for activity memoization."""

from __future__ import annotations

import hashlib
import json
from collections.abc import Callable
from typing import Any

from temporalio.contrib.activity_cache._serializers import serialize_for_hash


def compute_cache_key(
    fn_name: str,
    args: tuple[Any, ...],
    key_fn: Callable[..., dict[str, Any]] | None = None,
) -> str:
    """Compute a deterministic cache key from a function name and arguments.

    Args:
        fn_name: The function/activity name.
        args: Positional arguments passed to the function.
        key_fn: Optional function that receives the activity arguments and
            returns a dict of values to include in the cache key. If provided,
            only the returned dict is hashed (not the full args).

    Returns:
        A 32-character hex string (SHA256 prefix).
    """
    if key_fn is not None:
        # Hash only the values key_fn selects from the arguments
        key_data = key_fn(*args)
        serialized = serialize_for_hash(key_data)
    else:
        serialized = serialize_for_hash(args)

    payload = json.dumps({"fn": fn_name, "args": serialized}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()[:32]