Skip to content

Commit 1bf3882

Browse files
committed
Serialize mcp tests
1 parent 050919d commit 1bf3882

File tree

1 file changed

+35
-37
lines changed

1 file changed

+35
-37
lines changed

encoderfile-core/tests/test_mcp.rs

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,10 @@
1-
const LOCALHOST: &str = "localhost";
2-
31
use anyhow::Result;
4-
use encoderfile_core::{
5-
AppState,
6-
dev_utils::embedding_state,
7-
transport::mcp,
8-
};
9-
use rmcp::{
10-
ServiceExt,
11-
model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
12-
transport::StreamableHttpClientTransport,
13-
};
2+
use encoderfile_core::{AppState, transport::mcp};
143
use tokio::net::TcpListener;
15-
use tower_http::trace::DefaultOnResponse;
164
use tokio::sync::oneshot;
5+
use tower_http::trace::DefaultOnResponse;
176

18-
async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>, done_sender: oneshot::Sender<()>) -> Result<()> {
7+
async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>) -> Result<()> {
198
let model_type = state.model_type.clone();
209
let router = mcp::make_router(state).layer(
2110
tower_http::trace::TraceLayer::new_for_http()
@@ -26,45 +15,42 @@ async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>,
2615
tracing::info!("Running {:?} MCP server on {}", model_type, &addr);
2716
let listener = TcpListener::bind(addr).await?;
2817
axum::serve(listener, router)
29-
.with_graceful_shutdown(
30-
async {
31-
receiver.await;
32-
tracing::info!("Received shutdown signal, shutting down");
33-
done_sender.send(());
34-
()
35-
})
36-
.await;
18+
.with_graceful_shutdown(async {
19+
receiver.await.ok();
20+
tracing::info!("Received shutdown signal, shutting down");
21+
})
22+
.await
23+
.expect("Error while shutting down server");
3724
Ok(())
3825
}
3926

4027
macro_rules! test_mcp_server_impl {
4128
($mod_name:ident, $state_func:ident, $req_type:ident, $resp_type:ident) => {
42-
mod $mod_name {
29+
pub mod $mod_name {
4330
use encoderfile_core::{
4431
common::{$req_type, $resp_type},
4532
dev_utils::$state_func,
4633
};
4734
use rmcp::{
4835
ServiceExt,
49-
transport::StreamableHttpClientTransport,
5036
model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
37+
transport::StreamableHttpClientTransport,
5138
};
5239
use tokio::sync::oneshot;
5340

5441
const LOCALHOST: &str = "localhost";
5542
const PORT: i32 = 9100;
5643

57-
#[tokio::test]
58-
#[test_log::test]
59-
async fn $mod_name() {
44+
pub async fn $mod_name() {
6045
let addr = format!("{}:{}", LOCALHOST, PORT);
6146
let dummy_state = $state_func();
6247
let (sender, receiver) = oneshot::channel();
63-
let (done_sender, done_receiver) = oneshot::channel();
64-
let mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver, done_sender));
48+
let _mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver));
6549
// Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
66-
let client_transport =
67-
StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, PORT));
50+
let client_transport = StreamableHttpClientTransport::from_uri(format!(
51+
"http://{}:{}/mcp",
52+
LOCALHOST, PORT
53+
));
6854
let client_info = ClientInfo {
6955
protocol_version: Default::default(),
7056
capabilities: ClientCapabilities::default(),
@@ -119,15 +105,13 @@ macro_rules! test_mcp_server_impl {
119105
)
120106
.expect("failed to parse tool result");
121107
assert_eq!(embeddings_response.results.len(), 2);
122-
client.cancel().await;
123-
sender.send(());
124-
done_receiver.await;
108+
client.cancel().await.expect("Error cancelling the agent");
109+
sender.send(()).expect("Error sending end of test signal");
125110
}
126111
}
127-
}
112+
};
128113
}
129114

130-
131115
test_mcp_server_impl!(
132116
test_mcp_embedding,
133117
embedding_state,
@@ -148,9 +132,23 @@ test_mcp_server_impl!(
148132
TokenClassificationRequest,
149133
TokenClassificationResponse
150134
);
135+
151136
test_mcp_server_impl!(
152137
test_mcp_sequence_classification,
153138
sequence_classification_state,
154139
SequenceClassificationRequest,
155140
SequenceClassificationResponse
156-
);
141+
);
142+
143+
#[tokio::test]
144+
#[test_log::test]
145+
async fn test_mcp_servers() {
146+
self::test_mcp_embedding::test_mcp_embedding().await;
147+
tracing::info!("Testing embedding");
148+
self::test_mcp_sentence_embedding::test_mcp_sentence_embedding().await;
149+
tracing::info!("Testing sentence embedding");
150+
self::test_mcp_token_classification::test_mcp_token_classification().await;
151+
tracing::info!("Testing token classification");
152+
self::test_mcp_sequence_classification::test_mcp_sequence_classification().await;
153+
tracing::info!("Testing sequence classification");
154+
}

0 commit comments

Comments
 (0)