From c4edefbef82fec42c138e76a7eca51482537e757 Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Thu, 19 Mar 2026 10:43:43 -0400 Subject: [PATCH] feat: add return types to invoke_hook --- mellea/plugins/manager.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/mellea/plugins/manager.py b/mellea/plugins/manager.py index 978b7ae99..f29b7eb2a 100644 --- a/mellea/plugins/manager.py +++ b/mellea/plugins/manager.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast from mellea.plugins.base import MelleaBasePayload, PluginViolationError from mellea.plugins.context import build_global_context @@ -175,13 +175,17 @@ def deregister_session_plugins(session_id: str) -> None: logger.debug("Plugin %s already unregistered", name, exc_info=True) +# Hooks return the same payload they received. Use this to accurately reflect that typing. +_MelleaBasePayload = TypeVar("_MelleaBasePayload", bound=MelleaBasePayload) + + async def invoke_hook( hook_type: HookType, - payload: MelleaBasePayload, + payload: _MelleaBasePayload, *, backend: Backend | None = None, **context_fields: Any, -) -> tuple[Any | None, MelleaBasePayload]: +) -> tuple[Any | None, _MelleaBasePayload]: """Invoke a hook if plugins are configured. Returns ``(result, possibly-modified-payload)``. @@ -241,7 +245,12 @@ async def invoke_hook( plugin_name=v.plugin_name or "", ) - modified = ( - result.modified_payload if result and result.modified_payload else payload + # `result` doesn't type the returned payload correctly. + # If the modified payload exists, cast it as the correct type here, + # else return the original payload. + modified: _MelleaBasePayload = ( + cast(_MelleaBasePayload, result.modified_payload) + if result and result.modified_payload + else payload ) return result, modified