1- const LOCALHOST : & str = "localhost" ;
2-
31use anyhow:: Result ;
4- use encoderfile_core:: {
5- AppState ,
6- dev_utils:: embedding_state,
7- transport:: mcp,
8- } ;
9- use rmcp:: {
10- ServiceExt ,
11- model:: { CallToolRequestParam , ClientCapabilities , ClientInfo , Implementation } ,
12- transport:: StreamableHttpClientTransport ,
13- } ;
2+ use encoderfile_core:: { AppState , transport:: mcp} ;
143use tokio:: net:: TcpListener ;
15- use tower_http:: trace:: DefaultOnResponse ;
164use tokio:: sync:: oneshot;
5+ use tower_http:: trace:: DefaultOnResponse ;
176
18- async fn run_mcp ( addr : String , state : AppState , receiver : oneshot:: Receiver < ( ) > , done_sender : oneshot :: Sender < ( ) > ) -> Result < ( ) > {
7+ async fn run_mcp ( addr : String , state : AppState , receiver : oneshot:: Receiver < ( ) > ) -> Result < ( ) > {
198 let model_type = state. model_type . clone ( ) ;
209 let router = mcp:: make_router ( state) . layer (
2110 tower_http:: trace:: TraceLayer :: new_for_http ( )
@@ -26,45 +15,42 @@ async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>,
2615 tracing:: info!( "Running {:?} MCP server on {}" , model_type, & addr) ;
2716 let listener = TcpListener :: bind ( addr) . await ?;
2817 axum:: serve ( listener, router)
29- . with_graceful_shutdown (
30- async {
31- receiver. await ;
32- tracing:: info!( "Received shutdown signal, shutting down" ) ;
33- done_sender. send ( ( ) ) ;
34- ( )
35- } )
36- . await ;
18+ . with_graceful_shutdown ( async {
19+ receiver. await . ok ( ) ;
20+ tracing:: info!( "Received shutdown signal, shutting down" ) ;
21+ } )
22+ . await
23+ . expect ( "Error while shutting down server" ) ;
3724 Ok ( ( ) )
3825}
3926
4027macro_rules! test_mcp_server_impl {
4128 ( $mod_name: ident, $state_func: ident, $req_type: ident, $resp_type: ident) => {
42- mod $mod_name {
29+ pub mod $mod_name {
4330 use encoderfile_core:: {
4431 common:: { $req_type, $resp_type} ,
4532 dev_utils:: $state_func,
4633 } ;
4734 use rmcp:: {
4835 ServiceExt ,
49- transport:: StreamableHttpClientTransport ,
5036 model:: { CallToolRequestParam , ClientCapabilities , ClientInfo , Implementation } ,
37+ transport:: StreamableHttpClientTransport ,
5138 } ;
5239 use tokio:: sync:: oneshot;
5340
5441 const LOCALHOST : & str = "localhost" ;
5542 const PORT : i32 = 9100 ;
5643
57- #[ tokio:: test]
58- #[ test_log:: test]
59- async fn $mod_name( ) {
44+ pub async fn $mod_name( ) {
6045 let addr = format!( "{}:{}" , LOCALHOST , PORT ) ;
6146 let dummy_state = $state_func( ) ;
6247 let ( sender, receiver) = oneshot:: channel( ) ;
63- let ( done_sender, done_receiver) = oneshot:: channel( ) ;
64- let mcp_server = tokio:: spawn( super :: run_mcp( addr, dummy_state, receiver, done_sender) ) ;
48+ let _mcp_server = tokio:: spawn( super :: run_mcp( addr, dummy_state, receiver) ) ;
6549 // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
66- let client_transport =
67- StreamableHttpClientTransport :: from_uri( format!( "http://{}:{}/mcp" , LOCALHOST , PORT ) ) ;
50+ let client_transport = StreamableHttpClientTransport :: from_uri( format!(
51+ "http://{}:{}/mcp" ,
52+ LOCALHOST , PORT
53+ ) ) ;
6854 let client_info = ClientInfo {
6955 protocol_version: Default :: default ( ) ,
7056 capabilities: ClientCapabilities :: default ( ) ,
@@ -119,15 +105,13 @@ macro_rules! test_mcp_server_impl {
119105 )
120106 . expect( "failed to parse tool result" ) ;
121107 assert_eq!( embeddings_response. results. len( ) , 2 ) ;
122- client. cancel( ) . await ;
123- sender. send( ( ) ) ;
124- done_receiver. await ;
108+ client. cancel( ) . await . expect( "Error cancelling the agent" ) ;
109+ sender. send( ( ) ) . expect( "Error sending end of test signal" ) ;
125110 }
126111 }
127- }
112+ } ;
128113}
129114
130-
131115test_mcp_server_impl ! (
132116 test_mcp_embedding,
133117 embedding_state,
@@ -148,9 +132,23 @@ test_mcp_server_impl!(
148132 TokenClassificationRequest ,
149133 TokenClassificationResponse
150134) ;
135+
151136test_mcp_server_impl ! (
152137 test_mcp_sequence_classification,
153138 sequence_classification_state,
154139 SequenceClassificationRequest ,
155140 SequenceClassificationResponse
156- ) ;
141+ ) ;
142+
143+ #[ tokio:: test]
144+ #[ test_log:: test]
145+ async fn test_mcp_servers ( ) {
146+ self :: test_mcp_embedding:: test_mcp_embedding ( ) . await ;
147+ tracing:: info!( "Testing embedding" ) ;
148+ self :: test_mcp_sentence_embedding:: test_mcp_sentence_embedding ( ) . await ;
149+ tracing:: info!( "Testing sentence embedding" ) ;
150+ self :: test_mcp_token_classification:: test_mcp_token_classification ( ) . await ;
151+ tracing:: info!( "Testing token classification" ) ;
152+ self :: test_mcp_sequence_classification:: test_mcp_sequence_classification ( ) . await ;
153+ tracing:: info!( "Testing sequence classification" ) ;
154+ }
0 commit comments