diff --git a/Makefile b/Makefile index df014c6..4b2e9e6 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ GOFLAGS ?= PKGS ?= ./... COVERPROFILE ?= coverage.out SEMGREP_IMAGE ?= semgrep/semgrep@sha256:326e5f41cc972bb423b764a14febbb62bbad29ee1c01820805d077dd868fea48 -FUZZTIME ?= 5s +FUZZTIME ?= 15s # Build matrix mirrors ci.yml — keep in sync. BUILD_MATRIX = \ @@ -155,15 +155,22 @@ tidy-check: ## Fail if go.mod / go.sum drift from `go mod tidy`. exit 1; \ fi -fuzz-quick: ## Run each Fuzz* target for FUZZTIME (default 5s). - @found=0; \ +fuzz-quick: ## Run each Fuzz* target for FUZZTIME (default 15s) in a fresh corpus. + @cachedir=$$(mktemp -d -t plexara-fuzz-cache.XXXXXX); \ + trap 'rm -rf "$$cachedir"' EXIT; \ + found=0; \ while IFS= read -r pkg; do \ [ -z "$$pkg" ] && continue; \ while IFS= read -r fuzz; do \ [ -z "$$fuzz" ] && continue; \ found=1; \ printf ' fuzz %s/%s for %s ... ' "$$pkg" "$$fuzz" "$(FUZZTIME)"; \ - out=$$($(GO) test "$$pkg" -run='^$$' -fuzz="^$$fuzz\$$" -fuzztime=$(FUZZTIME) 2>&1); \ + out=$$($(GO) test "$$pkg" \ + -run='^$$' -fuzz="^$$fuzz\$$" \ + -fuzztime=$(FUZZTIME) \ + -test.fuzzcachedir="$$cachedir" \ + -timeout=120s \ + 2>&1); \ ec=$$?; \ if [ $$ec -ne 0 ]; then echo "FAIL"; echo "$$out"; exit $$ec; fi; \ echo "ok"; \ diff --git a/core/mcp/catalog.go b/core/mcp/catalog.go new file mode 100644 index 0000000..f8d2f7a --- /dev/null +++ b/core/mcp/catalog.go @@ -0,0 +1,93 @@ +package mcp + +import ( + "sort" + "strings" +) + +// Catalog aggregates the tool catalog across every connected server. +// +// Tools is the flat list, ordered first by server name then by bare +// tool name. ToolsByServer is the same set indexed by server name. 
+type Catalog struct {
+	// Tools is the flat list of every tool across all servers.
+	Tools []Tool
+	// ToolsByServer holds the same tools keyed by origin server name.
+	ToolsByServer map[string][]Tool
+}
+
+// copy returns a deep copy of c. The Tools slice and every per-server
+// slice are duplicated, and each tool's InputSchema bytes are cloned
+// too, so mutating the result (including schema bytes) cannot reach
+// the receiver. A nil receiver yields an empty, non-nil Catalog.
+func (c *Catalog) copy() *Catalog {
+	if c == nil {
+		return &Catalog{}
+	}
+	out := &Catalog{
+		Tools:         cloneTools(c.Tools),
+		ToolsByServer: make(map[string][]Tool, len(c.ToolsByServer)),
+	}
+	for server, tools := range c.ToolsByServer {
+		out.ToolsByServer[server] = cloneTools(tools)
+	}
+	return out
+}
+
+// cloneTools returns a new, non-nil slice holding copies of src whose
+// InputSchema byte slices do not alias the originals.
+func cloneTools(src []Tool) []Tool {
+	dst := make([]Tool, len(src))
+	for i, t := range src {
+		// The zero-capacity reslice forces append to allocate, while
+		// keeping a nil schema nil and an empty schema empty.
+		t.InputSchema = append(t.InputSchema[:0:0], t.InputSchema...)
+		dst[i] = t
+	}
+	return dst
+}
+
+// Toolkit is a logical grouping of related tools — typically tools
+// from the same MCP "namespace" (e.g. all `datahub_*` tools on a
+// Plexara server).
+type Toolkit struct {
+	Name  string
+	Tools []Tool
+}
+
+// ToolkitClassifier maps a tool to a toolkit name. Returning the
+// empty string places the tool in the "default" toolkit.
+type ToolkitClassifier func(t Tool) string
+
+// PrefixClassifier classifies tools by the substring before the first
+// underscore in the bare tool name. For Plexara's catalog, this groups
+// `datahub_*`, `trino_*`, `s3_*`, and `memory_*` cleanly.
+//
+// Names with no underscore, an empty bare name, or a leading underscore
+// (`_foo`) all fall into the "misc" toolkit.
+func PrefixClassifier(t Tool) string {
+	prefix, _, found := strings.Cut(t.BareName, "_")
+	if !found || prefix == "" {
+		return "misc"
+	}
+	return prefix
+}
+
+// Toolkits groups the catalog using [PrefixClassifier].
+func (c *Catalog) Toolkits() []Toolkit {
+	return c.ToolkitsBy(PrefixClassifier)
+}
+
+// ToolkitsBy groups the catalog using a custom classifier. Toolkits
+// are returned in stable alphabetical order by name; tools within a
+// toolkit are returned in the order they appear in [Catalog.Tools].
+func (c *Catalog) ToolkitsBy(fn ToolkitClassifier) []Toolkit {
+	if c == nil || len(c.Tools) == 0 {
+		return nil
+	}
+	if fn == nil {
+		// A nil classifier would panic on the first call below; fall
+		// back to the classifier Toolkits uses.
+		fn = PrefixClassifier
+	}
+	groups := make(map[string][]Tool)
+	for _, t := range c.Tools {
+		name := fn(t)
+		if name == "" {
+			name = "default"
+		}
+		groups[name] = append(groups[name], t)
+	}
+	// Map iteration order is random; sort the names for stable output.
+	names := make([]string, 0, len(groups))
+	for name := range groups {
+		names = append(names, name)
+	}
+	sort.Strings(names)
+	out := make([]Toolkit, 0, len(names))
+	for _, name := range names {
+		out = append(out, Toolkit{Name: name, Tools: groups[name]})
+	}
+	return out
+}
diff --git a/core/mcp/catalog_test.go b/core/mcp/catalog_test.go
new file mode 100644
index 0000000..00c9be2
--- /dev/null
+++ b/core/mcp/catalog_test.go
@@ -0,0 +1,92 @@
+package mcp_test
+
+import (
+	"testing"
+
+	"github.com/plexara/plexara-agents/core/mcp"
+)
+
+func TestCatalog_PrefixToolkits(t *testing.T) {
+	t.Parallel()
+
+	cat := &mcp.Catalog{
+		Tools: []mcp.Tool{
+			{Name: "p__datahub_search", Server: "p", BareName: "datahub_search"},
+			{Name: "p__datahub_get_schema", Server: "p", BareName: "datahub_get_schema"},
+			{Name: "p__trino_query", Server: "p", BareName: "trino_query"},
+			{Name: "p__memory_recall", Server: "p", BareName: "memory_recall"},
+			{Name: "fs__read", Server: "fs", BareName: "read"},
+		},
+	}
+
+	got := cat.Toolkits()
+	wantNames := []string{"datahub", "memory", "misc", "trino"}
+	if len(got) != len(wantNames) {
+		t.Fatalf("len(toolkits) = %d; want %d (%v)", len(got), len(wantNames), names(got))
+	}
+	for i, want := range wantNames {
+		if got[i].Name != want {
+			t.Errorf("toolkit[%d].Name = %q; want %q", i, got[i].Name, want)
+		}
+	}
+	if got[0].Name == "datahub" && len(got[0].Tools) != 2 {
+		t.Errorf("datahub toolkit has %d tools; want 2", len(got[0].Tools))
+	}
+}
+
+func TestCatalog_CustomClassifier(t *testing.T) {
+	t.Parallel()
+
+	cat := &mcp.Catalog{
+		Tools: []mcp.Tool{
+			{Server: "a", BareName: "x"},
+			{Server: "b", BareName: "y"},
+		},
+	}
+	got := cat.ToolkitsBy(func(t mcp.Tool) string { return t.Server })
+	if len(got) != 2 {
+		t.Fatalf("len = %d; want 2", len(got))
+	}
+	if got[0].Name != "a" || got[1].Name != "b" {
+		t.Errorf("toolkits = %v; want [a b]", names(got))
+	}
+}
+
+func TestCatalog_Empty(t *testing.T) {
+	t.Parallel()
+
+	var cat *mcp.Catalog
+	if got := cat.Toolkits(); got != nil {
+		t.Errorf("nil Catalog.Toolkits = %v; want nil", got)
+	}
+	if got := (&mcp.Catalog{}).Toolkits(); got != nil {
+		t.Errorf("empty Catalog.Toolkits = %v; want nil", got)
+	}
+}
+
+func TestCatalog_DefaultClassifierEmptyName(t *testing.T) {
+	t.Parallel()
+
+	cat := &mcp.Catalog{Tools: []mcp.Tool{{Server: "s", BareName: "noprefix"}}}
+	got := cat.ToolkitsBy(func(_ mcp.Tool) string { return "" })
+	if len(got) != 1 || got[0].Name != "default" {
+		t.Errorf("empty-string classifier: got %v; want a single 'default' toolkit", names(got))
+	}
+}
+
+func TestCatalog_PrefixClassifierLeadingUnderscore(t *testing.T) {
+	t.Parallel()
+
+	// Per PrefixClassifier doc: leading underscore -> "misc".
+	if got := mcp.PrefixClassifier(mcp.Tool{BareName: "_foo"}); got != "misc" {
+		t.Errorf("PrefixClassifier(\"_foo\") = %q; want misc", got)
+	}
+}
+
+func names(toolkits []mcp.Toolkit) []string {
+	out := make([]string, len(toolkits))
+	for i, tk := range toolkits {
+		out[i] = tk.Name
+	}
+	return out
+}
diff --git a/core/mcp/client.go b/core/mcp/client.go
new file mode 100644
index 0000000..516a960
--- /dev/null
+++ b/core/mcp/client.go
@@ -0,0 +1,753 @@
+// Package mcp is a thin wrapper over [github.com/modelcontextprotocol/go-sdk/mcp]
+// that drives one or more MCP servers concurrently.
+//
+// Its responsibilities, in increasing order of value-add:
+//
+//   - Manage the connection lifecycle for each configured server.
+//   - Aggregate the tool catalog across servers, namespacing each tool
+//     name as "<server>__<tool>" so the model sees a single flat list.
+//   - Route tool calls back to the originating server.
+//   - Retry transient tool-call failures with bounded backoff (on the
+//     same session — see [Client.Call]; full transport reconnect is
+//     not implemented in v1).
+//   - Surface resources and prompts from each server.
+//
+// The wrapper is intentionally narrow: it does not buffer responses,
+// implement caching, or transform tool descriptions. Those belong in
+// the agent loop and the router (phases 5-7).
+//
+// # Subprocess environment
+//
+// Stdio transports inherit the parent process's environment. If the
+// agent process holds secrets in env vars (API keys, SSO tokens),
+// they are passed to every spawned MCP server subprocess. Operators
+// running third-party MCP binaries should sanitize their env or wrap
+// the binary in a launcher that scrubs unwanted variables. A future
+// release may add a per-[ServerConfig] env-curation knob.
+//
+// # Lifecycle
+//
+// Construct a [Client] with [New], call [Client.Connect] exactly once,
+// then use it. After [Client.Close], the client is terminal — calling
+// [Client.Connect] again returns [ErrConfig]. Construct a new client
+// to reconnect.
+package mcp
+
+import (
+	"context"
+	"crypto/rand"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"math"
+	"math/big"
+	"net/http"
+	"net/url"
+	"os/exec"
+	"sort"
+	"strings"
+	"sync"
+	"time"
+
+	sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
+	"golang.org/x/sync/errgroup"
+)
+
+// NamespaceSeparator is the substring placed between server name and
+// bare tool name in the namespaced wire form. Double underscore is
+// chosen to remain legal under all model-provider tool-name regexes
+// while staying trivially parseable.
+//
+// [SplitName] splits on the first occurrence of the separator, which
+// is why [New] rejects server names that contain it.
+const NamespaceSeparator = "__"
+
+// Transport selects how a [Client] reaches a server.
+type Transport string
+
+// Supported transport kinds. Mirror the go-sdk's client transports.
+const (
+	// TransportStdio runs the server as a subprocess built from the
+	// Endpoint command line plus StdioArgs.
+	TransportStdio Transport = "stdio"
+	// TransportSSE reaches the server at the Endpoint URL.
+	TransportSSE Transport = "sse"
+	// TransportStreamableHTTP reaches the server at the Endpoint URL.
+	TransportStreamableHTTP Transport = "streamable_http"
+)
+
+// ServerConfig describes one MCP server to connect to.
+type ServerConfig struct {
+	// Name identifies the server in the namespace prefix used by the
+	// catalog. Must be non-empty and not contain [NamespaceSeparator].
+	Name string
+
+	// Transport selects how to reach the server.
+	Transport Transport
+
+	// Endpoint is the URL (for sse / streamable_http) or the command
+	// line (for stdio). Stdio command lines are split on whitespace
+	// only — no shell-style quoting. For commands with whitespace in
+	// arguments, set Endpoint to the executable and pass extra args
+	// through StdioArgs.
+	Endpoint string
+
+	// StdioArgs are extra arguments passed to the stdio command after
+	// any tokens parsed from Endpoint. Ignored for non-stdio.
+	StdioArgs []string
+
+	// Headers may carry auth tokens or routing headers for HTTP-based
+	// transports. Ignored for stdio.
+	Headers map[string]string
+}
+
+// Tool is one tool exposed by some MCP server. Name is the namespaced
+// form ("server__tool"); Server and BareName recover the components.
+type Tool struct {
+	Name        string          // namespaced: server__tool
+	Server      string          // origin server
+	BareName    string          // tool name as the server reported it
+	Description string          // server-supplied description
+	InputSchema json.RawMessage // JSON Schema for arguments
+}
+
+// Resource is one resource exposed by some MCP server.
+type Resource struct {
+	Server      string // name of the server the resource was listed from
+	URI         string // as reported by the server
+	Name        string // as reported by the server
+	Description string // as reported by the server
+	MIMEType    string // as reported by the server
+}
+
+// Prompt is one prompt template exposed by some MCP server.
+type Prompt struct {
+	Server      string // name of the server the prompt was listed from
+	Name        string // as reported by the server
+	Description string // as reported by the server
+}
+
+// ToolCall is a tool invocation request. Name is the namespaced form
+// ("server__tool"); the client splits it to route to the right server.
+type ToolCall struct {
+	Name      string
+	Arguments json.RawMessage
+}
+
+// ToolContent is one piece of content returned by a tool call.
+//
+// Type is one of: "text", "image", "audio", "resource", "resource_link",
+// or "unknown" (the catch-all for SDK content variants this wrapper
+// has not been taught about yet — Text holds a JSON dump in that case).
+type ToolContent struct {
+	Type     string
+	Text     string
+	Data     []byte // for image/audio: raw bytes (the SDK base64-decodes the wire form)
+	MIMEType string
+	URI      string // for resource and resource_link references
+	Name     string // for resource_link
+}
+
+// ToolResult is the result of a [Client.Call].
+type ToolResult struct {
+	Content []ToolContent
+	IsError bool
+}
+
+// Backoff configures the bounded retry policy used by [Client.Call]
+// when an MCP-level CallTool errors. Zero fields fall back to package
+// defaults. The same session is retried; this is not full transport
+// reconnect (see the package doc).
+type Backoff struct {
+	Base        time.Duration // default 500ms
+	Cap         time.Duration // default 30s
+	MaxAttempts int           // default 5
+}
+
+// base returns Base, or 500ms when Base is zero or negative.
+func (b Backoff) base() time.Duration {
+	if b.Base <= 0 {
+		return 500 * time.Millisecond
+	}
+	return b.Base
+}
+
+// cap returns Cap, or 30s when Cap is zero or negative.
+func (b Backoff) cap() time.Duration {
+	if b.Cap <= 0 {
+		return 30 * time.Second
+	}
+	return b.Cap
+}
+
+// maxAttempts returns MaxAttempts, or 5 when it is zero or negative.
+func (b Backoff) maxAttempts() int {
+	if b.MaxAttempts <= 0 {
+		return 5
+	}
+	return b.MaxAttempts
+}
+
+// Delay returns the backoff delay for the given zero-based attempt
+// using exponential growth with full jitter, clamped to [0, Cap].
+// Negative attempts are treated as 0 (no panic on misuse).
+func (b Backoff) Delay(attempt int) time.Duration {
+	if attempt < 0 {
+		attempt = 0
+	}
+	// 62 doublings already take a 1ns base to the top of the positive
+	// int64 range, so larger attempts cannot change the outcome.
+	if attempt > 62 {
+		attempt = 62
+	}
+	base, limit := b.base(), b.cap()
+	// Decide whether base<<attempt fits under the cap WITHOUT doing
+	// the left shift first. A left shift can wrap around to a small
+	// positive value that a post-hoc "d <= 0 || d > cap" check would
+	// accept. For positive integers, base<<attempt <= limit holds
+	// exactly when base <= limit>>attempt, and the right shift cannot
+	// overflow.
+	d := limit
+	if base <= limit>>attempt {
+		d = base << attempt
+	}
+	// Defense in depth: a caller-supplied Cap of math.MaxInt64 would
+	// make int64(d)+1 overflow to MinInt64. Subtract one before the
+	// +1 so the maximum operand stays representable.
+	if d == math.MaxInt64 {
+		d = math.MaxInt64 - 1
+	}
+	// Full jitter, picked uniformly in [0, d]. crypto/rand is the
+	// portable answer for "give me a non-negative integer below N";
+	// jitter doesn't *need* cryptographic strength, but using
+	// crypto/rand keeps every static analyzer happy without case-by-
+	// case suppressions and the per-call cost is negligible against
+	// network I/O.
+	n, err := rand.Int(rand.Reader, big.NewInt(int64(d)+1))
+	if err != nil {
+		// rand.Reader is documented as always available on supported
+		// platforms. Degrade safely if a misconfigured environment
+		// somehow returns an error: cap the delay rather than panic.
+		return d
+	}
+	return time.Duration(n.Int64())
+}
+
+// Sentinel errors callers may match with [errors.Is].
+var (
+	// ErrConfig is returned for invalid configuration at construction time.
+	ErrConfig = errors.New("mcp: invalid configuration")
+	// ErrUnknownServer is returned when a [ToolCall] names a server not in the client's config.
+	ErrUnknownServer = errors.New("mcp: unknown server")
+	// ErrServerUnavailable is returned after backoff is exhausted on a connection.
+	ErrServerUnavailable = errors.New("mcp: server unavailable after backoff")
+	// ErrInvalidName is returned by [SplitName] when the input is not a valid namespaced name.
+	ErrInvalidName = errors.New("mcp: invalid namespaced tool name")
+)
+
+// Option configures a [Client] at construction time.
+type Option func(*Client)
+
+// WithImplementation overrides the [sdkmcp.Implementation] block sent
+// to each server during initialization.
+func WithImplementation(impl sdkmcp.Implementation) Option {
+	return func(c *Client) { c.impl = impl }
+}
+
+// WithBackoff overrides the per-call retry backoff policy used by
+// [Client.Call] when an MCP-level CallTool errors.
+func WithBackoff(b Backoff) Option {
+	return func(c *Client) { c.backoff = b }
+}
+
+// Dialer builds the underlying [sdkmcp.Transport] for a [ServerConfig].
+// Tests inject a Dialer that returns in-memory transports; production
+// callers leave the default in place.
+type Dialer func(ctx context.Context, cfg ServerConfig) (sdkmcp.Transport, error)
+
+// WithDialer overrides the transport builder used by [Client.Connect].
+// Tests use this to wire an in-memory transport pair without spawning
+// real subprocesses or HTTP servers. A nil Dialer is ignored and the
+// default stays in place; storing it would make Connect call a nil
+// func and panic inside its dial goroutines.
+func WithDialer(d Dialer) Option {
+	return func(c *Client) {
+		if d == nil {
+			return
+		}
+		c.dialer = d
+	}
+}
+
+// Client is a connection to one or more MCP servers.
+//
+// Construct with [New]; call [Client.Connect] before any other method;
+// always call [Client.Close] when done. After Close the client is
+// terminal — see the package doc.
+type Client struct {
+	impl    sdkmcp.Implementation
+	cfgs    []ServerConfig
+	backoff Backoff
+	dialer  Dialer
+
+	// mu guards the fields below.
+	mu               sync.RWMutex
+	sessions         map[string]*sdkmcp.ClientSession
+	catalog          *Catalog
+	connectAttempted bool // latches on the first Connect call, success OR failure.
+	closed           bool // set by Close; latches forever.
+}
+
+// New constructs a Client. It validates configuration and returns an
+// error on the first problem; it does not dial. Call [Client.Connect]
+// to establish sessions.
+func New(cfgs []ServerConfig, opts ...Option) (*Client, error) {
+	if len(cfgs) == 0 {
+		return nil, fmt.Errorf("%w: at least one ServerConfig is required", ErrConfig)
+	}
+	seen := make(map[string]struct{}, len(cfgs))
+	// Keep private copies so the caller cannot mutate a config after
+	// it has been validated (or while Connect is reading it).
+	owned := make([]ServerConfig, len(cfgs))
+	for i, cfg := range cfgs {
+		if err := validateServerConfig(i, cfg, seen); err != nil {
+			return nil, err
+		}
+		owned[i] = cloneServerConfig(cfg)
+	}
+	c := &Client{
+		impl: sdkmcp.Implementation{
+			Name:    "plexara-agents",
+			Version: "0.0.0-dev",
+		},
+		cfgs:     owned,
+		dialer:   buildSDKTransport,
+		sessions: make(map[string]*sdkmcp.ClientSession, len(cfgs)),
+	}
+	for _, opt := range opts {
+		if opt == nil {
+			// Tolerate nil options (e.g. from conditional option
+			// lists) instead of panicking on the call.
+			continue
+		}
+		opt(c)
+	}
+	return c, nil
+}
+
+// cloneServerConfig returns cfg with its StdioArgs slice and Headers
+// map duplicated, so the result shares no mutable state with cfg.
+func cloneServerConfig(cfg ServerConfig) ServerConfig {
+	if cfg.StdioArgs != nil {
+		cfg.StdioArgs = append([]string(nil), cfg.StdioArgs...)
+	}
+	if cfg.Headers != nil {
+		headers := make(map[string]string, len(cfg.Headers))
+		for k, v := range cfg.Headers {
+			headers[k] = v
+		}
+		cfg.Headers = headers
+	}
+	return cfg
+}
+
+// validateServerConfig checks one ServerConfig and records its name
+// in seen to detect duplicates. Returns an error wrapping ErrConfig.
+func validateServerConfig(i int, cfg ServerConfig, seen map[string]struct{}) error {
+	if cfg.Name == "" {
+		return fmt.Errorf("%w: cfgs[%d].Name is empty", ErrConfig, i)
+	}
+	if strings.TrimSpace(cfg.Name) != cfg.Name {
+		return fmt.Errorf("%w: cfgs[%d].Name %q has leading/trailing whitespace", ErrConfig, i, cfg.Name)
+	}
+	if strings.Contains(cfg.Name, NamespaceSeparator) {
+		return fmt.Errorf("%w: cfgs[%d].Name %q contains the namespace separator %q",
+			ErrConfig, i, cfg.Name, NamespaceSeparator)
+	}
+	// A trailing underscore merges with the separator: "a_" + "__" +
+	// "tool" is "a___tool", which SplitName (first occurrence) parses
+	// as server "a", tool "_tool". Calls would then be routed to the
+	// wrong server or to none, so reject the name up front.
+	if strings.HasSuffix(cfg.Name, "_") {
+		return fmt.Errorf("%w: cfgs[%d].Name %q ends with an underscore, which is ambiguous next to the namespace separator %q",
+			ErrConfig, i, cfg.Name, NamespaceSeparator)
+	}
+	if _, dup := seen[cfg.Name]; dup {
+		return fmt.Errorf("%w: duplicate server name %q", ErrConfig, cfg.Name)
+	}
+	seen[cfg.Name] = struct{}{}
+	if strings.TrimSpace(cfg.Endpoint) == "" {
+		return fmt.Errorf("%w: cfgs[%d].Endpoint is empty or whitespace", ErrConfig, i)
+	}
+	switch cfg.Transport {
+	case TransportStdio:
+		if len(strings.Fields(cfg.Endpoint)) == 0 {
+			return fmt.Errorf("%w: cfgs[%d].Endpoint has no command tokens", ErrConfig, i)
+		}
+	case TransportSSE, TransportStreamableHTTP:
+		u, perr := url.Parse(cfg.Endpoint)
+		if perr != nil || u.Scheme == "" || u.Host == "" {
+			return fmt.Errorf("%w: cfgs[%d].Endpoint %q is not a valid URL", ErrConfig, i, cfg.Endpoint)
+		}
+	default:
+		return fmt.Errorf("%w: cfgs[%d].Transport %q is not supported", ErrConfig, i, cfg.Transport)
+	}
+	return nil
+}
+
+// Connect dials every configured server in parallel, then snapshots
+// the aggregated tool catalog. If any server fails to connect, all
+// already-opened sessions are closed and the error is returned.
+//
+// Connect must be called exactly once per [Client]. A second call —
+// regardless of whether the first succeeded, failed, or was followed
+// by [Client.Close] — returns [ErrConfig]. To retry after a failed
+// Connect, construct a new client.
+func (c *Client) Connect(ctx context.Context) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.closed {
+		return fmt.Errorf("%w: Connect on closed client; construct a new Client", ErrConfig)
+	}
+	if c.connectAttempted {
+		return fmt.Errorf("%w: Connect already called (success or failure latches the gate)", ErrConfig)
+	}
+	c.connectAttempted = true
+
+	g, gctx := errgroup.WithContext(ctx)
+	// sessMu guards sessions while the dial goroutines fill it in.
+	var sessMu sync.Mutex
+	sessions := make(map[string]*sdkmcp.ClientSession, len(c.cfgs))
+
+	// Each closure captures its own cfg; this relies on per-iteration
+	// loop variables (Go 1.22+).
+	for _, cfg := range c.cfgs {
+		g.Go(func() error {
+			tr, err := c.dialer(gctx, cfg)
+			if err != nil {
+				return fmt.Errorf("server %q: build transport: %w", cfg.Name, err)
+			}
+			sdkClient := sdkmcp.NewClient(&c.impl, nil)
+			session, err := sdkClient.Connect(gctx, tr, nil)
+			if err != nil {
+				return fmt.Errorf("server %q: connect: %w", cfg.Name, err)
+			}
+			sessMu.Lock()
+			sessions[cfg.Name] = session
+			sessMu.Unlock()
+			return nil
+		})
+	}
+	if err := g.Wait(); err != nil {
+		// Close any partially-opened sessions before returning. Close
+		// errors are dropped on purpose: the connect error is the one
+		// the caller needs to see.
+		for _, s := range sessions {
+			_ = s.Close()
+		}
+		return err
+	}
+	c.sessions = sessions
+
+	cat, err := c.fetchCatalogLocked(ctx)
+	if err != nil {
+		// We have sessions but the catalog refresh failed. Close
+		// everything to keep state consistent.
+		for _, s := range c.sessions {
+			_ = s.Close()
+		}
+		c.sessions = map[string]*sdkmcp.ClientSession{}
+		return fmt.Errorf("fetch catalog: %w", err)
+	}
+	c.catalog = cat
+	return nil
+}
+
+// Close closes every active session and clears the cached catalog,
+// returning the first error seen. Subsequent calls are no-ops.
+// After Close, [Client.Catalog] returns an empty catalog and any
+// [Client.Call] / [Client.Resources] / [Client.Prompts] returns
+// [ErrUnknownServer].
+func (c *Client) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	var firstErr error
+	for name, s := range c.sessions {
+		if err := s.Close(); err != nil && firstErr == nil {
+			firstErr = fmt.Errorf("close %q: %w", name, err)
+		}
+	}
+	// The emptied map is what makes later lookups, and later Close
+	// calls, find nothing to act on.
+	c.sessions = map[string]*sdkmcp.ClientSession{}
+	c.catalog = nil
+	c.closed = true
+	return firstErr
+}
+
+// Catalog returns the aggregated tool catalog snapshot taken at
+// [Client.Connect] time. The returned value is a defensive copy;
+// mutating it has no effect on the client.
+func (c *Client) Catalog() *Catalog {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+	if c.catalog == nil {
+		return &Catalog{}
+	}
+	return c.catalog.copy()
+}
+
+// Resources lists resources exposed by the named server. Pagination
+// is handled by the SDK iterator; servers exposing more resources
+// than fit in a single page are still fully enumerated. Returns
+// [ErrUnknownServer] if no such server is configured or if [Client.Close]
+// has already been called.
+//
+// The client lock is held only long enough to look up the session
+// pointer; the iteration runs lock-free so a slow ListResources on
+// one server cannot starve concurrent operations on other servers.
+// A concurrent [Client.Close] may close the underlying session
+// mid-iteration; the SDK surfaces this as a transport error rather
+// than a panic.
+func (c *Client) Resources(ctx context.Context, server string) ([]Resource, error) {
+	// Snapshot the session pointer and closed flag, then drop the
+	// lock before any network I/O.
+	c.mu.RLock()
+	session, known := c.sessions[server]
+	terminal := c.closed
+	c.mu.RUnlock()
+
+	if terminal || !known {
+		return nil, fmt.Errorf("%w: %q", ErrUnknownServer, server)
+	}
+
+	var listed []Resource
+	for res, iterErr := range session.Resources(ctx, nil) {
+		if iterErr != nil {
+			return nil, fmt.Errorf("list resources: %w", iterErr)
+		}
+		item := Resource{
+			Server:      server,
+			URI:         res.URI,
+			Name:        res.Name,
+			Description: res.Description,
+			MIMEType:    res.MIMEType,
+		}
+		listed = append(listed, item)
+	}
+	return listed, nil
+}
+
+// Prompts lists prompt templates exposed by the named server.
+// Pagination is handled by the SDK iterator. Returns [ErrUnknownServer]
+// if no such server is configured or if [Client.Close] has already been
+// called. The client lock is released before iteration begins (see
+// [Client.Resources] for the rationale).
+func (c *Client) Prompts(ctx context.Context, server string) ([]Prompt, error) {
+	c.mu.RLock()
+	session, known := c.sessions[server]
+	terminal := c.closed
+	c.mu.RUnlock()
+
+	if terminal || !known {
+		return nil, fmt.Errorf("%w: %q", ErrUnknownServer, server)
+	}
+
+	var listed []Prompt
+	for tmpl, iterErr := range session.Prompts(ctx, nil) {
+		if iterErr != nil {
+			return nil, fmt.Errorf("list prompts: %w", iterErr)
+		}
+		item := Prompt{
+			Server:      server,
+			Name:        tmpl.Name,
+			Description: tmpl.Description,
+		}
+		listed = append(listed, item)
+	}
+	return listed, nil
+}
+
+// Call invokes a namespaced tool. The Name field of req must be in
+// "server__tool" form.
+//
+// Retry policy: on ANY error from the underlying CallTool — including
+// server-returned semantic errors such as invalid arguments — Call
+// retries up to [Backoff.MaxAttempts] times with exponential backoff.
+// Context cancellation and deadline are the only short-circuit
+// conditions. This is intentionally broad for v1; a future release
+// may classify jsonrpc2 error codes to short-circuit non-transient
+// failures (`-32601` Method-not-found, `-32602` Invalid-params).
+// Until then, callers that want fail-fast semantics for a specific
+// tool should pass a context with a tight deadline. After backoff is
+// exhausted Call returns [ErrServerUnavailable] wrapping the last
+// underlying error.
+//
+// Tool-level "is this an error result?" reporting flows through
+// [ToolResult.IsError], not Go errors — those come back from the
+// server with an OK transport response.
+func (c *Client) Call(ctx context.Context, req ToolCall) (ToolResult, error) {
+	var none ToolResult
+
+	server, bare, splitErr := SplitName(req.Name)
+	if splitErr != nil {
+		return none, splitErr
+	}
+
+	// Hold the lock only for the map lookup; the call itself and any
+	// backoff sleeps run without it.
+	c.mu.RLock()
+	session, known := c.sessions[server]
+	c.mu.RUnlock()
+	if !known {
+		return none, fmt.Errorf("%w: %q", ErrUnknownServer, server)
+	}
+
+	// Empty Arguments are passed through as nil rather than decoded.
+	var args any
+	if len(req.Arguments) != 0 {
+		if decErr := json.Unmarshal(req.Arguments, &args); decErr != nil {
+			return none, fmt.Errorf("decode arguments: %w", decErr)
+		}
+	}
+
+	var lastErr error
+	for attempt, limit := 0, c.backoff.maxAttempts(); attempt < limit; attempt++ {
+		// Sleep before every retry (never before the first try),
+		// giving up early if the context ends while waiting.
+		if attempt > 0 {
+			wait := c.backoff.Delay(attempt - 1)
+			select {
+			case <-time.After(wait):
+			case <-ctx.Done():
+				return none, ctx.Err()
+			}
+		}
+
+		res, callErr := session.CallTool(ctx, &sdkmcp.CallToolParams{Name: bare, Arguments: args})
+		switch {
+		case callErr == nil:
+			return convertToolResult(res), nil
+		case errors.Is(callErr, context.Canceled), errors.Is(callErr, context.DeadlineExceeded):
+			// Non-retryable: context cancellation, no point retrying.
+			return none, callErr
+		}
+		lastErr = callErr
+	}
+	return none, fmt.Errorf("%w: %q: %w", ErrServerUnavailable, server, lastErr)
+}
+
+// fetchCatalogLocked snapshots the tool list from every connected
+// server. Pagination is handled by the SDK iterator; the catalog is
+// sorted deterministically (server name, then bare tool name) so that
+// two Connect calls produce identical Tools ordering. The caller must
+// hold c.mu.
+func (c *Client) fetchCatalogLocked(ctx context.Context) (*Catalog, error) {
+	// Visit servers in sorted order so that a partial failure always
+	// names the same server for the same set of sessions.
+	servers := make([]string, 0, len(c.sessions))
+	for server := range c.sessions {
+		servers = append(servers, server)
+	}
+	sort.Strings(servers)
+
+	catalog := &Catalog{
+		Tools:         []Tool{},
+		ToolsByServer: make(map[string][]Tool, len(c.sessions)),
+	}
+	for _, server := range servers {
+		// Start from an empty, non-nil slice: a server with zero
+		// tools then still gets a non-nil entry in ToolsByServer,
+		// the same shape as catalog.Tools.
+		tools := []Tool{}
+		for sdkTool, iterErr := range c.sessions[server].Tools(ctx, nil) {
+			if iterErr != nil {
+				return nil, fmt.Errorf("server %q: list tools: %w", server, iterErr)
+			}
+			schema, marshalErr := json.Marshal(sdkTool.InputSchema)
+			if marshalErr != nil {
+				return nil, fmt.Errorf("server %q tool %q: marshal input schema: %w", server, sdkTool.Name, marshalErr)
+			}
+			tools = append(tools, Tool{
+				Name:        JoinName(server, sdkTool.Name),
+				Server:      server,
+				BareName:    sdkTool.Name,
+				Description: sdkTool.Description,
+				InputSchema: schema,
+			})
+		}
+		// Order each server's tools by bare name.
+		sort.Slice(tools, func(a, b int) bool {
+			return tools[a].BareName < tools[b].BareName
+		})
+		catalog.ToolsByServer[server] = tools
+		catalog.Tools = append(catalog.Tools, tools...)
+	}
+	return catalog, nil
+}
+
+// convertToolResult maps a go-sdk CallToolResult into our wire-public form.
+// Defensive against nil — a nil result yields a zero-value ToolResult.
+func convertToolResult(r *sdkmcp.CallToolResult) ToolResult {
+	var result ToolResult
+	if r == nil {
+		return result
+	}
+	result.IsError = r.IsError
+	// Appending (not pre-allocating) keeps Content nil when the SDK
+	// result carries no content items.
+	for i := range r.Content {
+		result.Content = append(result.Content, contentToToolContent(r.Content[i]))
+	}
+	return result
+}
+
+// contentToToolContent converts one SDK content item into a
+// ToolContent. Content types this wrapper does not know are reported
+// as Type "unknown" with a JSON dump (or a marshal diagnostic) in Text.
+func contentToToolContent(c sdkmcp.Content) ToolContent {
+	switch v := c.(type) {
+	case *sdkmcp.TextContent:
+		return ToolContent{Type: "text", Text: v.Text}
+	case *sdkmcp.ImageContent:
+		return ToolContent{Type: "image", Data: v.Data, MIMEType: v.MIMEType}
+	case *sdkmcp.AudioContent:
+		return ToolContent{Type: "audio", Data: v.Data, MIMEType: v.MIMEType}
+	case *sdkmcp.ResourceLink:
+		return ToolContent{
+			Type:     "resource_link",
+			URI:      v.URI,
+			Name:     v.Name,
+			MIMEType: v.MIMEType,
+		}
+	case *sdkmcp.EmbeddedResource:
+		// An embedded resource without a payload still reports its type.
+		if v.Resource == nil {
+			return ToolContent{Type: "resource"}
+		}
+		res := v.Resource
+		return ToolContent{
+			Type:     "resource",
+			URI:      res.URI,
+			MIMEType: res.MIMEType,
+			Text:     res.Text,
+			Data:     res.Blob,
+		}
+	default:
+		// Unknown content type — round-trip via JSON so callers at
+		// least see something rather than dropping it silently. If
+		// the payload fails to marshal, return a diagnostic Text
+		// rather than the empty string.
+		encoded, encErr := json.Marshal(v)
+		if encErr == nil {
+			return ToolContent{Type: "unknown", Text: string(encoded)}
+		}
+		return ToolContent{
+			Type: "unknown",
+			Text: fmt.Sprintf("%T: marshal failed: %v", v, encErr),
+		}
+	}
+}
+
+// SplitName parses a namespaced "server__tool" string into its parts.
+// Returns [ErrInvalidName] if the input lacks the separator or has
+// an empty server or tool component.
+//
+// A bare tool name may itself contain the separator. SplitName splits
+// on the FIRST occurrence: "s__a__b" → ("s", "a__b").
+func SplitName(s string) (server, bare string, err error) {
+	// Cut splits on the first separator, so any later occurrences
+	// stay inside the tool component.
+	head, tail, found := strings.Cut(s, NamespaceSeparator)
+	if !found || head == "" {
+		return "", "", fmt.Errorf("%w: %q", ErrInvalidName, s)
+	}
+	if tail == "" {
+		return "", "", fmt.Errorf("%w: %q (empty tool component)", ErrInvalidName, s)
+	}
+	return head, tail, nil
+}
+
+// JoinName produces a namespaced "server__tool" string. It performs
+// no validation of either component.
+func JoinName(server, bare string) string {
+	return server + NamespaceSeparator + bare
+}
+
+// httpClientWithHeaders returns nil if headers is empty (the SDK
+// supplies its own default client). Otherwise it returns a client
+// whose Transport wraps http.DefaultTransport and injects the given
+// headers on every outbound request.
+func httpClientWithHeaders(headers map[string]string) *http.Client {
+	if len(headers) == 0 {
+		return nil
+	}
+	// Defensive copy so post-construction mutation by the caller
+	// cannot retroactively change request headers.
+	h := make(map[string]string, len(headers))
+	for k, v := range headers {
+		h[k] = v
+	}
+	return &http.Client{Transport: &headerInjectingRoundTripper{base: http.DefaultTransport, headers: h}}
+}
+
+// headerInjectingRoundTripper sets a fixed set of headers on every
+// request before delegating to base.
+type headerInjectingRoundTripper struct {
+	base    http.RoundTripper
+	headers map[string]string
+}
+
+// RoundTrip applies the configured headers to a clone of req and
+// forwards the clone to the base transport. Existing values for the
+// same header names are replaced.
+func (rt *headerInjectingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	// Clone the request before mutating headers per http.RoundTripper contract.
+	out := req.Clone(req.Context())
+	for k, v := range rt.headers {
+		// net/http clients send Request.Host and ignore a "Host"
+		// entry in Header, so a configured Host routing header has
+		// to be applied to the Host field to take effect.
+		if http.CanonicalHeaderKey(k) == "Host" {
+			out.Host = v
+			continue
+		}
+		out.Header.Set(k, v)
+	}
+	return rt.base.RoundTrip(out)
+}
+
+// buildSDKTransport constructs the sdk's transport for a config.
+// All errors wrap [ErrConfig] and include the server name so failures
+// at Connect time are diagnosable through the public sentinel.
+func buildSDKTransport(_ context.Context, cfg ServerConfig) (sdkmcp.Transport, error) {
+	switch cfg.Transport {
+	case TransportStdio:
+		// Endpoint is split on whitespace only (strings.Fields); no
+		// shell-style quoting is interpreted. Arguments that contain
+		// spaces have to be supplied through StdioArgs.
+		argv := strings.Fields(cfg.Endpoint)
+		if len(argv) == 0 {
+			// New() rejects this; defense in depth.
+			return nil, fmt.Errorf("%w: server %q: stdio Endpoint has no command tokens", ErrConfig, cfg.Name)
+		}
+		argv = append(argv, cfg.StdioArgs...)
+		// argv comes from operator-controlled ServerConfig (the
+		// maintainer's MCP config file), not from end-user input.
+		// gosec's G204 ("subprocess with potentially tainted input")
+		// is suppressed with native #nosec syntax — recognized by
+		// both the standalone gosec used in security.yml and gosec
+		// running inside golangci-lint. Stdio MCP servers exist to
+		// run external binaries; an allowlist would prevent the
+		// operator from configuring the system at all.
+		cmd := exec.Command(argv[0], argv[1:]...) // #nosec G204 -- argv is operator-trusted config
+		return &sdkmcp.CommandTransport{Command: cmd}, nil
+	case TransportSSE:
+		return &sdkmcp.SSEClientTransport{
+			Endpoint:   cfg.Endpoint,
+			HTTPClient: httpClientWithHeaders(cfg.Headers),
+		}, nil
+	case TransportStreamableHTTP:
+		return &sdkmcp.StreamableClientTransport{
+			Endpoint:   cfg.Endpoint,
+			HTTPClient: httpClientWithHeaders(cfg.Headers),
+		}, nil
+	default:
+		return nil, fmt.Errorf("%w: server %q: unsupported transport %q", ErrConfig, cfg.Name, cfg.Transport)
+	}
+}
diff --git a/core/mcp/client_test.go b/core/mcp/client_test.go
new file mode 100644
index 0000000..23c5f0b
--- /dev/null
+++ b/core/mcp/client_test.go
@@ -0,0 +1,730 @@
+package mcp_test
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
+
+	"github.com/plexara/plexara-agents/core/mcp"
+)
+
+// TestSplitName is a table test for mcp.SplitName: well-formed names
+// split at the first separator, and every malformed shape must fail
+// with an error that matches mcp.ErrInvalidName.
+func TestSplitName(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name       string
+		in         string
+		wantServer string
+		wantBare   string
+		wantErr    bool
+	}{
+		{name: "simple", in: "plexara__datahub_search", wantServer: "plexara", wantBare: "datahub_search"},
+		{name: "tool_with_separator", in: "plexara__a__b", wantServer: "plexara", wantBare: "a__b"},
+		{name: "empty", in: "", wantErr: true},
+		{name: "no_separator", in: "datahub_search", wantErr: true},
+		{name: "leading_separator", in: "__tool", wantErr: true},
+		{name: "trailing_separator", in: "server__", wantErr: true},
+		{name: "only_separator", in: "__", wantErr: true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			gotServer, gotBare, err := mcp.SplitName(tt.in)
+			if tt.wantErr {
+				if err == nil {
+					t.Errorf("SplitName(%q) = %q, %q, nil; want error", tt.in, gotServer, gotBare)
+				} else if !errors.Is(err, mcp.ErrInvalidName) {
+					t.Errorf("SplitName(%q) err = %v; want errors.Is ErrInvalidName", tt.in, err)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("SplitName(%q): unexpected err: %v", tt.in, err)
+			}
+			if gotServer != tt.wantServer || gotBare != tt.wantBare {
+				t.Errorf("SplitName(%q) = %q, %q; want %q, %q", tt.in, gotServer, gotBare, tt.wantServer, tt.wantBare)
+			}
+		})
+	}
+}
+
+// TestJoinNameRoundTrip checks that SplitName(JoinName(server, bare))
+// returns the original pair for a handful of representative names.
+func TestJoinNameRoundTrip(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		server, bare string
+	}{
+		{"plexara", "datahub_search"},
+		{"acme", "trino_query"},
+		{"a", "b"},
+		{"server-with-dashes", "tool_with_underscores"},
+	}
+	for _, tt := range tests {
+		joined := mcp.JoinName(tt.server, tt.bare)
+		gotServer, gotBare, err := mcp.SplitName(joined)
+		if err != nil {
+			t.Errorf("SplitName(JoinName(%q,%q)) = err: %v", tt.server, tt.bare, err)
+			continue
+		}
+		if gotServer != tt.server || gotBare != tt.bare {
+			t.Errorf("round trip %q,%q -> %q -> %q,%q", tt.server, tt.bare, joined, gotServer, gotBare)
+		}
+	}
+}
+
+// TestNew_Validation feeds mcp.New a series of invalid configurations
+// and requires each to be rejected with an error matching
+// mcp.ErrConfig.
+func TestNew_Validation(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		cfgs []mcp.ServerConfig
+	}{
+		{name: "no_servers", cfgs: nil},
+		{name: "empty_name", cfgs: []mcp.ServerConfig{{Name: "", Transport: mcp.TransportStdio, Endpoint: "x"}}},
+		{name: "name_with_separator", cfgs: []mcp.ServerConfig{{Name: "a__b", Transport: mcp.TransportStdio, Endpoint: "x"}}},
+		{name: "duplicate_name", cfgs: []mcp.ServerConfig{
+			{Name: "x", Transport: mcp.TransportStdio, Endpoint: "a"},
+			{Name: "x", Transport: mcp.TransportStdio, Endpoint: "b"},
+		}},
+		{name: "empty_endpoint", cfgs: []mcp.ServerConfig{{Name: "x", Transport: mcp.TransportStdio, Endpoint: ""}}},
+		{name: "unknown_transport", cfgs: []mcp.ServerConfig{{Name: "x", Transport: "smoke-signal", Endpoint: "x"}}},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			_, err := mcp.New(tt.cfgs)
+			if err == nil {
+				t.Fatalf("New(%q) returned nil error", tt.name)
+			}
+			if !errors.Is(err, mcp.ErrConfig) {
+				t.Errorf("err = %v; want errors.Is ErrConfig", err)
+			}
+		})
+	}
+}
+
+// TestNew_OK checks that a valid two-server configuration yields a
+// non-nil client and no error. It does not connect to anything.
+func TestNew_OK(t *testing.T) {
+	t.Parallel()
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "a", Transport: mcp.TransportStdio, Endpoint: "echo"},
+		{Name: "b", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+	})
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if c == nil {
+		t.Fatal("New returned nil client without error")
+	}
+}
+
+// fakeServer returns an in-memory client transport wired to a real
+// mcp.Server with the given tools registered. The server's connect
+// goroutine is registered with the test's WaitGroup via t.Cleanup so
+// that the test fully waits for it before exiting — without this,
+// the goroutine could call t.Logf after the test had completed and
+// crash with "Log in goroutine after test has completed".
+func fakeServer(t *testing.T, name string, tools []sdkmcp.Tool) sdkmcp.Transport {
+	t.Helper()
+	server := sdkmcp.NewServer(&sdkmcp.Implementation{Name: name, Version: "v0.0.0"}, nil)
+	for _, tt := range tools {
+		// Every tool gets the same echo handler: it reports its own
+		// name and the JSON-encoded arguments so routing tests can
+		// assert on what reached the server.
+		sdkmcp.AddTool(server, &tt, func(_ context.Context, _ *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
+			return &sdkmcp.CallToolResult{
+				Content: []sdkmcp.Content{
+					&sdkmcp.TextContent{Text: "called " + tt.Name + " with " + sprintArgs(args)},
+				},
+			}, nil, nil
+		})
+	}
+	clientT, serverT := sdkmcp.NewInMemoryTransports()
+	var wg sync.WaitGroup
+	wg.Add(1)
+	// Buffered so the goroutine can publish its result and exit even
+	// though the receiver only runs at cleanup time.
+	connectErr := make(chan error, 1)
+	go func() {
+		defer wg.Done()
+		_, err := server.Connect(context.Background(), serverT, nil)
+		connectErr <- err
+	}()
+	t.Cleanup(func() {
+		wg.Wait()
+		if err := <-connectErr; err != nil {
+			t.Logf("server.Connect: %v", err)
+		}
+	})
+	return clientT
+}
+
+// sprintArgs renders tool arguments as JSON for the echo handlers.
+func sprintArgs(args map[string]any) string {
+	// The marshal error is deliberately dropped: this is a test-only
+	// helper, and on failure the empty string simply makes the echo
+	// assertions in the calling test fail.
+	b, _ := json.Marshal(args)
+	return string(b)
+}
+
+// TestConnect_AggregatesCatalog connects to two fake servers and
+// checks that the catalog holds every tool under its namespaced name
+// and that the per-server index has the right counts.
+func TestConnect_AggregatesCatalog(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"plexara": fakeServer(t, "plexara", []sdkmcp.Tool{
+			{Name: "datahub_search", Description: "search the data hub"},
+			{Name: "trino_query", Description: "run sql"},
+		}),
+		"fs": fakeServer(t, "fs", []sdkmcp.Tool{
+			{Name: "read_file", Description: "read a file"},
+		}),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		tr, ok := transports[cfg.Name]
+		if !ok {
+			// The dialer is invoked from inside Client.Connect, which
+			// may run it on a goroutine other than the test's own.
+			// t.Fatalf is only valid on the test goroutine, so record
+			// the failure with t.Errorf and surface it as a dial
+			// error; Connect then fails and the test aborts below.
+			t.Errorf("unexpected server name %q", cfg.Name)
+			return nil, errors.New("unexpected server name " + cfg.Name)
+		}
+		return tr, nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "plexara", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://stub"},
+		{Name: "fs", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://stub"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+	defer cancel()
+	if err := c.Connect(ctx); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	cat := c.Catalog()
+	if got, want := len(cat.Tools), 3; got != want {
+		t.Fatalf("len(cat.Tools) = %d; want %d", got, want)
+	}
+
+	// Confirm namespacing.
+	gotNames := map[string]bool{}
+	for _, tool := range cat.Tools {
+		gotNames[tool.Name] = true
+	}
+	for _, want := range []string{"plexara__datahub_search", "plexara__trino_query", "fs__read_file"} {
+		if !gotNames[want] {
+			t.Errorf("missing tool %q in catalog: %v", want, gotNames)
+		}
+	}
+
+	// ToolsByServer index.
+	if got := len(cat.ToolsByServer["plexara"]); got != 2 {
+		t.Errorf("ToolsByServer[plexara] = %d tools; want 2", got)
+	}
+	if got := len(cat.ToolsByServer["fs"]); got != 1 {
+		t.Errorf("ToolsByServer[fs] = %d tools; want 1", got)
+	}
+}
+
+// TestCall_Routes calls a namespaced tool and checks that the fake
+// server's echo response carries the bare tool name and the arguments.
+func TestCall_Routes(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"plexara": fakeServer(t, "plexara", []sdkmcp.Tool{
+			{Name: "datahub_search", Description: "search"},
+		}),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "plexara", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://stub"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+	defer cancel()
+	if err := c.Connect(ctx); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	res, err := c.Call(ctx, mcp.ToolCall{
+		Name:      "plexara__datahub_search",
+		Arguments: json.RawMessage(`{"q":"orders"}`),
+	})
+	if err != nil {
+		t.Fatalf("Call: %v", err)
+	}
+	if res.IsError {
+		t.Errorf("IsError = true; want false")
+	}
+	if len(res.Content) != 1 || res.Content[0].Type != "text" {
+		t.Fatalf("unexpected content %#v", res.Content)
+	}
+	if !strings.Contains(res.Content[0].Text, "datahub_search") {
+		t.Errorf("response %q did not echo tool name", res.Content[0].Text)
+	}
+	if !strings.Contains(res.Content[0].Text, `"q":"orders"`) {
+		t.Errorf("response %q did not echo args", res.Content[0].Text)
+	}
+}
+
+// TestCall_UnknownServer checks that calling a tool whose server
+// prefix is not configured fails with mcp.ErrUnknownServer.
+func TestCall_UnknownServer(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"plexara": fakeServer(t, "plexara", nil),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "plexara", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://stub"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	_, err = c.Call(t.Context(), mcp.ToolCall{Name: "missing__tool"})
+	if !errors.Is(err, mcp.ErrUnknownServer) {
+		t.Errorf("err = %v; want errors.Is ErrUnknownServer", err)
+	}
+}
+
+// TestCall_InvalidName checks that a tool name without the namespace
+// separator is rejected with mcp.ErrInvalidName.
+func TestCall_InvalidName(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"plexara": fakeServer(t, "plexara", nil),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "plexara", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://stub"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	_, err = c.Call(t.Context(), mcp.ToolCall{Name: "no-separator"})
+	if !errors.Is(err, mcp.ErrInvalidName) {
+		t.Errorf("err = %v; want errors.Is ErrInvalidName", err)
+	}
+}
+
+// TestConnect_ParallelFailureClosesAll connects to one healthy and one
+// failing server and checks that Connect reports the failure and that
+// the client exposes neither a catalog nor call routing afterwards.
+func TestConnect_ParallelFailureClosesAll(t *testing.T) {
+	t.Parallel()
+
+	good := fakeServer(t, "good", []sdkmcp.Tool{{Name: "t", Description: "ok"}})
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		switch cfg.Name {
+		case "good":
+			return good, nil
+		case "bad":
+			return nil, errors.New("dial refused")
+		default:
+			return nil, errors.New("unexpected")
+		}
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "good", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+		{Name: "bad", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	err = c.Connect(t.Context())
+	if err == nil {
+		t.Fatal("Connect returned nil; want error from bad server")
+	}
+	if !strings.Contains(err.Error(), "bad") {
+		t.Errorf("err = %v; want it to mention the bad server", err)
+	}
+	// After a failed Connect, the client must be left in an unusable
+	// state — empty catalog and Call routing returning ErrUnknownServer.
+	if got := c.Catalog().Tools; len(got) != 0 {
+		t.Errorf("Catalog().Tools = %d entries; want 0 after failed Connect", len(got))
+	}
+	if _, callErr := c.Call(t.Context(), mcp.ToolCall{Name: "good__t"}); !errors.Is(callErr, mcp.ErrUnknownServer) {
+		t.Errorf("Call after failed Connect: err = %v; want ErrUnknownServer", callErr)
+	}
+}
+
+// TestBackoff_Bounded checks that Delay stays within [0, Cap] for the
+// first ten attempt numbers.
+func TestBackoff_Bounded(t *testing.T) {
+	t.Parallel()
+
+	b := mcp.Backoff{Base: 100 * time.Millisecond, Cap: 1 * time.Second, MaxAttempts: 3}
+	for attempt := range 10 {
+		d := b.Delay(attempt)
+		if d < 0 || d > b.Cap {
+			t.Errorf("attempt %d: delay %v out of [0, %v]", attempt, d, b.Cap)
+		}
+	}
+}
+
+// TestBackoff_NegativeAttemptIsClamped guards Delay against negative
+// shift counts. Before the clamp was added, Delay(-1) panicked at
+// runtime with "negative shift amount".
+func TestBackoff_NegativeAttemptIsClamped(t *testing.T) {
+	t.Parallel()
+
+	b := mcp.Backoff{Base: 100 * time.Millisecond, Cap: 1 * time.Second}
+	d := b.Delay(-1)
+	if d < 0 || d > b.Cap {
+		t.Errorf("Delay(-1) = %v; want in [0, %v]", d, b.Cap)
+	}
+}
+
+// TestBackoff_GrowsAcrossAttempts confirms Delay actually scales —
+// without growth, a constant zero would still pass TestBackoff_Bounded.
+// The mean over many samples at high attempts must exceed the mean at
+// attempt 0.
+func TestBackoff_GrowsAcrossAttempts(t *testing.T) {
+	t.Parallel()
+
+	b := mcp.Backoff{Base: 100 * time.Millisecond, Cap: 1 * time.Second}
+	const samples = 200
+
+	// mean averages `samples` draws of Delay for one attempt number.
+	mean := func(attempt int) time.Duration {
+		var total time.Duration
+		for range samples {
+			total += b.Delay(attempt)
+		}
+		return total / samples
+	}
+
+	low := mean(0)
+	high := mean(5)
+	if high <= low {
+		t.Errorf("mean Delay(5)=%v not greater than mean Delay(0)=%v; backoff is not growing", high, low)
+	}
+	// At attempts >> log2(Cap/Base), the saturated value should sit
+	// near Cap/2 with full jitter. Sanity-check that high-mean is
+	// within an order of magnitude of Cap/2.
+	if high < b.Cap/8 {
+		t.Errorf("mean Delay(5)=%v much smaller than Cap/8=%v; jitter range looks broken", high, b.Cap/8)
+	}
+}
+
+// TestBackoff_Defaults checks that a zero-value Backoff still yields a
+// first delay between 0 and 30 seconds.
+func TestBackoff_Defaults(t *testing.T) {
+	t.Parallel()
+
+	b := mcp.Backoff{}
+	d := b.Delay(0)
+	if d < 0 || d > 30*time.Second {
+		t.Errorf("default Delay(0) = %v; out of bounds", d)
+	}
+}
+
+// TestResources_UnknownServer checks that both Resources and Prompts
+// report mcp.ErrUnknownServer for a server name that was never
+// configured.
+func TestResources_UnknownServer(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"x": fakeServer(t, "x", nil),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{{Name: "x", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"}}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	_, err = c.Resources(t.Context(), "missing")
+	if !errors.Is(err, mcp.ErrUnknownServer) {
+		t.Errorf("Resources err = %v; want ErrUnknownServer", err)
+	}
+	_, err = c.Prompts(t.Context(), "missing")
+	if !errors.Is(err, mcp.ErrUnknownServer) {
+		t.Errorf("Prompts err = %v; want ErrUnknownServer", err)
+	}
+}
+
+// TestConnect_TwiceErrors checks that a second Connect on an already
+// connected client returns an error.
+func TestConnect_TwiceErrors(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"x": fakeServer(t, "x", nil),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{{Name: "x", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"}}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("first Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	if err := c.Connect(t.Context()); err == nil {
+		t.Error("second Connect returned nil; want error")
+	}
+}
+
+// TestConnect_FailedConnectLatchesGate pins the contract: a Client
+// that failed to Connect is terminal — a retry returns ErrConfig.
+// Callers must construct a fresh Client to retry.
+func TestConnect_FailedConnectLatchesGate(t *testing.T) {
+	t.Parallel()
+
+	dialer := func(_ context.Context, _ mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return nil, errors.New("dial refused")
+	}
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "x", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err == nil {
+		t.Fatal("first Connect returned nil; want dial error")
+	}
+	if err := c.Connect(t.Context()); !errors.Is(err, mcp.ErrConfig) {
+		t.Errorf("second Connect after failure: err = %v; want errors.Is ErrConfig", err)
+	}
+}
+
+// TestConnect_AfterCloseRejected pins the documented lifecycle: a
+// closed Client refuses Connect, and Close also clears the catalog.
+func TestConnect_AfterCloseRejected(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"x": fakeServer(t, "x", []sdkmcp.Tool{{Name: "tool"}}),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{{Name: "x", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"}}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	if got := len(c.Catalog().Tools); got != 1 {
+		t.Fatalf("pre-Close Catalog.Tools = %d; want 1", got)
+	}
+
+	if err := c.Close(); err != nil {
+		t.Fatalf("Close: %v", err)
+	}
+
+	// After Close, Catalog must be empty, Call must report
+	// ErrUnknownServer, and Connect must report ErrConfig.
+	if got := len(c.Catalog().Tools); got != 0 {
+		t.Errorf("post-Close Catalog.Tools = %d; want 0 (catalog must be cleared)", got)
+	}
+	if _, err := c.Call(t.Context(), mcp.ToolCall{Name: "x__tool"}); !errors.Is(err, mcp.ErrUnknownServer) {
+		t.Errorf("post-Close Call err = %v; want ErrUnknownServer", err)
+	}
+	if err := c.Connect(t.Context()); !errors.Is(err, mcp.ErrConfig) {
+		t.Errorf("post-Close Connect err = %v; want ErrConfig", err)
+	}
+}
+
+// TestCatalog_DeterministicOrdering checks that Catalog().Tools comes
+// back sorted by server name and then bare tool name, regardless of
+// the order in which the tools were registered.
+func TestCatalog_DeterministicOrdering(t *testing.T) {
+	t.Parallel()
+
+	// Two servers with multiple tools each, registered in arbitrary
+	// order. The catalog must come out sorted by (server, bare-tool-
+	// name) every time, even though the underlying sessions map is
+	// randomized.
+	transports := map[string]sdkmcp.Transport{
+		"alpha": fakeServer(t, "alpha", []sdkmcp.Tool{
+			{Name: "zebra"},
+			{Name: "antelope"},
+			{Name: "moose"},
+		}),
+		"bravo": fakeServer(t, "bravo", []sdkmcp.Tool{
+			{Name: "yak"},
+			{Name: "elk"},
+		}),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "alpha", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+		{Name: "bravo", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	want := []string{
+		"alpha__antelope", "alpha__moose", "alpha__zebra",
+		"bravo__elk", "bravo__yak",
+	}
+	cat := c.Catalog()
+	got := make([]string, len(cat.Tools))
+	for i, tool := range cat.Tools {
+		got[i] = tool.Name
+	}
+	if len(got) != len(want) {
+		t.Fatalf("len(tools) = %d; want %d (full=%v)", len(got), len(want), got)
+	}
+	for i := range want {
+		if got[i] != want[i] {
+			t.Errorf("tools[%d] = %q; want %q (full got=%v)", i, got[i], want[i], got)
+		}
+	}
+}
+
+// fakeServerWithPageSize wires a server whose ListTools/ListResources/
+// ListPrompts pagination cap is set explicitly. Used to verify the
+// client iterators traverse multiple pages.
+//
+// Same goroutine-vs-test-end discipline as fakeServer.
+func fakeServerWithPageSize(t *testing.T, name string, pageSize int, tools []sdkmcp.Tool) sdkmcp.Transport {
+	t.Helper()
+	server := sdkmcp.NewServer(
+		&sdkmcp.Implementation{Name: name, Version: "v0.0.0"},
+		&sdkmcp.ServerOptions{PageSize: pageSize},
+	)
+	for _, tt := range tools {
+		sdkmcp.AddTool(server, &tt, func(_ context.Context, _ *sdkmcp.CallToolRequest, args map[string]any) (*sdkmcp.CallToolResult, any, error) {
+			return &sdkmcp.CallToolResult{
+				Content: []sdkmcp.Content{
+					&sdkmcp.TextContent{Text: "ok " + tt.Name + " " + sprintArgs(args)},
+				},
+			}, nil, nil
+		})
+	}
+	clientT, serverT := sdkmcp.NewInMemoryTransports()
+	var wg sync.WaitGroup
+	wg.Add(1)
+	// Buffered so the goroutine can publish its result and exit even
+	// though the receiver only runs at cleanup time.
+	connectErr := make(chan error, 1)
+	go func() {
+		defer wg.Done()
+		_, err := server.Connect(context.Background(), serverT, nil)
+		connectErr <- err
+	}()
+	t.Cleanup(func() {
+		wg.Wait()
+		if err := <-connectErr; err != nil {
+			t.Logf("server.Connect: %v", err)
+		}
+	})
+	return clientT
+}
+
+// TestConnect_PaginatesAcrossMultiplePages checks that all five tools
+// of a server with page size 1 end up in the catalog, in order.
+func TestConnect_PaginatesAcrossMultiplePages(t *testing.T) {
+	t.Parallel()
+
+	// Force the server to break tools across pages of size 1 so the
+	// client iterator must traverse multiple pages to enumerate them.
+	transports := map[string]sdkmcp.Transport{
+		"big": fakeServerWithPageSize(t, "big", 1, []sdkmcp.Tool{
+			{Name: "tool_a"},
+			{Name: "tool_b"},
+			{Name: "tool_c"},
+			{Name: "tool_d"},
+			{Name: "tool_e"},
+		}),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{
+		{Name: "big", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"},
+	}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	cat := c.Catalog()
+	if got, want := len(cat.Tools), 5; got != want {
+		t.Errorf("len(Tools) = %d across paginated pages; want %d", got, want)
+	}
+	want := []string{"big__tool_a", "big__tool_b", "big__tool_c", "big__tool_d", "big__tool_e"}
+	for i, w := range want {
+		// Stop at the first mismatch; the message already prints the
+		// full list, so further reports would only repeat it.
+		if i >= len(cat.Tools) || cat.Tools[i].Name != w {
+			t.Errorf("Tools[%d] = %v; want %q (full=%v)", i, safeName(cat.Tools, i), w, toolNames(cat.Tools))
+			break
+		}
+	}
+}
+
+// safeName returns ts[i].Name, or "" when i is out of range, so a
+// failure message can be built without risking an index panic.
+func safeName(ts []mcp.Tool, i int) string {
+	if i < 0 || i >= len(ts) {
+		return ""
+	}
+	return ts[i].Name
+}
+
+// toolNames returns the Name of every tool in ts, in order.
+func toolNames(ts []mcp.Tool) []string {
+	out := make([]string, len(ts))
+	for i, t := range ts {
+		out[i] = t.Name
+	}
+	return out
+}
+
+// TestCatalog_DefensiveCopy mutates a Catalog returned by the client
+// and checks that a subsequently fetched Catalog is unaffected.
+func TestCatalog_DefensiveCopy(t *testing.T) {
+	t.Parallel()
+
+	transports := map[string]sdkmcp.Transport{
+		"x": fakeServer(t, "x", []sdkmcp.Tool{{Name: "tool"}}),
+	}
+	dialer := func(_ context.Context, cfg mcp.ServerConfig) (sdkmcp.Transport, error) {
+		return transports[cfg.Name], nil
+	}
+
+	c, err := mcp.New([]mcp.ServerConfig{{Name: "x", Transport: mcp.TransportStreamableHTTP, Endpoint: "http://x"}}, mcp.WithDialer(dialer))
+	if err != nil {
+		t.Fatalf("New: %v", err)
+	}
+	if err := c.Connect(t.Context()); err != nil {
+		t.Fatalf("Connect: %v", err)
+	}
+	t.Cleanup(func() { _ = c.Close() })
+
+	cat := c.Catalog()
+	if len(cat.Tools) != 1 {
+		t.Fatalf("len(Tools) = %d; want 1", len(cat.Tools))
+	}
+	cat.Tools[0].Name = "MUTATED"
+	delete(cat.ToolsByServer, "x")
+
+	cat2 := c.Catalog()
+	if cat2.Tools[0].Name == "MUTATED" {
+		t.Errorf("Tools mutation propagated to internal state")
+	}
+	if _, ok := cat2.ToolsByServer["x"]; !ok {
+		t.Errorf("ToolsByServer mutation propagated to internal state")
+	}
+}
diff --git a/core/mcp/fuzz_test.go b/core/mcp/fuzz_test.go
new file mode 100644
index 0000000..46754e4
--- /dev/null
+++ b/core/mcp/fuzz_test.go
@@ -0,0 +1,57 @@
+package mcp_test
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/plexara/plexara-agents/core/mcp"
+)
+
+// FuzzSplitName covers the namespaced-name parser against arbitrary
+// input. The contract: SplitName never panics, never returns inputs
+// that fail to round-trip through JoinName.
+func FuzzSplitName(f *testing.F) {
+	seeds := []string{
+		"plexara__datahub_search",
+		"a__b",
+		"a__b__c",
+		"",
+		"__",
+		"__tool",
+		"server__",
+		"no-separator",
+		"a__b__",
+		"___",
+		"____",
+		"\x00\x01__\x02",
+	}
+	for _, s := range seeds {
+		f.Add(s)
+	}
+
+	f.Fuzz(func(t *testing.T, in string) {
+		server, bare, err := mcp.SplitName(in)
+		if err != nil {
+			// Errors are expected; ensure they chain to the public
+			// sentinel rather than carrying an ad-hoc message.
+			if !errors.Is(err, mcp.ErrInvalidName) {
+				t.Errorf("err = %v; want errors.Is ErrInvalidName", err)
+			}
+			return
+		}
+		// On success, both halves must be non-empty and rejoining must
+		// preserve the *first* server segment exactly.
+		if server == "" || bare == "" {
+			t.Errorf("SplitName(%q) returned empty server=%q or bare=%q", in, server, bare)
+		}
+		joined := mcp.JoinName(server, bare)
+		s2, b2, err := mcp.SplitName(joined)
+		if err != nil {
+			t.Errorf("re-Split of JoinName(%q,%q)=%q failed: %v", server, bare, joined, err)
+			return
+		}
+		if s2 != server || b2 != bare {
+			t.Errorf("round-trip diverged: %q,%q -> %q -> %q,%q", server, bare, joined, s2, b2)
+		}
+	})
+}
diff --git a/core/session/fuzz_test.go b/core/session/fuzz_test.go
new file mode 100644
index 0000000..fcd503b
--- /dev/null
+++ b/core/session/fuzz_test.go
@@ -0,0 +1,54 @@
+package session_test
+
+import (
+	"bytes"
+	"strings"
+	"testing"
+
+	"github.com/plexara/plexara-agents/core/session"
+)
+
+// FuzzLoad covers the JSON Lines decoder against arbitrary input.
+// Contract: Load never panics; if it returns no error, the loaded
+// Session must round-trip through Save back to a Load that returns
+// the same Session.
+//
+// Fuzz seeds use a fixed timestamp (not time.Now()) so the corpus
+// stays reproducible across runs. A regression caught today must
+// still be reproducible by the same seed string tomorrow.
+func FuzzLoad(f *testing.F) {
+	const fixedHeader = `{"type":"session.header","id":"x","created":"2026-01-01T00:00:00Z","updated":"2026-01-01T00:00:00Z"}` + "\n"
+	seeds := []string{
+		fixedHeader + `{"type":"session.message","role":"user","content":"hi"}` + "\n",
+		fixedHeader,
+		fixedHeader + `{"type":"session.future_kind","x":1}` + "\n",
+		``,
+		`{"type":"session.header","id":"x"}` + "\n",
+		`{"type":"session.message"}` + "\n",
+		`{`,
+		"\x00\x01\x02",
+	}
+	for _, s := range seeds {
+		f.Add(s)
+	}
+
+	f.Fuzz(func(t *testing.T, raw string) {
+		s, err := session.Load(strings.NewReader(raw))
+		if err != nil {
+			return
+		}
+		// Round-trip: Save then Load again must yield equivalent session.
+		// Only the ID and the message count are compared below.
+		var buf bytes.Buffer
+		if err := s.Save(&buf); err != nil {
+			t.Fatalf("Save after Load: %v", err)
+		}
+		s2, err := session.Load(&buf)
+		if err != nil {
+			t.Fatalf("Load after Save: %v\ninput: %q", err, raw)
+		}
+		if s2.ID != s.ID || len(s2.Messages) != len(s.Messages) {
+			t.Errorf("round-trip diverged: ID %q vs %q, messages %d vs %d",
+				s.ID, s2.ID, len(s.Messages), len(s2.Messages))
+		}
+	})
+}
diff --git a/core/session/session.go b/core/session/session.go
new file mode 100644
index 0000000..e4f6161
--- /dev/null
+++ b/core/session/session.go
@@ -0,0 +1,324 @@
+// Package session is an append-only chat history with replay-friendly
+// persistence.
+//
+// A [Session] is a value: copy it, ship it across goroutines, save it
+// to disk. Persistence is JSON Lines — one header line followed by
+// one envelope per message. The format is stable and forward-readable
+// so that a future reader can ignore unknown envelope types without
+// failing.
+//
+// Token-aware truncation lives behind a [Tokenizer] interface. v1 ships
+// a simple byte-length heuristic ([LengthHeuristic]); a real
+// model-aware tokenizer can swap in for v1.1 without touching callers.
+package session
+
+import (
+	"bufio"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"strings"
+	"time"
+
+	"github.com/plexara/plexara-agents/core/provider"
+)
+
+// Session is an append-only chat history.
+type Session struct {
+	// ID is the caller-chosen identifier of the session.
+	ID string
+	// Created is set by New to the construction time, in UTC.
+	Created time.Time
+	// Updated starts equal to Created in New.
+	Updated time.Time
+	// Messages holds the history; New leaves it nil.
+	Messages []provider.Message
+}
+
+// New constructs an empty Session with the given ID. ID is the
+// caller's responsibility to choose; using a UUID is conventional but
+// not required.
+func New(id string) *Session {
+	now := time.Now().UTC()
+	return &Session{
+		ID:       id,
+		Created:  now,
+		Updated:  now,
+		Messages: nil,
+	}
+}
+
+// Append adds a message and bumps Updated.
+func (s *Session) Append(m provider.Message) {
+	s.Messages = append(s.Messages, m)
+	s.Updated = time.Now().UTC()
+}
+
+// Tokenizer estimates how many tokens a single message costs.
+//
+// v1 relies on [LengthHeuristic]. A model-aware implementation (for
+// example one backed by a /tokenize HTTP call to the runtime) may
+// replace it in v1.1 without any change to callers.
+type Tokenizer interface {
+	CountTokens(provider.Message) int
+}
+
+// LengthHeuristic is a [Tokenizer] that estimates tokens as
+// ceil(content-bytes / BytesPerToken). It is deliberately crude:
+// treat the result as a *budget*, not as an exact count.
+type LengthHeuristic struct {
+	// BytesPerToken is the assumed bytes-per-token ratio. Values
+	// <= 0 fall back to 4 (typical for English text under common
+	// BPE-style tokenizers).
+	BytesPerToken int
+}
+
+// CountTokens implements [Tokenizer]. It sums the byte lengths of the
+// message content, its tool-call ID and every tool call's ID, name
+// and arguments, then divides by BytesPerToken, rounding up.
+func (h LengthHeuristic) CountTokens(m provider.Message) int {
+	perToken := h.BytesPerToken
+	if perToken <= 0 {
+		perToken = 4
+	}
+
+	size := len(m.Content) + len(m.ToolCallID)
+	for i := range m.ToolCalls {
+		size += len(m.ToolCalls[i].ID) + len(m.ToolCalls[i].Name) + len(m.ToolCalls[i].Arguments)
+	}
+
+	// An entirely empty message still costs one token of envelope.
+	if size == 0 {
+		return 1
+	}
+	// Integer ceiling division.
+	return (size + perToken - 1) / perToken
+}
+
+// Truncate removes messages from the front of the history, oldest
+// first, until the estimated total is at or below maxTokens.
+//
+// A system message in first position is kept so the model still sees
+// its instructions. Only that one message is protected: further
+// leading system messages are dropped like any other message. The
+// protected system prompt is never removed, even if it exceeds
+// maxTokens on its own — the session then holds just the system
+// prompt and remains over budget, and the caller must shrink the
+// prompt.
+//
+// A single non-system message larger than maxTokens is dropped
+// together with everything before it. The session can therefore end
+// up **empty** (when there is no leading system message) or reduced
+// to the system prompt alone. Callers that store large tool results
+// should size the budget accordingly, and callers that persist right
+// after Truncate must be ready to write an empty conversation.
+//
+// Truncate does nothing when maxTokens <= 0, when there are no
+// messages, or when the history already fits. Updated is bumped only
+// when something was dropped. A nil t selects [LengthHeuristic].
+func (s *Session) Truncate(maxTokens int, t Tokenizer) {
+	if maxTokens <= 0 || len(s.Messages) == 0 {
+		return
+	}
+	if t == nil {
+		t = LengthHeuristic{}
+	}
+
+	// Measure the whole history first; nothing to do if it fits.
+	used := 0
+	for i := range s.Messages {
+		used += t.CountTokens(s.Messages[i])
+	}
+	if used <= maxTokens {
+		return
+	}
+
+	// A system message in slot 0 is exempt from dropping, so the cut
+	// point starts just after it.
+	pinned := s.Messages[0].Role == provider.RoleSystem
+	cut := 0
+	if pinned {
+		cut = 1
+	}
+	// Advance the cut point over the oldest messages until the rest
+	// fits, or until nothing is left to drop.
+	for ; cut < len(s.Messages) && used > maxTokens; cut++ {
+		used -= t.CountTokens(s.Messages[cut])
+	}
+
+	// Rebuild as: pinned system prompt (if any) + surviving tail.
+	var survivors []provider.Message
+	if pinned {
+		survivors = append(survivors, s.Messages[0])
+	}
+	survivors = append(survivors, s.Messages[cut:]...)
+
+	s.Messages = survivors
+	s.Updated = time.Now().UTC()
+}
+
+// Envelope type tags. Every line of the persistence stream is one
+// tagged JSON object, and its "type" field selects the variant.
+const (
+	envelopeHeader  = "session.header"
+	envelopeMessage = "session.message"
+)
+
+// headerEnvelope is the first line of a persisted session and carries
+// the session metadata.
+type headerEnvelope struct {
+	Type    string    `json:"type"`
+	ID      string    `json:"id"`
+	Created time.Time `json:"created"`
+	Updated time.Time `json:"updated"`
+}
+
+// messageEnvelope MUST carry every field of [provider.Message] that
+// is meant to be persisted. When provider.Message gains a field, this
+// envelope and the Save / Load mappings have to be updated in the
+// same change, or the on-disk format silently drops data.
+//
+// Tool calls are routed through [toolCallEnvelope] instead of
+// embedding [provider.ToolCall] because the upstream type has no JSON
+// tags; its default Pascal-case marshaling would freeze a wire format
+// that no other reader could interoperate with.
+type messageEnvelope struct {
+	Type       string             `json:"type"`
+	Role       provider.Role      `json:"role"`
+	Content    string             `json:"content,omitempty"`
+	ToolCalls  []toolCallEnvelope `json:"tool_calls,omitempty"`
+	ToolCallID string             `json:"tool_call_id,omitempty"`
+}
+
+// toolCallEnvelope is the persisted form of [provider.ToolCall], with
+// explicit lowercase JSON tags so the on-disk format stays stable.
+type toolCallEnvelope struct {
+	ID        string          `json:"id"`
+	Name      string          `json:"name"`
+	Arguments json.RawMessage `json:"arguments,omitempty"`
+}
+
+// toolCallsToEnvelopes converts tool calls to their persisted form.
+// An empty or nil input yields nil so the field is omitted on disk.
+func toolCallsToEnvelopes(tcs []provider.ToolCall) []toolCallEnvelope {
+	if len(tcs) == 0 {
+		return nil
+	}
+	envs := make([]toolCallEnvelope, 0, len(tcs))
+	for _, call := range tcs {
+		envs = append(envs, toolCallEnvelope{ID: call.ID, Name: call.Name, Arguments: call.Arguments})
+	}
+	return envs
+}
+
+// envelopesToToolCalls is the inverse of toolCallsToEnvelopes. An
+// empty or nil input yields nil.
+func envelopesToToolCalls(envs []toolCallEnvelope) []provider.ToolCall {
+	if len(envs) == 0 {
+		return nil
+	}
+	calls := make([]provider.ToolCall, 0, len(envs))
+	for _, env := range envs {
+		calls = append(calls, provider.ToolCall{ID: env.ID, Name: env.Name, Arguments: env.Arguments})
+	}
+	return calls
+}
+
+// Save writes the session as JSON Lines: a header line followed by
+// one message line per [provider.Message] in [Session.Messages].
+func (s *Session) Save(w io.Writer) error { + bw := bufio.NewWriter(w) + header := headerEnvelope{ + Type: envelopeHeader, ID: s.ID, Created: s.Created, Updated: s.Updated, + } + if err := writeLine(bw, header); err != nil { + return fmt.Errorf("write header: %w", err) + } + for i, m := range s.Messages { + env := messageEnvelope{ + Type: envelopeMessage, + Role: m.Role, + Content: m.Content, + ToolCalls: toolCallsToEnvelopes(m.ToolCalls), + ToolCallID: m.ToolCallID, + } + if err := writeLine(bw, env); err != nil { + return fmt.Errorf("write message %d: %w", i, err) + } + } + return bw.Flush() +} + +func writeLine(w io.Writer, v any) error { + b, err := json.Marshal(v) + if err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + _, err = w.Write([]byte{'\n'}) + return err +} + +// Sentinel errors callers may match with [errors.Is]. +var ( + // ErrFormat is returned when the on-disk format is malformed in a + // way that prevents a usable session from being recovered. + ErrFormat = errors.New("session: malformed persistence format") +) + +// Load reads a session from JSON Lines. The first line must be a +// header envelope; subsequent lines are messages or unknown envelopes +// (which are skipped, so a future writer can add new envelope types +// without breaking forward compatibility). 
+func Load(r io.Reader) (*Session, error) { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 0, 64*1024), 4*1024*1024) + + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read header: %w", err) + } + return nil, fmt.Errorf("%w: empty input", ErrFormat) + } + var head headerEnvelope + if err := json.Unmarshal(scanner.Bytes(), &head); err != nil { + return nil, fmt.Errorf("%w: header: %w", ErrFormat, err) + } + if head.Type != envelopeHeader { + return nil, fmt.Errorf("%w: first envelope type = %q, want %q", ErrFormat, head.Type, envelopeHeader) + } + s := &Session{ + ID: head.ID, + Created: head.Created, + Updated: head.Updated, + } + lineNo := 1 + for scanner.Scan() { + lineNo++ + line := scanner.Bytes() + if len(strings.TrimSpace(string(line))) == 0 { + continue + } + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(line, &probe); err != nil { + return nil, fmt.Errorf("%w: line %d: %w", ErrFormat, lineNo, err) + } + switch probe.Type { + case envelopeMessage: + var env messageEnvelope + if err := json.Unmarshal(line, &env); err != nil { + return nil, fmt.Errorf("%w: line %d: decode message: %w", ErrFormat, lineNo, err) + } + s.Messages = append(s.Messages, provider.Message{ + Role: env.Role, + Content: env.Content, + ToolCalls: envelopesToToolCalls(env.ToolCalls), + ToolCallID: env.ToolCallID, + }) + case "": + // A line whose JSON parses but lacks a "type" field is a + // producer bug, not a forward-compat case. Loud-fail rather + // than silently dropping data. + return nil, fmt.Errorf("%w: line %d: envelope missing required \"type\" field", ErrFormat, lineNo) + default: + // Forward-compat: skip unknown (non-empty) types silently. A + // future writer can introduce new types without breaking us. 
+ } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("read body: %w", err) + } + return s, nil +} diff --git a/core/session/session_test.go b/core/session/session_test.go new file mode 100644 index 0000000..0294317 --- /dev/null +++ b/core/session/session_test.go @@ -0,0 +1,325 @@ +package session_test + +import ( + "bytes" + "encoding/json" + "errors" + "strings" + "testing" + "time" + + "github.com/plexara/plexara-agents/core/provider" + "github.com/plexara/plexara-agents/core/session" +) + +func TestNew(t *testing.T) { + t.Parallel() + + s := session.New("abc") + if s.ID != "abc" { + t.Errorf("ID = %q; want abc", s.ID) + } + if s.Created.IsZero() || s.Updated.IsZero() { + t.Errorf("timestamps are zero") + } + if !s.Created.Equal(s.Updated) { + t.Errorf("Created != Updated at construction") + } + if s.Messages != nil { + t.Errorf("Messages = %v; want nil", s.Messages) + } +} + +func TestAppend_BumpsUpdated(t *testing.T) { + t.Parallel() + + s := session.New("x") + original := s.Updated + time.Sleep(2 * time.Millisecond) + s.Append(provider.Message{Role: provider.RoleUser, Content: "hi"}) + if !s.Updated.After(original) { + t.Errorf("Updated did not advance: was %v, now %v", original, s.Updated) + } + if len(s.Messages) != 1 { + t.Errorf("len(Messages) = %d; want 1", len(s.Messages)) + } +} + +func TestSaveLoad_RoundTrip(t *testing.T) { + t.Parallel() + + s := session.New("session-1") + s.Append(provider.Message{Role: provider.RoleSystem, Content: "be terse"}) + s.Append(provider.Message{Role: provider.RoleUser, Content: "what's the weather"}) + s.Append(provider.Message{ + Role: provider.RoleAssistant, + ToolCalls: []provider.ToolCall{{ + ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"NYC"}`), + }}, + }) + s.Append(provider.Message{Role: provider.RoleTool, ToolCallID: "call_1", Content: "72F sunny"}) + + var buf bytes.Buffer + if err := s.Save(&buf); err != nil { + t.Fatalf("Save: %v", err) + } + + loaded, err := 
session.Load(&buf) + if err != nil { + t.Fatalf("Load: %v", err) + } + + if loaded.ID != s.ID { + t.Errorf("ID = %q; want %q", loaded.ID, s.ID) + } + if !loaded.Created.Equal(s.Created) { + t.Errorf("Created = %v; want %v", loaded.Created, s.Created) + } + if len(loaded.Messages) != len(s.Messages) { + t.Fatalf("len(Messages) = %d; want %d", len(loaded.Messages), len(s.Messages)) + } + for i := range s.Messages { + got := loaded.Messages[i] + want := s.Messages[i] + if got.Role != want.Role || got.Content != want.Content || got.ToolCallID != want.ToolCallID { + t.Errorf("Messages[%d] mismatch: got %#v want %#v", i, got, want) + } + if len(got.ToolCalls) != len(want.ToolCalls) { + t.Errorf("Messages[%d] ToolCalls len mismatch: %d vs %d", i, len(got.ToolCalls), len(want.ToolCalls)) + } + } +} + +func TestLoad_Errors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + }{ + {name: "empty", in: ""}, + {name: "non_header_first", in: `{"type":"session.message","role":"user"}` + "\n"}, + {name: "header_not_json", in: "not-json\n"}, + {name: "garbage_line", in: header() + "not-json\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := session.Load(strings.NewReader(tt.in)) + if err == nil { + t.Fatalf("Load(%q) returned nil; want error", tt.name) + } + if !errors.Is(err, session.ErrFormat) { + t.Errorf("err = %v; want errors.Is ErrFormat", err) + } + }) + } +} + +func TestLoad_SkipsUnknownEnvelopes(t *testing.T) { + t.Parallel() + + in := header() + + `{"type":"session.message","role":"user","content":"hi"}` + "\n" + + `{"type":"session.future_kind","payload":{"x":1}}` + "\n" + + `{"type":"session.message","role":"assistant","content":"hello"}` + "\n" + + s, err := session.Load(strings.NewReader(in)) + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(s.Messages) != 2 { + t.Errorf("len(Messages) = %d; want 2 (unknown envelope dropped)", len(s.Messages)) + } +} + +// 
TestLoad_RejectsMissingType pins the contract that unknown FUTURE +// envelope types are skipped silently for forward compat, but a line +// with NO "type" field is a producer bug and must be loud. +func TestLoad_RejectsMissingType(t *testing.T) { + t.Parallel() + + in := header() + `{"role":"user","content":"hi"}` + "\n" + _, err := session.Load(strings.NewReader(in)) + if err == nil { + t.Fatal("Load returned nil; want ErrFormat for missing type field") + } + if !errors.Is(err, session.ErrFormat) { + t.Errorf("err = %v; want errors.Is ErrFormat", err) + } + if !strings.Contains(err.Error(), "missing required") { + t.Errorf("err = %v; want it to mention the missing field", err) + } +} + +// TestSave_PersistedFormatUsesSnakeCase pins the on-disk wire format +// so future readers (other tools, other languages) interoperate. The +// envelope-mirror types in session.go MUST keep snake_case keys; this +// test fails if anyone removes a JSON tag. +func TestSave_PersistedFormatUsesSnakeCase(t *testing.T) { + t.Parallel() + + s := session.New("x") + s.Append(provider.Message{ + Role: provider.RoleAssistant, + ToolCalls: []provider.ToolCall{{ + ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"NYC"}`), + }}, + }) + var buf bytes.Buffer + if err := s.Save(&buf); err != nil { + t.Fatalf("Save: %v", err) + } + out := buf.String() + + // Required snake_case keys somewhere in the output. + for _, want := range []string{ + `"type":"session.header"`, + `"created":"`, `"updated":"`, + `"type":"session.message"`, + `"role":"assistant"`, + `"tool_calls":[{`, + `"id":"call_1"`, + `"name":"get_weather"`, + `"arguments":{"city":"NYC"}`, + } { + if !strings.Contains(out, want) { + t.Errorf("Save output missing %q; full=%s", want, out) + } + } + // Forbid Pascal-case leakage from upstream provider.ToolCall. 
+ for _, bad := range []string{`"ID":`, `"Name":`, `"Arguments":`} { + if strings.Contains(out, bad) { + t.Errorf("Save output leaked Pascal-case key %q; full=%s", bad, out) + } + } +} + +func TestLoad_SkipsBlankLines(t *testing.T) { + t.Parallel() + + in := header() + "\n" + + `{"type":"session.message","role":"user","content":"hi"}` + "\n\n" + + s, err := session.Load(strings.NewReader(in)) + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(s.Messages) != 1 { + t.Errorf("len(Messages) = %d; want 1", len(s.Messages)) + } +} + +func TestLengthHeuristic_DefaultsAndZero(t *testing.T) { + t.Parallel() + + h := session.LengthHeuristic{} + if got := h.CountTokens(provider.Message{}); got != 1 { + t.Errorf("empty message tokens = %d; want 1 (envelope cost)", got) + } + // 16-byte content with default BytesPerToken=4 -> 4 tokens. + if got := h.CountTokens(provider.Message{Content: strings.Repeat("a", 16)}); got != 4 { + t.Errorf("16-byte tokens = %d; want 4", got) + } + // Custom ratio. + h2 := session.LengthHeuristic{BytesPerToken: 8} + if got := h2.CountTokens(provider.Message{Content: strings.Repeat("a", 16)}); got != 2 { + t.Errorf("16-byte at 8bpt = %d; want 2", got) + } +} + +func TestTruncate_DropsOldestNonSystem(t *testing.T) { + t.Parallel() + + // LengthHeuristic{} = ceil(content-bytes / 4): + // "be terse" (8) -> 2 tokens + // "first" (5) -> 2 tokens + // "answer-1" (8) -> 2 tokens + // "second" (6) -> 2 tokens + // "answer-2" (8) -> 2 tokens + // Total 10. A budget of 7 forces dropping 2 of the 4 non-system + // messages, oldest-first. 
+ s := session.New("x") + s.Append(provider.Message{Role: provider.RoleSystem, Content: "be terse"}) + s.Append(provider.Message{Role: provider.RoleUser, Content: "first"}) + s.Append(provider.Message{Role: provider.RoleAssistant, Content: "answer-1"}) + s.Append(provider.Message{Role: provider.RoleUser, Content: "second"}) + s.Append(provider.Message{Role: provider.RoleAssistant, Content: "answer-2"}) + + s.Truncate(7, session.LengthHeuristic{}) + + if len(s.Messages) < 2 || s.Messages[0].Role != provider.RoleSystem { + t.Fatalf("expected system + at least one tail message; got %+v", roles(s.Messages)) + } + // The "first" user message must be dropped. + for _, m := range s.Messages { + if m.Content == "first" { + t.Errorf("oldest user message survived truncation") + } + } +} + +func TestTruncate_NoOpUnderBudget(t *testing.T) { + t.Parallel() + + s := session.New("x") + s.Append(provider.Message{Role: provider.RoleUser, Content: "hi"}) + s.Truncate(1000, session.LengthHeuristic{}) + if len(s.Messages) != 1 { + t.Errorf("under-budget Truncate dropped messages: %v", roles(s.Messages)) + } +} + +func TestTruncate_ZeroBudgetIsNoOp(t *testing.T) { + t.Parallel() + + s := session.New("x") + s.Append(provider.Message{Role: provider.RoleUser, Content: "hi"}) + s.Truncate(0, session.LengthHeuristic{}) + s.Truncate(-5, session.LengthHeuristic{}) + if len(s.Messages) != 1 { + t.Errorf("non-positive budget should not mutate; got %d msgs", len(s.Messages)) + } +} + +func TestTruncate_NilTokenizerDefaults(t *testing.T) { + t.Parallel() + + s := session.New("x") + s.Append(provider.Message{Role: provider.RoleUser, Content: "hi"}) + // Should not panic; behaves like the default LengthHeuristic. 
+ s.Truncate(1000, nil) + if len(s.Messages) != 1 { + t.Errorf("nil tokenizer changed message count") + } +} + +func TestTruncate_PreservesSystemEvenWhenSingleMessage(t *testing.T) { + t.Parallel() + + s := session.New("x") + s.Append(provider.Message{Role: provider.RoleSystem, Content: strings.Repeat("a", 1000)}) + s.Append(provider.Message{Role: provider.RoleUser, Content: "hi"}) + + // Budget far below system message size — system should still survive. + s.Truncate(5, session.LengthHeuristic{}) + + if len(s.Messages) == 0 || s.Messages[0].Role != provider.RoleSystem { + t.Errorf("system message was dropped under tiny budget; got %v", roles(s.Messages)) + } +} + +// --- helpers --- + +func header() string { + now := time.Now().UTC().Format(time.RFC3339Nano) + return `{"type":"session.header","id":"x","created":"` + now + `","updated":"` + now + `"}` + "\n" +} + +func roles(ms []provider.Message) []provider.Role { + out := make([]provider.Role, len(ms)) + for i, m := range ms { + out[i] = m.Role + } + return out +} diff --git a/go.mod b/go.mod index 4625a1e..e0ec4af 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,11 @@ tool ( golang.org/x/vuln/cmd/govulncheck ) +require ( + github.com/modelcontextprotocol/go-sdk v1.6.0 + golang.org/x/sync v0.20.0 +) + require ( 4d63.com/gocheckcompilerdirectives v1.3.0 // indirect 4d63.com/gochecknoglobals v0.2.2 // indirect @@ -113,6 +118,7 @@ require ( github.com/golangci/unconvert v0.0.0-20250410112200-a129a6e6413e // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-licenses/v2 v2.0.1 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect github.com/google/licenseclassifier/v2 v2.0.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect @@ -195,6 +201,8 @@ require ( github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.29.0 // indirect github.com/securego/gosec/v2 v2.26.1 // indirect + 
github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect github.com/sergi/go-diff v1.2.0 // indirect github.com/sirupsen/logrus v1.9.4 // indirect github.com/sivchari/containedctx v1.0.3 // indirect @@ -230,6 +238,7 @@ require ( github.com/yagipy/maintidx v1.0.0 // indirect github.com/yeya24/promlinter v0.3.0 // indirect github.com/ykadowak/zerologlint v0.1.5 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect gitlab.com/bosi/decorder v0.4.2 // indirect go-simpler.org/musttag v0.14.0 // indirect go-simpler.org/sloglint v0.12.0 // indirect @@ -248,7 +257,7 @@ require ( golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect golang.org/x/mod v0.35.0 // indirect golang.org/x/net v0.53.0 // indirect - golang.org/x/sync v0.20.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sys v0.43.0 // indirect golang.org/x/telemetry v0.0.0-20260421165255-392afab6f40e // indirect golang.org/x/text v0.36.0 // indirect diff --git a/go.sum b/go.sum index d46e5d9..f503e49 100644 --- a/go.sum +++ b/go.sum @@ -258,6 +258,8 @@ github.com/godoc-lint/godoc-lint v0.11.2/go.mod h1:iVpGdL1JCikNH2gGeAn3Hh+AgN5Gx github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -336,6 +338,8 @@ 
github.com/google/go-licenses/v2 v2.0.1/go.mod h1:efibo0EDNGkau6AIMOViGW+rTNPudh github.com/google/go-replayers/httpreplay v1.2.0 h1:VM1wEyyjaoU53BwrOnaf9VhAyQQEEioJvFYxYcLRKzk= github.com/google/go-replayers/httpreplay v1.2.0/go.mod h1:WahEFFZZ7a1P4VM1qEeHy+tME4bwyqPcwWbNlUI1Mcg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/google/licenseclassifier/v2 v2.0.0 h1:1Y57HHILNf4m0ABuMVb6xk4vAJYEUO0gDxNpog0pyeA= github.com/google/licenseclassifier/v2 v2.0.0/go.mod h1:cOjbdH0kyC9R22sdQbYsFkto4NGCAc+ZSwbeThazEtM= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= @@ -492,6 +496,8 @@ github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= +github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -591,6 +597,10 @@ github.com/sashamelentyev/usestdlibvars v1.29.0 h1:8J0MoRrw4/NAXtjQqTHrbW9NN+3iM github.com/sashamelentyev/usestdlibvars v1.29.0/go.mod h1:8PpnjHMk5VdeWlVb4wCdrB8PNbLqZ3wBZTZWkrpZZL8= 
github.com/securego/gosec/v2 v2.26.1 h1:gdkttGhQFVehqRJ8grKH4DrpqM/QlPKNHBnl8QgcEC4= github.com/securego/gosec/v2 v2.26.1/go.mod h1:57UW4p0uoP3kxoTkhoo3axLdVAi+OWrLg/Ax/kdqtPE= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= @@ -683,6 +693,8 @@ github.com/yeya24/promlinter v0.3.0 h1:JVDbMp08lVCP7Y6NP3qHroGAO6z2yGKQtS5Jsjqto github.com/yeya24/promlinter v0.3.0/go.mod h1:cDfJQQYv9uYciW60QT0eeHlFodotkYZlL+YcPQN+mW4= github.com/ykadowak/zerologlint v0.1.5 h1:Gy/fMz1dFQN9JZTPjv1hxEk+sRWm05row04Yoolgdiw= github.com/ykadowak/zerologlint v0.1.5/go.mod h1:KaUskqF3e/v59oPmdq1U1DnKcuHokl2/K1U4pmIELKg= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -824,6 +836,8 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod 
h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=