From 6d6468ec0a341d28db71ef1622a62eb14d8d85a1 Mon Sep 17 00:00:00 2001 From: Saeed Seyfi Date: Wed, 18 Mar 2026 16:50:21 +0100 Subject: [PATCH 1/2] feat: add contrib.activity_cache for activity memoization Activities aren't idempotent by default. When workflows retry or re-run, every activity executes again even when inputs haven't changed. This module provides content-addressed caching with shared remote storage via fsspec, as both a decorator and an interceptor. Closes #1374 --- pyproject.toml | 1 + temporalio/contrib/activity_cache/README.md | 110 +++++++++++++++++ temporalio/contrib/activity_cache/__init__.py | 23 ++++ .../contrib/activity_cache/_decorator.py | 104 ++++++++++++++++ .../contrib/activity_cache/_interceptor.py | 97 +++++++++++++++ temporalio/contrib/activity_cache/_keys.py | 38 ++++++ .../contrib/activity_cache/_serializers.py | 86 ++++++++++++++ temporalio/contrib/activity_cache/_store.py | 111 ++++++++++++++++++ tests/contrib/activity_cache/__init__.py | 0 .../contrib/activity_cache/test_decorator.py | 97 +++++++++++++++ tests/contrib/activity_cache/test_keys.py | 75 ++++++++++++ 11 files changed, 742 insertions(+) create mode 100644 temporalio/contrib/activity_cache/README.md create mode 100644 temporalio/contrib/activity_cache/__init__.py create mode 100644 temporalio/contrib/activity_cache/_decorator.py create mode 100644 temporalio/contrib/activity_cache/_interceptor.py create mode 100644 temporalio/contrib/activity_cache/_keys.py create mode 100644 temporalio/contrib/activity_cache/_serializers.py create mode 100644 temporalio/contrib/activity_cache/_store.py create mode 100644 tests/contrib/activity_cache/__init__.py create mode 100644 tests/contrib/activity_cache/test_decorator.py create mode 100644 tests/contrib/activity_cache/test_keys.py diff --git a/pyproject.toml b/pyproject.toml index 7a2df7ea8..717685aab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ opentelemetry = ["opentelemetry-api>=1.11.1,<2", 
"opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] openai-agents = ["openai-agents>=0.3,<0.7", "mcp>=1.9.4, <2"] google-adk = ["google-adk>=1.27.0,<2"] +activity-cache = ["fsspec>=2024.1.0"] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" diff --git a/temporalio/contrib/activity_cache/README.md b/temporalio/contrib/activity_cache/README.md new file mode 100644 index 000000000..e1ba36f75 --- /dev/null +++ b/temporalio/contrib/activity_cache/README.md @@ -0,0 +1,110 @@ +# Activity Cache for Temporal + +Content-addressed activity memoization with shared remote storage. Same inputs = skip execution, return cached result. Makes workflows idempotent across retries, re-runs, and distributed workers. + +## Problem + +Temporal activities are not idempotent by default. If a workflow retries or re-runs, activities execute again — even when their inputs haven't changed. For expensive activities (AI inference, API calls, data processing), this wastes time and money. This module caches activity results so repeated calls with the same inputs return instantly. + +## Install + +```bash +pip install temporalio[activity-cache] + +# With a specific cloud backend: +pip install temporalio[activity-cache] gcsfs # Google Cloud Storage +pip install temporalio[activity-cache] s3fs # Amazon S3 +``` + +## Usage + +### As a decorator (explicit, per-activity) + +```python +from temporalio import activity +from temporalio.contrib.activity_cache import cached + +@cached("gs://my-bucket/cache", ttl=timedelta(days=90)) +@activity.defn +async def extract(input: ExtractInput) -> ExtractOutput: + ... # Only runs on cache miss +``` + +### With custom key function + +By default, all arguments are included in the cache key. 
Use `key_fn` to select specific args: + +```python +@cached( + "gs://my-bucket/cache", + key_fn=lambda input: { + "component": input.component_name, + "content_hash": input.content_hash, + }, +) +@activity.defn +async def extract(input: ExtractInput) -> ExtractOutput: + # Cache key only considers component + content_hash + # Changes to input.timestamp or input.run_id won't bust the cache + ... +``` + +### As an interceptor (transparent, all activities) + +```python +from temporalio.contrib.activity_cache import CachingInterceptor, no_cache + +worker = Worker( + client, + task_queue="my-queue", + activities=[extract, register, verify, commit], + interceptors=[CachingInterceptor("gs://my-bucket/cache", ttl=timedelta(days=90))], +) + +# Opt out specific activities: +@no_cache +@activity.defn +async def commit(input: CommitInput) -> CommitOutput: + ... # Always executes, never cached +``` + +## How It Works + +1. **Key computation**: SHA256 hash of function name + serialized arguments +2. **Cache check**: Look up `{base_url}/{fn_name}/{key}.pkl` in remote store +3. **Cache hit**: Unpickle and return the stored result — activity body never runs +4. 
**Cache miss**: Execute activity, pickle result, upload to remote store + +## Serialization + +Arguments are serialized deterministically for cache key computation: + +| Type | Serialization | +|------|--------------| +| `str`, `int`, `float`, `bool`, `None` | Pass-through | +| `bytes` | SHA256 hash (first 16 chars) | +| `Path` (file) | SHA256 of file content (first 16 chars) | +| Pydantic `BaseModel` | `.model_dump()` (recursive) | +| `dataclass` | `dataclasses.asdict()` (recursive) | +| `dict` | Sorted keys, recursive values | +| `list`, `tuple` | Recursive elements | + +Register custom serializers for domain types: + +```python +from temporalio.contrib.activity_cache import register_serializer + +register_serializer(MyType, lambda obj: {"id": obj.id, "version": obj.version}) +``` + +## Storage Backends + +Any backend supported by [fsspec](https://filesystem-spec.readthedocs.io/): + +| Scheme | Backend | Extra package | +|--------|---------|--------------| +| `gs://` | Google Cloud Storage | `gcsfs` | +| `s3://` | Amazon S3 | `s3fs` | +| `az://` | Azure Blob Storage | `adlfs` | +| `file://` | Local filesystem | (none) | +| `memory://` | In-memory (testing) | (none) | diff --git a/temporalio/contrib/activity_cache/__init__.py b/temporalio/contrib/activity_cache/__init__.py new file mode 100644 index 000000000..9ab050aef --- /dev/null +++ b/temporalio/contrib/activity_cache/__init__.py @@ -0,0 +1,23 @@ +"""Content-addressed activity memoization with shared remote storage. + +This package provides activity-level caching for Temporal workflows. Cached +activities return stored results on repeated calls with the same inputs, +making workflows idempotent across retries, re-runs, and multi-worker +deployments. + +The cache is stored remotely (GCS, S3, Azure, local, etc.) via `fsspec`_, +so it is shared across all workers. + +.. 
_fsspec: https://filesystem-spec.readthedocs.io/ +""" + +from temporalio.contrib.activity_cache._decorator import cached, no_cache +from temporalio.contrib.activity_cache._interceptor import CachingInterceptor +from temporalio.contrib.activity_cache._serializers import register_serializer + +__all__ = [ + "CachingInterceptor", + "cached", + "no_cache", + "register_serializer", +] diff --git a/temporalio/contrib/activity_cache/_decorator.py b/temporalio/contrib/activity_cache/_decorator.py new file mode 100644 index 000000000..5ff1caace --- /dev/null +++ b/temporalio/contrib/activity_cache/_decorator.py @@ -0,0 +1,104 @@ +"""Cached decorator for activity memoization.""" + +from __future__ import annotations + +import functools +import logging +from collections.abc import Callable +from datetime import timedelta +from typing import Any, TypeVar + +from temporalio.contrib.activity_cache._keys import compute_cache_key +from temporalio.contrib.activity_cache._store import CacheStore + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + +# Marker attribute to opt out of caching via interceptor +NO_CACHE_ATTR = "__temporal_no_cache__" + + +def no_cache(fn: F) -> F: + """Mark an activity to be excluded from interceptor-based caching. + + Use this when using :class:`CachingInterceptor` but you want specific + activities to always execute without caching. + + Example:: + + @no_cache + @activity.defn + async def always_runs(input: Input) -> Output: + ... + """ + setattr(fn, NO_CACHE_ATTR, True) + return fn + + +def cached( + store_url: str, + ttl: timedelta | None = None, + key_fn: Callable[..., dict[str, Any]] | None = None, + **storage_options: object, +) -> Callable[[F], F]: + """Decorator for content-addressed activity memoization. + + Wraps an async function so that repeated calls with the same inputs + return the cached result without re-executing. The cache is stored + remotely (GCS, S3, local, etc.) 
so it is shared across distributed + workers. + + Args: + store_url: Base URL for the cache store. The scheme determines the + backend (``gs://`` for GCS, ``s3://`` for S3, etc.). + ttl: Optional time-to-live for cached entries. If None, entries + never expire. + key_fn: Optional function that receives the activity arguments and + returns a dict of values to hash for the cache key. If not + provided, all arguments are included in the key. + **storage_options: Extra keyword arguments passed to the fsspec + filesystem (e.g., ``project="my-gcp-project"``). + + Example:: + + @cached("gs://my-bucket/cache", ttl=timedelta(days=90)) + @activity.defn + async def extract(input: ExtractInput) -> ExtractOutput: + ... # Only runs on cache miss + + # With key_fn to select which args matter: + @cached( + "gs://my-bucket/cache", + key_fn=lambda input: { + "component": input.component, + "content_hash": input.content_hash, + }, + ) + @activity.defn + async def extract(input: ExtractInput) -> ExtractOutput: + ... 
+ """ + store = CacheStore(store_url, **storage_options) + + def decorator(fn: F) -> F: + fn_name = getattr(fn, "__name__", str(fn)) + + @functools.wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + key = compute_cache_key(fn_name, args, key_fn) + + hit, value = await store.get(fn_name, key) + if hit: + logger.debug("Cache hit for %s (key=%s)", fn_name, key[:8]) + return value + + logger.debug("Cache miss for %s (key=%s)", fn_name, key[:8]) + result = await fn(*args, **kwargs) + + await store.set(fn_name, key, result, ttl) + return result + + return wrapper # type: ignore[return-value] + + return decorator diff --git a/temporalio/contrib/activity_cache/_interceptor.py b/temporalio/contrib/activity_cache/_interceptor.py new file mode 100644 index 000000000..e632afaf2 --- /dev/null +++ b/temporalio/contrib/activity_cache/_interceptor.py @@ -0,0 +1,97 @@ +"""Temporal ActivityInboundInterceptor for transparent activity caching.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from datetime import timedelta +from typing import Any + +import temporalio.worker + +from temporalio.contrib.activity_cache._decorator import NO_CACHE_ATTR +from temporalio.contrib.activity_cache._keys import compute_cache_key +from temporalio.contrib.activity_cache._store import CacheStore + +logger = logging.getLogger(__name__) + + +class CachingInterceptor(temporalio.worker.Interceptor): + """Temporal interceptor that caches all activity results. + + When added to a worker's interceptor list, all activities are + automatically cached. Use :func:`~temporalio.contrib.activity_cache.no_cache` + to opt out specific activities. + + Args: + store_url: Base URL for the cache store. + ttl: Optional default TTL for all cached entries. + key_fn: Optional function to compute cache keys from activity args. + If not provided, all arguments are included in the key. + **storage_options: Extra keyword arguments passed to the fsspec + filesystem. 
+ + Example:: + + from temporalio.contrib.activity_cache import CachingInterceptor + + worker = Worker( + client, + task_queue="my-queue", + activities=[extract, register, verify], + interceptors=[CachingInterceptor("gs://bucket/cache")], + ) + """ + + def __init__( + self, + store_url: str, + ttl: timedelta | None = None, + key_fn: Callable[..., dict[str, Any]] | None = None, + **storage_options: object, + ) -> None: + self._store = CacheStore(store_url, **storage_options) + self._ttl = ttl + self._key_fn = key_fn + + def intercept_activity( + self, + next: temporalio.worker.ActivityInboundInterceptor, + ) -> temporalio.worker.ActivityInboundInterceptor: + """Wrap the activity inbound interceptor with caching.""" + return _CachingActivityInboundInterceptor(next, self) + + +class _CachingActivityInboundInterceptor( + temporalio.worker.ActivityInboundInterceptor, +): + def __init__( + self, + next: temporalio.worker.ActivityInboundInterceptor, + root: CachingInterceptor, + ) -> None: + super().__init__(next) + self._root = root + + async def execute_activity( + self, + input: temporalio.worker.ExecuteActivityInput, + ) -> Any: + """Execute the activity with caching.""" + # Check opt-out marker + if getattr(input.fn, NO_CACHE_ATTR, False): + return await super().execute_activity(input) + + fn_name = getattr(input.fn, "__name__", str(input.fn)) + key = compute_cache_key(fn_name, input.args, self._root._key_fn) + + hit, value = await self._root._store.get(fn_name, key) + if hit: + logger.debug("Cache hit for %s (key=%s)", fn_name, key[:8]) + return value + + logger.debug("Cache miss for %s (key=%s)", fn_name, key[:8]) + result = await super().execute_activity(input) + + await self._root._store.set(fn_name, key, result, self._root._ttl) + return result diff --git a/temporalio/contrib/activity_cache/_keys.py b/temporalio/contrib/activity_cache/_keys.py new file mode 100644 index 000000000..e5f24c65d --- /dev/null +++ b/temporalio/contrib/activity_cache/_keys.py @@ -0,0 
+1,38 @@ +"""Cache key computation for activity memoization.""" + +from __future__ import annotations + +import hashlib +import inspect +import json +from typing import Any, Callable + +from temporalio.contrib.activity_cache._serializers import serialize_for_hash + + +def compute_cache_key( + fn_name: str, + args: tuple[Any, ...], + key_fn: Callable[..., dict[str, Any]] | None = None, +) -> str: + """Compute a deterministic cache key from a function name and arguments. + + Args: + fn_name: The function/activity name. + args: Positional arguments passed to the function. + key_fn: Optional function that receives the activity arguments and + returns a dict of values to include in the cache key. If provided, + only the returned dict is hashed (not the full args). + + Returns: + A 32-character hex string (SHA256 prefix). + """ + if key_fn is not None: + # Resolve args to kwargs for key_fn + key_data = key_fn(*args) + serialized = serialize_for_hash(key_data) + else: + serialized = serialize_for_hash(args) + + payload = json.dumps({"fn": fn_name, "args": serialized}, sort_keys=True) + return hashlib.sha256(payload.encode()).hexdigest()[:32] diff --git a/temporalio/contrib/activity_cache/_serializers.py b/temporalio/contrib/activity_cache/_serializers.py new file mode 100644 index 000000000..1ff8593d0 --- /dev/null +++ b/temporalio/contrib/activity_cache/_serializers.py @@ -0,0 +1,86 @@ +"""Input serialization for deterministic cache key computation.""" + +from __future__ import annotations + +import dataclasses +import hashlib +from pathlib import Path +from typing import Any, Callable + + +# Registry of custom serializers: type → serializer function +_custom_serializers: dict[type, Callable[[Any], Any]] = {} + + +def register_serializer(type_: type, serializer: Callable[[Any], Any]) -> None: + """Register a custom serializer for a type. 
+ + The serializer function should return a JSON-serializable value that + deterministically represents the input for cache key computation. + + Args: + type_: The type to register the serializer for. + serializer: A function that takes an instance of ``type_`` and returns + a JSON-serializable representation. + """ + _custom_serializers[type_] = serializer + + +def serialize_for_hash(value: Any) -> Any: + """Convert a value to a JSON-serializable form for cache key computation. + + Handles common types deterministically: + + - Primitives (str, int, float, bool, None) pass through + - bytes → SHA256 hash (first 16 chars) + - Path → SHA256 of file content (first 16 chars) + - Pydantic BaseModel → ``.model_dump()`` (recursive) + - dataclass → ``dataclasses.asdict()`` (recursive) + - dict → sorted keys, recursive values + - list/tuple → recursive elements + - Custom registered types → custom serializer + + Args: + value: The value to serialize. + + Returns: + A JSON-serializable representation suitable for hashing. + + Raises: + TypeError: If the value type is not supported and no custom + serializer is registered. 
+ """ + if value is None or isinstance(value, (str, int, float, bool)): + return value + + # Check custom serializers first (allows overriding built-in behavior) + for type_, serializer in _custom_serializers.items(): + if isinstance(value, type_): + return serialize_for_hash(serializer(value)) + + if isinstance(value, bytes): + return {"__bytes__": hashlib.sha256(value).hexdigest()[:16]} + + if isinstance(value, Path): + if value.is_file(): + content = value.read_bytes() + return {"__path__": hashlib.sha256(content).hexdigest()[:16]} + return {"__path__": str(value)} + + if isinstance(value, dict): + return {str(k): serialize_for_hash(v) for k, v in sorted(value.items())} + + if isinstance(value, (list, tuple)): + return [serialize_for_hash(item) for item in value] + + # Pydantic BaseModel (check without importing pydantic) + if hasattr(value, "model_dump"): + return serialize_for_hash(value.model_dump()) + + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return serialize_for_hash(dataclasses.asdict(value)) + + raise TypeError( + f"Cannot serialize type {type(value).__name__} for cache key. " + f"Register a custom serializer with register_serializer()." + ) diff --git a/temporalio/contrib/activity_cache/_store.py b/temporalio/contrib/activity_cache/_store.py new file mode 100644 index 000000000..f0fe851b5 --- /dev/null +++ b/temporalio/contrib/activity_cache/_store.py @@ -0,0 +1,111 @@ +"""Cache store backed by fsspec for remote/local storage.""" + +from __future__ import annotations + +import json +import pickle +from datetime import datetime, timedelta, timezone +from typing import Any +from urllib.parse import urlparse + +import fsspec + + +class CacheStore: + """A key-value cache store backed by any fsspec-compatible filesystem. 
+ + Cache entries are stored as two files per key: + + - ``{prefix}/{fn_name}/{key}.pkl`` — pickled return value + - ``{prefix}/{fn_name}/{key}.meta.json`` — metadata (expiration) + + Args: + base_url: Base URL for the cache store. The scheme determines the + fsspec backend (``gs://`` for GCS, ``s3://`` for S3, etc.). + **storage_options: Extra keyword arguments passed to + ``fsspec.filesystem()``. + """ + + def __init__(self, base_url: str, **storage_options: object) -> None: + self._base_url = base_url.rstrip("/") + parsed = urlparse(self._base_url) + self._protocol = parsed.scheme or "file" + self._base_path = ( + parsed.netloc + parsed.path if parsed.netloc else parsed.path + ) + self._fs = fsspec.filesystem(self._protocol, **storage_options) + + def _value_path(self, fn_name: str, key: str) -> str: + return f"{self._base_path}/{fn_name}/{key}.pkl" + + def _meta_path(self, fn_name: str, key: str) -> str: + return f"{self._base_path}/{fn_name}/{key}.meta.json" + + async def get(self, fn_name: str, key: str) -> tuple[bool, Any]: + """Retrieve a cached value. + + Args: + fn_name: The function/activity name (used as namespace). + key: The cache key. + + Returns: + A tuple of ``(hit, value)``. If ``hit`` is False, ``value`` + is None. + """ + meta_path = self._meta_path(fn_name, key) + + if not self._fs.exists(meta_path): + return False, None + + # Check TTL + meta = json.loads(self._fs.cat_file(meta_path)) + expires_at = meta.get("expires_at") + if expires_at is not None: + if datetime.fromisoformat(expires_at) < datetime.now(timezone.utc): + # Expired — clean up lazily + self._delete_entry(fn_name, key) + return False, None + + value_path = self._value_path(fn_name, key) + if not self._fs.exists(value_path): + return False, None + + data = self._fs.cat_file(value_path) + return True, pickle.loads(data) # noqa: S301 + + async def set( + self, fn_name: str, key: str, value: Any, ttl: timedelta | None = None + ) -> None: + """Store a value in the cache. 
+ + Args: + fn_name: The function/activity name (used as namespace). + key: The cache key. + value: The value to cache (must be picklable). + ttl: Optional time-to-live. If None, the entry never expires. + """ + meta: dict[str, Any] = {} + if ttl is not None: + expires_at = datetime.now(timezone.utc) + ttl + meta["expires_at"] = expires_at.isoformat() + + value_path = self._value_path(fn_name, key) + meta_path = self._meta_path(fn_name, key) + + # Write value first, then metadata (metadata signals "entry exists") + self._fs.pipe_file(value_path, pickle.dumps(value)) + self._fs.pipe_file(meta_path, json.dumps(meta).encode()) + + async def delete(self, fn_name: str, key: str) -> None: + """Delete a cache entry. + + Args: + fn_name: The function/activity name (used as namespace). + key: The cache key. + """ + self._delete_entry(fn_name, key) + + def _delete_entry(self, fn_name: str, key: str) -> None: + for path in [self._value_path(fn_name, key), self._meta_path(fn_name, key)]: + if self._fs.exists(path): + self._fs.rm(path) diff --git a/tests/contrib/activity_cache/__init__.py b/tests/contrib/activity_cache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/contrib/activity_cache/test_decorator.py b/tests/contrib/activity_cache/test_decorator.py new file mode 100644 index 000000000..c3b69636f --- /dev/null +++ b/tests/contrib/activity_cache/test_decorator.py @@ -0,0 +1,97 @@ +"""Tests for the @cached decorator.""" + +import uuid +from datetime import timedelta + +import pytest + +from temporalio.contrib.activity_cache import cached + + +class TestCachedDecorator: + """Tests for the @cached decorator with memory:// backend.""" + + @pytest.fixture + def store_url(self) -> str: + """Unique memory:// URL per test.""" + return f"memory://cache-test/{uuid.uuid4()}" + + async def test_cache_hit(self, store_url: str) -> None: + """Second call with same args returns cached result.""" + calls = 0 + + @cached(store_url) + async def fn(x: int) -> 
# File: tests/contrib/activity_cache/test_decorator.py (patch 1/2)
"""Tests for the @cached decorator."""

import uuid
from datetime import timedelta

import pytest

from temporalio.contrib.activity_cache import cached


class TestCachedDecorator:
    """Tests for the @cached decorator with memory:// backend."""

    @pytest.fixture
    def store_url(self) -> str:
        """Unique memory:// URL per test."""
        return f"memory://cache-test/{uuid.uuid4()}"

    async def test_cache_hit(self, store_url: str) -> None:
        """Second call with same args returns cached result."""
        counter = {"n": 0}

        @cached(store_url)
        async def fn(x: int) -> int:
            counter["n"] += 1
            return x * 2

        assert await fn(5) == 10
        assert counter["n"] == 1
        assert await fn(5) == 10
        assert counter["n"] == 1  # Body skipped; served from cache

    async def test_cache_miss_on_different_args(self, store_url: str) -> None:
        """Different args cause a cache miss."""
        counter = {"n": 0}

        @cached(store_url)
        async def fn(x: int) -> int:
            counter["n"] += 1
            return x * 2

        assert await fn(5) == 10
        assert await fn(6) == 12
        assert counter["n"] == 2

    async def test_ttl_expiry(self, store_url: str) -> None:
        """Expired entries are treated as misses."""
        counter = {"n": 0}

        # A negative TTL means every entry is expired at read time.
        @cached(store_url, ttl=timedelta(seconds=-1))
        async def fn(x: int) -> int:
            counter["n"] += 1
            return x

        assert await fn(1) == 1
        assert counter["n"] == 1
        assert await fn(1) == 1
        assert counter["n"] == 2  # Re-executed: first entry had expired

    async def test_key_fn(self, store_url: str) -> None:
        """key_fn controls which args are in the cache key."""
        counter = {"n": 0}

        @cached(store_url, key_fn=lambda x, _ignored: {"x": x})
        async def fn(x: int, ignored: str) -> int:
            counter["n"] += 1
            return x

        assert await fn(1, "a") == 1
        assert counter["n"] == 1
        assert await fn(1, "b") == 1  # Second arg excluded from the key
        assert counter["n"] == 1

    async def test_preserves_function_name(self, store_url: str) -> None:
        """Wrapper preserves the original function name."""

        @cached(store_url)
        async def my_function(x: int) -> int:
            return x

        assert my_function.__name__ == "my_function"

    async def test_complex_return_values(self, store_url: str) -> None:
        """Complex return types survive pickle round-trip."""

        @cached(store_url)
        async def fn() -> dict:
            return {"nested": [1, 2, {"deep": True}], "set_like": [1, 2, 3]}

        first = await fn()
        second = await fn()
        assert first == second
        assert first == {"nested": [1, 2, {"deep": True}], "set_like": [1, 2, 3]}
# File: tests/contrib/activity_cache/test_keys.py (patch 1/2)
"""Tests for cache key computation."""

from dataclasses import dataclass

from temporalio.contrib.activity_cache._keys import compute_cache_key


class TestComputeCacheKey:
    """Tests for deterministic cache key generation."""

    def test_same_inputs_same_key(self) -> None:
        """Identical inputs produce identical keys."""
        first = compute_cache_key("fn", ("a", "b"))
        second = compute_cache_key("fn", ("a", "b"))
        assert first == second

    def test_different_inputs_different_key(self) -> None:
        """Different inputs produce different keys."""
        first = compute_cache_key("fn", ("a", "b"))
        second = compute_cache_key("fn", ("a", "c"))
        assert first != second

    def test_different_fn_name_different_key(self) -> None:
        """Different function names produce different keys."""
        assert compute_cache_key("fn_a", ("x",)) != compute_cache_key("fn_b", ("x",))

    def test_key_is_32_hex_chars(self) -> None:
        """Keys are 32-character hex strings."""
        key = compute_cache_key("fn", ("input",))
        assert len(key) == 32
        assert set(key) <= set("0123456789abcdef")

    def test_key_fn_selects_args(self) -> None:
        """key_fn controls which arguments are included."""
        selector = lambda x, y: {"x": x}  # noqa: E731
        first = compute_cache_key("fn", ("a", "b"), key_fn=selector)
        second = compute_cache_key("fn", ("a", "DIFFERENT"), key_fn=selector)
        assert first == second

    def test_key_fn_different_selected_args(self) -> None:
        """key_fn with different selected values produces different keys."""
        selector = lambda x, y: {"x": x}  # noqa: E731
        first = compute_cache_key("fn", ("a", "b"), key_fn=selector)
        second = compute_cache_key("fn", ("DIFFERENT", "b"), key_fn=selector)
        assert first != second

    def test_dataclass_inputs(self) -> None:
        """Dataclass inputs are serialized deterministically."""

        @dataclass
        class Input:
            name: str
            value: int

        key_a = compute_cache_key("fn", (Input("test", 42),))
        key_b = compute_cache_key("fn", (Input("test", 42),))
        key_c = compute_cache_key("fn", (Input("test", 99),))
        assert key_a == key_b
        assert key_a != key_c

    def test_bytes_content_addressed(self) -> None:
        """Bytes inputs use content hash, not identity."""
        key_a = compute_cache_key("fn", (b"hello",))
        key_b = compute_cache_key("fn", (b"hello",))
        key_c = compute_cache_key("fn", (b"world",))
        assert key_a == key_b
        assert key_a != key_c
# File: tests/contrib/activity_cache/test_interceptor.py (patch 2/2)
"""Tests for the CachingInterceptor with real Temporal workers."""

import uuid
from collections.abc import AsyncIterator
from datetime import timedelta

import pytest

from temporalio import activity, workflow
from temporalio.contrib.activity_cache import CachingInterceptor, no_cache
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker

# Track call counts across activities
_call_counts: dict[str, int] = {}


@activity.defn
async def add(x: int, y: int) -> int:
    """Activity that tracks how many times it's called."""
    _call_counts["add"] = _call_counts.get("add", 0) + 1
    return x + y


@no_cache
@activity.defn
async def add_no_cache(x: int, y: int) -> int:
    """Activity opted out of caching."""
    _call_counts["add_no_cache"] = _call_counts.get("add_no_cache", 0) + 1
    return x + y


@workflow.defn
class CallTwiceWorkflow:
    """Calls the same activity twice with the same args."""

    @workflow.run
    async def run(self, x: int, y: int) -> list[int]:
        r1 = await workflow.execute_activity(
            add,
            args=[x, y],
            start_to_close_timeout=timedelta(seconds=30),
        )
        r2 = await workflow.execute_activity(
            add,
            args=[x, y],
            start_to_close_timeout=timedelta(seconds=30),
        )
        return [r1, r2]


@workflow.defn
class CallNoCacheWorkflow:
    """Calls a @no_cache activity twice."""

    @workflow.run
    async def run(self, x: int, y: int) -> list[int]:
        r1 = await workflow.execute_activity(
            add_no_cache,
            args=[x, y],
            start_to_close_timeout=timedelta(seconds=30),
        )
        r2 = await workflow.execute_activity(
            add_no_cache,
            args=[x, y],
            start_to_close_timeout=timedelta(seconds=30),
        )
        return [r1, r2]


class TestCachingInterceptor:
    """Tests for CachingInterceptor with real Temporal workers."""

    @pytest.fixture(autouse=True)
    def reset_counts(self) -> None:
        """Reset call counts before each test."""
        _call_counts.clear()

    @pytest.fixture
    def store_url(self) -> str:
        """Unique memory:// URL per test."""
        return f"memory://interceptor-test/{uuid.uuid4()}"

    @pytest.fixture
    async def env(self) -> AsyncIterator[WorkflowEnvironment]:
        """Start a local Temporal test environment, shut it down after.

        Fix: the fixture previously returned the environment without ever
        shutting it down, leaking a dev-server process per test.
        """
        env = await WorkflowEnvironment.start_local()
        try:
            yield env
        finally:
            await env.shutdown()

    async def test_interceptor_caches_activity(
        self, env: WorkflowEnvironment, store_url: str
    ) -> None:
        """Second call with same args is served from cache."""
        task_queue = str(uuid.uuid4())

        async with Worker(
            env.client,
            task_queue=task_queue,
            workflows=[CallTwiceWorkflow],
            activities=[add],
            interceptors=[CachingInterceptor(store_url)],
        ):
            result = await env.client.execute_workflow(
                CallTwiceWorkflow.run,
                args=[3, 4],
                id=str(uuid.uuid4()),
                task_queue=task_queue,
            )

        assert result == [7, 7]
        assert _call_counts["add"] == 1  # Only called once, second was cached

    async def test_no_cache_skips_caching(
        self, env: WorkflowEnvironment, store_url: str
    ) -> None:
        """@no_cache activities always execute."""
        task_queue = str(uuid.uuid4())

        async with Worker(
            env.client,
            task_queue=task_queue,
            workflows=[CallNoCacheWorkflow],
            activities=[add_no_cache],
            interceptors=[CachingInterceptor(store_url)],
        ):
            result = await env.client.execute_workflow(
                CallNoCacheWorkflow.run,
                args=[3, 4],
                id=str(uuid.uuid4()),
                task_queue=task_queue,
            )

        assert result == [7, 7]
        assert _call_counts["add_no_cache"] == 2  # Called both times