diff --git a/go/sdk/variants/backend.go b/go/sdk/variants/backend.go index bcc805b..2d3ffbd 100644 --- a/go/sdk/variants/backend.go +++ b/go/sdk/variants/backend.go @@ -6,6 +6,7 @@ package variants import ( "context" + "fmt" "github.com/modelcontextprotocol/go-sdk/mcp" ) @@ -46,11 +47,37 @@ type inMemoryBackend struct { server *mcp.Server } -// connect creates an in-memory transport pair, connects the inner server, -// and creates a proxy client to communicate with it. Notifications for -// progress and logging are forwarded directly to the front session with -// _meta variant ID injection. +// captureMCPMethodHandler captures and returns a reference to the inner +// server's handler chain. This is a workaround using AddReceivingMiddleware +// to gain a reference to mcp.Server.receivingMethodHandler_, since the SDK +// does not expose a public accessor for it. This can be replaced once the +// SDK exposes a public accessor for the receiving handler chain. +func captureMCPMethodHandler(server *mcp.Server) (mcp.MethodHandler, error) { + var handler mcp.MethodHandler + + // The middleware is identity (returns next unmodified), so the handler + // chain is unchanged, no extra hop introduced even if called multiple times. + server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler { + handler = next + return next + }) + + if handler == nil { + return nil, fmt.Errorf("failed to capture backend MCP method handler") + } + return handler, nil +} + +// connect creates an in-memory transport pair and connects the inner server. +// Requests bypass the transport via serverHandler to preserve context values. +// The transport is kept alive only for notification forwarding (progress, +// logging) from the inner server to the front client. func (b *inMemoryBackend) connect(ctx context.Context, variant ServerVariant, frontSession *mcp.ServerSession) (*innerConnection, error) { + mcpMethodHandler, err := captureMCPMethodHandler(b.server) + if err != nil { + return nil, err + } + serverTransport, clientSideTransport := mcp.NewInMemoryTransports() serverSession, err := b.server.Connect(ctx, serverTransport, nil) @@ -87,8 +114,13 @@ func (b *inMemoryBackend) connect(ctx context.Context, variant ServerVariant, fr } return &innerConnection{ - clientSession: clientSession, + backendSession: &backendSession{ + variantID: variant.ID, + serverSession: serverSession, + mcpMethodHandler: mcpMethodHandler, + }, cleanupFn: func() { + clientSession.Close() serverSession.Close() }, }, nil diff --git a/go/sdk/variants/dispatch.go b/go/sdk/variants/dispatch.go index 89c5a28..ef1440e 100644 --- a/go/sdk/variants/dispatch.go +++ b/go/sdk/variants/dispatch.go @@ -86,10 +86,10 @@ func variantIDFromMeta(req mcp.Request) string { return id } -// getSession extracts the variant ID from request _meta and returns the -// corresponding client session for dispatching. Falls back to the +// getConnection extracts the variant ID from request _meta and returns the +// corresponding innerConnection for dispatching. Falls back to the // first-ranked variant when no variant is specified. -func (d *dispatcher) getSession(ctx context.Context, req mcp.Request) (string, *mcp.ClientSession, error) { +func (d *dispatcher) getConnection(ctx context.Context, req mcp.Request) (*innerConnection, error) { variantID := variantIDFromMeta(req) // If no variant specified, use first-ranked (default). @@ -102,17 +102,17 @@ func (d *dispatcher) getSession(ctx context.Context, req mcp.Request) (string, * if variantID == "" { ranked := d.server.RankedVariants(ctx, VariantHints{}) if len(ranked) == 0 { - return "", nil, errors.New("no variants available") + return nil, errors.New("no variants available") } variantID = ranked[0].ID } conn, ok := d.connections[variantID] if !ok { - return "", nil, d.createInvalidVariantError(ctx, variantID) + return nil, d.createInvalidVariantError(ctx, variantID) } - return variantID, conn.clientSession, nil + return conn, nil } // enrichError adds activeVariant to a jsonrpc.Error's Data field for @@ -163,24 +163,30 @@ func enrichError(err error, variantID string) error { // handleList handles list methods by forwarding to the appropriate variant. // Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors. func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - variantID, session, err := d.getSession(ctx, req) + conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } + backendSession := conn.backendSession + variantID := backendSession.variantID params := req.GetParams() + extra := req.GetExtra() switch method { case "tools/list": p, _ := params.(*mcp.ListToolsParams) - if p != nil && p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err + if p != nil { + injectVariantMeta(p, variantID) + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor } - p.Cursor = innerCursor } - result, err := session.ListTools(ctx, p) + result, err := backendSession.ListTools(ctx, p, extra) if err != nil { return nil, enrichError(err, variantID) } @@ -191,14 +197,17 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ case "resources/list": p, _ := params.(*mcp.ListResourcesParams) - if p != nil && p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err + if p != nil { + injectVariantMeta(p, variantID) + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor } - p.Cursor = innerCursor } - result, err := session.ListResources(ctx, p) + result, err := backendSession.ListResources(ctx, p, extra) if err != nil { return nil, enrichError(err, variantID) } @@ -209,14 +218,17 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ case "prompts/list": p, _ := params.(*mcp.ListPromptsParams) - if p != nil && p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err + if p != nil { + injectVariantMeta(p, variantID) + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor } - p.Cursor = innerCursor } - result, err := session.ListPrompts(ctx, p) + result, err := backendSession.ListPrompts(ctx, p, extra) if err != nil { return nil, enrichError(err, variantID) } @@ -227,14 +239,17 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ case "resources/templates/list": p, _ := params.(*mcp.ListResourceTemplatesParams) - if p != nil && p.Cursor != "" { - innerCursor, err := unwrapCursor(p.Cursor, variantID) - if err != nil { - return nil, err + if p != nil { + injectVariantMeta(p, variantID) + if p.Cursor != "" { + innerCursor, err := unwrapCursor(p.Cursor, variantID) + if err != nil { + return nil, err + } + p.Cursor = innerCursor } - p.Cursor = innerCursor } - result, err := session.ListResourceTemplates(ctx, p) + result, err := backendSession.ListResourceTemplates(ctx, p, extra) if err != nil { return nil, enrichError(err, variantID) } @@ -254,18 +269,19 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ // handleCall handles call methods (tools/call, resources/read, prompts/get). func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) { - variantID, session, err := d.getSession(ctx, req) + conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } + backendSession := conn.backendSession + variantID := backendSession.variantID params := req.GetParams() + extra := req.GetExtra() var result mcp.Result + switch method { case "tools/call": - // The server SDK unmarshals tools/call params as *CallToolParamsRaw - // (with json.RawMessage arguments). Convert to *CallToolParams for - // the client-side call. raw, _ := params.(*mcp.CallToolParamsRaw) if raw == nil { return nil, &jsonrpc.Error{ @@ -273,11 +289,8 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ Message: "missing or invalid tools/call params", } } - result, err = session.CallTool(ctx, &mcp.CallToolParams{ - Meta: raw.Meta, - Name: raw.Name, - Arguments: raw.Arguments, - }) + injectVariantMeta(raw, variantID) + result, err = backendSession.CallTool(ctx, raw, extra) case "resources/read": p, _ := params.(*mcp.ReadResourceParams) if p == nil { @@ -286,7 +299,8 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ Message: "missing or invalid resources/read params", } } - result, err = session.ReadResource(ctx, p) + injectVariantMeta(p, variantID) + result, err = backendSession.ReadResource(ctx, p, extra) case "prompts/get": p, _ := params.(*mcp.GetPromptParams) if p == nil { @@ -295,7 +309,8 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ Message: "missing or invalid prompts/get params", } } - result, err = session.GetPrompt(ctx, p) + injectVariantMeta(p, variantID) + result, err = backendSession.GetPrompt(ctx, p, extra) default: return nil, errors.New("unsupported call method: " + method) } @@ -312,11 +327,12 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ // handleSubscribe handles resources/subscribe. func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { - variantID, session, err := d.getSession(ctx, req) + conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } + backendSession := conn.backendSession params, _ := req.GetParams().(*mcp.SubscribeParams) if params == nil { return nil, &jsonrpc.Error{ @@ -324,8 +340,9 @@ func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp. Message: "missing or invalid resources/subscribe params", } } - if err := session.Subscribe(ctx, params); err != nil { - return nil, enrichError(err, variantID) + injectVariantMeta(params, backendSession.variantID) + if err := backendSession.Subscribe(ctx, params, req.GetExtra()); err != nil { + return nil, enrichError(err, backendSession.variantID) } return nil, nil } @@ -334,11 +351,12 @@ func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp. // Per SEP-2053: "Servers MUST continue to accept resources/unsubscribe for // existing subscription ids even if the underlying resource is no longer available." func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) { - variantID, session, err := d.getSession(ctx, req) + conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } + backendSession := conn.backendSession params, _ := req.GetParams().(*mcp.UnsubscribeParams) if params == nil { return nil, &jsonrpc.Error{ @@ -346,8 +364,9 @@ func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mc Message: "missing or invalid resources/unsubscribe params", } } - if err := session.Unsubscribe(ctx, params); err != nil { - return nil, enrichError(err, variantID) + injectVariantMeta(params, backendSession.variantID) + if err := backendSession.Unsubscribe(ctx, params, req.GetExtra()); err != nil { + return nil, enrichError(err, backendSession.variantID) } return nil, nil } @@ -358,11 +377,12 @@ func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mc // handleCompletion handles completion/complete. func (d *dispatcher) handleCompletion(ctx context.Context, req mcp.Request) (mcp.Result, error) { - variantID, session, err := d.getSession(ctx, req) + conn, err := d.getConnection(ctx, req) if err != nil { return nil, err } + backendSession := conn.backendSession params, _ := req.GetParams().(*mcp.CompleteParams) if params == nil { return nil, &jsonrpc.Error{ @@ -370,9 +390,10 @@ func (d *dispatcher) handleCompletion(ctx context.Context, req mcp.Request) (mcp Message: "missing or invalid completion/complete params", } } - result, err := session.Complete(ctx, params) + injectVariantMeta(params, backendSession.variantID) + result, err := backendSession.Complete(ctx, params, req.GetExtra()) if err != nil { - return nil, enrichError(err, variantID) + return nil, enrichError(err, backendSession.variantID) } return result, nil } diff --git a/go/sdk/variants/session.go b/go/sdk/variants/session.go index c3e4d22..12b869f 100644 --- a/go/sdk/variants/session.go +++ b/go/sdk/variants/session.go @@ -6,6 +6,7 @@ package variants import ( "context" + "fmt" "sync" "github.com/modelcontextprotocol/go-sdk/mcp" @@ -19,19 +20,134 @@ import ( // server. In stateful mode one is created per client session; in // stateless mode a single instance is shared across all requests. type innerConnection struct { - clientSession *mcp.ClientSession - cleanupFn func() + backendSession *backendSession + cleanupFn func() } -// close tears down the client session and invokes the backend-specific -// cleanup function. +// close invokes the backend-specific cleanup function which tears down +// both the client and server sessions. func (c *innerConnection) close() { - c.clientSession.Close() if c.cleanupFn != nil { c.cleanupFn() } } +// backendSession bypasses the in-memory transport and calls the inner +// server's handler chain directly, preserving the caller's +// context.Context values. It mirrors the mcp.ClientSession API so that +// dispatch code reads naturally. +type backendSession struct { + variantID string + serverSession *mcp.ServerSession + mcpMethodHandler mcp.MethodHandler +} + +func (s *backendSession) ListTools(ctx context.Context, p *mcp.ListToolsParams, extra *mcp.RequestExtra) (*mcp.ListToolsResult, error) { + result, err := s.mcpMethodHandler(ctx, "tools/list", &mcp.ListToolsRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.ListToolsResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for tools/list", result) + } + return r, nil +} + +func (s *backendSession) ListResources(ctx context.Context, p *mcp.ListResourcesParams, extra *mcp.RequestExtra) (*mcp.ListResourcesResult, error) { + result, err := s.mcpMethodHandler(ctx, "resources/list", &mcp.ListResourcesRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.ListResourcesResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for resources/list", result) + } + return r, nil +} + +func (s *backendSession) ListPrompts(ctx context.Context, p *mcp.ListPromptsParams, extra *mcp.RequestExtra) (*mcp.ListPromptsResult, error) { + result, err := s.mcpMethodHandler(ctx, "prompts/list", &mcp.ListPromptsRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.ListPromptsResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for prompts/list", result) + } + return r, nil +} + +func (s *backendSession) ListResourceTemplates(ctx context.Context, p *mcp.ListResourceTemplatesParams, extra *mcp.RequestExtra) (*mcp.ListResourceTemplatesResult, error) { + result, err := s.mcpMethodHandler(ctx, "resources/templates/list", &mcp.ListResourceTemplatesRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.ListResourceTemplatesResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for resources/templates/list", result) + } + return r, nil +} + +func (s *backendSession) CallTool(ctx context.Context, p *mcp.CallToolParamsRaw, extra *mcp.RequestExtra) (*mcp.CallToolResult, error) { + result, err := s.mcpMethodHandler(ctx, "tools/call", &mcp.CallToolRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.CallToolResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for tools/call", result) + } + return r, nil +} + +func (s *backendSession) ReadResource(ctx context.Context, p *mcp.ReadResourceParams, extra *mcp.RequestExtra) (*mcp.ReadResourceResult, error) { + result, err := s.mcpMethodHandler(ctx, "resources/read", &mcp.ReadResourceRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.ReadResourceResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for resources/read", result) + } + return r, nil +} + +func (s *backendSession) GetPrompt(ctx context.Context, p *mcp.GetPromptParams, extra *mcp.RequestExtra) (*mcp.GetPromptResult, error) { + result, err := s.mcpMethodHandler(ctx, "prompts/get", &mcp.GetPromptRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.GetPromptResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for prompts/get", result) + } + return r, nil +} + +func (s *backendSession) Subscribe(ctx context.Context, p *mcp.SubscribeParams, extra *mcp.RequestExtra) error { + _, err := s.mcpMethodHandler(ctx, "resources/subscribe", &mcp.SubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) + return err +} + +func (s *backendSession) Unsubscribe(ctx context.Context, p *mcp.UnsubscribeParams, extra *mcp.RequestExtra) error { + _, err := s.mcpMethodHandler(ctx, "resources/unsubscribe", &mcp.UnsubscribeRequest{Session: s.serverSession, Params: p, Extra: extra}) + return err +} + +func (s *backendSession) Complete(ctx context.Context, p *mcp.CompleteParams, extra *mcp.RequestExtra) (*mcp.CompleteResult, error) { + result, err := s.mcpMethodHandler(ctx, "completion/complete", &mcp.CompleteRequest{Session: s.serverSession, Params: p, Extra: extra}) + if err != nil { + return nil, err + } + r, ok := result.(*mcp.CompleteResult) + if !ok && result != nil { + return nil, fmt.Errorf("unexpected result type %T for completion/complete", result) + } + return r, nil +} + // sessionState holds all per-session state for one front client. type sessionState struct { dispatcher *dispatcher