Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hud/environment/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ async def connect(self) -> None:
"transport": self._transport,
"auth": self._auth,
}
client_timeout = getattr(self._transport, "_hud_client_timeout", None)
if client_timeout is not None:
client_kwargs["timeout"] = client_timeout
if self._elicitation_handler is not None:
client_kwargs["elicitation_handler"] = self._elicitation_handler

Expand Down
35 changes: 23 additions & 12 deletions hud/environment/connectors/mcp_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

from hud.environment.connectors.base import BaseConnectorMixin

Expand Down Expand Up @@ -66,8 +66,7 @@ def connect_mcp(
if settings.client_timeout > 0
else min(request_timeout, settings.__class__.model_fields["client_timeout"].default)
)
server_config.setdefault("sse_read_timeout", timeout)
transport = _build_transport(server_config)
transport = _build_transport(server_config, timeout=timeout)

return self._add_connection(
name,
Expand Down Expand Up @@ -121,17 +120,29 @@ def connect_mcp_config(
return self


def _build_transport(server_config: dict[str, Any]) -> Any:
def _build_transport(server_config: dict[str, Any], *, timeout: float | None = None) -> Any:
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
from fastmcp.mcp_config import infer_transport_type_from_url

url = server_config["url"]
transport_type = server_config.get("transport") or infer_transport_type_from_url(url)
transport_cls = SSETransport if transport_type == "sse" else StreamableHttpTransport

return transport_cls(
url=url,
headers=server_config.get("headers"),
auth=server_config.get("auth"),
sse_read_timeout=server_config.get("sse_read_timeout"),
)
transport_timeout = timeout if timeout is not None else server_config.get("sse_read_timeout")
transport_kwargs = {
"url": url,
"headers": server_config.get("headers"),
"auth": server_config.get("auth"),
"httpx_client_factory": server_config.get("httpx_client_factory"),
}

if transport_type == "sse":
return SSETransport(
**transport_kwargs,
sse_read_timeout=transport_timeout,
)

transport = StreamableHttpTransport(**transport_kwargs)
if transport_timeout is not None:
# FastMCP 3.x wants streamable HTTP timeouts on the client/session,
# not on the transport constructor.
cast("Any", transport)._hud_client_timeout = transport_timeout
return transport
29 changes: 29 additions & 0 deletions hud/environment/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,35 @@ async def test_connect_creates_client(self) -> None:
# Client is now set
assert connector.client is mock_client

@pytest.mark.asyncio
async def test_connect_passes_transport_timeout_to_client(self) -> None:
"""connect() forwards transport timeout to FastMCP client session kwargs."""

class Transport:
_hud_client_timeout = 300

transport = Transport()
connector = Connector(
transport=transport,
config=ConnectionConfig(),
name="test",
connection_type=ConnectionType.REMOTE,
auth="test-token",
)

mock_client = MagicMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.is_connected = MagicMock(return_value=True)

with patch("fastmcp.client.Client", return_value=mock_client) as mock_cls:
await connector.connect()

mock_cls.assert_called_once_with(
transport=transport,
auth="test-token",
timeout=300,
)

@pytest.mark.asyncio
async def test_disconnect_clears_client(self) -> None:
"""disconnect() closes client and clears state."""
Expand Down
44 changes: 43 additions & 1 deletion hud/environment/tests/test_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,53 @@ def mount(self, server: Any, *, prefix: str | None = None) -> None:
env = TestEnv()
with patch("hud.settings.settings", spec=Settings) as mock_settings:
mock_settings.hud_mcp_url = "https://mcp.hud.ai"
mock_settings.client_timeout = 300 # Used in connect_mcp for sse_read_timeout
mock_settings.client_timeout = 300 # Used in connect_mcp transport timeout logic

env.connect_hub("browser")

# connect_hub creates a connection named "hud" (from mcp_config key)
assert "hud" in env._connections
# Verify hub config is stored for serialization
assert env._hub_config == {"name": "browser"}

def test_connect_mcp_streamable_transport_uses_client_timeout(self) -> None:
"""Streamable HTTP uses FastMCP client timeout instead of deprecated transport arg."""
from fastmcp.client.transports import StreamableHttpTransport

from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin
from hud.settings import Settings

class TestEnv(MCPConfigConnectorMixin):
def __init__(self) -> None:
self._connections: dict[str, Connector] = {}

env = TestEnv()
with patch("hud.settings.settings", spec=Settings) as mock_settings:
mock_settings.client_timeout = 300
env.connect_mcp({"browser": {"url": "https://mcp.hud.ai/browser"}})

transport = env._connections["browser"]._transport
assert isinstance(transport, StreamableHttpTransport)
assert transport.sse_read_timeout is None
assert getattr(transport, "_hud_client_timeout", None) == 300

def test_connect_mcp_sse_transport_keeps_sse_timeout(self) -> None:
"""SSE transports should continue to receive sse_read_timeout directly."""
from fastmcp.client.transports import SSETransport

from hud.environment.connectors.mcp_config import MCPConfigConnectorMixin
from hud.settings import Settings

class TestEnv(MCPConfigConnectorMixin):
def __init__(self) -> None:
self._connections: dict[str, Connector] = {}

env = TestEnv()
with patch("hud.settings.settings", spec=Settings) as mock_settings:
mock_settings.client_timeout = 300
env.connect_mcp({"browser": {"url": "https://mcp.hud.ai/browser", "transport": "sse"}})

transport = env._connections["browser"]._transport
assert isinstance(transport, SSETransport)
assert transport.sse_read_timeout is not None
assert transport.sse_read_timeout.total_seconds() == 300
Loading