Skip to content

Commit da7f0c0

Browse files
authored
feat(a2a_client): Add httpx params to be passed to allow auth with A2A (#298)
1 parent c3690be commit da7f0c0

File tree

2 files changed

+156
-48
lines changed

2 files changed

+156
-48
lines changed

src/strands_tools/a2a_client.py

Lines changed: 65 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,23 @@
77
- Agent discovery through agent cards from multiple URLs
88
- Message sending to specific A2A agents
99
- Push notification support for real-time task completion alerts
10+
- Custom authentication support via httpx client arguments
11+
12+
Usage Examples:
13+
14+
Basic usage without authentication:
15+
>>> provider = A2AClientToolProvider(
16+
... known_agent_urls=["http://agent1.example.com", "http://agent2.example.com"]
17+
... )
18+
19+
With OAuth/Bearer token authentication:
20+
>>> provider = A2AClientToolProvider(
21+
... known_agent_urls=["http://secure-agent.example.com"],
22+
... httpx_client_args={
23+
... "headers": {"Authorization": "Bearer your-token-here"},
24+
... "timeout": 300
25+
... }
26+
... )
1027
"""
1128

1229
import asyncio
@@ -34,6 +51,7 @@ def __init__(
3451
timeout: int = DEFAULT_TIMEOUT,
3552
webhook_url: str | None = None,
3653
webhook_token: str | None = None,
54+
httpx_client_args: dict[str, Any] | None = None,
3755
):
3856
"""
3957
Initialize A2A client tool provider.
@@ -43,12 +61,26 @@ def __init__(
4361
timeout: Timeout for HTTP operations in seconds (defaults to 300)
4462
webhook_url: Optional webhook URL for push notifications
4563
webhook_token: Optional authentication token for webhook notifications
64+
httpx_client_args: Optional dictionary of arguments to pass to httpx.AsyncClient
65+
constructor. This allows custom auth, headers, proxies, etc.
66+
Example: {"headers": {"Authorization": "Bearer token"}, "timeout": 60}
67+
68+
Note: To avoid event loop issues in multi-turn conversations,
69+
a fresh client is created for each async operation using these args.
70+
This prevents "Event loop is closed" errors when the provider is used
71+
across multiple asyncio.run() calls.
4672
"""
4773
self.timeout = timeout
4874
self._known_agent_urls: list[str] = known_agent_urls or []
4975
self._discovered_agents: dict[str, AgentCard] = {}
50-
self._httpx_client: httpx.AsyncClient | None = None
51-
self._client_factory: ClientFactory | None = None
76+
77+
# Store client args instead of client instance to avoid event loop issues
78+
self._httpx_client_args: dict[str, Any] = httpx_client_args or {}
79+
80+
# Set default timeout if not provided in client args
81+
if "timeout" not in self._httpx_client_args:
82+
self._httpx_client_args["timeout"] = self.timeout
83+
5284
self._initial_discovery_done: bool = False
5385

5486
# Push notification configuration
@@ -76,27 +108,39 @@ def tools(self) -> list[AgentTool]:
76108

77109
return tools
78110

79-
async def _ensure_httpx_client(self) -> httpx.AsyncClient:
80-
"""Ensure the shared HTTP client is initialized."""
81-
if self._httpx_client is None:
82-
self._httpx_client = httpx.AsyncClient(timeout=self.timeout)
83-
return self._httpx_client
84-
85-
async def _ensure_client_factory(self) -> ClientFactory:
86-
"""Ensure the ClientFactory is initialized."""
87-
if self._client_factory is None:
88-
httpx_client = await self._ensure_httpx_client()
89-
config = ClientConfig(
90-
httpx_client=httpx_client,
91-
streaming=False, # Use non-streaming mode for simpler response handling
92-
push_notification_configs=[self._push_config] if self._push_config else [],
93-
)
94-
self._client_factory = ClientFactory(config)
95-
return self._client_factory
111+
def _get_httpx_client(self) -> httpx.AsyncClient:
112+
"""
113+
Get a fresh httpx client for the current operation.
114+
115+
Creates a new client using the stored client args. This prevents event loop
116+
issues when the provider is used across multiple asyncio.run() calls.
117+
118+
Similar to the Gemini model provider fix in strands-agents/sdk-python#932,
119+
we create fresh clients per operation rather than reusing a single instance.
120+
"""
121+
return httpx.AsyncClient(**self._httpx_client_args)
122+
123+
def _get_client_factory(self) -> ClientFactory:
124+
"""
125+
Get a ClientFactory for the current operation.
126+
127+
Creates a fresh ClientFactory with a fresh httpx client for each call to avoid
128+
event loop issues when the provider is used across multiple asyncio.run() calls.
129+
130+
Note: We don't cache the ClientFactory because it contains the httpx client,
131+
which would cause "Event loop is closed" errors in multi-turn conversations.
132+
"""
133+
httpx_client = self._get_httpx_client()
134+
config = ClientConfig(
135+
httpx_client=httpx_client,
136+
streaming=False, # Use non-streaming mode for simpler response handling
137+
push_notification_configs=[self._push_config] if self._push_config else [],
138+
)
139+
return ClientFactory(config)
96140

97141
async def _create_a2a_card_resolver(self, url: str) -> A2ACardResolver:
98142
"""Create a new A2A card resolver for the given URL."""
99-
httpx_client = await self._ensure_httpx_client()
143+
httpx_client = self._get_httpx_client()
100144
logger.info(f"A2ACardResolver created for {url}")
101145
return A2ACardResolver(httpx_client=httpx_client, base_url=url)
102146

@@ -243,7 +287,7 @@ async def _send_message(
243287

244288
# Get the agent card and create client using factory
245289
agent_card = await self._discover_agent_card(target_agent_url)
246-
client_factory = await self._ensure_client_factory()
290+
client_factory = self._get_client_factory()
247291
client = client_factory.create(agent_card)
248292

249293
if message_id is None:

tests/test_a2a_client.py

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_init_default_parameters():
1212
assert provider.timeout == DEFAULT_TIMEOUT
1313
assert provider._known_agent_urls == []
1414
assert provider._discovered_agents == {}
15-
assert provider._httpx_client is None
15+
assert provider._httpx_client_args == {"timeout": DEFAULT_TIMEOUT}
1616

1717

1818
def test_init_custom_parameters():
@@ -26,6 +26,30 @@ def test_init_custom_parameters():
2626
assert provider._known_agent_urls == agent_urls
2727

2828

29+
def test_init_with_httpx_client_args():
30+
"""Test initialization with httpx client args."""
31+
client_args = {"headers": {"Authorization": "Bearer token"}, "timeout": 60}
32+
provider = A2AClientToolProvider(httpx_client_args=client_args)
33+
34+
assert provider._httpx_client_args["headers"] == {"Authorization": "Bearer token"}
35+
assert provider._httpx_client_args["timeout"] == 60
36+
37+
38+
def test_init_without_httpx_client_args():
39+
"""Test initialization without httpx client args uses default timeout."""
40+
provider = A2AClientToolProvider(timeout=45)
41+
42+
assert provider._httpx_client_args == {"timeout": 45}
43+
44+
45+
def test_init_httpx_client_args_overrides_timeout():
46+
"""Test that httpx_client_args timeout takes precedence."""
47+
client_args = {"timeout": 120}
48+
provider = A2AClientToolProvider(timeout=45, httpx_client_args=client_args)
49+
50+
assert provider._httpx_client_args["timeout"] == 120
51+
52+
2953
def test_tools_property():
3054
"""Test that tools property returns decorated methods."""
3155
provider = A2AClientToolProvider()
@@ -38,32 +62,51 @@ def test_tools_property():
3862
assert "a2a_send_message" in tool_names
3963

4064

41-
@pytest.mark.asyncio
42-
async def test_ensure_httpx_client_creates_new_client():
43-
"""Test _ensure_httpx_client creates new client when none exists."""
65+
def test_get_httpx_client_creates_new_client():
66+
"""Test _get_httpx_client creates new client with default args."""
4467
provider = A2AClientToolProvider(timeout=45)
4568

4669
with patch("httpx.AsyncClient") as mock_client_class:
4770
mock_client = Mock()
4871
mock_client_class.return_value = mock_client
4972

50-
result = await provider._ensure_httpx_client()
73+
result = provider._get_httpx_client()
5174

5275
mock_client_class.assert_called_once_with(timeout=45)
5376
assert result == mock_client
54-
assert provider._httpx_client == mock_client
5577

5678

57-
@pytest.mark.asyncio
58-
async def test_ensure_httpx_client_reuses_existing():
59-
"""Test _ensure_httpx_client reuses existing client."""
60-
provider = A2AClientToolProvider()
61-
existing_client = Mock()
62-
provider._httpx_client = existing_client
79+
def test_get_httpx_client_uses_custom_args():
80+
"""Test _get_httpx_client uses custom client args."""
81+
client_args = {"headers": {"Authorization": "Bearer token"}, "timeout": 120}
82+
provider = A2AClientToolProvider(httpx_client_args=client_args)
83+
84+
with patch("httpx.AsyncClient") as mock_client_class:
85+
mock_client = Mock()
86+
mock_client_class.return_value = mock_client
87+
88+
result = provider._get_httpx_client()
89+
90+
mock_client_class.assert_called_once_with(headers={"Authorization": "Bearer token"}, timeout=120)
91+
assert result == mock_client
92+
93+
94+
def test_get_httpx_client_creates_fresh_each_time():
95+
"""Test _get_httpx_client creates fresh client each time to avoid event loop issues."""
96+
provider = A2AClientToolProvider(timeout=60)
97+
98+
with patch("httpx.AsyncClient") as mock_client_class:
99+
mock_client1 = Mock()
100+
mock_client2 = Mock()
101+
mock_client_class.side_effect = [mock_client1, mock_client2]
63102

64-
result = await provider._ensure_httpx_client()
103+
result1 = provider._get_httpx_client()
104+
result2 = provider._get_httpx_client()
65105

66-
assert result == existing_client
106+
# Should create a new client each time
107+
assert mock_client_class.call_count == 2
108+
assert result1 == mock_client1
109+
assert result2 == mock_client2
67110

68111

69112
@pytest.mark.asyncio
@@ -310,7 +353,7 @@ async def test_send_message_without_message_id():
310353
@pytest.mark.asyncio
311354
@patch("strands_tools.a2a_client.uuid4")
312355
@patch.object(A2AClientToolProvider, "_discover_agent_card")
313-
@patch.object(A2AClientToolProvider, "_ensure_client_factory")
356+
@patch.object(A2AClientToolProvider, "_get_client_factory")
314357
@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents")
315358
async def test_send_message_success(mock_ensure, mock_factory, mock_discover, mock_uuid):
316359
"""Test _send_message successful message sending."""
@@ -374,12 +417,12 @@ async def test_send_message_error(mock_ensure, mock_discover):
374417

375418

376419
@pytest.mark.asyncio
377-
@patch.object(A2AClientToolProvider, "_ensure_httpx_client")
378-
async def test_create_a2a_card_resolver(mock_ensure_client):
420+
@patch.object(A2AClientToolProvider, "_get_httpx_client")
421+
async def test_create_a2a_card_resolver(mock_get_client):
379422
"""Test _create_a2a_card_resolver creates resolver with correct parameters."""
380423
provider = A2AClientToolProvider()
381424
mock_client = Mock()
382-
mock_ensure_client.return_value = mock_client
425+
mock_get_client.return_value = mock_client
383426

384427
with patch("strands_tools.a2a_client.A2ACardResolver") as mock_resolver_class:
385428
mock_resolver = Mock()
@@ -391,13 +434,12 @@ async def test_create_a2a_card_resolver(mock_ensure_client):
391434
assert result == mock_resolver
392435

393436

394-
@pytest.mark.asyncio
395-
@patch.object(A2AClientToolProvider, "_ensure_httpx_client")
396-
async def test_ensure_client_factory(mock_ensure_client):
397-
"""Test _ensure_client_factory creates ClientFactory with correct parameters."""
437+
@patch.object(A2AClientToolProvider, "_get_httpx_client")
438+
def test_get_client_factory(mock_get_client):
439+
"""Test _get_client_factory creates ClientFactory with correct parameters."""
398440
provider = A2AClientToolProvider()
399441
mock_client = Mock()
400-
mock_ensure_client.return_value = mock_client
442+
mock_get_client.return_value = mock_client
401443

402444
with patch("strands_tools.a2a_client.ClientFactory") as mock_factory_class:
403445
with patch("strands_tools.a2a_client.ClientConfig") as mock_config_class:
@@ -406,18 +448,40 @@ async def test_ensure_client_factory(mock_ensure_client):
406448
mock_factory = Mock()
407449
mock_factory_class.return_value = mock_factory
408450

409-
result = await provider._ensure_client_factory()
451+
result = provider._get_client_factory()
410452

411453
mock_config_class.assert_called_once()
412454
mock_factory_class.assert_called_once_with(mock_config)
413455
assert result == mock_factory
414-
assert provider._client_factory == mock_factory
456+
457+
458+
def test_get_client_factory_creates_fresh_each_time():
459+
"""Test _get_client_factory creates fresh factory each time to avoid event loop issues."""
460+
provider = A2AClientToolProvider()
461+
462+
with patch.object(provider, "_get_httpx_client") as mock_get_client:
463+
with patch("strands_tools.a2a_client.ClientFactory") as mock_factory_class:
464+
mock_client1 = Mock()
465+
mock_client2 = Mock()
466+
mock_get_client.side_effect = [mock_client1, mock_client2]
467+
468+
mock_factory1 = Mock()
469+
mock_factory2 = Mock()
470+
mock_factory_class.side_effect = [mock_factory1, mock_factory2]
471+
472+
result1 = provider._get_client_factory()
473+
result2 = provider._get_client_factory()
474+
475+
# Should create a new factory each time
476+
assert mock_factory_class.call_count == 2
477+
assert result1 == mock_factory1
478+
assert result2 == mock_factory2
415479

416480

417481
@pytest.mark.asyncio
418482
@patch("strands_tools.a2a_client.uuid4")
419483
@patch.object(A2AClientToolProvider, "_discover_agent_card")
420-
@patch.object(A2AClientToolProvider, "_ensure_client_factory")
484+
@patch.object(A2AClientToolProvider, "_get_client_factory")
421485
@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents")
422486
async def test_send_message_task_response(mock_ensure, mock_factory, mock_discover, mock_uuid):
423487
"""Test _send_message handling task response from ClientFactory."""
@@ -466,7 +530,7 @@ async def mock_send_message_iter(message):
466530
@pytest.mark.asyncio
467531
@patch("strands_tools.a2a_client.uuid4")
468532
@patch.object(A2AClientToolProvider, "_discover_agent_card")
469-
@patch.object(A2AClientToolProvider, "_ensure_client_factory")
533+
@patch.object(A2AClientToolProvider, "_get_client_factory")
470534
@patch.object(A2AClientToolProvider, "_ensure_discovered_known_agents")
471535
async def test_send_message_task_response_no_update(mock_ensure, mock_factory, mock_discover, mock_uuid):
472536
"""Test _send_message handling task response with no update event."""

0 commit comments

Comments
 (0)