Skip to content

Commit 03b62d2

Browse files
tonyxwzjlowin
andauthored
feat: handle error from the initialize middleware (#2531)
* feat: handle error from the initialize middleware In some situation, the initialize middleware can check the status of the server and decide to raise an error. Example use case: in a FastMCPProxy, an initialization middleware overrides the on_initialize method and connect to the underlying proxied client. When client respond with error, I want to pass this error to the client. * docs update * test: use McpError assertions now that exception propagation is fixed - Update tests to catch McpError specifically instead of generic Exception - Remove commented-out code in low_level.py --------- Co-authored-by: Jeremiah Lowin <[email protected]>
1 parent 95e58e8 commit 03b62d2

File tree

3 files changed

+132
-4
lines changed

3 files changed

+132
-4
lines changed

docs/servers/middleware.mdx

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,32 @@ This hierarchy allows you to target your middleware logic with the right level o
108108
The `on_initialize` hook receives the client's initialization request but **returns `None`** rather than a result. The initialization response is handled internally by the MCP protocol and cannot be modified by middleware. This hook is useful for client detection, logging connections, or initializing session state, but not for modifying the initialization handshake itself.
109109
</Note>
110110

111+
**Example:**
112+
113+
```python
114+
from fastmcp.server.middleware import Middleware, MiddlewareContext
115+
from mcp import McpError
116+
from mcp.types import ErrorData
117+
118+
class InitializationMiddleware(Middleware):
119+
async def on_initialize(self, context: MiddlewareContext, call_next):
120+
# Check client capabilities before initialization
121+
client_info = context.message.params.get("clientInfo", {})
122+
client_name = client_info.get("name", "unknown")
123+
124+
# Reject unsupported clients BEFORE call_next
125+
if client_name == "unsupported-client":
126+
raise McpError(ErrorData(code=-32000, message="This client is not supported"))
127+
128+
# Log successful initialization
129+
await call_next(context)
130+
print(f"Client {client_name} initialized successfully")
131+
```
132+
133+
<Warning>
134+
If you raise `McpError` in `on_initialize` **after** calling `call_next()`, the error will only be logged and will not be sent to the client. The initialization response has already been sent at that point. Always raise `McpError` **before** `call_next()` if you want to reject the initialization.
135+
</Warning>
136+
111137
### MCP Session Availability in Middleware
112138

113139
<VersionBadge version="2.13.1" />
@@ -787,4 +813,4 @@ class CustomHeaderMiddleware(Middleware):
787813
return result
788814

789815
mcp.add_middleware(CustomHeaderMiddleware())
790-
```
816+
```

src/fastmcp/server/low_level.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import anyio
88
import mcp.types
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from mcp import McpError
1011
from mcp.server.lowlevel.server import (
1112
LifespanResultT,
1213
NotificationOptions,
@@ -104,9 +105,23 @@ async def call_original_handler(
104105
fastmcp_context=fastmcp_ctx,
105106
)
106107

107-
return await self.fastmcp._apply_middleware(
108-
mw_context, call_original_handler
109-
)
108+
try:
109+
return await self.fastmcp._apply_middleware(
110+
mw_context, call_original_handler
111+
)
112+
except McpError as e:
113+
# McpError can be thrown from middleware in `on_initialize`
114+
# send the error to responder.
115+
if not responder._completed:
116+
with responder:
117+
await responder.respond(e.error)
118+
else:
119+
# Don't re-raise: prevents responding to initialize request twice
120+
logger.warning(
121+
"Received McpError but responder is already completed. "
122+
"Cannot send error response as response was already sent.",
123+
exc_info=e,
124+
)
110125

111126
# Fall through to default handling (task methods now handled via registered handlers)
112127
return await super()._received_request(responder)

tests/server/middleware/test_initialization_middleware.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from typing import Any
44

55
import mcp.types as mt
6+
import pytest
7+
from mcp import McpError
8+
from mcp.types import ErrorData
69

710
from fastmcp import Client, FastMCP
811
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
@@ -286,3 +289,87 @@ async def on_initialize(
286289
assert middleware.initialize_result.serverInfo.name == "TestServer"
287290
assert middleware.initialize_result.protocolVersion is not None
288291
assert middleware.initialize_result.capabilities is not None
292+
293+
294+
async def test_middleware_mcp_error_during_initialization():
295+
"""Test that McpError raised in middleware during initialization is sent to client."""
296+
server = FastMCP("TestServer")
297+
298+
class ErrorThrowingMiddleware(Middleware):
299+
async def on_initialize(
300+
self,
301+
context: MiddlewareContext[mt.InitializeRequest],
302+
call_next: CallNext[mt.InitializeRequest, None],
303+
) -> None:
304+
raise McpError(
305+
ErrorData(
306+
code=mt.INVALID_PARAMS, message="Invalid initialization parameters"
307+
)
308+
)
309+
310+
server.add_middleware(ErrorThrowingMiddleware())
311+
312+
with pytest.raises(McpError) as exc_info:
313+
async with Client(server):
314+
pass
315+
316+
assert exc_info.value.error.message == "Invalid initialization parameters"
317+
assert exc_info.value.error.code == mt.INVALID_PARAMS
318+
319+
320+
async def test_middleware_mcp_error_before_call_next():
321+
"""Test McpError raised before calling next middleware."""
322+
server = FastMCP("TestServer")
323+
324+
class EarlyErrorMiddleware(Middleware):
325+
async def on_initialize(
326+
self,
327+
context: MiddlewareContext[mt.InitializeRequest],
328+
call_next: CallNext[mt.InitializeRequest, None],
329+
) -> None:
330+
raise McpError(
331+
ErrorData(code=mt.INVALID_REQUEST, message="Request validation failed")
332+
)
333+
334+
server.add_middleware(EarlyErrorMiddleware())
335+
336+
with pytest.raises(McpError) as exc_info:
337+
async with Client(server):
338+
pass
339+
340+
assert exc_info.value.error.message == "Request validation failed"
341+
assert exc_info.value.error.code == mt.INVALID_REQUEST
342+
343+
344+
async def test_middleware_mcp_error_after_call_next():
345+
"""Test that McpError raised after call_next doesn't break the connection.
346+
347+
When an error is raised after call_next, the responder has already completed,
348+
so the error is caught but not sent to the responder (checked via _completed flag).
349+
"""
350+
server = FastMCP("TestServer")
351+
352+
class PostProcessingErrorMiddleware(Middleware):
353+
def __init__(self):
354+
super().__init__()
355+
self.error_raised = False
356+
357+
async def on_initialize(
358+
self,
359+
context: MiddlewareContext[mt.InitializeRequest],
360+
call_next: CallNext[mt.InitializeRequest, mt.InitializeResult | None],
361+
) -> mt.InitializeResult | None:
362+
await call_next(context)
363+
self.error_raised = True
364+
raise McpError(
365+
ErrorData(code=mt.INTERNAL_ERROR, message="Post-processing failed")
366+
)
367+
368+
middleware = PostProcessingErrorMiddleware()
369+
server.add_middleware(middleware)
370+
371+
# Error is logged but not re-raised to prevent duplicate response
372+
async with Client(server):
373+
pass
374+
375+
assert middleware.error_raised is True

0 commit comments

Comments
 (0)