diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 59057487c..516d65c47 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any from google.protobuf.json_format import MessageToDict, MessageToJson, Parse +from pydantic import ValidationError if TYPE_CHECKING: @@ -21,6 +22,7 @@ from a2a.types import ( AgentCard, GetTaskPushNotificationConfigParams, + InvalidParamsError, TaskIdParams, TaskNotFoundError, TaskQueryParams, @@ -257,8 +259,19 @@ async def on_get_task( """ task_id = request.path_params['id'] history_length_str = request.query_params.get('historyLength') - history_length = int(history_length_str) if history_length_str else None - params = TaskQueryParams(id=task_id, history_length=history_length) + try: + params = TaskQueryParams( + id=task_id, + history_length=history_length_str + if history_length_str + else None, + ) + except ValidationError: + raise ServerError( + error=InvalidParamsError( + message='historyLength must be a valid integer' + ) + ) from None task = await self.request_handler.on_get_task(params, context) if task: return MessageToDict(proto_utils.ToProto.task(task)) diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 9ea8c9686..8bab3fded 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -399,5 +399,17 @@ async def test_send_message_rejected_task( assert expected_response == actual_response +@pytest.mark.anyio +async def test_get_task_invalid_history_length_returns_422( + client: AsyncClient, +) -> None: + """Non-numeric historyLength query param returns 422 InvalidParamsError.""" + response = await client.get('/v1/tasks/some-task-id?historyLength=abc') + assert response.status_code == 422 + data = response.json() + assert 'message' in data + assert 'historylength' in data['message'].lower() + + if __name__ == '__main__': pytest.main([__file__])