From 6e2ae62eb4dbf299e7f61f60f8525af91cebd04b Mon Sep 17 00:00:00 2001 From: Jason Steving Date: Tue, 17 Mar 2026 18:46:01 -0700 Subject: [PATCH 1/6] Initial Generally Working Implementation More to be done to improve quality of the integration. This currently contains a dumping ground of in progress files that will need to be cleaned up before merged into main. --- pyproject.toml | 3 + .../contrib/google_gemini_sdk/__init__.py | 57 +++ .../google_gemini_sdk/_gemini_plugin.py | 103 ++++++ .../google_gemini_sdk/_heartbeat_decorator.py | 40 +++ .../_invoke_model_activity.py | 182 ++++++++++ .../_model_activity_parameters.py | 52 +++ .../first_class_example/start_workflow.py | 32 ++ .../first_class_example/worker.py | 153 ++++++++ temporalio/contrib/google_gemini_sdk/justfile | 11 + .../contrib/google_gemini_sdk/testing.py | 161 +++++++++ .../contrib/google_gemini_sdk/workflow.py | 332 ++++++++++++++++++ uv.lock | 12 +- 12 files changed, 1131 insertions(+), 7 deletions(-) create mode 100644 temporalio/contrib/google_gemini_sdk/__init__.py create mode 100644 temporalio/contrib/google_gemini_sdk/_gemini_plugin.py create mode 100644 temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py create mode 100644 temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py create mode 100644 temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py create mode 100644 temporalio/contrib/google_gemini_sdk/first_class_example/start_workflow.py create mode 100644 temporalio/contrib/google_gemini_sdk/first_class_example/worker.py create mode 100644 temporalio/contrib/google_gemini_sdk/justfile create mode 100644 temporalio/contrib/google_gemini_sdk/testing.py create mode 100644 temporalio/contrib/google_gemini_sdk/workflow.py diff --git a/pyproject.toml b/pyproject.toml index 7a2df7ea8..909153304 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,9 @@ opentelemetry = ["opentelemetry-api>=1.11.1,<2", "opentelemetry-sdk>=1.11.1,<2"] pydantic = ["pydantic>=2.0.0,<3"] openai-agents = ["openai-agents>=0.3,<0.7", "mcp>=1.9.4, <2"] google-adk = ["google-adk>=1.27.0,<2"] +google-gemini = [ + "google-genai>=1.66.0", +] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" diff --git a/temporalio/contrib/google_gemini_sdk/__init__.py b/temporalio/contrib/google_gemini_sdk/__init__.py new file mode 100644 index 000000000..a9a9d3b88 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/__init__.py @@ -0,0 +1,57 @@ +"""First-class Temporal integration for the Google Gemini SDK. + +.. warning:: + This module is experimental and may change in future versions. + Use with caution in production environments. + +Quickstart:: + + from temporalio.contrib.google_gemini_sdk import ( + GeminiAgent, + GeminiPlugin, + activity_as_tool, + run_agent, + ) + + @workflow.defn + class MyAgentWorkflow: + @workflow.run + async def run(self, query: str) -> str: + return await run_agent( + GeminiAgent( + model="gemini-2.5-flash", + system_instruction="You are a helpful assistant.", + tools=[ + activity_as_tool(my_tool, start_to_close_timeout=timedelta(seconds=30)), + ], + ), + query, + ) +""" + +from temporalio.contrib.google_gemini_sdk._gemini_plugin import GeminiPlugin +from temporalio.contrib.google_gemini_sdk._model_activity_parameters import ( + ModelActivityParameters, +) +from temporalio.contrib.google_gemini_sdk.workflow import ( + ActivityTool, + GeminiAgent, + GeminiAgentWorkflowError, + GeminiToolSerializationError, + activity_as_tool, + run_agent, +) +from temporalio.contrib.google_gemini_sdk import testing, workflow + +__all__ = [ + "ActivityTool", + "GeminiAgent", + "GeminiAgentWorkflowError", + "GeminiPlugin", + "GeminiToolSerializationError", + "ModelActivityParameters", + "activity_as_tool", + "run_agent", + "testing", + "workflow", +] diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py new file mode 100644 index 000000000..bf8e0a979 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py @@ -0,0 +1,103 @@ +"""Temporal plugin for Google Gemini SDK integration.""" + +from __future__ import annotations + +import dataclasses +from collections.abc import Callable + +from google import genai + +from temporalio.contrib.google_gemini_sdk._invoke_model_activity import ( + GeminiModelActivity, +) +from temporalio.contrib.google_gemini_sdk._model_activity_parameters import ( + ModelActivityParameters, +) +from temporalio.contrib.google_gemini_sdk.workflow import GeminiAgentWorkflowError +from temporalio.contrib.pydantic import ( + PydanticPayloadConverter as _DefaultPydanticPayloadConverter, +) +from temporalio.converter import DataConverter, DefaultPayloadConverter +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + + +class GeminiPlugin(SimplePlugin): + """A Temporal Worker Plugin configured for the Google Gemini SDK. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + This plugin configures: + - Pydantic Payload Converter (required for Gemini SDK types). + - Sandbox passthrough for ``google.genai`` and ``google.api_core`` modules. + - The ``generate_content_activity`` model invocation activity. + - ``GeminiAgentWorkflowError`` as a workflow failure exception type. + + Example: + >>> plugin = GeminiPlugin() + >>> client = await Client.connect("localhost:7233", plugins=[plugin]) + >>> async with Worker( + ... client, + ... task_queue="my-queue", + ... workflows=[MyAgentWorkflow], + ... activities=[my_tool_activity], + ... ): + ... await asyncio.Event().wait() + """ + + def __init__( + self, + model_params: ModelActivityParameters | None = None, + client_factory: Callable[[], genai.Client] | None = None, + _model_activity: GeminiModelActivity | None = None, + ) -> None: + """Initialize the Gemini plugin. + + Args: + model_params: Optional default parameters for model activity execution. + Currently accepted but not applied automatically; pass ``model_params`` + directly to :func:`~temporalio.contrib.google_gemini_sdk.workflow.run_agent`. + client_factory: Optional factory function for creating the Gemini client. + Defaults to reading ``GOOGLE_API_KEY`` from the environment. + _model_activity: Internal override for testing. Prefer using + :class:`~temporalio.contrib.google_gemini_sdk.testing.GeminiEnvironment` + instead of setting this directly. + """ + model_activity = _model_activity or GeminiModelActivity(client_factory) + + def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: + if not runner: + raise ValueError("No WorkflowRunner provided to GeminiPlugin.") + if isinstance(runner, SandboxedWorkflowRunner): + return dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + "google.genai", "google.api_core" + ), + ) + return runner + + super().__init__( + name="GeminiPlugin", + data_converter=self._configure_data_converter, + activities=[model_activity.generate_content_activity], + workflow_runner=workflow_runner, + workflow_failure_exception_types=[GeminiAgentWorkflowError], + ) + + def _configure_data_converter( + self, converter: DataConverter | None + ) -> DataConverter: + if converter is None: + return DataConverter( + payload_converter_class=_DefaultPydanticPayloadConverter + ) + elif converter.payload_converter_class is DefaultPayloadConverter: + return dataclasses.replace( + converter, + payload_converter_class=_DefaultPydanticPayloadConverter, + ) + return converter diff --git a/temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py b/temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py new file mode 100644 index 000000000..4baff6706 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py @@ -0,0 +1,40 @@ +import asyncio +from collections.abc import Awaitable, Callable +from functools import wraps +from typing import Any, TypeVar, cast + +from temporalio import activity + +F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + + +def _auto_heartbeater(fn: F) -> F: # type:ignore[reportUnusedClass] + # Propagate type hints from the original callable. + @wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + heartbeat_timeout = activity.info().heartbeat_timeout + heartbeat_task = None + if heartbeat_timeout: + # Heartbeat twice as often as the timeout + heartbeat_task = asyncio.create_task( + heartbeat_every(heartbeat_timeout.total_seconds() / 2) + ) + try: + return await fn(*args, **kwargs) + finally: + if heartbeat_task: + heartbeat_task.cancel() + # Wait for heartbeat cancellation to complete + try: + await heartbeat_task + except asyncio.CancelledError: + pass + + return cast(F, wrapper) + + +async def heartbeat_every(delay: float, *details: Any) -> None: + """Heartbeat every so often while not cancelled""" + while True: + await asyncio.sleep(delay) + activity.heartbeat(*details) diff --git a/temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py b/temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py new file mode 100644 index 000000000..4bd224c3d --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py @@ -0,0 +1,182 @@ +"""Gemini model invocation activity for Temporal workflows.""" + +from __future__ import annotations + +import os +from collections.abc import Callable +from typing import Any + +from pydantic import BaseModel + +from google import genai +from google.genai import types + +from temporalio import activity +from temporalio.exceptions import ApplicationError + +from temporalio.contrib.google_gemini_sdk._heartbeat_decorator import _auto_heartbeater + + +class FunctionCallOutput(BaseModel): + """A single function call returned by the model.""" + + name: str + args: dict[str, Any] + + +class ActivityModelInput(BaseModel): + """Input for the Gemini model invocation activity.""" + + model: str + system_instruction: str | None = None + contents: list[types.Content] + function_declarations: list[types.FunctionDeclaration] = [] + generation_config: types.GenerateContentConfig | None = None + + +class ActivityModelOutput(BaseModel): + """Output from the Gemini model invocation activity.""" + + text: str | None + function_calls: list[FunctionCallOutput] + model_content: types.Content # Model turn to append to conversation history + + +def _default_client_factory() -> genai.Client: + try: + api_key = os.environ["GOOGLE_API_KEY"] + except KeyError: + raise ApplicationError( + "GOOGLE_API_KEY environment variable is not set", + non_retryable=True, + ) + return genai.Client( + api_key=api_key, + http_options=types.HttpOptions( + retry_options=types.HttpRetryOptions(attempts=1), + ), + ) + + +def _map_google_exception(exc: Exception) -> None: + """Map google.api_core exceptions to ApplicationError with correct retryability.""" + try: + from google.api_core import exceptions as google_exceptions + except ImportError: + return + + if isinstance(exc, google_exceptions.ResourceExhausted): + raise ApplicationError( + str(exc), + type="ResourceExhausted", + non_retryable=False, + ) from exc + elif isinstance( + exc, + ( + google_exceptions.DeadlineExceeded, + google_exceptions.ServiceUnavailable, + google_exceptions.InternalServerError, + ), + ): + raise ApplicationError( + str(exc), type=type(exc).__name__, non_retryable=False + ) from exc + elif isinstance( + exc, + ( + google_exceptions.InvalidArgument, + google_exceptions.PermissionDenied, + google_exceptions.NotFound, + ), + ): + raise ApplicationError( + str(exc), type=type(exc).__name__, non_retryable=True + ) from exc + + +class GeminiModelActivity: + """Temporal activity class for invoking the Gemini model. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + """ + + def __init__( + self, + client_factory: Callable[[], genai.Client] | None = None, + ) -> None: + self._client_factory = client_factory or _default_client_factory + + @activity.defn + @_auto_heartbeater + async def generate_content_activity( + self, input: ActivityModelInput + ) -> ActivityModelOutput: + """Invoke the Gemini model and return text and/or function calls.""" + try: + client = self._client_factory() + + # Merge user's generation_config with required agent settings. + # AFC must be disabled so Temporal owns tool execution. + base_config = ( + input.generation_config + if input.generation_config is not None + else types.GenerateContentConfig() + ) + config = base_config.model_copy( + update={ + "system_instruction": input.system_instruction, + "tools": ( + [types.Tool(function_declarations=input.function_declarations)] + if input.function_declarations + else None + ), + "automatic_function_calling": types.AutomaticFunctionCallingConfig( + disable=True + ), + } + ) + + response = await client.aio.models.generate_content( + model=input.model, + contents=input.contents, + config=config, + ) + + function_calls: list[FunctionCallOutput] = [] + text_parts: list[str] = [] + model_content: types.Content + + if response.candidates and response.candidates[0].content: + model_content = response.candidates[0].content + for part in model_content.parts: + if part.function_call: + function_calls.append( + FunctionCallOutput( + name=part.function_call.name, + args=( + dict(part.function_call.args) + if part.function_call.args + else {} + ), + ) + ) + elif part.text: + text_parts.append(part.text) + else: + model_content = types.Content(role="model", parts=[]) + + # Only include text if there are no function calls (avoids SDK warning) + text = "".join(text_parts) if text_parts and not function_calls else None + + return ActivityModelOutput( + text=text, + function_calls=function_calls, + model_content=model_content, + ) + + except Exception as exc: + _map_google_exception(exc) + raise diff --git a/temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py b/temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py new file mode 100644 index 000000000..7ef728510 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py @@ -0,0 +1,52 @@ +"""Parameters for configuring Temporal activity execution for Gemini model calls.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import timedelta + +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + + +@dataclass +class ModelActivityParameters: + """Parameters for configuring Temporal activity execution for Gemini model calls. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + This class encapsulates all the parameters that can be used to configure + how Temporal activities are executed when making Gemini model calls. + """ + + task_queue: str | None = None + """Specific task queue to use for model activities.""" + + schedule_to_close_timeout: timedelta | None = None + """Maximum time from scheduling to completion.""" + + schedule_to_start_timeout: timedelta | None = None + """Maximum time from scheduling to starting.""" + + start_to_close_timeout: timedelta | None = timedelta(seconds=60) + """Maximum time for the activity to complete.""" + + heartbeat_timeout: timedelta | None = None + """Maximum time between heartbeats.""" + + retry_policy: RetryPolicy | None = None + """Policy for retrying failed activities.""" + + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL + """How the activity handles cancellation.""" + + versioning_intent: VersioningIntent | None = None + """Versioning intent for the activity.""" + + priority: Priority = field(default_factory=lambda: Priority.default) + """Priority for the activity execution.""" + + use_local_activity: bool = False + """Whether to use a local activity. Changing mid-workflow breaks determinism.""" diff --git a/temporalio/contrib/google_gemini_sdk/first_class_example/start_workflow.py b/temporalio/contrib/google_gemini_sdk/first_class_example/start_workflow.py new file mode 100644 index 000000000..1a773a87e --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/first_class_example/start_workflow.py @@ -0,0 +1,32 @@ +# ABOUTME: Client script to start the first-class Gemini agent workflow. +# No GOOGLE_API_KEY needed here — only the worker requires it. + +import asyncio +import sys +import uuid + +from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter + +TASK_QUEUE = "gemini-first-class" + + +async def main() -> None: + client = await Client.connect( + "localhost:7233", + data_converter=pydantic_data_converter, + ) + + query = sys.argv[1] if len(sys.argv) > 1 else "What's the weather like right now?" + + result = await client.execute_workflow( + "WeatherAgentWorkflow", + query, + id=f"gemini-first-class-{uuid.uuid4()}", + task_queue=TASK_QUEUE, + ) + print(f"\nResult:\n{result}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py new file mode 100644 index 000000000..dfd2effb4 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py @@ -0,0 +1,153 @@ +# ABOUTME: First-class Temporal + Gemini SDK integration demo. +# Demonstrates the clean developer experience: 3-line workflow, no manual loop, +# no dynamic activities, no tool registry, and no inspect hackery. + +import asyncio +import json +from datetime import timedelta + +from dotenv import load_dotenv +from pydantic import BaseModel, Field +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.envconfig import ClientConfig +from temporalio.worker import Worker + +with workflow.unsafe.imports_passed_through(): + import httpx + +from temporalio.contrib.google_gemini_sdk import GeminiAgent, GeminiPlugin, activity_as_tool, run_agent + + +# ============================================================================= +# System Instructions +# ============================================================================= + +SYSTEM_INSTRUCTIONS = """ +You are a helpful agent that can use tools to help the user. +You will be given an input from the user and a list of tools to use. +You may or may not need to use the tools to satisfy the user ask. +If no tools are needed, respond in haikus. +""" + +# ============================================================================= +# Tool Definitions — plain @activity.defn functions, no registry required +# ============================================================================= + +NWS_API_BASE = "https://api.weather.gov" +USER_AGENT = "weather-app/1.0" + + +class GetWeatherAlertsRequest(BaseModel): + """Request model for getting weather alerts.""" + + state: str = Field(description="Two-letter US state code (e.g. CA, NY)") + + +@activity.defn +async def get_weather_alerts(request: GetWeatherAlertsRequest) -> str: + """Get weather alerts for a US state. + + Args: + request: The request object containing: + - state: Two-letter US state code (e.g. CA, NY) + """ + headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"} + url = f"{NWS_API_BASE}/alerts/active/area/{request.state}" + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers, timeout=5.0) + response.raise_for_status() + return json.dumps(response.json()) + + +@activity.defn +async def get_ip_address() -> str: + """Get the public IP address of the current machine.""" + async with httpx.AsyncClient() as client: + response = await client.get("https://icanhazip.com") + response.raise_for_status() + return response.text.strip() + + +class GetLocationRequest(BaseModel): + """Request model for getting location info from an IP address.""" + + ipaddress: str = Field(description="An IP address") + + +@activity.defn +async def get_location_info(request: GetLocationRequest) -> str: + """Get the location information for an IP address including city, state, and country. + + Args: + request: The request object containing: + - ipaddress: An IP address to look up + """ + async with httpx.AsyncClient() as client: + response = await client.get(f"http://ip-api.com/json/{request.ipaddress}") + response.raise_for_status() + result = response.json() + return f"{result['city']}, {result['regionName']}, {result['country']}" + + +# ============================================================================= +# Workflow — 3 lines of real logic, no manual loop, no print() debugging +# ============================================================================= + +TASK_QUEUE = "gemini-first-class" + + +@workflow.defn +class WeatherAgentWorkflow: + """Durable agentic workflow powered by Gemini SDK and Temporal.""" + + @workflow.run + async def run(self, query: str) -> str: + return await run_agent( + GeminiAgent( + model="gemini-2.5-flash", + system_instruction=SYSTEM_INSTRUCTIONS, + tools=[ + activity_as_tool( + get_weather_alerts, + start_to_close_timeout=timedelta(seconds=30), + ), + activity_as_tool( + get_ip_address, + start_to_close_timeout=timedelta(seconds=10), + ), + activity_as_tool( + get_location_info, + start_to_close_timeout=timedelta(seconds=10), + ), + ], + ), + query, + ) + + +# ============================================================================= +# Worker — register plugin + user activities, nothing else required +# ============================================================================= + + +async def main() -> None: + load_dotenv() + + plugin = GeminiPlugin() + + config = ClientConfig.load_client_connect_config() + config.setdefault("target_host", "localhost:7233") + client = await Client.connect(**config, plugins=[plugin]) + + async with Worker( + client, + task_queue=TASK_QUEUE, + workflows=[WeatherAgentWorkflow], + activities=[get_weather_alerts, get_ip_address, get_location_info], + ): + await asyncio.Event().wait() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/temporalio/contrib/google_gemini_sdk/justfile b/temporalio/contrib/google_gemini_sdk/justfile new file mode 100644 index 000000000..b883c0379 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/justfile @@ -0,0 +1,11 @@ +set dotenv-filename := ".env.local" +set dotenv-load + +run: + uv run python test_gemini.py + +worker: + uv run python first_class_example/worker.py + +query q="What's the weather right now?": + uv run python first_class_example/start_workflow.py "{{q}}" diff --git a/temporalio/contrib/google_gemini_sdk/testing.py b/temporalio/contrib/google_gemini_sdk/testing.py new file mode 100644 index 000000000..4ae682ab8 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/testing.py @@ -0,0 +1,161 @@ +"""Testing utilities for the Google Gemini SDK Temporal integration. + +.. warning:: + This module is experimental and may change in future versions. + Use with caution in production environments. + +Example:: + + from temporalio.contrib.google_gemini_sdk.testing import ( + GeminiEnvironment, + MockGeminiResponse, + ) + + async def test_weather_agent(): + responses = [ + MockGeminiResponse.tool_call("get_weather_alerts", {"request": {"state": "CA"}}), + MockGeminiResponse.text("No active weather alerts in California."), + ] + async with GeminiEnvironment(responses=responses) as env: + client = await Client.connect("localhost:7233", plugins=[env.plugin]) + # ... run your workflow test +""" + +from __future__ import annotations + +from google.genai import types + +from temporalio import activity +from temporalio.contrib.google_gemini_sdk._invoke_model_activity import ( + ActivityModelInput, + ActivityModelOutput, + FunctionCallOutput, +) +from temporalio.contrib.google_gemini_sdk._gemini_plugin import GeminiPlugin + + +class MockGeminiResponse: + """Factory for constructing :class:`ActivityModelOutput` test fixtures. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + """ + + __test__ = False + + @staticmethod + def text(text: str) -> ActivityModelOutput: + """Return an output that simulates a plain-text model response.""" + return ActivityModelOutput( + text=text, + function_calls=[], + model_content=types.Content( + role="model", + parts=[types.Part.from_text(text=text)], + ), + ) + + @staticmethod + def tool_call(name: str, args: dict) -> ActivityModelOutput: + """Return an output that simulates a model requesting a tool call.""" + return ActivityModelOutput( + text=None, + function_calls=[FunctionCallOutput(name=name, args=args)], + model_content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall(name=name, args=args) + ) + ], + ), + ) + + +class TestGeminiModelActivity: + """A mock replacement for :class:`GeminiModelActivity` that returns pre-configured responses. + + Responses are consumed in FIFO order. If no responses remain, raises ``IndexError``. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + Example:: + + responses = [ + MockGeminiResponse.tool_call("lookup", {"lookup": {"query": "CA"}}), + MockGeminiResponse.text("Here is the answer."), + ] + activity_instance = TestGeminiModelActivity(responses) + plugin = GeminiPlugin(_model_activity=activity_instance) + """ + + __test__ = False + + def __init__(self, responses: list[ActivityModelOutput]) -> None: + self._responses = list(responses) + + @activity.defn + async def generate_content_activity( + self, input: ActivityModelInput + ) -> ActivityModelOutput: + """Return the next pre-configured response, ignoring the actual input.""" + if not self._responses: + raise IndexError( + "TestGeminiModelActivity has no more responses. " + "Add more responses to the list passed to the constructor." + ) + return self._responses.pop(0) + + +class GeminiEnvironment: + """A test environment that wires up a mock Gemini model activity. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + Example:: + + responses = [ + MockGeminiResponse.text("Hello!"), + ] + async with GeminiEnvironment(responses=responses) as env: + client = await Client.connect("localhost:7233", plugins=[env.plugin]) + async with Worker( + client, + task_queue="test-queue", + workflows=[MyWorkflow], + activities=[my_tool], + ): + result = await client.execute_workflow(...) + """ + + __test__ = False + + def __init__( + self, + responses: list[ActivityModelOutput] | None = None, + ) -> None: + test_activity = TestGeminiModelActivity(list(responses or [])) + self._plugin = GeminiPlugin(_model_activity=test_activity) + + async def __aenter__(self) -> GeminiEnvironment: + return self + + async def __aexit__(self, *args: object) -> None: + pass + + def applied_on_client(self, client: object) -> object: + """Return the client unchanged (plugin is applied at connect time via ``plugins=``). + + This method exists for API symmetry with other test environment helpers. + """ + return client + + @property + def plugin(self) -> GeminiPlugin: + """The :class:`GeminiPlugin` configured with the mock model activity.""" + return self._plugin diff --git a/temporalio/contrib/google_gemini_sdk/workflow.py b/temporalio/contrib/google_gemini_sdk/workflow.py new file mode 100644 index 000000000..c3a6ebd8a --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/workflow.py @@ -0,0 +1,332 @@ +"""Workflow utilities for Gemini SDK integration with Temporal.""" + +from __future__ import annotations + +import functools +import inspect +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any +from collections.abc import Callable + +from google.genai import types + +from temporalio import activity +from temporalio import workflow as temporal_workflow +from temporalio.common import Priority, RetryPolicy +from temporalio.exceptions import ApplicationError, TemporalError +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from temporalio.contrib.google_gemini_sdk._invoke_model_activity import ( + ActivityModelInput, + GeminiModelActivity, +) +from temporalio.contrib.google_gemini_sdk._model_activity_parameters import ( + ModelActivityParameters, +) + + +@dataclass +class ActivityTool: + """A Temporal activity wrapped as a Gemini tool. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + """ + + function_declaration: types.FunctionDeclaration + activity_name: str + schedule_kwargs: dict[str, Any] + + +def activity_as_tool( + fn: Callable, + *, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = timedelta(seconds=30), + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + versioning_intent: VersioningIntent | None = None, + priority: Priority = Priority.default, + api_option: str = "GEMINI_API", +) -> ActivityTool: + """Convert a Temporal activity function into a Gemini tool for use with :func:`run_agent`. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + Args: + fn: A Temporal activity function decorated with ``@activity.defn``. + task_queue: Specific task queue to use for this tool's activity. + schedule_to_close_timeout: Maximum time from scheduling to completion. + schedule_to_start_timeout: Maximum time from scheduling to starting. + start_to_close_timeout: Maximum time for the activity to complete. + heartbeat_timeout: Maximum time between heartbeats. + retry_policy: Policy for retrying failed activities. + cancellation_type: How the activity handles cancellation. + versioning_intent: Versioning intent for the activity. + priority: Priority for the activity execution. + api_option: Gemini API option for schema generation. Defaults to ``"GEMINI_API"``. + + Returns: + An :class:`ActivityTool` wrapping the activity with its Gemini schema. + + Raises: + ApplicationError: If the function is not properly decorated as a Temporal activity. + + Example: + >>> @activity.defn + ... async def get_weather(request: WeatherRequest) -> str: ... + >>> + >>> tool = activity_as_tool( + ... get_weather, + ... start_to_close_timeout=timedelta(seconds=30), + ... ) + """ + ret = activity._Definition.from_callable(fn) + if not ret: + raise ApplicationError( + "Bare function without @activity.defn decorator is not supported", + "invalid_tool", + ) + if ret.name is None: + raise ApplicationError( + "Activity must have a name to be used as a tool", + "invalid_tool", + ) + + # If the callable has a 'self' parameter (class-based activity), partially apply it + # so that FunctionDeclaration schema generation ignores the self param. + # The actual instance is resolved at execution time by the worker. + params = list(inspect.signature(fn).parameters.keys()) + schema_fn = fn + if len(params) > 0 and params[0] == "self": + partial = functools.partial(fn, None) + setattr(partial, "__name__", fn.__name__) + partial.__annotations__ = getattr(fn, "__annotations__", {}) + setattr( + partial, + "__temporal_activity_definition", + getattr(fn, "__temporal_activity_definition", None), + ) + partial.__doc__ = fn.__doc__ + schema_fn = partial + + function_declaration = types.FunctionDeclaration.from_callable_with_api_option( + callable=schema_fn, + api_option=api_option, + ) + + schedule_kwargs: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "versioning_intent": versioning_intent, + "priority": priority, + } + + return ActivityTool( + function_declaration=function_declaration, + activity_name=ret.name, + schedule_kwargs=schedule_kwargs, + ) + + +@dataclass +class GeminiAgent: + """Configuration for a Gemini-powered agentic loop. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + Example: + >>> agent = GeminiAgent( + ... model="gemini-2.5-flash", + ... system_instruction="You are a helpful assistant.", + ... tools=[ + ... activity_as_tool(get_weather, start_to_close_timeout=timedelta(seconds=30)), + ... ], + ... ) + """ + + model: str + system_instruction: str | None = None + tools: list[ActivityTool] = field(default_factory=list) + generation_config: types.GenerateContentConfig | None = None + api_option: str = "GEMINI_API" + + +async def run_agent( + agent: GeminiAgent, + initial_message: str, + *, + model_params: ModelActivityParameters | None = None, + max_turns: int = 10, +) -> str: + """Run the Gemini agentic loop inside a Temporal workflow. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + Each model call and each tool invocation becomes a separate Temporal activity, + giving full workflow history visibility and crash recovery. + + Args: + agent: The :class:`GeminiAgent` configuration. + initial_message: The user's initial query. + model_params: Optional parameters for configuring model activity execution. + max_turns: Maximum number of model call + tool execution rounds before raising. + + Returns: + The model's final text response. + + Raises: + GeminiAgentWorkflowError: If ``max_turns`` is exceeded or an unknown tool is called. + """ + if model_params is None: + model_params = ModelActivityParameters() + + history: list[types.Content] = [ + types.Content(role="user", parts=[types.Part.from_text(text=initial_message)]) + ] + + function_declarations = [t.function_declaration for t in agent.tools] + activity_tools: dict[str, ActivityTool] = { + t.function_declaration.name: t for t in agent.tools + } + + # Build kwargs for the model activity execution + model_kwargs: dict[str, Any] = { + "task_queue": model_params.task_queue, + "schedule_to_close_timeout": model_params.schedule_to_close_timeout, + "schedule_to_start_timeout": model_params.schedule_to_start_timeout, + "start_to_close_timeout": model_params.start_to_close_timeout, + "heartbeat_timeout": model_params.heartbeat_timeout, + "retry_policy": model_params.retry_policy, + "cancellation_type": model_params.cancellation_type, + "versioning_intent": model_params.versioning_intent, + "priority": model_params.priority, + } + + for _ in range(max_turns): + model_input = ActivityModelInput( + model=agent.model, + system_instruction=agent.system_instruction, + contents=history, + function_declarations=function_declarations, + generation_config=agent.generation_config, + ) + + if model_params.use_local_activity: + result = await temporal_workflow.execute_local_activity_method( + GeminiModelActivity.generate_content_activity, + model_input, + **{ + k: v + for k, v in model_kwargs.items() + if k + not in ( + "task_queue", + "schedule_to_start_timeout", + "versioning_intent", + ) + }, + ) + else: + result = await temporal_workflow.execute_activity_method( + GeminiModelActivity.generate_content_activity, + model_input, + **model_kwargs, + ) + + if result.function_calls: + history.append(result.model_content) + + for fc in result.function_calls: + tool = activity_tools.get(fc.name) + if tool is None: + raise GeminiAgentWorkflowError( + f"Model called unknown tool '{fc.name}'. " + f"Available tools: {list(activity_tools.keys())}" + ) + + # Extract positional args from the dict Gemini returns. + # Gemini wraps each parameter under its name, e.g.: + # get_weather_alerts(request: WeatherRequest) → {"request": {"state": "CA"}} + # list(args.values()) unwraps to [{"state": "CA"}], which Temporal + # then deserializes as the activity's Pydantic parameter type. + # For no-arg activities, args={} → dispatch_args=[] (no args passed). + dispatch_args = list(fc.args.values()) + + try: + tool_result = await temporal_workflow.execute_activity( + tool.activity_name, + args=dispatch_args, + **tool.schedule_kwargs, + ) + except Exception as exc: + raise GeminiAgentWorkflowError( + f"Tool '{fc.name}' raised an error: {exc}" + ) from exc + + try: + result_str = str(tool_result) + except Exception as exc: + raise GeminiToolSerializationError( + f"Tool '{fc.name}' returned a value that could not be converted to str" + ) from exc + + history.append( + types.Content( + role="user", + parts=[ + types.Part.from_function_response( + name=fc.name, + response={"result": result_str}, + ) + ], + ) + ) + else: + return result.text or "" + + raise GeminiAgentWorkflowError( + f"Agent exceeded maximum turns ({max_turns}) without producing a final response." + ) + + +class GeminiAgentWorkflowError(TemporalError): + """Raised when the Gemini agent loop cannot complete normally. + + .. warning:: + This exception is experimental and may change in future versions. + Use with caution in production environments. + + This is raised when: + - The agent exceeds ``max_turns`` without returning a text response. + - The model calls a tool that was not registered. + - A tool activity raises an unexpected error. + """ + + +class GeminiToolSerializationError(TemporalError): + """Raised when a tool's return value cannot be converted to a string. + + .. warning:: + This exception is experimental and may change in future versions. + Use with caution in production environments. + + All tool outputs are converted to strings before being sent back to the model. + If ``str(result)`` raises, this exception is raised instead. + """ diff --git a/uv.lock b/uv.lock index 9921726a0..c75bfe154 100644 --- a/uv.lock +++ b/uv.lock @@ -1556,7 +1556,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/3f/9859f655d11901e7b2996c6e3d33e0caa9a1d4572c3bc61ed0faa64b2f4c/greenlet-3.3.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:9bc885b89709d901859cf95179ec9f6bb67a3d2bb1f0e88456461bd4b7f8fd0d", size = 277747, upload-time = "2026-02-20T20:16:21.325Z" }, { url = "https://files.pythonhosted.org/packages/fb/07/cb284a8b5c6498dbd7cba35d31380bb123d7dceaa7907f606c8ff5993cbf/greenlet-3.3.2-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b568183cf65b94919be4438dc28416b234b678c608cafac8874dfeeb2a9bbe13", size = 579202, upload-time = "2026-02-20T20:47:28.955Z" }, { url = "https://files.pythonhosted.org/packages/ed/45/67922992b3a152f726163b19f890a85129a992f39607a2a53155de3448b8/greenlet-3.3.2-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:527fec58dc9f90efd594b9b700662ed3fb2493c2122067ac9c740d98080a620e", size = 590620, upload-time = "2026-02-20T20:55:55.581Z" }, - { url = "https://files.pythonhosted.org/packages/03/5f/6e2a7d80c353587751ef3d44bb947f0565ec008a2e0927821c007e96d3a7/greenlet-3.3.2-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:508c7f01f1791fbc8e011bd508f6794cb95397fdb198a46cb6635eb5b78d85a7", size = 602132, upload-time = "2026-02-20T21:02:43.261Z" }, { url = "https://files.pythonhosted.org/packages/ad/55/9f1ebb5a825215fadcc0f7d5073f6e79e3007e3282b14b22d6aba7ca6cb8/greenlet-3.3.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad0c8917dd42a819fe77e6bdfcb84e3379c0de956469301d9fd36427a1ca501f", size = 591729, upload-time = "2026-02-20T20:20:58.395Z" }, { url = "https://files.pythonhosted.org/packages/24/b4/21f5455773d37f94b866eb3cf5caed88d6cea6dd2c6e1f9c34f463cba3ec/greenlet-3.3.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:97245cc10e5515dbc8c3104b2928f7f02b6813002770cfaffaf9a6e0fc2b94ef", size = 1551946, upload-time = "2026-02-20T20:49:31.102Z" }, { url = "https://files.pythonhosted.org/packages/00/68/91f061a926abead128fe1a87f0b453ccf07368666bd59ffa46016627a930/greenlet-3.3.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8c1fdd7d1b309ff0da81d60a9688a8bd044ac4e18b250320a96fc68d31c209ca", size = 1618494, upload-time = "2026-02-20T20:21:06.541Z" }, @@ -1564,7 +1563,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f3/47/16400cb42d18d7a6bb46f0626852c1718612e35dcb0dffa16bbaffdf5dd2/greenlet-3.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:c56692189a7d1c7606cb794be0a8381470d95c57ce5be03fb3d0ef57c7853b86", size = 278890, upload-time = "2026-02-20T20:19:39.263Z" }, { url = "https://files.pythonhosted.org/packages/a3/90/42762b77a5b6aa96cd8c0e80612663d39211e8ae8a6cd47c7f1249a66262/greenlet-3.3.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ebd458fa8285960f382841da585e02201b53a5ec2bac6b156fc623b5ce4499f", size = 581120, upload-time = "2026-02-20T20:47:30.161Z" }, { url = "https://files.pythonhosted.org/packages/bf/6f/f3d64f4fa0a9c7b5c5b3c810ff1df614540d5aa7d519261b53fba55d4df9/greenlet-3.3.2-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a443358b33c4ec7b05b79a7c8b466f5d275025e750298be7340f8fc63dff2a55", size = 594363, upload-time = "2026-02-20T20:55:56.965Z" }, - { url = "https://files.pythonhosted.org/packages/9c/8b/1430a04657735a3f23116c2e0d5eb10220928846e4537a938a41b350bed6/greenlet-3.3.2-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4375a58e49522698d3e70cc0b801c19433021b5c37686f7ce9c65b0d5c8677d2", size = 605046, upload-time = "2026-02-20T21:02:45.234Z" }, { url = "https://files.pythonhosted.org/packages/72/83/3e06a52aca8128bdd4dcd67e932b809e76a96ab8c232a8b025b2850264c5/greenlet-3.3.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e2cd90d413acbf5e77ae41e5d3c9b3ac1d011a756d7284d7f3f2b806bbd6358", size = 594156, upload-time = "2026-02-20T20:20:59.955Z" }, { url = "https://files.pythonhosted.org/packages/70/79/0de5e62b873e08fe3cef7dbe84e5c4bc0e8ed0c7ff131bccb8405cd107c8/greenlet-3.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:442b6057453c8cb29b4fb36a2ac689382fc71112273726e2423f7f17dc73bf99", size = 1554649, upload-time = "2026-02-20T20:49:32.293Z" }, { url = "https://files.pythonhosted.org/packages/5a/00/32d30dee8389dc36d42170a9c66217757289e2afb0de59a3565260f38373/greenlet-3.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:45abe8eb6339518180d5a7fa47fa01945414d7cca5ecb745346fc6a87d2750be", size = 1619472, upload-time = "2026-02-20T20:21:07.966Z" }, @@ -1573,7 +1571,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd", size = 280358, upload-time = "2026-02-20T20:17:43.971Z" }, { url = "https://files.pythonhosted.org/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd", size = 601217, upload-time = "2026-02-20T20:47:31.462Z" }, { url = "https://files.pythonhosted.org/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac", size = 611792, upload-time = "2026-02-20T20:55:58.423Z" }, - { url = "https://files.pythonhosted.org/packages/5c/c5/cc09412a29e43406eba18d61c70baa936e299bc27e074e2be3806ed29098/greenlet-3.3.2-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ae9e21c84035c490506c17002f5c8ab25f980205c3e61ddb3a2a2a2e6c411fcb", size = 626250, upload-time = "2026-02-20T21:02:46.596Z" }, { url = "https://files.pythonhosted.org/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070", size = 613875, upload-time = "2026-02-20T20:21:01.102Z" }, { url = "https://files.pythonhosted.org/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79", size = 1571467, upload-time = "2026-02-20T20:49:33.495Z" }, { url = "https://files.pythonhosted.org/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395", size = 1640001, upload-time = "2026-02-20T20:21:09.154Z" }, @@ -1582,7 +1579,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" }, { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, - { url = "https://files.pythonhosted.org/packages/94/2b/4d012a69759ac9d77210b8bfb128bc621125f5b20fc398bce3940d036b1c/greenlet-3.3.2-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ccd21bb86944ca9be6d967cf7691e658e43417782bce90b5d2faeda0ff78a7dd", size = 628268, upload-time = "2026-02-20T21:02:48.024Z" }, { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, { url = "https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, @@ -1591,7 +1587,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, { url = "https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, - { url = "https://files.pythonhosted.org/packages/cd/ac/85804f74f1ccea31ba518dcc8ee6f14c79f73fe36fa1beba38930806df09/greenlet-3.3.2-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e3cb43ce200f59483eb82949bf1835a99cf43d7571e900d7c8d5c62cdf25d2f9", size = 675371, upload-time = "2026-02-20T21:02:49.664Z" }, { url = "https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, @@ -1600,7 +1595,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, - { url = "https://files.pythonhosted.org/packages/d1/67/8197b7e7e602150938049d8e7f30de1660cfb87e4c8ee349b42b67bdb2e1/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:59b3e2c40f6706b05a9cd299c836c6aa2378cabe25d021acd80f13abf81181cf", size = 666581, upload-time = "2026-02-20T21:02:51.526Z" }, { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, { url = "https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" }, @@ -4282,6 +4276,9 @@ dependencies = [ google-adk = [ { name = "google-adk" }, ] +google-gemini = [ + { name = "google-genai" }, +] grpc = [ { name = "grpcio" }, ] @@ -4329,6 +4326,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "google-adk", marker = "extra == 'google-adk'", specifier = ">=1.27.0,<2" }, + { name = "google-genai", marker = "extra == 'google-gemini'", specifier = ">=1.66.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "mcp", marker = "extra == 'openai-agents'", specifier = ">=1.9.4,<2" }, { name = "nexus-rpc", specifier = "==1.4.0" }, @@ -4341,7 +4339,7 @@ requires-dist = [ { name = "types-protobuf", specifier = ">=3.20" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, ] -provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk"] +provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "google-gemini"] [package.metadata.requires-dev] dev = [ From 8f5e77a6829e7a3bf322f9f62741d87511c4f6c6 Mon Sep 17 00:00:00 2001 From: Jason Steving Date: Wed, 18 Mar 2026 12:01:01 -0700 Subject: [PATCH 2/6] Rewrite Gemini SDK integration to intercept at the HTTP transport layer Replace the previous approach (wrapping generate_content in a Temporal activity with a manual agentic loop) with HTTP-level interception that lets the Gemini SDK's native automatic function calling (AFC) drive the conversation. TemporalHttpxClient overrides httpx.AsyncClient.send() so every HTTP call the Gemini SDK makes becomes a durable Temporal activity, recorded in the workflow event history and subject to Temporal retry/timeout semantics. activity_as_tool() converts @activity.defn functions into Gemini-compatible callables dispatched via workflow.execute_activity, making each tool invocation independently durable. Credentials (x-goog-api-key) are stripped from serialized requests before they reach event history and re-injected from os.environ inside the activity. OAuth/authorization headers are intentionally left in place since they are short-lived and cannot be reconstructed. The package __init__.py uses lazy imports for all httpx-dependent symbols (GeminiPlugin, TemporalHttpxClient, temporal_http_options) so that sandbox-safe imports like activity_as_tool never trigger an httpx load. GeminiPlugin stores the pre-built genai.Client in a passthrough'd module so workflows can retrieve it without os.environ access. --- .../contrib/google_gemini_sdk/__init__.py | 101 +++++-- .../google_gemini_sdk/_client_store.py | 43 +++ .../google_gemini_sdk/_gemini_plugin.py | 80 ++--- .../google_gemini_sdk/_heartbeat_decorator.py | 40 --- .../google_gemini_sdk/_http_activity.py | 94 ++++++ .../_invoke_model_activity.py | 182 ----------- .../_model_activity_parameters.py | 52 ---- .../_temporal_httpx_client.py | 207 +++++++++++++ .../first_class_example/worker.py | 55 +++- .../contrib/google_gemini_sdk/workflow.py | 284 +++--------------- 10 files changed, 552 insertions(+), 586 deletions(-) create mode 100644 temporalio/contrib/google_gemini_sdk/_client_store.py delete mode 100644 temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py create mode 100644 temporalio/contrib/google_gemini_sdk/_http_activity.py delete mode 100644 temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py delete mode 100644 temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py create mode 100644 temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py diff --git a/temporalio/contrib/google_gemini_sdk/__init__.py b/temporalio/contrib/google_gemini_sdk/__init__.py index a9a9d3b88..af04a991a 100644 --- a/temporalio/contrib/google_gemini_sdk/__init__.py +++ b/temporalio/contrib/google_gemini_sdk/__init__.py @@ -4,54 +4,103 @@ This module is experimental and may change in future versions. Use with caution in production environments. +This integration lets you use the Gemini SDK **exactly as you normally would** +while making every network call and every tool invocation **durable Temporal +activities**. + +- :func:`temporal_http_options` — wrap ``genai.Client`` so all HTTP calls + (model calls, streaming responses) are routed through a Temporal activity. +- :func:`activity_as_tool` — convert any ``@activity.defn`` function into a + Gemini tool callable; Gemini's AFC invokes it as a Temporal activity. +- :class:`GeminiPlugin` — stores the pre-built client, registers the HTTP + transport activity, and configures the worker. +- :func:`get_gemini_client` — retrieve the client inside a workflow. + Quickstart:: - from temporalio.contrib.google_gemini_sdk import ( - GeminiAgent, - GeminiPlugin, - activity_as_tool, - run_agent, + # ---- worker setup (outside sandbox) ---- + gemini_client = genai.Client( + api_key=os.environ["GOOGLE_API_KEY"], + http_options=temporal_http_options( + start_to_close_timeout=timedelta(seconds=60), + ), ) + plugin = GeminiPlugin(gemini_client=gemini_client) + @activity.defn + async def get_weather(state: str) -> str: ... + + # ---- workflow (sandbox-safe) ---- @workflow.defn - class MyAgentWorkflow: + class AgentWorkflow: @workflow.run async def run(self, query: str) -> str: - return await run_agent( - GeminiAgent( - model="gemini-2.5-flash", - system_instruction="You are a helpful assistant.", + client = get_gemini_client() + response = await client.aio.models.generate_content( + model="gemini-2.5-flash", + contents=query, + config=types.GenerateContentConfig( tools=[ - activity_as_tool(my_tool, start_to_close_timeout=timedelta(seconds=30)), + activity_as_tool( + get_weather, + start_to_close_timeout=timedelta(seconds=30), + ), ], ), - query, ) + return response.text """ -from temporalio.contrib.google_gemini_sdk._gemini_plugin import GeminiPlugin -from temporalio.contrib.google_gemini_sdk._model_activity_parameters import ( - ModelActivityParameters, -) +# --- Sandbox-safe imports (loaded eagerly) --- +# These modules have NO httpx / google.genai imports and are safe to load +# inside the Temporal workflow sandbox. +from temporalio.contrib.google_gemini_sdk._client_store import get_gemini_client from temporalio.contrib.google_gemini_sdk.workflow import ( - ActivityTool, - GeminiAgent, GeminiAgentWorkflowError, GeminiToolSerializationError, activity_as_tool, - run_agent, ) -from temporalio.contrib.google_gemini_sdk import testing, workflow __all__ = [ - "ActivityTool", - "GeminiAgent", "GeminiAgentWorkflowError", "GeminiPlugin", "GeminiToolSerializationError", - "ModelActivityParameters", + "TemporalHttpxClient", "activity_as_tool", - "run_agent", - "testing", - "workflow", + "get_gemini_client", + "temporal_http_options", ] + + +# --- Lazy imports for httpx-dependent symbols --- +# GeminiPlugin, TemporalHttpxClient, and temporal_http_options all transitively +# import httpx (via _gemini_plugin → _http_activity → httpx, and via +# _temporal_httpx_client → httpx). They must NOT be loaded inside the workflow +# sandbox. They are imported lazily so that sandbox-safe imports like +# ``from temporalio.contrib.google_gemini_sdk import activity_as_tool`` +# never trigger an httpx import. +def __getattr__(name: str): # type: ignore[override] + _lazy = { + "GeminiPlugin": ( + "temporalio.contrib.google_gemini_sdk._gemini_plugin", + "GeminiPlugin", + ), + "TemporalHttpxClient": ( + "temporalio.contrib.google_gemini_sdk._temporal_httpx_client", + "TemporalHttpxClient", + ), + "temporal_http_options": ( + "temporalio.contrib.google_gemini_sdk._temporal_httpx_client", + "temporal_http_options", + ), + } + if name in _lazy: + import importlib + + module_path, attr = _lazy[name] + mod = importlib.import_module(module_path) + value = getattr(mod, attr) + # Cache on the module so __getattr__ is only called once per name. + globals()[name] = value + return value + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/temporalio/contrib/google_gemini_sdk/_client_store.py b/temporalio/contrib/google_gemini_sdk/_client_store.py new file mode 100644 index 000000000..c7120170c --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_client_store.py @@ -0,0 +1,43 @@ +"""Shared storage for the pre-built ``genai.Client`` instance. + +This module is added to the Temporal sandbox passthrough list by +:class:`~temporalio.contrib.google_gemini_sdk.GeminiPlugin`. Because it is +passthrough'd, the module-level ``_gemini_client`` variable is shared between +the real runtime (where the worker sets it) and the sandboxed workflow (where +:func:`get_gemini_client` reads it). + +This module intentionally has **no** ``httpx`` or ``google.genai`` imports so +that it can also be loaded safely by the sandbox's restricted importer if the +passthrough hasn't been configured yet. +""" + +from __future__ import annotations + +from typing import Any + +# Set by GeminiPlugin.__init__ before the worker starts. +_gemini_client: Any = None + + +def get_gemini_client() -> Any: + """Return the ``genai.Client`` stored by :class:`GeminiPlugin`. + + .. warning:: + This function is experimental and may change in future versions. + Use with caution in production environments. + + Call this inside a workflow to obtain the pre-built Gemini client that was + passed to :class:`~temporalio.contrib.google_gemini_sdk.GeminiPlugin` at + worker setup time. The client is created **once** outside the workflow + sandbox, so ``os.environ`` access, SSL cert loading, etc. happen at startup + — not during workflow execution. + + Raises: + RuntimeError: If no client has been configured (i.e. ``GeminiPlugin`` + was not initialised with a ``gemini_client``). + """ + if _gemini_client is None: + raise RuntimeError( + "No Gemini client configured. Pass gemini_client= to GeminiPlugin()." + ) + return _gemini_client diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py index bf8e0a979..1d29f48fa 100644 --- a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py +++ b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py @@ -3,16 +3,10 @@ from __future__ import annotations import dataclasses -from collections.abc import Callable +from typing import Any -from google import genai - -from temporalio.contrib.google_gemini_sdk._invoke_model_activity import ( - GeminiModelActivity, -) -from temporalio.contrib.google_gemini_sdk._model_activity_parameters import ( - ModelActivityParameters, -) +from temporalio.contrib.google_gemini_sdk import _client_store +from temporalio.contrib.google_gemini_sdk._http_activity import gemini_api_call from temporalio.contrib.google_gemini_sdk.workflow import GeminiAgentWorkflowError from temporalio.contrib.pydantic import ( PydanticPayloadConverter as _DefaultPydanticPayloadConverter, @@ -30,43 +24,41 @@ class GeminiPlugin(SimplePlugin): This class is experimental and may change in future versions. Use with caution in production environments. - This plugin configures: - - Pydantic Payload Converter (required for Gemini SDK types). - - Sandbox passthrough for ``google.genai`` and ``google.api_core`` modules. - - The ``generate_content_activity`` model invocation activity. - - ``GeminiAgentWorkflowError`` as a workflow failure exception type. + This plugin: - Example: - >>> plugin = GeminiPlugin() - >>> client = await Client.connect("localhost:7233", plugins=[plugin]) - >>> async with Worker( - ... client, - ... task_queue="my-queue", - ... workflows=[MyAgentWorkflow], - ... activities=[my_tool_activity], - ... ): - ... await asyncio.Event().wait() + - Stores the pre-built ``genai.Client`` so that workflows can retrieve it + via :func:`~temporalio.contrib.google_gemini_sdk.get_gemini_client` + without accessing ``os.environ`` or creating heavy objects in the sandbox. + - Registers ``gemini_api_call`` — the durable HTTP transport invoked + by :class:`~temporalio.contrib.google_gemini_sdk.TemporalHttpxClient`. + - Configures the Pydantic data converter and sandbox passthrough modules. + + Example:: + + gemini_client = genai.Client( + api_key=os.environ["GOOGLE_API_KEY"], + http_options=temporal_http_options( + start_to_close_timeout=timedelta(seconds=60), + ), + ) + plugin = GeminiPlugin(gemini_client=gemini_client) + client = await Client.connect("localhost:7233", plugins=[plugin]) """ def __init__( self, - model_params: ModelActivityParameters | None = None, - client_factory: Callable[[], genai.Client] | None = None, - _model_activity: GeminiModelActivity | None = None, + gemini_client: Any, ) -> None: """Initialize the Gemini plugin. Args: - model_params: Optional default parameters for model activity execution. - Currently accepted but not applied automatically; pass ``model_params`` - directly to :func:`~temporalio.contrib.google_gemini_sdk.workflow.run_agent`. - client_factory: Optional factory function for creating the Gemini client. - Defaults to reading ``GOOGLE_API_KEY`` from the environment. - _model_activity: Internal override for testing. Prefer using - :class:`~temporalio.contrib.google_gemini_sdk.testing.GeminiEnvironment` - instead of setting this directly. + gemini_client: A pre-built ``genai.Client`` instance. Create it at + worker startup (where ``os.environ`` is available) with + ``http_options=temporal_http_options(...)`` so that its HTTP + calls are routed through Temporal activities. """ - model_activity = _model_activity or GeminiModelActivity(client_factory) + # Store the client in the passthrough'd module so the sandbox can see it. + _client_store._gemini_client = gemini_client def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: if not runner: @@ -75,7 +67,19 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: return dataclasses.replace( runner, restrictions=runner.restrictions.with_passthrough_modules( - "google.genai", "google.api_core" + # google.genai — so the workflow can call methods on the + # pre-built genai.Client and use types.GenerateContentConfig + "google.genai", + "google.api_core", + # pydantic internals — avoids "imported after initial + # workflow load" warnings when google.genai types + # trigger lazy pydantic schema compilation. + "pydantic_core", + "pydantic", + "annotated_types", + # The client store module — passthrough'd so the sandbox + # sees the same _gemini_client reference the worker set. + "temporalio.contrib.google_gemini_sdk._client_store", ), ) return runner @@ -83,7 +87,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: super().__init__( name="GeminiPlugin", data_converter=self._configure_data_converter, - activities=[model_activity.generate_content_activity], + activities=[gemini_api_call], workflow_runner=workflow_runner, workflow_failure_exception_types=[GeminiAgentWorkflowError], ) diff --git a/temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py b/temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py deleted file mode 100644 index 4baff6706..000000000 --- a/temporalio/contrib/google_gemini_sdk/_heartbeat_decorator.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -from collections.abc import Awaitable, Callable -from functools import wraps -from typing import Any, TypeVar, cast - -from temporalio import activity - -F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) - - -def _auto_heartbeater(fn: F) -> F: # type:ignore[reportUnusedClass] - # Propagate type hints from the original callable. - @wraps(fn) - async def wrapper(*args: Any, **kwargs: Any) -> Any: - heartbeat_timeout = activity.info().heartbeat_timeout - heartbeat_task = None - if heartbeat_timeout: - # Heartbeat twice as often as the timeout - heartbeat_task = asyncio.create_task( - heartbeat_every(heartbeat_timeout.total_seconds() / 2) - ) - try: - return await fn(*args, **kwargs) - finally: - if heartbeat_task: - heartbeat_task.cancel() - # Wait for heartbeat cancellation to complete - try: - await heartbeat_task - except asyncio.CancelledError: - pass - - return cast(F, wrapper) - - -async def heartbeat_every(delay: float, *details: Any) -> None: - """Heartbeat every so often while not cancelled""" - while True: - await asyncio.sleep(delay) - activity.heartbeat(*details) diff --git a/temporalio/contrib/google_gemini_sdk/_http_activity.py b/temporalio/contrib/google_gemini_sdk/_http_activity.py new file mode 100644 index 000000000..9a094df1e --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_http_activity.py @@ -0,0 +1,94 @@ +"""Temporal activity for executing HTTP requests on behalf of the Gemini SDK. + +The Gemini SDK makes HTTP calls internally through an httpx.AsyncClient. +:class:`~temporalio.contrib.google_gemini_sdk.workflow.TemporalHttpxClient` +overrides that client's ``send()`` method to dispatch through this activity +instead of making direct network calls. This ensures every Gemini model call +is durably recorded in Temporal's workflow event history and benefits from +Temporal's retry/timeout semantics. +""" + +from __future__ import annotations + +import httpx +from pydantic import BaseModel + +from temporalio import activity + + +class HttpRequestData(BaseModel): + """Serialized form of an httpx.Request for transport across the activity boundary.""" + + method: str + url: str + headers: dict[str, str] + content: bytes + timeout: float | None = None + + +class HttpResponseData(BaseModel): + """Serialized form of an httpx.Response for transport across the activity boundary.""" + + status_code: int + headers: dict[str, str] + content: bytes + + +@activity.defn +async def gemini_api_call(req: HttpRequestData) -> HttpResponseData: + """Execute an HTTP request and return the serialized response. + + .. warning:: + This activity is experimental and may change in future versions. + Use with caution in production environments. + + This activity is registered automatically by + :class:`~temporalio.contrib.google_gemini_sdk.GeminiPlugin` and is invoked + by :class:`~temporalio.contrib.google_gemini_sdk.workflow.TemporalHttpxClient` + whenever the Gemini SDK needs to make a network call from within a workflow. + + Do not call this activity directly. + """ + import os + + # ── Credential re-injection ────────────────────────────────────────── + # TemporalHttpxClient.send() (in _temporal_httpx_client.py) strips the + # "x-goog-api-key" header from the serialized request so that the API + # key never appears in Temporal's event history. Here — inside the + # activity, which runs on the worker outside the workflow sandbox — we + # read the API key from the worker's environment and put it back before + # making the real HTTP call. + # + # NOTE: The "authorization" header (used by Vertex AI OAuth flows) is + # NOT stripped because those tokens are short-lived and refreshed + # per-request by the SDK — we cannot reconstruct them here. + # + # This means the API key must be set in the worker's environment + # (GOOGLE_API_KEY or GEMINI_API_KEY) for API-key auth to work. + # ───────────────────────────────────────────────────────────────────── + headers = dict(req.headers) + api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") + if api_key and "x-goog-api-key" not in headers: + headers["x-goog-api-key"] = api_key + + async with httpx.AsyncClient() as client: + response = await client.request( + method=req.method, + url=req.url, + headers=headers, + content=req.content, + timeout=req.timeout, + ) + # response.content is already decoded (decompressed) by httpx. + # Strip content-encoding so the reconstructed Response on the + # workflow side doesn't try to decompress it a second time. + headers = { + k: v + for k, v in response.headers.items() + if k.lower() not in ("content-encoding", "content-length") + } + return HttpResponseData( + status_code=response.status_code, + headers=headers, + content=response.content, + ) diff --git a/temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py b/temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py deleted file mode 100644 index 4bd224c3d..000000000 --- a/temporalio/contrib/google_gemini_sdk/_invoke_model_activity.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Gemini model invocation activity for Temporal workflows.""" - -from __future__ import annotations - -import os -from collections.abc import Callable -from typing import Any - -from pydantic import BaseModel - -from google import genai -from google.genai import types - -from temporalio import activity -from temporalio.exceptions import ApplicationError - -from temporalio.contrib.google_gemini_sdk._heartbeat_decorator import _auto_heartbeater - - -class FunctionCallOutput(BaseModel): - """A single function call returned by the model.""" - - name: str - args: dict[str, Any] - - -class ActivityModelInput(BaseModel): - """Input for the Gemini model invocation activity.""" - - model: str - system_instruction: str | None = None - contents: list[types.Content] - function_declarations: list[types.FunctionDeclaration] = [] - generation_config: types.GenerateContentConfig | None = None - - -class ActivityModelOutput(BaseModel): - """Output from the Gemini model invocation activity.""" - - text: str | None - function_calls: list[FunctionCallOutput] - model_content: types.Content # Model turn to append to conversation history - - -def _default_client_factory() -> genai.Client: - try: - api_key = os.environ["GOOGLE_API_KEY"] - except KeyError: - raise ApplicationError( - "GOOGLE_API_KEY environment variable is not set", - non_retryable=True, - ) - return genai.Client( - api_key=api_key, - http_options=types.HttpOptions( - retry_options=types.HttpRetryOptions(attempts=1), - ), - ) - - -def _map_google_exception(exc: Exception) -> None: - """Map google.api_core exceptions to ApplicationError with correct retryability.""" - try: - from google.api_core import exceptions as google_exceptions - except ImportError: - return - - if isinstance(exc, google_exceptions.ResourceExhausted): - raise ApplicationError( - str(exc), - type="ResourceExhausted", - non_retryable=False, - ) from exc - elif isinstance( - exc, - ( - google_exceptions.DeadlineExceeded, - google_exceptions.ServiceUnavailable, - google_exceptions.InternalServerError, - ), - ): - raise ApplicationError( - str(exc), type=type(exc).__name__, non_retryable=False - ) from exc - elif isinstance( - exc, - ( - google_exceptions.InvalidArgument, - google_exceptions.PermissionDenied, - google_exceptions.NotFound, - ), - ): - raise ApplicationError( - str(exc), type=type(exc).__name__, non_retryable=True - ) from exc - - -class GeminiModelActivity: - """Temporal activity class for invoking the Gemini model. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - """ - - def __init__( - self, - client_factory: Callable[[], genai.Client] | None = None, - ) -> None: - self._client_factory = client_factory or _default_client_factory - - @activity.defn - @_auto_heartbeater - async def generate_content_activity( - self, input: ActivityModelInput - ) -> ActivityModelOutput: - """Invoke the Gemini model and return text and/or function calls.""" - try: - client = self._client_factory() - - # Merge user's generation_config with required agent settings. - # AFC must be disabled so Temporal owns tool execution. - base_config = ( - input.generation_config - if input.generation_config is not None - else types.GenerateContentConfig() - ) - config = base_config.model_copy( - update={ - "system_instruction": input.system_instruction, - "tools": ( - [types.Tool(function_declarations=input.function_declarations)] - if input.function_declarations - else None - ), - "automatic_function_calling": types.AutomaticFunctionCallingConfig( - disable=True - ), - } - ) - - response = await client.aio.models.generate_content( - model=input.model, - contents=input.contents, - config=config, - ) - - function_calls: list[FunctionCallOutput] = [] - text_parts: list[str] = [] - model_content: types.Content - - if response.candidates and response.candidates[0].content: - model_content = response.candidates[0].content - for part in model_content.parts: - if part.function_call: - function_calls.append( - FunctionCallOutput( - name=part.function_call.name, - args=( - dict(part.function_call.args) - if part.function_call.args - else {} - ), - ) - ) - elif part.text: - text_parts.append(part.text) - else: - model_content = types.Content(role="model", parts=[]) - - # Only include text if there are no function calls (avoids SDK warning) - text = "".join(text_parts) if text_parts and not function_calls else None - - return ActivityModelOutput( - text=text, - function_calls=function_calls, - model_content=model_content, - ) - - except Exception as exc: - _map_google_exception(exc) - raise diff --git a/temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py b/temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py deleted file mode 100644 index 7ef728510..000000000 --- a/temporalio/contrib/google_gemini_sdk/_model_activity_parameters.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Parameters for configuring Temporal activity execution for Gemini model calls.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from datetime import timedelta - -from temporalio.common import Priority, RetryPolicy -from temporalio.workflow import ActivityCancellationType, VersioningIntent - - -@dataclass -class ModelActivityParameters: - """Parameters for configuring Temporal activity execution for Gemini model calls. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - - This class encapsulates all the parameters that can be used to configure - how Temporal activities are executed when making Gemini model calls. - """ - - task_queue: str | None = None - """Specific task queue to use for model activities.""" - - schedule_to_close_timeout: timedelta | None = None - """Maximum time from scheduling to completion.""" - - schedule_to_start_timeout: timedelta | None = None - """Maximum time from scheduling to starting.""" - - start_to_close_timeout: timedelta | None = timedelta(seconds=60) - """Maximum time for the activity to complete.""" - - heartbeat_timeout: timedelta | None = None - """Maximum time between heartbeats.""" - - retry_policy: RetryPolicy | None = None - """Policy for retrying failed activities.""" - - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL - """How the activity handles cancellation.""" - - versioning_intent: VersioningIntent | None = None - """Versioning intent for the activity.""" - - priority: Priority = field(default_factory=lambda: Priority.default) - """Priority for the activity execution.""" - - use_local_activity: bool = False - """Whether to use a local activity. Changing mid-workflow breaks determinism.""" diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py b/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py new file mode 100644 index 000000000..019834d66 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py @@ -0,0 +1,207 @@ +"""Temporal-aware httpx client and helpers. + +This module imports ``httpx`` and must **never** be loaded inside the Temporal +workflow sandbox. It is imported lazily by ``__init__.py`` and at module level +in worker code (outside the sandbox). +""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any + +import httpx +from google.genai import types + +from temporalio import workflow as temporal_workflow +from temporalio.common import RetryPolicy + +from temporalio.contrib.google_gemini_sdk._http_activity import ( + HttpRequestData, + gemini_api_call, +) + + +class _NoOpAsyncTransport(httpx.AsyncBaseTransport): + """Placeholder transport; never called because TemporalHttpxClient.send() intercepts all requests.""" + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + raise RuntimeError( + "_NoOpAsyncTransport.handle_async_request() should never be " + "called. All requests are routed through " + "workflow.execute_activity via TemporalHttpxClient.send()." + ) + + +class TemporalHttpxClient(httpx.AsyncClient): + """An ``httpx.AsyncClient`` that routes all HTTP calls through a Temporal activity. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + Pass an instance to ``genai.Client(http_options=types.HttpOptions(httpx_async_client=...))``. + Every HTTP request the Gemini SDK makes (model calls, streaming responses, etc.) + is dispatched as + :func:`~temporalio.contrib.google_gemini_sdk._http_activity.gemini_api_call`, + making it durable and visible in the workflow event history. + + The :func:`temporal_http_options` helper constructs this for you. + + Args: + start_to_close_timeout: Maximum time for a single HTTP request activity. + schedule_to_close_timeout: Maximum time from scheduling to completion. + heartbeat_timeout: Maximum time between activity heartbeats. + retry_policy: Retry policy for failed HTTP request activities. + """ + + def __init__( + self, + *, + start_to_close_timeout: timedelta | None = timedelta(seconds=60), + schedule_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + ) -> None: + # Use a no-op transport to avoid SSL cert file I/O at construction time. + # The transport is never invoked because send() is fully overridden. + super().__init__(transport=_NoOpAsyncTransport()) + self._activity_kwargs: dict[str, Any] = { + "start_to_close_timeout": start_to_close_timeout, + "schedule_to_close_timeout": schedule_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + } + + async def send( + self, + request: httpx.Request, + *, + stream: bool = False, + **kwargs: Any, + ) -> httpx.Response: + """Dispatch the HTTP request as a Temporal activity instead of making a direct call.""" + try: + content = request.content + except httpx.RequestNotRead: + content = b"" + + # Extract the timeout the SDK set on this request. httpx stores it + # as request.extensions["timeout"] (a dict with pool/connect/read/write + # keys). We pass the read timeout through so the activity's httpx + # client doesn't use the default 5 s. + timeout: float | None = None + timeout_ext = request.extensions.get("timeout", {}) + if isinstance(timeout_ext, dict): + # Prefer read timeout (waiting for response), fall back to pool + timeout = timeout_ext.get("read") or timeout_ext.get("pool") + + # ── Credential stripping ────────────────────────────────────────── + # The Gemini SDK injects "x-goog-api-key" into every outgoing request + # at genai.Client construction time. Because the serialized + # HttpRequestData is stored in Temporal's event history (visible in the + # Temporal UI and accessible to anyone with namespace read access), we + # strip this header so the API key is never persisted. + # + # The matching activity (gemini_api_call) re-injects the API key + # from os.environ on the worker side before making the real HTTP call. + # See _http_activity.py for the other half. + # + # NOTE: We intentionally do NOT strip the "authorization" header + # (used by Vertex AI with OAuth / service-account credentials). + # OAuth tokens are short-lived and refreshed per-request by the SDK; + # we cannot re-inject them in the activity. Vertex AI users should + # be aware that bearer tokens will appear in event history — use + # Temporal's payload codec / encryption if this is a concern. + # ───────────────────────────────────────────────────────────────────── + headers = { + k: v + for k, v in request.headers.items() + if k.lower() != "x-goog-api-key" + } + + req_data = HttpRequestData( + method=request.method, + url=str(request.url), + headers=headers, + content=content, + timeout=timeout, + ) + + resp_data = await temporal_workflow.execute_activity( + gemini_api_call, + req_data, + **self._activity_kwargs, + ) + + return httpx.Response( + status_code=resp_data.status_code, + headers=resp_data.headers, + content=resp_data.content, + request=request, + ) + + +def temporal_http_options( + *, + start_to_close_timeout: timedelta = timedelta(seconds=60), + schedule_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, +) -> types.HttpOptions: + """Create ``HttpOptions`` that route all Gemini SDK HTTP calls through Temporal. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + Pass the result to ``genai.Client(http_options=...)`` so that every model + call made from within a Temporal workflow is durably recorded as a Temporal + activity. + + Create the ``genai.Client`` at **module level** (outside the workflow + sandbox) so that ``os.environ`` is read at import time, not inside the + workflow. The workflow then references the pre-built client. + + Gemini SDK retries are disabled (``attempts=1``) because Temporal owns the + retry policy instead. + + Args: + start_to_close_timeout: Maximum time for a single HTTP request activity. + Defaults to 60 seconds. + schedule_to_close_timeout: Maximum time from scheduling to completion. + heartbeat_timeout: Maximum time between heartbeats. + retry_policy: Retry policy for failed HTTP request activities. + + Returns: + A :class:`~google.genai.types.HttpOptions` instance with a + :class:`TemporalHttpxClient` set as the async HTTP backend. + + Example:: + + # ---- module level (outside workflow sandbox) ---- + gemini_client = genai.Client( + api_key=os.environ["GOOGLE_API_KEY"], + http_options=temporal_http_options( + start_to_close_timeout=timedelta(seconds=60), + ), + ) + + # ---- inside the workflow ---- + @workflow.defn + class AgentWorkflow: + @workflow.run + async def run(self, query: str) -> str: + response = await gemini_client.aio.models.generate_content(...) + return response.text + """ + return types.HttpOptions( + httpx_async_client=TemporalHttpxClient( + start_to_close_timeout=start_to_close_timeout, + schedule_to_close_timeout=schedule_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + ), + # Temporal owns retries; disable SDK-level retries to avoid interference. + retry_options=types.HttpRetryOptions(attempts=1), + ) diff --git a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py index dfd2effb4..d927d3134 100644 --- a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py +++ b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py @@ -1,9 +1,17 @@ # ABOUTME: First-class Temporal + Gemini SDK integration demo. -# Demonstrates the clean developer experience: 3-line workflow, no manual loop, -# no dynamic activities, no tool registry, and no inspect hackery. +# +# Key differences from example/durable_agent_worker.py: +# - AFC is ENABLED: Gemini's SDK owns the agentic loop, no manual while-True +# - Tools are plain @activity.defn functions; no registry, no dynamic activities +# - activity_as_tool() makes each tool call a durable Temporal activity +# - temporal_http_options() makes each model HTTP call a durable Temporal activity +# - genai.Client is created once in main(), stored via GeminiPlugin +# - Workflow retrieves it with get_gemini_client() — no os.environ in sandbox +# - No run_agent(), no inspect hackery, no print() logging import asyncio import json +import os from datetime import timedelta from dotenv import load_dotenv @@ -15,8 +23,13 @@ with workflow.unsafe.imports_passed_through(): import httpx + from google.genai import types -from temporalio.contrib.google_gemini_sdk import GeminiAgent, GeminiPlugin, activity_as_tool, run_agent +from temporalio.contrib.google_gemini_sdk import ( + GeminiPlugin, + activity_as_tool, + get_gemini_client, +) # ============================================================================= @@ -91,7 +104,7 @@ async def get_location_info(request: GetLocationRequest) -> str: # ============================================================================= -# Workflow — 3 lines of real logic, no manual loop, no print() debugging +# Workflow — natural Gemini SDK usage; AFC drives the loop; all calls are durable # ============================================================================= TASK_QUEUE = "gemini-first-class" @@ -99,13 +112,22 @@ async def get_location_info(request: GetLocationRequest) -> str: @workflow.defn class WeatherAgentWorkflow: - """Durable agentic workflow powered by Gemini SDK and Temporal.""" + """Durable agentic workflow powered by Gemini SDK and Temporal. + + The Gemini SDK's automatic function calling (AFC) drives the multi-turn + agentic loop. temporal_http_options() ensures every model HTTP call is a + durable Temporal activity. activity_as_tool() ensures every tool invocation + is also a durable Temporal activity. Together, every step of the agentic + loop is visible in the workflow event history and recoverable after a crash. + """ @workflow.run async def run(self, query: str) -> str: - return await run_agent( - GeminiAgent( - model="gemini-2.5-flash", + client = get_gemini_client() + response = await client.aio.models.generate_content( + model="gemini-2.5-flash", + contents=query, + config=types.GenerateContentConfig( system_instruction=SYSTEM_INSTRUCTIONS, tools=[ activity_as_tool( @@ -122,19 +144,30 @@ async def run(self, query: str) -> str: ), ], ), - query, ) + return response.text # ============================================================================= -# Worker — register plugin + user activities, nothing else required +# Worker — create client once, pass to plugin, start worker # ============================================================================= async def main() -> None: load_dotenv() - plugin = GeminiPlugin() + # Import here so it's outside the sandbox (only main() runs outside). + from google.genai import Client as GeminiClient + from temporalio.contrib.google_gemini_sdk import temporal_http_options + + gemini_client = GeminiClient( + api_key=os.environ["GOOGLE_API_KEY"], + http_options=temporal_http_options( + start_to_close_timeout=timedelta(seconds=60), + ), + ) + + plugin = GeminiPlugin(gemini_client=gemini_client) config = ClientConfig.load_client_connect_config() config.setdefault("target_host", "localhost:7233") diff --git a/temporalio/contrib/google_gemini_sdk/workflow.py b/temporalio/contrib/google_gemini_sdk/workflow.py index c3a6ebd8a..669401a4b 100644 --- a/temporalio/contrib/google_gemini_sdk/workflow.py +++ b/temporalio/contrib/google_gemini_sdk/workflow.py @@ -1,44 +1,25 @@ -"""Workflow utilities for Gemini SDK integration with Temporal.""" +"""Workflow utilities for Google Gemini SDK integration with Temporal. + +This module is loaded inside the Temporal workflow sandbox and therefore must +**not** import ``httpx``, ``google.genai``, or any other module that the +sandbox cannot handle. HTTP-transport helpers live in +``_temporal_httpx_client.py`` which is loaded lazily outside the sandbox. +""" from __future__ import annotations import functools import inspect -from dataclasses import dataclass, field from datetime import timedelta from typing import Any from collections.abc import Callable -from google.genai import types - from temporalio import activity from temporalio import workflow as temporal_workflow from temporalio.common import Priority, RetryPolicy from temporalio.exceptions import ApplicationError, TemporalError from temporalio.workflow import ActivityCancellationType, VersioningIntent -from temporalio.contrib.google_gemini_sdk._invoke_model_activity import ( - ActivityModelInput, - GeminiModelActivity, -) -from temporalio.contrib.google_gemini_sdk._model_activity_parameters import ( - ModelActivityParameters, -) - - -@dataclass -class ActivityTool: - """A Temporal activity wrapped as a Gemini tool. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - """ - - function_declaration: types.FunctionDeclaration - activity_name: str - schedule_kwargs: dict[str, Any] - def activity_as_tool( fn: Callable, @@ -52,41 +33,34 @@ def activity_as_tool( cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, versioning_intent: VersioningIntent | None = None, priority: Priority = Priority.default, - api_option: str = "GEMINI_API", -) -> ActivityTool: - """Convert a Temporal activity function into a Gemini tool for use with :func:`run_agent`. +) -> Callable: + """Convert a Temporal activity into a Gemini-compatible async tool callable. .. warning:: This API is experimental and may change in future versions. Use with caution in production environments. + Returns an async callable with the same name, docstring, and type signature as + ``fn``. When Gemini's automatic function calling (AFC) invokes the returned + callable from within a Temporal workflow, the call is executed as a Temporal + activity via :func:`workflow.execute_activity`. Each tool invocation therefore + appears as a separate, durable entry in the workflow event history. + + Because AFC is left **enabled**, the Gemini SDK owns the agentic loop — no + manual ``while`` loop or ``run_agent()`` helper is required. Pass the returned + callable directly to ``GenerateContentConfig(tools=[...])``. + + For undocumented arguments see :py:meth:`workflow.execute_activity`. + Args: fn: A Temporal activity function decorated with ``@activity.defn``. - task_queue: Specific task queue to use for this tool's activity. - schedule_to_close_timeout: Maximum time from scheduling to completion. - schedule_to_start_timeout: Maximum time from scheduling to starting. - start_to_close_timeout: Maximum time for the activity to complete. - heartbeat_timeout: Maximum time between heartbeats. - retry_policy: Policy for retrying failed activities. - cancellation_type: How the activity handles cancellation. - versioning_intent: Versioning intent for the activity. - priority: Priority for the activity execution. - api_option: Gemini API option for schema generation. Defaults to ``"GEMINI_API"``. Returns: - An :class:`ActivityTool` wrapping the activity with its Gemini schema. + An async callable suitable for use as a Gemini tool. Raises: - ApplicationError: If the function is not properly decorated as a Temporal activity. - - Example: - >>> @activity.defn - ... async def get_weather(request: WeatherRequest) -> str: ... - >>> - >>> tool = activity_as_tool( - ... get_weather, - ... start_to_close_timeout=timedelta(seconds=30), - ... ) + ApplicationError: If ``fn`` is not decorated with ``@activity.defn`` or + has no activity name. """ ret = activity._Definition.from_callable(fn) if not ret: @@ -96,16 +70,17 @@ def activity_as_tool( ) if ret.name is None: raise ApplicationError( - "Activity must have a name to be used as a tool", + "Activity must have a name to be used as a Gemini tool", "invalid_tool", ) - # If the callable has a 'self' parameter (class-based activity), partially apply it - # so that FunctionDeclaration schema generation ignores the self param. - # The actual instance is resolved at execution time by the worker. + # For class-based activities the first parameter is 'self'. Partially apply + # it so that Gemini inspects only the user-facing parameters when building + # the function-call schema, while the worker resolves the real instance at + # execution time. params = list(inspect.signature(fn).parameters.keys()) - schema_fn = fn - if len(params) > 0 and params[0] == "self": + schema_fn: Callable = fn + if params and params[0] == "self": partial = functools.partial(fn, None) setattr(partial, "__name__", fn.__name__) partial.__annotations__ = getattr(fn, "__annotations__", {}) @@ -117,11 +92,7 @@ def activity_as_tool( partial.__doc__ = fn.__doc__ schema_fn = partial - function_declaration = types.FunctionDeclaration.from_callable_with_api_option( - callable=schema_fn, - api_option=api_option, - ) - + activity_name: str = ret.name schedule_kwargs: dict[str, Any] = { "task_queue": task_queue, "schedule_to_close_timeout": schedule_to_close_timeout, @@ -134,189 +105,31 @@ def activity_as_tool( "priority": priority, } - return ActivityTool( - function_declaration=function_declaration, - activity_name=ret.name, - schedule_kwargs=schedule_kwargs, - ) - - -@dataclass -class GeminiAgent: - """Configuration for a Gemini-powered agentic loop. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - - Example: - >>> agent = GeminiAgent( - ... model="gemini-2.5-flash", - ... system_instruction="You are a helpful assistant.", - ... tools=[ - ... activity_as_tool(get_weather, start_to_close_timeout=timedelta(seconds=30)), - ... ], - ... ) - """ - - model: str - system_instruction: str | None = None - tools: list[ActivityTool] = field(default_factory=list) - generation_config: types.GenerateContentConfig | None = None - api_option: str = "GEMINI_API" - - -async def run_agent( - agent: GeminiAgent, - initial_message: str, - *, - model_params: ModelActivityParameters | None = None, - max_turns: int = 10, -) -> str: - """Run the Gemini agentic loop inside a Temporal workflow. - - .. warning:: - This API is experimental and may change in future versions. - Use with caution in production environments. - - Each model call and each tool invocation becomes a separate Temporal activity, - giving full workflow history visibility and crash recovery. - - Args: - agent: The :class:`GeminiAgent` configuration. - initial_message: The user's initial query. - model_params: Optional parameters for configuring model activity execution. - max_turns: Maximum number of model call + tool execution rounds before raising. - - Returns: - The model's final text response. - - Raises: - GeminiAgentWorkflowError: If ``max_turns`` is exceeded or an unknown tool is called. - """ - if model_params is None: - model_params = ModelActivityParameters() - - history: list[types.Content] = [ - types.Content(role="user", parts=[types.Part.from_text(text=initial_message)]) - ] - - function_declarations = [t.function_declaration for t in agent.tools] - activity_tools: dict[str, ActivityTool] = { - t.function_declaration.name: t for t in agent.tools - } - - # Build kwargs for the model activity execution - model_kwargs: dict[str, Any] = { - "task_queue": model_params.task_queue, - "schedule_to_close_timeout": model_params.schedule_to_close_timeout, - "schedule_to_start_timeout": model_params.schedule_to_start_timeout, - "start_to_close_timeout": model_params.start_to_close_timeout, - "heartbeat_timeout": model_params.heartbeat_timeout, - "retry_policy": model_params.retry_policy, - "cancellation_type": model_params.cancellation_type, - "versioning_intent": model_params.versioning_intent, - "priority": model_params.priority, - } - - for _ in range(max_turns): - model_input = ActivityModelInput( - model=agent.model, - system_instruction=agent.system_instruction, - contents=history, - function_declarations=function_declarations, - generation_config=agent.generation_config, + async def wrapper(*args: Any, **kwargs: Any) -> Any: + sig = inspect.signature(schema_fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + activity_args = list(bound.arguments.values()) + return await temporal_workflow.execute_activity( + activity_name, + args=activity_args, + **schedule_kwargs, ) - if model_params.use_local_activity: - result = await temporal_workflow.execute_local_activity_method( - GeminiModelActivity.generate_content_activity, - model_input, - **{ - k: v - for k, v in model_kwargs.items() - if k - not in ( - "task_queue", - "schedule_to_start_timeout", - "versioning_intent", - ) - }, - ) - else: - result = await temporal_workflow.execute_activity_method( - GeminiModelActivity.generate_content_activity, - model_input, - **model_kwargs, - ) - - if result.function_calls: - history.append(result.model_content) - - for fc in result.function_calls: - tool = activity_tools.get(fc.name) - if tool is None: - raise GeminiAgentWorkflowError( - f"Model called unknown tool '{fc.name}'. " - f"Available tools: {list(activity_tools.keys())}" - ) + wrapper.__name__ = schema_fn.__name__ + wrapper.__doc__ = schema_fn.__doc__ + setattr(wrapper, "__signature__", inspect.signature(schema_fn)) + wrapper.__annotations__ = getattr(schema_fn, "__annotations__", {}) - # Extract positional args from the dict Gemini returns. - # Gemini wraps each parameter under its name, e.g.: - # get_weather_alerts(request: WeatherRequest) → {"request": {"state": "CA"}} - # list(args.values()) unwraps to [{"state": "CA"}], which Temporal - # then deserializes as the activity's Pydantic parameter type. - # For no-arg activities, args={} → dispatch_args=[] (no args passed). - dispatch_args = list(fc.args.values()) - - try: - tool_result = await temporal_workflow.execute_activity( - tool.activity_name, - args=dispatch_args, - **tool.schedule_kwargs, - ) - except Exception as exc: - raise GeminiAgentWorkflowError( - f"Tool '{fc.name}' raised an error: {exc}" - ) from exc - - try: - result_str = str(tool_result) - except Exception as exc: - raise GeminiToolSerializationError( - f"Tool '{fc.name}' returned a value that could not be converted to str" - ) from exc - - history.append( - types.Content( - role="user", - parts=[ - types.Part.from_function_response( - name=fc.name, - response={"result": result_str}, - ) - ], - ) - ) - else: - return result.text or "" - - raise GeminiAgentWorkflowError( - f"Agent exceeded maximum turns ({max_turns}) without producing a final response." - ) + return wrapper class GeminiAgentWorkflowError(TemporalError): - """Raised when the Gemini agent loop cannot complete normally. + """Raised when a Gemini-driven agentic workflow cannot complete normally. .. warning:: This exception is experimental and may change in future versions. Use with caution in production environments. - - This is raised when: - - The agent exceeds ``max_turns`` without returning a text response. - - The model calls a tool that was not registered. - - A tool activity raises an unexpected error. """ @@ -326,7 +139,4 @@ class GeminiToolSerializationError(TemporalError): .. warning:: This exception is experimental and may change in future versions. Use with caution in production environments. - - All tool outputs are converted to strings before being sent back to the model. - If ``str(result)`` raises, this exception is raised instead. """ From ebbaeee74171b18ac92fd6ed161fab0432794efe Mon Sep 17 00:00:00 2001 From: Jason Steving Date: Wed, 18 Mar 2026 13:54:47 -0700 Subject: [PATCH 3/6] Move genai.Client creation into GeminiPlugin to simplify user-facing API GeminiPlugin now accepts genai.Client kwargs (api_key, vertexai, project, etc.) directly and creates the client internally with temporal_http_options(), eliminating the need for users to manually wire up the HTTP transport. Activity timeout/retry configuration is exposed as explicit constructor parameters. Also adds TYPE_CHECKING imports for better IDE support in __init__.py and _client_store.py. --- .../contrib/google_gemini_sdk/__init__.py | 34 ++++--- .../google_gemini_sdk/_client_store.py | 15 +-- .../google_gemini_sdk/_gemini_plugin.py | 95 +++++++++++++++---- .../first_class_example/worker.py | 29 +++--- 4 files changed, 126 insertions(+), 47 deletions(-) diff --git a/temporalio/contrib/google_gemini_sdk/__init__.py b/temporalio/contrib/google_gemini_sdk/__init__.py index af04a991a..1d905b33a 100644 --- a/temporalio/contrib/google_gemini_sdk/__init__.py +++ b/temporalio/contrib/google_gemini_sdk/__init__.py @@ -8,24 +8,18 @@ while making every network call and every tool invocation **durable Temporal activities**. -- :func:`temporal_http_options` — wrap ``genai.Client`` so all HTTP calls - (model calls, streaming responses) are routed through a Temporal activity. +- :class:`GeminiPlugin` — creates and owns the ``genai.Client``, registers + the HTTP transport activity, and configures the worker. Pass the same args + you would pass to ``genai.Client()`` — the plugin handles ``http_options`` + internally. - :func:`activity_as_tool` — convert any ``@activity.defn`` function into a Gemini tool callable; Gemini's AFC invokes it as a Temporal activity. -- :class:`GeminiPlugin` — stores the pre-built client, registers the HTTP - transport activity, and configures the worker. - :func:`get_gemini_client` — retrieve the client inside a workflow. Quickstart:: # ---- worker setup (outside sandbox) ---- - gemini_client = genai.Client( - api_key=os.environ["GOOGLE_API_KEY"], - http_options=temporal_http_options( - start_to_close_timeout=timedelta(seconds=60), - ), - ) - plugin = GeminiPlugin(gemini_client=gemini_client) + plugin = GeminiPlugin(api_key=os.environ["GOOGLE_API_KEY"]) @activity.defn async def get_weather(state: str) -> str: ... @@ -51,7 +45,23 @@ async def run(self, query: str) -> str: return response.text """ -# --- Sandbox-safe imports (loaded eagerly) --- +from __future__ import annotations + +from typing import TYPE_CHECKING + +# --- Type-checking imports (never executed at runtime) --- +# These give IDEs and type checkers full visibility into the lazy-loaded +# symbols so that autocomplete, go-to-definition, and hover docs work. +if TYPE_CHECKING: + from temporalio.contrib.google_gemini_sdk._gemini_plugin import ( + GeminiPlugin as GeminiPlugin, + ) + from temporalio.contrib.google_gemini_sdk._temporal_httpx_client import ( + TemporalHttpxClient as TemporalHttpxClient, + temporal_http_options as temporal_http_options, + ) + +# --- Sandbox-safe imports (loaded eagerly at runtime) --- # These modules have NO httpx / google.genai imports and are safe to load # inside the Temporal workflow sandbox. from temporalio.contrib.google_gemini_sdk._client_store import get_gemini_client diff --git a/temporalio/contrib/google_gemini_sdk/_client_store.py b/temporalio/contrib/google_gemini_sdk/_client_store.py index c7120170c..ef4045d09 100644 --- a/temporalio/contrib/google_gemini_sdk/_client_store.py +++ b/temporalio/contrib/google_gemini_sdk/_client_store.py @@ -6,20 +6,23 @@ the real runtime (where the worker sets it) and the sandboxed workflow (where :func:`get_gemini_client` reads it). -This module intentionally has **no** ``httpx`` or ``google.genai`` imports so -that it can also be loaded safely by the sandbox's restricted importer if the -passthrough hasn't been configured yet. +This module intentionally has **no** runtime ``httpx`` or ``google.genai`` +imports so that it can also be loaded safely by the sandbox's restricted +importer if the passthrough hasn't been configured yet. """ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from google.genai import Client as _GeminiClient # Set by GeminiPlugin.__init__ before the worker starts. -_gemini_client: Any = None +_gemini_client: _GeminiClient | None = None -def get_gemini_client() -> Any: +def get_gemini_client() -> _GeminiClient: """Return the ``genai.Client`` stored by :class:`GeminiPlugin`. .. warning:: diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py index 1d29f48fa..7e4d9759b 100644 --- a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py +++ b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py @@ -3,8 +3,10 @@ from __future__ import annotations import dataclasses +from datetime import timedelta from typing import Any +from temporalio.common import RetryPolicy from temporalio.contrib.google_gemini_sdk import _client_store from temporalio.contrib.google_gemini_sdk._http_activity import gemini_api_call from temporalio.contrib.google_gemini_sdk.workflow import GeminiAgentWorkflowError @@ -24,39 +26,96 @@ class GeminiPlugin(SimplePlugin): This class is experimental and may change in future versions. Use with caution in production environments. - This plugin: + This plugin creates and owns the ``genai.Client`` instance, ensuring it + is always configured with :func:`temporal_http_options` so that every HTTP + call (LLM model calls, streaming, etc.) is routed through a Temporal + activity. Workflows retrieve the client via :func:`get_gemini_client`. - - Stores the pre-built ``genai.Client`` so that workflows can retrieve it - via :func:`~temporalio.contrib.google_gemini_sdk.get_gemini_client` - without accessing ``os.environ`` or creating heavy objects in the sandbox. - - Registers ``gemini_api_call`` — the durable HTTP transport invoked - by :class:`~temporalio.contrib.google_gemini_sdk.TemporalHttpxClient`. + It also: + + - Registers the ``gemini_api_call`` activity (the durable HTTP transport). - Configures the Pydantic data converter and sandbox passthrough modules. + All ``genai.Client`` constructor arguments (``api_key``, ``vertexai``, + ``project``, ``credentials``, etc.) are forwarded via ``**kwargs`` — the + plugin automatically handles ``http_options``. If the Gemini SDK adds new + constructor parameters in a future release, they are forwarded without + any changes to this plugin. + Example:: - gemini_client = genai.Client( - api_key=os.environ["GOOGLE_API_KEY"], - http_options=temporal_http_options( - start_to_close_timeout=timedelta(seconds=60), - ), - ) - plugin = GeminiPlugin(gemini_client=gemini_client) + plugin = GeminiPlugin(api_key=os.environ["GOOGLE_API_KEY"]) client = await Client.connect("localhost:7233", plugins=[plugin]) + async with Worker(client, task_queue="q", workflows=[MyWorkflow], + activities=[my_tool]): + ... + + Vertex AI example:: + + plugin = GeminiPlugin(vertexai=True, project="my-project", location="us-central1") """ def __init__( self, - gemini_client: Any, + *, + # ── Temporal activity configuration for model HTTP calls ───────── + start_to_close_timeout: timedelta = timedelta(seconds=60), + schedule_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + # ── genai.Client constructor args ──────────────────────────────── + # Forwarded directly to genai.Client(). The plugin adds + # http_options automatically — do NOT pass it here. + # See genai.Client docs for available kwargs: api_key, vertexai, + # project, location, credentials, etc. + **gemini_client_kwargs: Any, ) -> None: """Initialize the Gemini plugin. Args: - gemini_client: A pre-built ``genai.Client`` instance. Create it at - worker startup (where ``os.environ`` is available) with - ``http_options=temporal_http_options(...)`` so that its HTTP - calls are routed through Temporal activities. + start_to_close_timeout: Maximum time for a single model HTTP call + activity. Defaults to 60 seconds. + schedule_to_close_timeout: Maximum time from scheduling to + completion for model HTTP call activities. + heartbeat_timeout: Maximum time between heartbeats for model HTTP + call activities. + retry_policy: Retry policy for failed model HTTP call activities. + **gemini_client_kwargs: Forwarded to ``genai.Client()``. Do NOT + pass ``http_options`` — the plugin manages it internally. + See ``genai.Client`` for available options (``api_key``, + ``vertexai``, ``project``, ``location``, ``credentials``, …). """ + if "http_options" in gemini_client_kwargs: + raise ValueError( + "Do not pass http_options to GeminiPlugin — it configures " + "temporal_http_options() internally. Pass other genai.Client " + "args (api_key, vertexai, project, etc.) directly." + ) + + # Create the genai.Client with temporal_http_options() so that all + # HTTP calls go through a Temporal activity. This happens at worker + # startup (outside the sandbox) where os.environ is available. + # + # When no kwargs are provided (e.g. in test environments), skip client + # creation — get_gemini_client() will raise at workflow runtime. + gemini_client = None + if gemini_client_kwargs: + from google.genai import Client as GeminiClient + + from temporalio.contrib.google_gemini_sdk._temporal_httpx_client import ( + temporal_http_options, + ) + + gemini_client = GeminiClient( + http_options=temporal_http_options( + start_to_close_timeout=start_to_close_timeout, + schedule_to_close_timeout=schedule_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + ), + **gemini_client_kwargs, + ) + # Store the client in the passthrough'd module so the sandbox can see it. _client_store._gemini_client = gemini_client diff --git a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py index d927d3134..8d1b0f9e8 100644 --- a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py +++ b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py @@ -123,6 +123,12 @@ class WeatherAgentWorkflow: @workflow.run async def run(self, query: str) -> str: + # Retrieve the pre-built genai.Client that was created at worker + # startup and stored via GeminiPlugin. We cannot instantiate + # genai.Client here because its constructor always reads os.environ + # (for the API key, project ID, etc.), which Temporal's workflow + # sandbox forbids. get_gemini_client() reads from a passthrough'd + # module, so the sandbox sees the real, pre-configured client object. client = get_gemini_client() response = await client.aio.models.generate_content( model="gemini-2.5-flash", @@ -149,26 +155,27 @@ async def run(self, query: str) -> str: # ============================================================================= -# Worker — create client once, pass to plugin, start worker +# Worker — plugin owns client creation, start worker # ============================================================================= async def main() -> None: load_dotenv() - # Import here so it's outside the sandbox (only main() runs outside). - from google.genai import Client as GeminiClient - from temporalio.contrib.google_gemini_sdk import temporal_http_options - - gemini_client = GeminiClient( + # GeminiPlugin creates the genai.Client internally, ensuring it is always + # wired with temporal_http_options() so every LLM HTTP call runs as a + # durable Temporal activity. Pass the same args you'd pass to + # genai.Client() — the plugin handles http_options for you. + # + # The client is created here (at worker startup, outside the sandbox) + # because genai.Client() reads os.environ internally, which Temporal's + # workflow sandbox forbids. Workflows retrieve the pre-built client + # via get_gemini_client(). + plugin = GeminiPlugin( api_key=os.environ["GOOGLE_API_KEY"], - http_options=temporal_http_options( - start_to_close_timeout=timedelta(seconds=60), - ), + start_to_close_timeout=timedelta(seconds=60), ) - plugin = GeminiPlugin(gemini_client=gemini_client) - config = ClientConfig.load_client_connect_config() config.setdefault("target_host", "localhost:7233") client = await Client.connect(**config, plugins=[plugin]) From 7ff46ac13fd171eeeb56e55ea689d1597ce47546 Mon Sep 17 00:00:00 2001 From: Jason Steving Date: Wed, 18 Mar 2026 16:36:06 -0700 Subject: [PATCH 4/6] Add field-level encryption for sensitive headers in Temporal event history MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the manual credential stripping/re-injection pattern (strip x-goog-api-key before serialization, re-inject from env vars in the activity) with a proper PayloadCodec that encrypts sensitive header values within HttpRequestData payloads using Fernet. This approach is better because: - All headers (including authorization) are now protected, not just x-goog-api-key - The Temporal UI still shows the full request structure with all non-sensitive fields readable — no Codec Server needed - No env var dependency in the activity for credential re-injection New module _sensitive_fields_codec.py provides three components: - TypeTaggingPydanticJSONConverter: writes the Python type name into payload metadata so the codec can identify which model produced it - SensitiveFieldsCodec: encrypts/decrypts specific keys within specific dict fields of registered Pydantic models - make_sensitive_fields_data_converter: factory wiring both into a single DataConverter GeminiPlugin gains two new parameters: - sensitive_activity_fields: which header keys to encrypt (defaults to x-goog-api-key and authorization) - sensitive_activity_fields_encryption_key: the Fernet key Encryption is on by default; pass sensitive_activity_fields=None to disable for local dev. --- .../contrib/google_gemini_sdk/__init__.py | 5 + .../google_gemini_sdk/_gemini_plugin.py | 173 ++++++++++--- .../google_gemini_sdk/_http_activity.py | 37 +-- .../_sensitive_fields_codec.py | 242 ++++++++++++++++++ .../_temporal_httpx_client.py | 29 +-- .../first_class_example/worker.py | 45 +++- 6 files changed, 440 insertions(+), 91 deletions(-) create mode 100644 temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py diff --git a/temporalio/contrib/google_gemini_sdk/__init__.py b/temporalio/contrib/google_gemini_sdk/__init__.py index 1d905b33a..3f3711aff 100644 --- a/temporalio/contrib/google_gemini_sdk/__init__.py +++ b/temporalio/contrib/google_gemini_sdk/__init__.py @@ -72,6 +72,7 @@ async def run(self, query: str) -> str: ) __all__ = [ + "DEFAULT_SENSITIVE_HEADER_KEYS", "GeminiAgentWorkflowError", "GeminiPlugin", "GeminiToolSerializationError", @@ -91,6 +92,10 @@ async def run(self, query: str) -> str: # never trigger an httpx import. def __getattr__(name: str): # type: ignore[override] _lazy = { + "DEFAULT_SENSITIVE_HEADER_KEYS": ( + "temporalio.contrib.google_gemini_sdk._gemini_plugin", + "DEFAULT_SENSITIVE_HEADER_KEYS", + ), "GeminiPlugin": ( "temporalio.contrib.google_gemini_sdk._gemini_plugin", "GeminiPlugin", diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py index 7e4d9759b..b005fdc65 100644 --- a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py +++ b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py @@ -8,16 +8,55 @@ from temporalio.common import RetryPolicy from temporalio.contrib.google_gemini_sdk import _client_store -from temporalio.contrib.google_gemini_sdk._http_activity import gemini_api_call +from temporalio.contrib.google_gemini_sdk._http_activity import ( + HttpRequestData, + gemini_api_call, +) from temporalio.contrib.google_gemini_sdk.workflow import GeminiAgentWorkflowError from temporalio.contrib.pydantic import ( PydanticPayloadConverter as _DefaultPydanticPayloadConverter, ) -from temporalio.converter import DataConverter, DefaultPayloadConverter +from typing import Sequence + +import temporalio.api.common.v1 +from temporalio.converter import DataConverter, DefaultPayloadConverter, PayloadCodec from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +#: Default set of HTTP header keys that contain credentials and should be +#: encrypted in Temporal's event history. Pass to ``sensitive_activity_fields`` +#: to use these defaults, or provide your own set. +DEFAULT_SENSITIVE_HEADER_KEYS: set[str] = {"x-goog-api-key", "authorization"} + +_ENCRYPTION_KEY_HELP = """\ +sensitive_activity_fields_encryption_key must be a Fernet key (44 URL-safe base64 bytes). + +To generate one: + + Local dev / quick start: + python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())" + + Production: + Store the key in a secret manager (e.g. Google Secret Manager, AWS Secrets + Manager, HashiCorp Vault) and load it at worker startup. The same key must + be used by every worker and client that reads/writes this Temporal namespace. + +Then pass it to GeminiPlugin: + + plugin = GeminiPlugin( + api_key=os.environ["GOOGLE_API_KEY"], + sensitive_activity_fields_encryption_key=os.environ["SENSITIVE_ACTIVITY_FIELDS_ENCRYPTION_KEY"].encode(), + ) + +If you don't need credential encryption (e.g. local dev with a private Temporal +server), pass sensitive_activity_fields=None to skip it entirely: + + plugin = GeminiPlugin( + api_key=os.environ["GOOGLE_API_KEY"], + sensitive_activity_fields=None, + )""" + class GeminiPlugin(SimplePlugin): """A Temporal Worker Plugin configured for the Google Gemini SDK. @@ -34,25 +73,36 @@ class GeminiPlugin(SimplePlugin): It also: - Registers the ``gemini_api_call`` activity (the durable HTTP transport). - - Configures the Pydantic data converter and sandbox passthrough modules. + - Optionally configures a **field-level encryption codec** that encrypts + credential headers (``x-goog-api-key``, ``authorization``) within + ``HttpRequestData`` payloads so they never appear in plaintext in + Temporal's event history. All other fields remain human-readable. + - Configures sandbox passthrough modules. All ``genai.Client`` constructor arguments (``api_key``, ``vertexai``, ``project``, ``credentials``, etc.) are forwarded via ``**kwargs`` — the - plugin automatically handles ``http_options``. If the Gemini SDK adds new - constructor parameters in a future release, they are forwarded without - any changes to this plugin. + plugin automatically handles ``http_options``. - Example:: + Example with encryption (recommended for production):: - plugin = GeminiPlugin(api_key=os.environ["GOOGLE_API_KEY"]) - client = await Client.connect("localhost:7233", plugins=[plugin]) - async with Worker(client, task_queue="q", workflows=[MyWorkflow], - activities=[my_tool]): - ... + plugin = GeminiPlugin( + api_key=os.environ["GOOGLE_API_KEY"], + sensitive_activity_fields_encryption_key=os.environ["SENSITIVE_ACTIVITY_FIELDS_ENCRYPTION_KEY"].encode(), + ) + + Example without encryption (local dev with private Temporal server):: + + plugin = GeminiPlugin( + api_key=os.environ["GOOGLE_API_KEY"], + sensitive_activity_fields=None, + ) Vertex AI example:: - plugin = GeminiPlugin(vertexai=True, project="my-project", location="us-central1") + plugin = GeminiPlugin( + vertexai=True, project="my-project", location="us-central1", + sensitive_activity_fields_encryption_key=os.environ["SENSITIVE_ACTIVITY_FIELDS_ENCRYPTION_KEY"].encode(), + ) """ def __init__( @@ -63,6 +113,9 @@ def __init__( schedule_to_close_timeout: timedelta | None = None, heartbeat_timeout: timedelta | None = None, retry_policy: RetryPolicy | None = None, + # ── Sensitive activity field encryption ────────────────────────── + sensitive_activity_fields: set[str] | None = DEFAULT_SENSITIVE_HEADER_KEYS, + sensitive_activity_fields_encryption_key: bytes | None = None, # ── genai.Client constructor args ──────────────────────────────── # Forwarded directly to genai.Client(). The plugin adds # http_options automatically — do NOT pass it here. @@ -80,6 +133,16 @@ def __init__( heartbeat_timeout: Maximum time between heartbeats for model HTTP call activities. retry_policy: Retry policy for failed model HTTP call activities. + sensitive_activity_fields: Set of HTTP header keys to encrypt in + Temporal's event history. Defaults to + ``DEFAULT_SENSITIVE_HEADER_KEYS`` (``{"x-goog-api-key", + "authorization"}``). Pass ``None`` to disable encryption + entirely (e.g. local dev with a private Temporal server). + sensitive_activity_fields_encryption_key: A Fernet key for + encrypting the fields specified above. Generate one with + ``cryptography.fernet.Fernet.generate_key()``. Must be the + same across all workers and clients for a given namespace. + Required when ``sensitive_activity_fields`` is not ``None``. **gemini_client_kwargs: Forwarded to ``genai.Client()``. Do NOT pass ``http_options`` — the plugin manages it internally. See ``genai.Client`` for available options (``api_key``, @@ -92,11 +155,37 @@ def __init__( "args (api_key, vertexai, project, etc.) directly." ) - # Create the genai.Client with temporal_http_options() so that all - # HTTP calls go through a Temporal activity. This happens at worker - # startup (outside the sandbox) where os.environ is available. + # ── Build the DataConverter ────────────────────────────────────── + if sensitive_activity_fields is not None: + if sensitive_activity_fields_encryption_key is None: + raise ValueError( + "sensitive_activity_fields_encryption_key is required when " + "sensitive_activity_fields is set.\n\n" + _ENCRYPTION_KEY_HELP + ) + + from temporalio.contrib.google_gemini_sdk._sensitive_fields_codec import ( + make_sensitive_fields_data_converter, + ) + + self._data_converter = make_sensitive_fields_data_converter( + model_configs={ + HttpRequestData: {"headers": sensitive_activity_fields}, + }, + encryption_key=sensitive_activity_fields_encryption_key, + ) + else: + # No encryption — use plain Pydantic converter. + # Credential headers will be visible in Temporal's event history. + self._data_converter = DataConverter( + payload_converter_class=_DefaultPydanticPayloadConverter + ) + + # ── Create the genai.Client ────────────────────────────────────── + # Uses temporal_http_options() so all HTTP calls go through a Temporal + # activity. Created at worker startup (outside the sandbox) where + # os.environ is available. # - # When no kwargs are provided (e.g. in test environments), skip client + # When no kwargs are provided (e.g. test environments), skip client # creation — get_gemini_client() will raise at workflow runtime. gemini_client = None if gemini_client_kwargs: @@ -126,18 +215,11 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: return dataclasses.replace( runner, restrictions=runner.restrictions.with_passthrough_modules( - # google.genai — so the workflow can call methods on the - # pre-built genai.Client and use types.GenerateContentConfig "google.genai", "google.api_core", - # pydantic internals — avoids "imported after initial - # workflow load" warnings when google.genai types - # trigger lazy pydantic schema compilation. "pydantic_core", "pydantic", "annotated_types", - # The client store module — passthrough'd so the sandbox - # sees the same _gemini_client reference the worker set. "temporalio.contrib.google_gemini_sdk._client_store", ), ) @@ -154,13 +236,36 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: def _configure_data_converter( self, converter: DataConverter | None ) -> DataConverter: - if converter is None: - return DataConverter( - payload_converter_class=_DefaultPydanticPayloadConverter - ) - elif converter.payload_converter_class is DefaultPayloadConverter: - return dataclasses.replace( - converter, - payload_converter_class=_DefaultPydanticPayloadConverter, - ) - return converter + if converter is not None and converter.payload_codec is not None: + # Caller has their own codec — chain ours first, then theirs. + if self._data_converter.payload_codec is not None: + return dataclasses.replace( + self._data_converter, + payload_codec=_CompositeCodec( + [self._data_converter.payload_codec, converter.payload_codec] + ), + ) + return self._data_converter + + +class _CompositeCodec(PayloadCodec): + """Chains multiple codecs in order (encode: left→right, decode: right→left).""" + + def __init__(self, codecs: list[PayloadCodec]) -> None: + self._codecs = codecs + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + result = list(payloads) + for codec in self._codecs: + result = await codec.encode(result) + return result + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + result = list(payloads) + for codec in reversed(self._codecs): + result = await codec.decode(result) + return result diff --git a/temporalio/contrib/google_gemini_sdk/_http_activity.py b/temporalio/contrib/google_gemini_sdk/_http_activity.py index 9a094df1e..fcbfd89c2 100644 --- a/temporalio/contrib/google_gemini_sdk/_http_activity.py +++ b/temporalio/contrib/google_gemini_sdk/_http_activity.py @@ -1,11 +1,17 @@ """Temporal activity for executing HTTP requests on behalf of the Gemini SDK. The Gemini SDK makes HTTP calls internally through an httpx.AsyncClient. -:class:`~temporalio.contrib.google_gemini_sdk.workflow.TemporalHttpxClient` +:class:`~temporalio.contrib.google_gemini_sdk._temporal_httpx_client.TemporalHttpxClient` overrides that client's ``send()`` method to dispatch through this activity instead of making direct network calls. This ensures every Gemini model call is durably recorded in Temporal's workflow event history and benefits from Temporal's retry/timeout semantics. + +Credential headers (``x-goog-api-key``, ``authorization``) are **not** stripped +or re-injected here. They are encrypted transparently by +:class:`~temporalio.contrib.google_gemini_sdk._sensitive_fields_codec.SensitiveFieldsCodec` +before the payload reaches Temporal's event history, and decrypted before the +activity receives it. """ from __future__ import annotations @@ -44,38 +50,19 @@ async def gemini_api_call(req: HttpRequestData) -> HttpResponseData: This activity is registered automatically by :class:`~temporalio.contrib.google_gemini_sdk.GeminiPlugin` and is invoked - by :class:`~temporalio.contrib.google_gemini_sdk.workflow.TemporalHttpxClient` + by :class:`~temporalio.contrib.google_gemini_sdk._temporal_httpx_client.TemporalHttpxClient` whenever the Gemini SDK needs to make a network call from within a workflow. + Credential headers arrive fully intact (decrypted by the codec before the + activity receives the payload). No manual re-injection is needed. + Do not call this activity directly. """ - import os - - # ── Credential re-injection ────────────────────────────────────────── - # TemporalHttpxClient.send() (in _temporal_httpx_client.py) strips the - # "x-goog-api-key" header from the serialized request so that the API - # key never appears in Temporal's event history. Here — inside the - # activity, which runs on the worker outside the workflow sandbox — we - # read the API key from the worker's environment and put it back before - # making the real HTTP call. - # - # NOTE: The "authorization" header (used by Vertex AI OAuth flows) is - # NOT stripped because those tokens are short-lived and refreshed - # per-request by the SDK — we cannot reconstruct them here. - # - # This means the API key must be set in the worker's environment - # (GOOGLE_API_KEY or GEMINI_API_KEY) for API-key auth to work. - # ───────────────────────────────────────────────────────────────────── - headers = dict(req.headers) - api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("GEMINI_API_KEY") - if api_key and "x-goog-api-key" not in headers: - headers["x-goog-api-key"] = api_key - async with httpx.AsyncClient() as client: response = await client.request( method=req.method, url=req.url, - headers=headers, + headers=req.headers, content=req.content, timeout=req.timeout, ) diff --git a/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py b/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py new file mode 100644 index 000000000..d1f18311d --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py @@ -0,0 +1,242 @@ +"""Targeted field-level encryption for Temporal payloads. + +Encrypts specific keys within specific fields of specific Pydantic models, +leaving everything else fully readable in the Temporal UI without a Codec Server. + +This module provides three components that work together: + +1. **TypeTaggingPydanticJSONConverter** — a Pydantic JSON converter that writes + the Python type name into payload metadata so the codec can identify which + model produced a payload. + +2. **SensitiveFieldsCodec** — a PayloadCodec that reads the type tag, looks up + which dict-field keys are sensitive, and encrypts/decrypts just those values. + +3. **make_sensitive_fields_data_converter** — a factory that wires both into a + single ``DataConverter`` ready to pass to ``Client.connect()`` and ``Worker()``. + +See ``plans/sensitive_fields_codec.md`` for the full design document. +""" + +from __future__ import annotations + +import dataclasses +import json +from typing import Any, Sequence + +from cryptography.fernet import Fernet +from pydantic import BaseModel + +import temporalio.api.common.v1 +from temporalio.contrib.pydantic import ( + PydanticJSONPlainPayloadConverter, + PydanticPayloadConverter, + pydantic_data_converter, +) +from temporalio.converter import ( + DataConverter, + PayloadCodec, +) + +# ── Constants ──────────────────────────────────────────────────────────────── + +SENSITIVE_FIELDS_ENCODING = b"json/sensitive-fields" +"""Encoding metadata value written to payloads that have had fields encrypted.""" + +SENTINEL_PREFIX = "__enc__" +"""Prefix on encrypted values — a defensive guard against double-decryption.""" + +TYPE_TAG_METADATA_KEY = "x-python-type" +"""Metadata key (str) carrying the Python ``__qualname__`` of the serialized value.""" + + +# ── Converter ──────────────────────────────────────────────────────────────── + + +class TypeTaggingPydanticJSONConverter(PydanticJSONPlainPayloadConverter): + """Pydantic JSON converter that tags payloads with the Python type name. + + For registered (watched) types, writes ``type(value).__qualname__`` into + ``payload.metadata[x-python-type]`` after standard Pydantic serialization. + This tag is read by :class:`SensitiveFieldsCodec` to route encryption. + """ + + def __init__( + self, + watched_types: frozenset[type], + to_json_options: Any = None, + ) -> None: + super().__init__(to_json_options) + self._watched = watched_types + + def to_payload( + self, value: Any + ) -> temporalio.api.common.v1.Payload | None: + payload = super().to_payload(value) + if payload is not None and isinstance(value, tuple(self._watched)): + payload.metadata[TYPE_TAG_METADATA_KEY] = ( + type(value).__qualname__.encode("utf-8") + ) + return payload + + +# ── Codec ──────────────────────────────────────────────────────────────────── + + +class SensitiveFieldsCodec(PayloadCodec): + """PayloadCodec that encrypts specific keys within dict fields of registered models. + + For each registered Pydantic model type, the config specifies which dict + fields contain sensitive keys and which keys within those dicts to encrypt. + + Example config:: + + model_configs = { + HttpRequestData: {"headers": {"x-goog-api-key", "authorization"}}, + } + + This encrypts ``headers["x-goog-api-key"]`` and ``headers["authorization"]`` + within ``HttpRequestData`` payloads while leaving all other headers, the URL, + method, body, etc. fully readable. + """ + + def __init__( + self, + model_configs: dict[type[BaseModel], dict[str, set[str]]], + encryption_key: bytes, + ) -> None: + """Initialize the codec. + + Args: + model_configs: Mapping of model type → dict field name → set of + sensitive keys within that dict. + encryption_key: A Fernet key (32 bytes, URL-safe base64). Generate + one with ``cryptography.fernet.Fernet.generate_key()``. + """ + self._configs: dict[str, dict[str, set[str]]] = { + t.__qualname__: fields_config + for t, fields_config in model_configs.items() + } + self._fernet = Fernet(encryption_key) + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + """Encrypt sensitive dict-field keys in matching payloads.""" + return [self._encode_one(p) for p in payloads] + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + """Decrypt sensitive dict-field keys in matching payloads.""" + return [self._decode_one(p) for p in payloads] + + def _encode_one( + self, p: temporalio.api.common.v1.Payload + ) -> temporalio.api.common.v1.Payload: + # Gate: type tag must be present and registered. + type_name = p.metadata.get(TYPE_TAG_METADATA_KEY, b"").decode("utf-8") + if not type_name or type_name not in self._configs: + return p + + config = self._configs[type_name] + data: dict[str, Any] = json.loads(p.data) + changed = False + + for dict_field, sensitive_keys in config.items(): + field_val = data.get(dict_field) + if not isinstance(field_val, dict): + continue + for key in sensitive_keys: + if key in field_val and isinstance(field_val[key], str): + plaintext = field_val[key].encode("utf-8") + ciphertext = self._fernet.encrypt(plaintext).decode("utf-8") + field_val[key] = SENTINEL_PREFIX + ciphertext + changed = True + + if not changed: + return p + + new_metadata = dict(p.metadata) + new_metadata["encoding"] = SENSITIVE_FIELDS_ENCODING # bytes + return temporalio.api.common.v1.Payload( + metadata=new_metadata, + data=json.dumps(data).encode("utf-8"), + ) + + def _decode_one( + self, p: temporalio.api.common.v1.Payload + ) -> temporalio.api.common.v1.Payload: + # Gate 1: encoding must match. + if p.metadata.get("encoding") != SENSITIVE_FIELDS_ENCODING: + return p + + # Gate 2: type tag must be present. + type_name = p.metadata.get(TYPE_TAG_METADATA_KEY, b"").decode("utf-8") + if not type_name: + return p + + # Gate 3: type must be registered. + config = self._configs.get(type_name) + if config is None: + return p + + data: dict[str, Any] = json.loads(p.data) + + for dict_field, sensitive_keys in config.items(): + field_val = data.get(dict_field) + if not isinstance(field_val, dict): + continue + for key in sensitive_keys: + val = field_val.get(key) + if isinstance(val, str) and val.startswith(SENTINEL_PREFIX): + ciphertext = val[len(SENTINEL_PREFIX) :].encode("utf-8") + field_val[key] = self._fernet.decrypt(ciphertext).decode("utf-8") + + new_metadata = dict(p.metadata) + new_metadata["encoding"] = b"json/plain" + return temporalio.api.common.v1.Payload( + metadata=new_metadata, + data=json.dumps(data).encode("utf-8"), + ) + + +# ── Factory ────────────────────────────────────────────────────────────────── + + +def make_sensitive_fields_data_converter( + model_configs: dict[type[BaseModel], dict[str, set[str]]], + encryption_key: bytes, +) -> DataConverter: + """Create a ``DataConverter`` with targeted field-level encryption. + + Args: + model_configs: Mapping of model type → dict field name → set of + sensitive keys within that dict. + encryption_key: A Fernet key. + + Returns: + A ``DataConverter`` drop-in replacement for ``pydantic_data_converter``. + """ + watched_types = frozenset(model_configs.keys()) + + # Create a zero-arg converter class that closes over watched_types. + # DataConverter instantiates this class with no arguments in __post_init__. + class _TaggingConverter(PydanticPayloadConverter): + def __init__(self, to_json_options: Any = None) -> None: + # Let PydanticPayloadConverter set up all builtin converters normally, + # then swap the JSON converter with our type-tagging subclass. + super().__init__(to_json_options) + tagging_json = TypeTaggingPydanticJSONConverter( + watched_types, to_json_options + ) + self.converters = { + k: tagging_json if k == b"json/plain" else v + for k, v in self.converters.items() + } + + return dataclasses.replace( + pydantic_data_converter, + payload_converter_class=_TaggingConverter, + payload_codec=SensitiveFieldsCodec(model_configs, encryption_key), + ) diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py b/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py index 019834d66..38493832e 100644 --- a/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py +++ b/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py @@ -96,34 +96,15 @@ async def send( # Prefer read timeout (waiting for response), fall back to pool timeout = timeout_ext.get("read") or timeout_ext.get("pool") - # ── Credential stripping ────────────────────────────────────────── - # The Gemini SDK injects "x-goog-api-key" into every outgoing request - # at genai.Client construction time. Because the serialized - # HttpRequestData is stored in Temporal's event history (visible in the - # Temporal UI and accessible to anyone with namespace read access), we - # strip this header so the API key is never persisted. - # - # The matching activity (gemini_api_call) re-injects the API key - # from os.environ on the worker side before making the real HTTP call. - # See _http_activity.py for the other half. - # - # NOTE: We intentionally do NOT strip the "authorization" header - # (used by Vertex AI with OAuth / service-account credentials). - # OAuth tokens are short-lived and refreshed per-request by the SDK; - # we cannot re-inject them in the activity. Vertex AI users should - # be aware that bearer tokens will appear in event history — use - # Temporal's payload codec / encryption if this is a concern. - # ───────────────────────────────────────────────────────────────────── - headers = { - k: v - for k, v in request.headers.items() - if k.lower() != "x-goog-api-key" - } + # Headers are passed through as-is. Sensitive headers (x-goog-api-key, + # authorization) are encrypted transparently by SensitiveFieldsCodec + # before the payload reaches Temporal's event history, and decrypted + # before the activity receives it. No manual stripping needed. req_data = HttpRequestData( method=request.method, url=str(request.url), - headers=headers, + headers=dict(request.headers), content=content, timeout=timeout, ) diff --git a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py index 8d1b0f9e8..10a37e56e 100644 --- a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py +++ b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py @@ -167,14 +167,43 @@ async def main() -> None: # durable Temporal activity. Pass the same args you'd pass to # genai.Client() — the plugin handles http_options for you. # - # The client is created here (at worker startup, outside the sandbox) - # because genai.Client() reads os.environ internally, which Temporal's - # workflow sandbox forbids. Workflows retrieve the pre-built client - # via get_gemini_client(). - plugin = GeminiPlugin( - api_key=os.environ["GOOGLE_API_KEY"], - start_to_close_timeout=timedelta(seconds=60), - ) + # The client is created at worker startup (outside the sandbox) because + # genai.Client() reads os.environ internally, which Temporal's workflow + # sandbox forbids. Workflows retrieve the pre-built client via + # get_gemini_client(). + # + # ── Sensitive activity field encryption ──────────────────────────── + # By default, credential headers (x-goog-api-key, authorization) in + # activity inputs are encrypted in Temporal's event history so they're + # never visible in plaintext. Set SENSITIVE_ACTIVITY_FIELDS_ENCRYPTION_KEY in + # your environment to enable this: + # + # Generate a key (once, store in secret manager): + # python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())" + # + # If you don't care about encrypting credentials (e.g. local dev with + # a private Temporal server), pass sensitive_activity_fields=None: + # + # plugin = GeminiPlugin( + # api_key=os.environ["GOOGLE_API_KEY"], + # sensitive_activity_fields=None, + # ) + # ───────────────────────────────────────────────────────────────── + sensitive_fields_key = os.environ.get("SENSITIVE_ACTIVITY_FIELDS_ENCRYPTION_KEY") + if sensitive_fields_key: + # Encrypt credential headers in Temporal's event history. + plugin = GeminiPlugin( + api_key=os.environ["GOOGLE_API_KEY"], + sensitive_activity_fields_encryption_key=sensitive_fields_key.encode(), + start_to_close_timeout=timedelta(seconds=60), + ) + else: + # No encryption — credentials will be visible in the Temporal UI. + plugin = GeminiPlugin( + api_key=os.environ["GOOGLE_API_KEY"], + sensitive_activity_fields=None, + start_to_close_timeout=timedelta(seconds=60), + ) config = ClientConfig.load_client_connect_config() config.setdefault("target_host", "localhost:7233") From 82a6c8e99e2d9fc83bd461c5a6668dd22e5fd4ad Mon Sep 17 00:00:00 2001 From: Jason Steving Date: Wed, 18 Mar 2026 16:59:09 -0700 Subject: [PATCH 5/6] Remove testing.py from branch before PR review --- .../contrib/google_gemini_sdk/testing.py | 161 ------------------ 1 file changed, 161 deletions(-) delete mode 100644 temporalio/contrib/google_gemini_sdk/testing.py diff --git a/temporalio/contrib/google_gemini_sdk/testing.py b/temporalio/contrib/google_gemini_sdk/testing.py deleted file mode 100644 index 4ae682ab8..000000000 --- a/temporalio/contrib/google_gemini_sdk/testing.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Testing utilities for the Google Gemini SDK Temporal integration. - -.. warning:: - This module is experimental and may change in future versions. - Use with caution in production environments. - -Example:: - - from temporalio.contrib.google_gemini_sdk.testing import ( - GeminiEnvironment, - MockGeminiResponse, - ) - - async def test_weather_agent(): - responses = [ - MockGeminiResponse.tool_call("get_weather_alerts", {"request": {"state": "CA"}}), - MockGeminiResponse.text("No active weather alerts in California."), - ] - async with GeminiEnvironment(responses=responses) as env: - client = await Client.connect("localhost:7233", plugins=[env.plugin]) - # ... run your workflow test -""" - -from __future__ import annotations - -from google.genai import types - -from temporalio import activity -from temporalio.contrib.google_gemini_sdk._invoke_model_activity import ( - ActivityModelInput, - ActivityModelOutput, - FunctionCallOutput, -) -from temporalio.contrib.google_gemini_sdk._gemini_plugin import GeminiPlugin - - -class MockGeminiResponse: - """Factory for constructing :class:`ActivityModelOutput` test fixtures. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - """ - - __test__ = False - - @staticmethod - def text(text: str) -> ActivityModelOutput: - """Return an output that simulates a plain-text model response.""" - return ActivityModelOutput( - text=text, - function_calls=[], - model_content=types.Content( - role="model", - parts=[types.Part.from_text(text=text)], - ), - ) - - @staticmethod - def tool_call(name: str, args: dict) -> ActivityModelOutput: - """Return an output that simulates a model requesting a tool call.""" - return ActivityModelOutput( - text=None, - function_calls=[FunctionCallOutput(name=name, args=args)], - model_content=types.Content( - role="model", - parts=[ - types.Part( - function_call=types.FunctionCall(name=name, args=args) - ) - ], - ), - ) - - -class TestGeminiModelActivity: - """A mock replacement for :class:`GeminiModelActivity` that returns pre-configured responses. - - Responses are consumed in FIFO order. If no responses remain, raises ``IndexError``. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - - Example:: - - responses = [ - MockGeminiResponse.tool_call("lookup", {"lookup": {"query": "CA"}}), - MockGeminiResponse.text("Here is the answer."), - ] - activity_instance = TestGeminiModelActivity(responses) - plugin = GeminiPlugin(_model_activity=activity_instance) - """ - - __test__ = False - - def __init__(self, responses: list[ActivityModelOutput]) -> None: - self._responses = list(responses) - - @activity.defn - async def generate_content_activity( - self, input: ActivityModelInput - ) -> ActivityModelOutput: - """Return the next pre-configured response, ignoring the actual input.""" - if not self._responses: - raise IndexError( - "TestGeminiModelActivity has no more responses. " - "Add more responses to the list passed to the constructor." - ) - return self._responses.pop(0) - - -class GeminiEnvironment: - """A test environment that wires up a mock Gemini model activity. - - .. warning:: - This class is experimental and may change in future versions. - Use with caution in production environments. - - Example:: - - responses = [ - MockGeminiResponse.text("Hello!"), - ] - async with GeminiEnvironment(responses=responses) as env: - client = await Client.connect("localhost:7233", plugins=[env.plugin]) - async with Worker( - client, - task_queue="test-queue", - workflows=[MyWorkflow], - activities=[my_tool], - ): - result = await client.execute_workflow(...) - """ - - __test__ = False - - def __init__( - self, - responses: list[ActivityModelOutput] | None = None, - ) -> None: - test_activity = TestGeminiModelActivity(list(responses or [])) - self._plugin = GeminiPlugin(_model_activity=test_activity) - - async def __aenter__(self) -> GeminiEnvironment: - return self - - async def __aexit__(self, *args: object) -> None: - pass - - def applied_on_client(self, client: object) -> object: - """Return the client unchanged (plugin is applied at connect time via ``plugins=``). - - This method exists for API symmetry with other test environment helpers. - """ - return client - - @property - def plugin(self) -> GeminiPlugin: - """The :class:`GeminiPlugin` configured with the mock model activity.""" - return self._plugin From b3b45536dc253c55909dcb5430c55de6b1c453f0 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 19 Mar 2026 14:25:55 -0700 Subject: [PATCH 6/6] Some PR changes --- .../contrib/google_gemini_sdk/__init__.py | 59 +------------------ .../google_gemini_sdk/_client_store.py | 46 --------------- .../google_gemini_sdk/_gemini_plugin.py | 44 ++------------ .../_sensitive_fields_codec.py | 11 ++-- .../_temporal_httpx_client.py | 38 +++++------- .../first_class_example/worker.py | 15 ++--- .../contrib/google_gemini_sdk/workflow.py | 31 +++++++++- 7 files changed, 61 insertions(+), 183 deletions(-) delete mode 100644 temporalio/contrib/google_gemini_sdk/_client_store.py diff --git a/temporalio/contrib/google_gemini_sdk/__init__.py b/temporalio/contrib/google_gemini_sdk/__init__.py index 3f3711aff..48fa7e7e7 100644 --- a/temporalio/contrib/google_gemini_sdk/__init__.py +++ b/temporalio/contrib/google_gemini_sdk/__init__.py @@ -49,22 +49,7 @@ async def run(self, query: str) -> str: from typing import TYPE_CHECKING -# --- Type-checking imports (never executed at runtime) --- -# These give IDEs and type checkers full visibility into the lazy-loaded -# symbols so that autocomplete, go-to-definition, and hover docs work. -if TYPE_CHECKING: - from temporalio.contrib.google_gemini_sdk._gemini_plugin import ( - GeminiPlugin as GeminiPlugin, - ) - from temporalio.contrib.google_gemini_sdk._temporal_httpx_client import ( - TemporalHttpxClient as TemporalHttpxClient, - temporal_http_options as temporal_http_options, - ) - -# --- Sandbox-safe imports (loaded eagerly at runtime) --- -# These modules have NO httpx / google.genai imports and are safe to load -# inside the Temporal workflow sandbox. -from temporalio.contrib.google_gemini_sdk._client_store import get_gemini_client +from temporalio.contrib.google_gemini_sdk._gemini_plugin import GeminiPlugin from temporalio.contrib.google_gemini_sdk.workflow import ( GeminiAgentWorkflowError, GeminiToolSerializationError, @@ -72,50 +57,8 @@ async def run(self, query: str) -> str: ) __all__ = [ - "DEFAULT_SENSITIVE_HEADER_KEYS", "GeminiAgentWorkflowError", "GeminiPlugin", "GeminiToolSerializationError", - "TemporalHttpxClient", "activity_as_tool", - "get_gemini_client", - "temporal_http_options", ] - - -# --- Lazy imports for httpx-dependent symbols --- -# GeminiPlugin, TemporalHttpxClient, and temporal_http_options all transitively -# import httpx (via _gemini_plugin → _http_activity → httpx, and via -# _temporal_httpx_client → httpx). They must NOT be loaded inside the workflow -# sandbox. They are imported lazily so that sandbox-safe imports like -# ``from temporalio.contrib.google_gemini_sdk import activity_as_tool`` -# never trigger an httpx import. -def __getattr__(name: str): # type: ignore[override] - _lazy = { - "DEFAULT_SENSITIVE_HEADER_KEYS": ( - "temporalio.contrib.google_gemini_sdk._gemini_plugin", - "DEFAULT_SENSITIVE_HEADER_KEYS", - ), - "GeminiPlugin": ( - "temporalio.contrib.google_gemini_sdk._gemini_plugin", - "GeminiPlugin", - ), - "TemporalHttpxClient": ( - "temporalio.contrib.google_gemini_sdk._temporal_httpx_client", - "TemporalHttpxClient", - ), - "temporal_http_options": ( - "temporalio.contrib.google_gemini_sdk._temporal_httpx_client", - "temporal_http_options", - ), - } - if name in _lazy: - import importlib - - module_path, attr = _lazy[name] - mod = importlib.import_module(module_path) - value = getattr(mod, attr) - # Cache on the module so __getattr__ is only called once per name. - globals()[name] = value - return value - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/temporalio/contrib/google_gemini_sdk/_client_store.py b/temporalio/contrib/google_gemini_sdk/_client_store.py deleted file mode 100644 index ef4045d09..000000000 --- a/temporalio/contrib/google_gemini_sdk/_client_store.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Shared storage for the pre-built ``genai.Client`` instance. - -This module is added to the Temporal sandbox passthrough list by -:class:`~temporalio.contrib.google_gemini_sdk.GeminiPlugin`. Because it is -passthrough'd, the module-level ``_gemini_client`` variable is shared between -the real runtime (where the worker sets it) and the sandboxed workflow (where -:func:`get_gemini_client` reads it). - -This module intentionally has **no** runtime ``httpx`` or ``google.genai`` -imports so that it can also be loaded safely by the sandbox's restricted -importer if the passthrough hasn't been configured yet. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from google.genai import Client as _GeminiClient - -# Set by GeminiPlugin.__init__ before the worker starts. -_gemini_client: _GeminiClient | None = None - - -def get_gemini_client() -> _GeminiClient: - """Return the ``genai.Client`` stored by :class:`GeminiPlugin`. - - .. warning:: - This function is experimental and may change in future versions. - Use with caution in production environments. - - Call this inside a workflow to obtain the pre-built Gemini client that was - passed to :class:`~temporalio.contrib.google_gemini_sdk.GeminiPlugin` at - worker setup time. The client is created **once** outside the workflow - sandbox, so ``os.environ`` access, SSL cert loading, etc. happen at startup - — not during workflow execution. - - Raises: - RuntimeError: If no client has been configured (i.e. ``GeminiPlugin`` - was not initialised with a ``gemini_client``). - """ - if _gemini_client is None: - raise RuntimeError( - "No Gemini client configured. Pass gemini_client= to GeminiPlugin()." - ) - return _gemini_client diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py index b005fdc65..163dfc1d5 100644 --- a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py +++ b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py @@ -4,10 +4,10 @@ import dataclasses from datetime import timedelta -from typing import Any +from typing import Any, Sequence +import temporalio.api.common.v1 from temporalio.common import RetryPolicy -from temporalio.contrib.google_gemini_sdk import _client_store from temporalio.contrib.google_gemini_sdk._http_activity import ( HttpRequestData, gemini_api_call, @@ -16,10 +16,7 @@ from temporalio.contrib.pydantic import ( PydanticPayloadConverter as _DefaultPydanticPayloadConverter, ) -from typing import Sequence - -import temporalio.api.common.v1 -from temporalio.converter import DataConverter, DefaultPayloadConverter, PayloadCodec +from temporalio.converter import DataConverter, PayloadCodec from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -180,34 +177,6 @@ def __init__( payload_converter_class=_DefaultPydanticPayloadConverter ) - # ── Create the genai.Client ────────────────────────────────────── - # Uses temporal_http_options() so all HTTP calls go through a Temporal - # activity. Created at worker startup (outside the sandbox) where - # os.environ is available. - # - # When no kwargs are provided (e.g. test environments), skip client - # creation — get_gemini_client() will raise at workflow runtime. - gemini_client = None - if gemini_client_kwargs: - from google.genai import Client as GeminiClient - - from temporalio.contrib.google_gemini_sdk._temporal_httpx_client import ( - temporal_http_options, - ) - - gemini_client = GeminiClient( - http_options=temporal_http_options( - start_to_close_timeout=start_to_close_timeout, - schedule_to_close_timeout=schedule_to_close_timeout, - heartbeat_timeout=heartbeat_timeout, - retry_policy=retry_policy, - ), - **gemini_client_kwargs, - ) - - # Store the client in the passthrough'd module so the sandbox can see it. - _client_store._gemini_client = gemini_client - def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: if not runner: raise ValueError("No WorkflowRunner provided to GeminiPlugin.") @@ -215,12 +184,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: return dataclasses.replace( runner, restrictions=runner.restrictions.with_passthrough_modules( - "google.genai", - "google.api_core", - "pydantic_core", - "pydantic", - "annotated_types", - "temporalio.contrib.google_gemini_sdk._client_store", + "google.genai" ), ) return runner diff --git a/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py b/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py index d1f18311d..827bfb918 100644 --- a/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py +++ b/temporalio/contrib/google_gemini_sdk/_sensitive_fields_codec.py @@ -69,13 +69,11 @@ def __init__( super().__init__(to_json_options) self._watched = watched_types - def to_payload( - self, value: Any - ) -> temporalio.api.common.v1.Payload | None: + def to_payload(self, value: Any) -> temporalio.api.common.v1.Payload | None: payload = super().to_payload(value) if payload is not None and isinstance(value, tuple(self._watched)): - payload.metadata[TYPE_TAG_METADATA_KEY] = ( - type(value).__qualname__.encode("utf-8") + payload.metadata[TYPE_TAG_METADATA_KEY] = type(value).__qualname__.encode( + "utf-8" ) return payload @@ -114,8 +112,7 @@ def __init__( one with ``cryptography.fernet.Fernet.generate_key()``. """ self._configs: dict[str, dict[str, set[str]]] = { - t.__qualname__: fields_config - for t, fields_config in model_configs.items() + t.__qualname__: fields_config for t, fields_config in model_configs.items() } self._fernet = Fernet(encryption_key) diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py b/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py index 38493832e..bb77ede46 100644 --- a/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py +++ b/temporalio/contrib/google_gemini_sdk/_temporal_httpx_client.py @@ -12,14 +12,15 @@ import httpx from google.genai import types +from google.genai.types import HttpOptions, HttpOptionsDict from temporalio import workflow as temporal_workflow from temporalio.common import RetryPolicy - from temporalio.contrib.google_gemini_sdk._http_activity import ( HttpRequestData, gemini_api_call, ) +from temporalio.workflow import ActivityConfig class _NoOpAsyncTransport(httpx.AsyncBaseTransport): @@ -58,20 +59,14 @@ class TemporalHttpxClient(httpx.AsyncClient): def __init__( self, *, - start_to_close_timeout: timedelta | None = timedelta(seconds=60), - schedule_to_close_timeout: timedelta | None = None, - heartbeat_timeout: timedelta | None = None, - retry_policy: RetryPolicy | None = None, + activity_config: ActivityConfig | None = None, ) -> None: # Use a no-op transport to avoid SSL cert file I/O at construction time. # The transport is never invoked because send() is fully overridden. super().__init__(transport=_NoOpAsyncTransport()) - self._activity_kwargs: dict[str, Any] = { - "start_to_close_timeout": start_to_close_timeout, - "schedule_to_close_timeout": schedule_to_close_timeout, - "heartbeat_timeout": heartbeat_timeout, - "retry_policy": retry_policy, - } + self._activity_config = activity_config or ActivityConfig( + start_to_close_timeout=timedelta(seconds=60) + ) # TODO do we want to merge if start to close timeout not set async def send( self, @@ -112,7 +107,7 @@ async def send( resp_data = await temporal_workflow.execute_activity( gemini_api_call, req_data, - **self._activity_kwargs, + **self._activity_config, ) return httpx.Response( @@ -123,12 +118,15 @@ async def send( ) +class SyncTemporalHttpxClient(httpx.Client): + def __init__(self): + pass + + def temporal_http_options( *, - start_to_close_timeout: timedelta = timedelta(seconds=60), - schedule_to_close_timeout: timedelta | None = None, - heartbeat_timeout: timedelta | None = None, - retry_policy: RetryPolicy | None = None, + http_options: HttpOptions | HttpOptionsDict | None = None, + activity_config: ActivityConfig | None = None, ) -> types.HttpOptions: """Create ``HttpOptions`` that route all Gemini SDK HTTP calls through Temporal. @@ -177,12 +175,8 @@ async def run(self, query: str) -> str: return response.text """ return types.HttpOptions( - httpx_async_client=TemporalHttpxClient( - start_to_close_timeout=start_to_close_timeout, - schedule_to_close_timeout=schedule_to_close_timeout, - heartbeat_timeout=heartbeat_timeout, - retry_policy=retry_policy, - ), + httpx_async_client=TemporalHttpxClient(activity_config=activity_config), + httpx_client=SyncTemporalHttpxClient(), # Temporal owns retries; disable SDK-level retries to avoid interference. retry_options=types.HttpRetryOptions(attempts=1), ) diff --git a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py index 10a37e56e..def1f17c4 100644 --- a/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py +++ b/temporalio/contrib/google_gemini_sdk/first_class_example/worker.py @@ -15,7 +15,10 @@ from datetime import timedelta from dotenv import load_dotenv +from google.genai import types from pydantic import BaseModel, Field + +import temporalio.contrib.google_gemini_sdk.workflow from temporalio import activity, workflow from temporalio.client import Client from temporalio.envconfig import ClientConfig @@ -23,14 +26,8 @@ with workflow.unsafe.imports_passed_through(): import httpx - from google.genai import types - -from temporalio.contrib.google_gemini_sdk import ( - GeminiPlugin, - activity_as_tool, - get_gemini_client, -) +from temporalio.contrib.google_gemini_sdk import GeminiPlugin, activity_as_tool # ============================================================================= # System Instructions @@ -122,14 +119,14 @@ class WeatherAgentWorkflow: """ @workflow.run - async def run(self, query: str) -> str: + async def run(self, query: str) -> str | None: # Retrieve the pre-built genai.Client that was created at worker # startup and stored via GeminiPlugin. We cannot instantiate # genai.Client here because its constructor always reads os.environ # (for the API key, project ID, etc.), which Temporal's workflow # sandbox forbids. get_gemini_client() reads from a passthrough'd # module, so the sandbox sees the real, pre-configured client object. - client = get_gemini_client() + client = temporalio.contrib.google_gemini_sdk.workflow.gemini_client() response = await client.aio.models.generate_content( model="gemini-2.5-flash", contents=query, diff --git a/temporalio/contrib/google_gemini_sdk/workflow.py b/temporalio/contrib/google_gemini_sdk/workflow.py index 669401a4b..c3929ca8b 100644 --- a/temporalio/contrib/google_gemini_sdk/workflow.py +++ b/temporalio/contrib/google_gemini_sdk/workflow.py @@ -10,13 +10,21 @@ import functools import inspect +from collections.abc import Callable from datetime import timedelta from typing import Any -from collections.abc import Callable + +import google.auth.credentials +from google.genai import Client as GeminiClient +from google.genai.client import DebugConfig +from google.genai.types import HttpOptions, HttpOptionsDict from temporalio import activity from temporalio import workflow as temporal_workflow from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.google_gemini_sdk._temporal_httpx_client import ( + temporal_http_options, +) from temporalio.exceptions import ApplicationError, TemporalError from temporalio.workflow import ActivityCancellationType, VersioningIntent @@ -124,6 +132,27 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper +def gemini_client( + *, + vertexai: bool | None = None, + api_key: str | None = None, + credentials: google.auth.credentials.Credentials | None = None, + project: str | None = None, + location: str | None = None, + debug_config: DebugConfig | None = None, + http_options: HttpOptions | HttpOptionsDict | None = None, +) -> GeminiClient: + return GeminiClient( + http_options=temporal_http_options(http_options=http_options), + vertexai=vertexai, + api_key=api_key, + credentials=credentials, + project=project, + location=location, + debug_config=debug_config, + ) + + class GeminiAgentWorkflowError(TemporalError): """Raised when a Gemini-driven agentic workflow cannot complete normally.