@@ -12,7 +12,7 @@ def test_init_default_parameters():
1212 assert provider .timeout == DEFAULT_TIMEOUT
1313 assert provider ._known_agent_urls == []
1414 assert provider ._discovered_agents == {}
15- assert provider ._httpx_client is None
15+ assert provider ._httpx_client_args == { "timeout" : DEFAULT_TIMEOUT }
1616
1717
1818def test_init_custom_parameters ():
@@ -26,6 +26,30 @@ def test_init_custom_parameters():
2626 assert provider ._known_agent_urls == agent_urls
2727
2828
29+ def test_init_with_httpx_client_args ():
30+ """Test initialization with httpx client args."""
31+ client_args = {"headers" : {"Authorization" : "Bearer token" }, "timeout" : 60 }
32+ provider = A2AClientToolProvider (httpx_client_args = client_args )
33+
34+ assert provider ._httpx_client_args ["headers" ] == {"Authorization" : "Bearer token" }
35+ assert provider ._httpx_client_args ["timeout" ] == 60
36+
37+
38+ def test_init_without_httpx_client_args ():
39+ """Test initialization without httpx client args uses default timeout."""
40+ provider = A2AClientToolProvider (timeout = 45 )
41+
42+ assert provider ._httpx_client_args == {"timeout" : 45 }
43+
44+
45+ def test_init_httpx_client_args_overrides_timeout ():
46+ """Test that httpx_client_args timeout takes precedence."""
47+ client_args = {"timeout" : 120 }
48+ provider = A2AClientToolProvider (timeout = 45 , httpx_client_args = client_args )
49+
50+ assert provider ._httpx_client_args ["timeout" ] == 120
51+
52+
2953def test_tools_property ():
3054 """Test that tools property returns decorated methods."""
3155 provider = A2AClientToolProvider ()
@@ -38,32 +62,51 @@ def test_tools_property():
3862 assert "a2a_send_message" in tool_names
3963
4064
41- @pytest .mark .asyncio
42- async def test_ensure_httpx_client_creates_new_client ():
43- """Test _ensure_httpx_client creates new client when none exists."""
65+ def test_get_httpx_client_creates_new_client ():
66+ """Test _get_httpx_client creates new client with default args."""
4467 provider = A2AClientToolProvider (timeout = 45 )
4568
4669 with patch ("httpx.AsyncClient" ) as mock_client_class :
4770 mock_client = Mock ()
4871 mock_client_class .return_value = mock_client
4972
50- result = await provider ._ensure_httpx_client ()
73+ result = provider ._get_httpx_client ()
5174
5275 mock_client_class .assert_called_once_with (timeout = 45 )
5376 assert result == mock_client
54- assert provider ._httpx_client == mock_client
5577
5678
57- @pytest .mark .asyncio
58- async def test_ensure_httpx_client_reuses_existing ():
59- """Test _ensure_httpx_client reuses existing client."""
60- provider = A2AClientToolProvider ()
61- existing_client = Mock ()
62- provider ._httpx_client = existing_client
79+ def test_get_httpx_client_uses_custom_args ():
80+ """Test _get_httpx_client uses custom client args."""
81+ client_args = {"headers" : {"Authorization" : "Bearer token" }, "timeout" : 120 }
82+ provider = A2AClientToolProvider (httpx_client_args = client_args )
83+
84+ with patch ("httpx.AsyncClient" ) as mock_client_class :
85+ mock_client = Mock ()
86+ mock_client_class .return_value = mock_client
87+
88+ result = provider ._get_httpx_client ()
89+
90+ mock_client_class .assert_called_once_with (headers = {"Authorization" : "Bearer token" }, timeout = 120 )
91+ assert result == mock_client
92+
93+
94+ def test_get_httpx_client_creates_fresh_each_time ():
95+ """Test _get_httpx_client creates fresh client each time to avoid event loop issues."""
96+ provider = A2AClientToolProvider (timeout = 60 )
97+
98+ with patch ("httpx.AsyncClient" ) as mock_client_class :
99+ mock_client1 = Mock ()
100+ mock_client2 = Mock ()
101+ mock_client_class .side_effect = [mock_client1 , mock_client2 ]
63102
64- result = await provider ._ensure_httpx_client ()
103+ result1 = provider ._get_httpx_client ()
104+ result2 = provider ._get_httpx_client ()
65105
66- assert result == existing_client
106+ # Should create a new client each time
107+ assert mock_client_class .call_count == 2
108+ assert result1 == mock_client1
109+ assert result2 == mock_client2
67110
68111
69112@pytest .mark .asyncio
@@ -310,7 +353,7 @@ async def test_send_message_without_message_id():
310353@pytest .mark .asyncio
311354@patch ("strands_tools.a2a_client.uuid4" )
312355@patch .object (A2AClientToolProvider , "_discover_agent_card" )
313- @patch .object (A2AClientToolProvider , "_ensure_client_factory " )
356+ @patch .object (A2AClientToolProvider , "_get_client_factory " )
314357@patch .object (A2AClientToolProvider , "_ensure_discovered_known_agents" )
315358async def test_send_message_success (mock_ensure , mock_factory , mock_discover , mock_uuid ):
316359 """Test _send_message successful message sending."""
@@ -374,12 +417,12 @@ async def test_send_message_error(mock_ensure, mock_discover):
374417
375418
376419@pytest .mark .asyncio
377- @patch .object (A2AClientToolProvider , "_ensure_httpx_client " )
378- async def test_create_a2a_card_resolver (mock_ensure_client ):
420+ @patch .object (A2AClientToolProvider , "_get_httpx_client " )
421+ async def test_create_a2a_card_resolver (mock_get_client ):
379422 """Test _create_a2a_card_resolver creates resolver with correct parameters."""
380423 provider = A2AClientToolProvider ()
381424 mock_client = Mock ()
382- mock_ensure_client .return_value = mock_client
425+ mock_get_client .return_value = mock_client
383426
384427 with patch ("strands_tools.a2a_client.A2ACardResolver" ) as mock_resolver_class :
385428 mock_resolver = Mock ()
@@ -391,13 +434,12 @@ async def test_create_a2a_card_resolver(mock_ensure_client):
391434 assert result == mock_resolver
392435
393436
394- @pytest .mark .asyncio
395- @patch .object (A2AClientToolProvider , "_ensure_httpx_client" )
396- async def test_ensure_client_factory (mock_ensure_client ):
397- """Test _ensure_client_factory creates ClientFactory with correct parameters."""
437+ @patch .object (A2AClientToolProvider , "_get_httpx_client" )
438+ def test_get_client_factory (mock_get_client ):
439+ """Test _get_client_factory creates ClientFactory with correct parameters."""
398440 provider = A2AClientToolProvider ()
399441 mock_client = Mock ()
400- mock_ensure_client .return_value = mock_client
442+ mock_get_client .return_value = mock_client
401443
402444 with patch ("strands_tools.a2a_client.ClientFactory" ) as mock_factory_class :
403445 with patch ("strands_tools.a2a_client.ClientConfig" ) as mock_config_class :
@@ -406,18 +448,40 @@ async def test_ensure_client_factory(mock_ensure_client):
406448 mock_factory = Mock ()
407449 mock_factory_class .return_value = mock_factory
408450
409- result = await provider ._ensure_client_factory ()
451+ result = provider ._get_client_factory ()
410452
411453 mock_config_class .assert_called_once ()
412454 mock_factory_class .assert_called_once_with (mock_config )
413455 assert result == mock_factory
414- assert provider ._client_factory == mock_factory
456+
457+
458+ def test_get_client_factory_creates_fresh_each_time ():
459+ """Test _get_client_factory creates fresh factory each time to avoid event loop issues."""
460+ provider = A2AClientToolProvider ()
461+
462+ with patch .object (provider , "_get_httpx_client" ) as mock_get_client :
463+ with patch ("strands_tools.a2a_client.ClientFactory" ) as mock_factory_class :
464+ mock_client1 = Mock ()
465+ mock_client2 = Mock ()
466+ mock_get_client .side_effect = [mock_client1 , mock_client2 ]
467+
468+ mock_factory1 = Mock ()
469+ mock_factory2 = Mock ()
470+ mock_factory_class .side_effect = [mock_factory1 , mock_factory2 ]
471+
472+ result1 = provider ._get_client_factory ()
473+ result2 = provider ._get_client_factory ()
474+
475+ # Should create a new factory each time
476+ assert mock_factory_class .call_count == 2
477+ assert result1 == mock_factory1
478+ assert result2 == mock_factory2
415479
416480
417481@pytest .mark .asyncio
418482@patch ("strands_tools.a2a_client.uuid4" )
419483@patch .object (A2AClientToolProvider , "_discover_agent_card" )
420- @patch .object (A2AClientToolProvider , "_ensure_client_factory " )
484+ @patch .object (A2AClientToolProvider , "_get_client_factory " )
421485@patch .object (A2AClientToolProvider , "_ensure_discovered_known_agents" )
422486async def test_send_message_task_response (mock_ensure , mock_factory , mock_discover , mock_uuid ):
423487 """Test _send_message handling task response from ClientFactory."""
@@ -466,7 +530,7 @@ async def mock_send_message_iter(message):
466530@pytest .mark .asyncio
467531@patch ("strands_tools.a2a_client.uuid4" )
468532@patch .object (A2AClientToolProvider , "_discover_agent_card" )
469- @patch .object (A2AClientToolProvider , "_ensure_client_factory " )
533+ @patch .object (A2AClientToolProvider , "_get_client_factory " )
470534@patch .object (A2AClientToolProvider , "_ensure_discovered_known_agents" )
471535async def test_send_message_task_response_no_update (mock_ensure , mock_factory , mock_discover , mock_uuid ):
472536 """Test _send_message handling task response with no update event."""
0 commit comments