diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index ace2c6ae3..dfd92d87c 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 3e7c2854b..27839cd35 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -4,7 +4,7 @@ import traceback from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING, Any from pydantic import ValidationError @@ -51,6 +51,7 @@ PREV_AGENT_CARD_WELL_KNOWN_PATH, ) from a2a.utils.errors import MethodNotImplementedError +from a2a.utils.helpers import maybe_await logger = logging.getLogger(__name__) @@ -178,9 +179,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB @@ -576,7 +578,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return JSONResponse( card_to_serve.model_dump( @@ -605,7 +607,9 @@ async def _handle_get_authenticated_extended_agent_card( context = self._context_builder.build(request) # If no base extended card is provided, pass the public card to the modifier base_card = card_to_serve if card_to_serve else self.agent_card - card_to_serve = self.extended_card_modifier(base_card, context) + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) + ) if card_to_serve: return JSONResponse( diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py index 1effa9d51..ceaf5ced1 100644 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ b/src/a2a/server/apps/jsonrpc/starlette_app.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index 3ae5ad6fe..12a03de84 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, ): diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index cdf86ab14..719085604 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -4,6 +4,8 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import TYPE_CHECKING, Any +from a2a.utils.helpers import maybe_await + if TYPE_CHECKING: from sse_starlette.sse import EventSourceResponse @@ -58,9 +60,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, ): @@ -150,7 +153,7 @@ async def handle_get_agent_card( """ card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return card_to_serve.model_dump(mode='json', exclude_none=True) @@ -182,9 +185,11 @@ async def handle_authenticated_agent_card( if self.extended_card_modifier: context = self._context_builder.build(request) - card_to_serve = self.extended_card_modifier(card_to_serve, context) + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) + ) elif self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return card_to_serve.model_dump(mode='json', exclude_none=True) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e2ec69a15..105b99471 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Sequence +from collections.abc import AsyncIterable, Awaitable, Sequence try: @@ -34,7 +34,7 @@ from a2a.types import AgentCard, TaskNotFoundError from a2a.utils import proto_utils from a2a.utils.errors import ServerError -from a2a.utils.helpers import validate, validate_async_generator +from a2a.utils.helpers import maybe_await, validate, validate_async_generator logger = logging.getLogger(__name__) @@ -89,7 +89,8 @@ def __init__( agent_card: AgentCard, request_handler: RequestHandler, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, ): """Initializes the GrpcHandler. @@ -339,7 +340,7 @@ async def GetAgentCard( """Get the agent card for the agent served.""" card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return proto_utils.ToProto.agent_card(card_to_serve) async def abort_context( diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index 567c61484..6df872fca 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncIterable, Callable +from collections.abc import AsyncIterable, Awaitable, Callable from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -46,7 +46,7 @@ TaskStatusUpdateEvent, ) from a2a.utils.errors import ServerError -from a2a.utils.helpers import validate +from a2a.utils.helpers import maybe_await, validate from a2a.utils.telemetry import SpanKind, trace_class @@ -63,10 +63,11 @@ def __init__( request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, ): """Initializes the JSONRPCHandler. @@ -450,9 +451,11 @@ async def get_authenticated_extended_card( card_to_serve = base_card if self.extended_card_modifier and context: - card_to_serve = self.extended_card_modifier(base_card, context) + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) + ) elif self.card_modifier: - card_to_serve = self.card_modifier(base_card) + card_to_serve = await maybe_await(self.card_modifier(base_card)) return GetAuthenticatedExtendedCardResponse( root=GetAuthenticatedExtendedCardSuccessResponse( diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 96acdc1e6..8164674e5 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -5,8 +5,8 @@ import json import logging -from collections.abc import Callable -from typing import Any +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar from uuid import uuid4 from a2a.types import ( @@ -24,6 +24,9 @@ from a2a.utils.telemetry import trace_function +T = TypeVar('T') + + logger = logging.getLogger(__name__) @@ -368,3 +371,10 @@ def canonicalize_agent_card(agent_card: AgentCard) -> str: # Recursively remove empty values cleaned_dict = _clean_empty(card_dict) return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) + + +async def maybe_await(value: T | Awaitable[T]) -> T: + """Awaits a value if it's awaitable, otherwise simply provides it back.""" + if inspect.isawaitable(value): + return await value + return value diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 26f923c14..647d9e86f 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -209,6 +209,34 @@ async def test_get_agent_card_with_modifier( ) -> None: """Test GetAgentCard call with a card_modifier.""" + async def modifier(card: types.AgentCard) -> types.AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Modified gRPC Agent' + return modified_card + + grpc_handler_modified = GrpcHandler( + agent_card=sample_agent_card, + request_handler=mock_request_handler, + card_modifier=modifier, + ) + + request_proto = a2a_pb2.GetAgentCardRequest() + response = await grpc_handler_modified.GetAgentCard( + request_proto, mock_grpc_context + ) + + assert response.name == 'Modified gRPC Agent' + assert response.version == sample_agent_card.version + + +@pytest.mark.asyncio +async def test_get_agent_card_with_modifier_sync( + mock_request_handler: AsyncMock, + sample_agent_card: types.AgentCard, + mock_grpc_context: AsyncMock, +) -> None: + """Test GetAgentCard call with a synchronous card_modifier.""" + def modifier(card: types.AgentCard) -> types.AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Modified gRPC Agent' diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index d1ead0211..4ed6e7025 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1295,6 +1295,57 @@ async def test_get_authenticated_extended_card_with_modifier(self) -> None: skills=[], ) + async def modifier( + card: AgentCard, context: ServerCallContext + ) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Modified Card' + modified_card.description = ( + f'Modified for context: {context.state.get("foo")}' + ) + return modified_card + + handler = JSONRPCHandler( + self.mock_agent_card, + mock_request_handler, + extended_agent_card=mock_base_card, + extended_card_modifier=modifier, + ) + request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-mod') + call_context = ServerCallContext(state={'foo': 'bar'}) + + # Act + response: GetAuthenticatedExtendedCardResponse = ( + await handler.get_authenticated_extended_card(request, call_context) + ) + + # Assert + self.assertIsInstance( + response.root, GetAuthenticatedExtendedCardSuccessResponse + ) + self.assertEqual(response.root.id, 'ext-card-req-mod') + modified_card = response.root.result + self.assertEqual(modified_card.name, 'Modified Card') + self.assertEqual(modified_card.description, 'Modified for context: bar') + self.assertEqual(modified_card.version, '1.0') + + async def test_get_authenticated_extended_card_with_modifier_sync( + self, + ) -> None: + """Test successful retrieval of a synchronously dynamically modified extended agent card.""" + # Arrange + mock_request_handler = AsyncMock(spec=DefaultRequestHandler) + mock_base_card = AgentCard( + name='Base Card', + description='Base details', + url='http://agent.example.com/api', + version='1.0', + capabilities=AgentCapabilities(), + default_input_modes=['text/plain'], + default_output_modes=['application/json'], + skills=[], + ) + def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Modified Card' diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index d65657dea..8080136c1 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -858,6 +858,30 @@ def test_dynamic_agent_card_modifier( ): """Test that the card_modifier dynamically alters the public agent card.""" + async def modifier(card: AgentCard) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Dynamically Modified Agent' + return modified_card + + app_instance = A2AStarletteApplication( + agent_card, handler, card_modifier=modifier + ) + client = TestClient(app_instance.build()) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == 'Dynamically Modified Agent' + assert ( + data['version'] == agent_card.version + ) # Ensure other fields are intact + + +def test_dynamic_agent_card_modifier_sync( + agent_card: AgentCard, handler: mock.AsyncMock +): + """Test that a synchronous card_modifier dynamically alters the public agent card.""" + def modifier(card: AgentCard) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Dynamically Modified Agent' @@ -885,6 +909,54 @@ def test_dynamic_extended_agent_card_modifier( """Test that the extended_card_modifier dynamically alters the extended agent card.""" agent_card.supports_authenticated_extended_card = True + async def modifier( + card: AgentCard, context: ServerCallContext + ) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.description = 'Dynamically Modified Extended Description' + return modified_card + + # Test with a base extended card + app_instance = A2AStarletteApplication( + agent_card, + handler, + extended_agent_card=extended_agent_card_fixture, + extended_card_modifier=modifier, + ) + client = TestClient(app_instance.build()) + + response = client.get(EXTENDED_AGENT_CARD_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == extended_agent_card_fixture.name + assert data['description'] == 'Dynamically Modified Extended Description' + + # Test without a base extended card (modifier should receive public card) + app_instance_no_base = A2AStarletteApplication( + agent_card, + handler, + extended_agent_card=None, + extended_card_modifier=modifier, + ) + client_no_base = TestClient(app_instance_no_base.build()) + response_no_base = client_no_base.get(EXTENDED_AGENT_CARD_PATH) + assert response_no_base.status_code == 200 + data_no_base = response_no_base.json() + assert data_no_base['name'] == agent_card.name + assert ( + data_no_base['description'] + == 'Dynamically Modified Extended Description' + ) + + +def test_dynamic_extended_agent_card_modifier_sync( + agent_card: AgentCard, + extended_agent_card_fixture: AgentCard, + handler: mock.AsyncMock, +): + """Test that a synchronous extended_card_modifier dynamically alters the extended agent card.""" + agent_card.supports_authenticated_extended_card = True + def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.description = 'Dynamically Modified Extended Description' @@ -928,6 +1000,27 @@ def test_fastapi_dynamic_agent_card_modifier( ): """Test that the card_modifier dynamically alters the public agent card for FastAPI.""" + async def modifier(card: AgentCard) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Dynamically Modified Agent' + return modified_card + + app_instance = A2AFastAPIApplication( + agent_card, handler, card_modifier=modifier + ) + client = TestClient(app_instance.build()) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == 'Dynamically Modified Agent' + + +def test_fastapi_dynamic_agent_card_modifier_sync( + agent_card: AgentCard, handler: mock.AsyncMock +): + """Test that a synchronous card_modifier dynamically alters the public agent card for FastAPI.""" + def modifier(card: AgentCard) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Dynamically Modified Agent'