diff --git a/mellea/plugins/manager.py b/mellea/plugins/manager.py index 978b7ae99..f29b7eb2a 100644 --- a/mellea/plugins/manager.py +++ b/mellea/plugins/manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from mellea.plugins.base import MelleaBasePayload, PluginViolationError from mellea.plugins.context import build_global_context @@ -175,13 +175,17 @@ def deregister_session_plugins(session_id: str) -> None: logger.debug("Plugin %s already unregistered", name, exc_info=True) +# Hooks return the same payload they received. Use this to accurately reflect that typing. +_MelleaBasePayload = TypeVar("_MelleaBasePayload", bound=MelleaBasePayload) + + async def invoke_hook( hook_type: HookType, - payload: MelleaBasePayload, + payload: _MelleaBasePayload, *, backend: Backend | None = None, **context_fields: Any, -) -> tuple[Any | None, MelleaBasePayload]: +) -> tuple[Any | None, _MelleaBasePayload]: """Invoke a hook if plugins are configured. Returns ``(result, possibly-modified-payload)``. @@ -241,7 +245,12 @@ async def invoke_hook( plugin_name=v.plugin_name or "", ) - modified = ( - result.modified_payload if result and result.modified_payload else payload + # `result` doesn't type the returned payload correctly. + # If the modified payload exists, cast it as the correct type here, + # else return the original payload. + modified: _MelleaBasePayload = ( + cast(_MelleaBasePayload, result.modified_payload) + if result and result.modified_payload + else payload ) return result, modified