Skip to content
Merged
42 changes: 37 additions & 5 deletions go/sdk/variants/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package variants

import (
"context"
"fmt"

"github.com/modelcontextprotocol/go-sdk/mcp"
)
Expand Down Expand Up @@ -46,11 +47,37 @@ type inMemoryBackend struct {
server *mcp.Server
}

// connect creates an in-memory transport pair, connects the inner server,
// and creates a proxy client to communicate with it. Notifications for
// progress and logging are forwarded directly to the front session with
// _meta variant ID injection.
// captureMCPMethodHandler captures and returns a reference to the inner
// server's handler chain. This is a workaround using AddReceivingMiddleware
// to gain a reference to mcp.Server.receivingMethodHandler_, since the SDK
// does not expose a public accessor for it. This can be replaced once the
// SDK exposes a public accessor for the receiving handler chain.
func captureMCPMethodHandler(server *mcp.Server) (mcp.MethodHandler, error) {
var handler mcp.MethodHandler

// The middleware is identity (returns next unmodified), so the handler
// chain is unchanged, no extra hop introduced even if called multiple times.
server.AddReceivingMiddleware(func(next mcp.MethodHandler) mcp.MethodHandler {
Comment thread
StevenRChen marked this conversation as resolved.
handler = next
return next
})

if handler == nil {
return nil, fmt.Errorf("failed to capture backend MCP method handler")
}
return handler, nil
}

// connect creates an in-memory transport pair and connects the inner server.
// Requests bypass the transport via serverHandler to preserve context values.
// The transport is kept alive only for notification forwarding (progress,
// logging) from the inner server to the front client.
func (b *inMemoryBackend) connect(ctx context.Context, variant ServerVariant, frontSession *mcp.ServerSession) (*innerConnection, error) {
mcpMethodHandler, err := captureMCPMethodHandler(b.server)
if err != nil {
return nil, err
}

serverTransport, clientSideTransport := mcp.NewInMemoryTransports()

serverSession, err := b.server.Connect(ctx, serverTransport, nil)
Expand Down Expand Up @@ -87,8 +114,13 @@ func (b *inMemoryBackend) connect(ctx context.Context, variant ServerVariant, fr
}

return &innerConnection{
clientSession: clientSession,
backendSession: &backendSession{
variantID: variant.ID,
serverSession: serverSession,
mcpMethodHandler: mcpMethodHandler,
},
cleanupFn: func() {
clientSession.Close()
serverSession.Close()
},
}, nil
Expand Down
123 changes: 72 additions & 51 deletions go/sdk/variants/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ func variantIDFromMeta(req mcp.Request) string {
return id
}

// getSession extracts the variant ID from request _meta and returns the
// corresponding client session for dispatching. Falls back to the
// getConnection extracts the variant ID from request _meta and returns the
// corresponding innerConnection for dispatching. Falls back to the
// first-ranked variant when no variant is specified.
func (d *dispatcher) getSession(ctx context.Context, req mcp.Request) (string, *mcp.ClientSession, error) {
func (d *dispatcher) getConnection(ctx context.Context, req mcp.Request) (*innerConnection, error) {
variantID := variantIDFromMeta(req)

// If no variant specified, use first-ranked (default).
Expand All @@ -102,17 +102,17 @@ func (d *dispatcher) getSession(ctx context.Context, req mcp.Request) (string, *
if variantID == "" {
ranked := d.server.RankedVariants(ctx, VariantHints{})
if len(ranked) == 0 {
return "", nil, errors.New("no variants available")
return nil, errors.New("no variants available")
}
variantID = ranked[0].ID
}

conn, ok := d.connections[variantID]
if !ok {
return "", nil, d.createInvalidVariantError(ctx, variantID)
return nil, d.createInvalidVariantError(ctx, variantID)
}

return variantID, conn.clientSession, nil
return conn, nil
}

// enrichError adds activeVariant to a jsonrpc.Error's Data field for
Expand Down Expand Up @@ -163,24 +163,30 @@ func enrichError(err error, variantID string) error {
// handleList handles list methods by forwarding to the appropriate variant.
// Implements cursor scoping per SEP-2053: unwraps incoming cursors and wraps outgoing cursors.
func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
variantID, session, err := d.getSession(ctx, req)
conn, err := d.getConnection(ctx, req)
if err != nil {
return nil, err
}

backendSession := conn.backendSession
variantID := backendSession.variantID
params := req.GetParams()
extra := req.GetExtra()

switch method {
case "tools/list":
p, _ := params.(*mcp.ListToolsParams)
if p != nil && p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
if p != nil {
injectVariantMeta(p, variantID)
if p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
}
p.Cursor = innerCursor
}
p.Cursor = innerCursor
}
result, err := session.ListTools(ctx, p)
result, err := backendSession.ListTools(ctx, p, extra)
if err != nil {
return nil, enrichError(err, variantID)
}
Expand All @@ -191,14 +197,17 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ

case "resources/list":
p, _ := params.(*mcp.ListResourcesParams)
if p != nil && p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
if p != nil {
injectVariantMeta(p, variantID)
if p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
}
p.Cursor = innerCursor
}
p.Cursor = innerCursor
}
result, err := session.ListResources(ctx, p)
result, err := backendSession.ListResources(ctx, p, extra)
if err != nil {
return nil, enrichError(err, variantID)
}
Expand All @@ -209,14 +218,17 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ

case "prompts/list":
p, _ := params.(*mcp.ListPromptsParams)
if p != nil && p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
if p != nil {
injectVariantMeta(p, variantID)
if p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
}
p.Cursor = innerCursor
}
p.Cursor = innerCursor
}
result, err := session.ListPrompts(ctx, p)
result, err := backendSession.ListPrompts(ctx, p, extra)
if err != nil {
return nil, enrichError(err, variantID)
}
Expand All @@ -227,14 +239,17 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ

case "resources/templates/list":
p, _ := params.(*mcp.ListResourceTemplatesParams)
if p != nil && p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
if p != nil {
injectVariantMeta(p, variantID)
Comment thread
StevenRChen marked this conversation as resolved.
if p.Cursor != "" {
innerCursor, err := unwrapCursor(p.Cursor, variantID)
if err != nil {
return nil, err
}
p.Cursor = innerCursor
}
p.Cursor = innerCursor
}
result, err := session.ListResourceTemplates(ctx, p)
result, err := backendSession.ListResourceTemplates(ctx, p, extra)
if err != nil {
return nil, enrichError(err, variantID)
}
Expand All @@ -254,30 +269,28 @@ func (d *dispatcher) handleList(ctx context.Context, method string, req mcp.Requ

// handleCall handles call methods (tools/call, resources/read, prompts/get).
func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Request) (mcp.Result, error) {
variantID, session, err := d.getSession(ctx, req)
conn, err := d.getConnection(ctx, req)
if err != nil {
return nil, err
}

backendSession := conn.backendSession
variantID := backendSession.variantID
params := req.GetParams()
extra := req.GetExtra()
var result mcp.Result

switch method {
case "tools/call":
// The server SDK unmarshals tools/call params as *CallToolParamsRaw
// (with json.RawMessage arguments). Convert to *CallToolParams for
// the client-side call.
raw, _ := params.(*mcp.CallToolParamsRaw)
if raw == nil {
return nil, &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "missing or invalid tools/call params",
}
}
result, err = session.CallTool(ctx, &mcp.CallToolParams{
Meta: raw.Meta,
Name: raw.Name,
Arguments: raw.Arguments,
})
injectVariantMeta(raw, variantID)
result, err = backendSession.CallTool(ctx, raw, extra)
case "resources/read":
p, _ := params.(*mcp.ReadResourceParams)
if p == nil {
Expand All @@ -286,7 +299,8 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ
Message: "missing or invalid resources/read params",
}
}
result, err = session.ReadResource(ctx, p)
injectVariantMeta(p, variantID)
result, err = backendSession.ReadResource(ctx, p, extra)
case "prompts/get":
p, _ := params.(*mcp.GetPromptParams)
if p == nil {
Expand All @@ -295,7 +309,8 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ
Message: "missing or invalid prompts/get params",
}
}
result, err = session.GetPrompt(ctx, p)
injectVariantMeta(p, variantID)
result, err = backendSession.GetPrompt(ctx, p, extra)
default:
return nil, errors.New("unsupported call method: " + method)
}
Expand All @@ -312,20 +327,22 @@ func (d *dispatcher) handleCall(ctx context.Context, method string, req mcp.Requ

// handleSubscribe handles resources/subscribe.
func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) {
variantID, session, err := d.getSession(ctx, req)
conn, err := d.getConnection(ctx, req)
if err != nil {
return nil, err
}

backendSession := conn.backendSession
params, _ := req.GetParams().(*mcp.SubscribeParams)
if params == nil {
return nil, &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "missing or invalid resources/subscribe params",
}
}
if err := session.Subscribe(ctx, params); err != nil {
return nil, enrichError(err, variantID)
injectVariantMeta(params, backendSession.variantID)
if err := backendSession.Subscribe(ctx, params, req.GetExtra()); err != nil {
return nil, enrichError(err, backendSession.variantID)
}
return nil, nil
}
Expand All @@ -334,20 +351,22 @@ func (d *dispatcher) handleSubscribe(ctx context.Context, req mcp.Request) (mcp.
// Per SEP-2053: "Servers MUST continue to accept resources/unsubscribe for
// existing subscription ids even if the underlying resource is no longer available."
func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mcp.Result, error) {
variantID, session, err := d.getSession(ctx, req)
conn, err := d.getConnection(ctx, req)
if err != nil {
return nil, err
}

backendSession := conn.backendSession
params, _ := req.GetParams().(*mcp.UnsubscribeParams)
if params == nil {
return nil, &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "missing or invalid resources/unsubscribe params",
}
}
if err := session.Unsubscribe(ctx, params); err != nil {
return nil, enrichError(err, variantID)
injectVariantMeta(params, backendSession.variantID)
if err := backendSession.Unsubscribe(ctx, params, req.GetExtra()); err != nil {
return nil, enrichError(err, backendSession.variantID)
}
return nil, nil
}
Expand All @@ -358,21 +377,23 @@ func (d *dispatcher) handleUnsubscribe(ctx context.Context, req mcp.Request) (mc

// handleCompletion handles completion/complete.
func (d *dispatcher) handleCompletion(ctx context.Context, req mcp.Request) (mcp.Result, error) {
variantID, session, err := d.getSession(ctx, req)
conn, err := d.getConnection(ctx, req)
if err != nil {
return nil, err
}

backendSession := conn.backendSession
params, _ := req.GetParams().(*mcp.CompleteParams)
if params == nil {
return nil, &jsonrpc.Error{
Code: jsonrpc.CodeInvalidParams,
Message: "missing or invalid completion/complete params",
}
}
result, err := session.Complete(ctx, params)
injectVariantMeta(params, backendSession.variantID)
result, err := backendSession.Complete(ctx, params, req.GetExtra())
if err != nil {
return nil, enrichError(err, variantID)
return nil, enrichError(err, backendSession.variantID)
}
return result, nil
}
Expand Down
Loading
Loading