diff --git a/hud/environment/connection.py b/hud/environment/connection.py index 2fde077bd..060eeec01 100644 --- a/hud/environment/connection.py +++ b/hud/environment/connection.py @@ -159,6 +159,9 @@ async def connect(self) -> None: "transport": self._transport, "auth": self._auth, } + client_timeout = getattr(self._transport, "_hud_client_timeout", None) + if client_timeout is not None: + client_kwargs["timeout"] = client_timeout if self._elicitation_handler is not None: client_kwargs["elicitation_handler"] = self._elicitation_handler diff --git a/hud/environment/connectors/mcp_config.py b/hud/environment/connectors/mcp_config.py index 048186742..4d73c71c0 100644 --- a/hud/environment/connectors/mcp_config.py +++ b/hud/environment/connectors/mcp_config.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from hud.environment.connectors.base import BaseConnectorMixin @@ -66,8 +66,7 @@ def connect_mcp( if settings.client_timeout > 0 else min(request_timeout, settings.__class__.model_fields["client_timeout"].default) ) - server_config.setdefault("sse_read_timeout", timeout) - transport = _build_transport(server_config) + transport = _build_transport(server_config, timeout=timeout) return self._add_connection( name, @@ -121,17 +120,29 @@ def connect_mcp_config( return self -def _build_transport(server_config: dict[str, Any]) -> Any: +def _build_transport(server_config: dict[str, Any], *, timeout: float | None = None) -> Any: from fastmcp.client.transports import SSETransport, StreamableHttpTransport from fastmcp.mcp_config import infer_transport_type_from_url url = server_config["url"] transport_type = server_config.get("transport") or infer_transport_type_from_url(url) - transport_cls = SSETransport if transport_type == "sse" else StreamableHttpTransport - - return transport_cls( - url=url, - headers=server_config.get("headers"), - auth=server_config.get("auth"), - sse_read_timeout=server_config.get("sse_read_timeout"), - ) + transport_timeout = timeout if timeout is not None else server_config.get("sse_read_timeout") + transport_kwargs = { + "url": url, + "headers": server_config.get("headers"), + "auth": server_config.get("auth"), + "httpx_client_factory": server_config.get("httpx_client_factory"), + } + + if transport_type == "sse": + return SSETransport( + **transport_kwargs, + sse_read_timeout=transport_timeout, + ) + + transport = StreamableHttpTransport(**transport_kwargs) + if transport_timeout is not None: + # FastMCP 3.x wants streamable HTTP timeouts on the client/session, + # not on the transport constructor. + cast("Any", transport)._hud_client_timeout = transport_timeout + return transport diff --git a/hud/environment/tests/test_connection.py b/hud/environment/tests/test_connection.py index 9ecd9114c..139759043 100644 --- a/hud/environment/tests/test_connection.py +++ b/hud/environment/tests/test_connection.py @@ -140,6 +140,35 @@ async def test_connect_creates_client(self) -> None: # Client is now set assert connector.client is mock_client + @pytest.mark.asyncio + async def test_connect_passes_transport_timeout_to_client(self) -> None: + """connect() forwards transport timeout to FastMCP client session kwargs.""" + + class Transport: + _hud_client_timeout = 300 + + transport = Transport() + connector = Connector( + transport=transport, + config=ConnectionConfig(), + name="test", + connection_type=ConnectionType.REMOTE, + auth="test-token", + ) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.is_connected = MagicMock(return_value=True) + + with patch("fastmcp.client.Client", return_value=mock_client) as mock_cls: + await connector.connect() + + mock_cls.assert_called_once_with( + transport=transport, + auth="test-token", + timeout=300, + ) + @pytest.mark.asyncio async def test_disconnect_clears_client(self) -> None: """disconnect() closes client and clears state.""" diff --git a/hud/environment/tests/test_connectors.py b/hud/environment/tests/test_connectors.py index a0e018cef..6ecabc144 100644 --- a/hud/environment/tests/test_connectors.py +++ b/hud/environment/tests/test_connectors.py @@ -197,7 +197,7 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: env = TestEnv() with patch("hud.settings.settings", spec=Settings) as mock_settings: mock_settings.hud_mcp_url = "https://mcp.hud.ai" - mock_settings.client_timeout = 300 # Used in connect_mcp for sse_read_timeout + mock_settings.client_timeout = 300 # Used in connect_mcp transport timeout logic env.connect_hub("browser") @@ -205,3 +205,45 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None: assert "hud" in env._connections # Verify hub config is stored for serialization assert env._hub_config == {"name": "browser"} + + def test_connect_mcp_streamable_transport_uses_client_timeout(self) -> None: + """Streamable HTTP uses FastMCP client timeout instead of deprecated transport arg.""" + from fastmcp.client.transports import StreamableHttpTransport + + from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + from hud.settings import Settings + + class TestEnv(MCPConfigConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + with patch("hud.settings.settings", spec=Settings) as mock_settings: + mock_settings.client_timeout = 300 + env.connect_mcp({"browser": {"url": "https://mcp.hud.ai/browser"}}) + + transport = env._connections["browser"]._transport + assert isinstance(transport, StreamableHttpTransport) + assert transport.sse_read_timeout is None + assert getattr(transport, "_hud_client_timeout", None) == 300 + + def test_connect_mcp_sse_transport_keeps_sse_timeout(self) -> None: + """SSE transports should continue to receive sse_read_timeout directly.""" + from fastmcp.client.transports import SSETransport + + from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin + from hud.settings import Settings + + class TestEnv(MCPConfigConnectorMixin): + def __init__(self) -> None: + self._connections: dict[str, Connector] = {} + + env = TestEnv() + with patch("hud.settings.settings", spec=Settings) as mock_settings: + mock_settings.client_timeout = 300 + env.connect_mcp({"browser": {"url": "https://mcp.hud.ai/browser", "transport": "sse"}}) + + transport = env._connections["browser"]._transport + assert isinstance(transport, SSETransport) + assert transport.sse_read_timeout is not None + assert transport.sse_read_timeout.total_seconds() == 300