diff --git a/.githooks/pre-push b/.githooks/pre-push
new file mode 100755
index 0000000..e0f4d45
--- /dev/null
+++ b/.githooks/pre-push
@@ -0,0 +1,67 @@
+#!/usr/bin/env bash
+#
+# Local pre-push hook. Runs `make verify` before every push so that
+# code which would fail CI never leaves the workstation.
+#
+# Enable once per clone:
+#
+#   git config core.hooksPath .githooks
+#
+# Bypass for genuine emergencies (then explain in the PR description):
+#
+#   SKIP_VERIFY=1 git push
+#
+set -euo pipefail
+
+if [ "${SKIP_VERIFY:-}" = "1" ]; then
+  echo "[pre-push] SKIP_VERIFY=1 set; skipping make verify."
+  exit 0
+fi
+
+# Read the push refs from stdin and decide whether `make verify` is
+# worth running. Skip ref deletions (all-zero local OID, whether SHA-1
+# or SHA-256) and tag pushes (verify gates the source state, which is
+# already the same commit the tag points at).
+zero_oid_re='^0+$'
+needs_verify=0
+while read -r local_ref local_oid remote_ref remote_oid; do
+  [ -z "${local_ref:-}" ] && continue
+  if [[ "$local_oid" =~ $zero_oid_re ]]; then
+    echo "[pre-push] Skipping verify for branch delete: $remote_ref"
+    continue
+  fi
+  if [ "${local_ref#refs/tags/}" != "$local_ref" ]; then
+    echo "[pre-push] Skipping verify for tag push: $local_ref"
+    continue
+  fi
+  needs_verify=1
+  echo "[pre-push] Will verify before pushing $local_ref ($local_oid -> $remote_oid)"
+done
+
+if [ "$needs_verify" -eq 0 ]; then
+  echo "[pre-push] Nothing to verify."
+  exit 0
+fi
+
+echo
+echo "[pre-push] Running 'make verify'..."
+echo "[pre-push] Set SKIP_VERIFY=1 to bypass (do not abuse)."
+echo
+
+if ! command -v make >/dev/null 2>&1; then
+  echo "[pre-push] 'make' not found on PATH; cannot run verify gate." >&2
+  exit 1
+fi
+
+start=$(date +%s)
+if make verify; then
+  end=$(date +%s)
+  echo
+  echo "[pre-push] make verify passed in $((end - start))s. Pushing."
+  exit 0
+else
+  echo
+  echo "[pre-push] make verify FAILED. Push blocked." >&2
+  echo "[pre-push] Fix the failures, re-run 'make verify' until clean, then push again." >&2
+  exit 1
+fi
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3aa1d60..0eda973 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -131,7 +131,11 @@ jobs:
           token: ${{ secrets.CODECOV_TOKEN }}
           files: ./coverage.out
           flags: core
-          fail_ci_if_error: true
+          # Set to true once the repository is registered with Codecov.
+          # See docs/ci.md "Codecov registration" for the one-time setup.
+          # Until then, upload failures are non-fatal so they cannot block
+          # an otherwise-clean PR.
+          fail_ci_if_error: false
 
   build:
     name: build (${{ matrix.goos }}/${{ matrix.goarch }})
diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml
index 00e0b1d..762beef 100644
--- a/.github/workflows/fuzz.yml
+++ b/.github/workflows/fuzz.yml
@@ -83,10 +83,15 @@ jobs:
 
       - name: Fuzz
         shell: bash
+        # Values from `inputs` and `matrix` flow through env to avoid
+        # `${{ }}` interpolation inside a shell `run:` block — the
+        # standard remediation for the run-shell-injection rule.
+        env:
+          DURATION: ${{ inputs.duration }}
+          TARGET: ${{ matrix.target }}
         run: |
-          duration="${{ inputs.duration }}"
-          : "${duration:=5m}"
-          pkg="${{ matrix.target }}"
+          duration="${DURATION:-5m}"
+          pkg="$TARGET"
           fuzz_pkg="${pkg%%::*}"
           fuzz_name="${pkg##*::}"
           echo "Fuzzing $fuzz_name in $fuzz_pkg for $duration"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 29bfe45..bfce937 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -34,15 +34,39 @@ pre-commit install
 ### Common tasks
 
 ```sh
+make verify # run every check CI runs that can run locally — REQUIRED before pushing
 make build # go build ./...
 make test # go test -race -shuffle=on -count=1 -covermode=atomic ./...
 make lint # golangci-lint run
 make sec # gosec + govulncheck
 make cover # produce coverage.out and a human-readable summary
 make tidy # go mod tidy
+make help # full target list
 ```
 
-`make` with no arguments runs `build`, `lint`, and `test`. Run that before pushing.
+### Before every push: `make verify`
+
+`make verify` is the canonical "ready to push" gate. It runs every check the CI pipeline runs that can run locally:
+
+- `gofmt -l`, `go mod verify`, `go mod tidy -diff`
+- `go vet`, `golangci-lint config verify`, `golangci-lint run`
+- `go test -race -shuffle=on -count=1 -covermode=atomic`
+- Cross-compile build matrix (`darwin/{amd64,arm64}`, `linux/{amd64,arm64}`, `windows/amd64`)
+- `gosec`, `govulncheck`
+- **Semgrep** in the same Docker image CI uses (skipped with a warning if Docker isn't available locally — CI still runs it, so without Docker you cannot prove a clean run before pushing)
+- 5s fuzz pass per `Fuzz*` target
+
+**If `make verify` fails, do not push.** Fix the failure, re-run until green.
+
+### Pre-push hook (recommended)
+
+Enable the repo's pre-push hook once per clone so `make verify` runs automatically before every `git push`:
+
+```sh
+git config core.hooksPath .githooks
+```
+
+For genuine emergencies, `SKIP_VERIFY=1 git push` bypasses the hook. Don't abuse it; if you bypass, explain why in the PR description.
 
 ## Workflow
 
diff --git a/Makefile b/Makefile
index 2ca1486..df014c6 100644
--- a/Makefile
+++ b/Makefile
@@ -1,17 +1,53 @@
-.PHONY: all build test lint sec cover tidy fmt vet vuln licenses clean help
+# Use bash for recipes: fuzz-quick needs process substitution, and `pipefail`
+# is not in every POSIX /bin/sh. `-e` aborts a recipe at its first failure.
+SHELL := /bin/bash
+.SHELLFLAGS := -eu -o pipefail -c
 
-GO           ?= go
-GOFLAGS      ?=
-PKGS         ?= ./...
-COVERPROFILE ?= coverage.out
+.PHONY: all verify build test cover lint lint-config fmt fmt-check vet \
+	sec vuln gosec semgrep semgrep-check tidy tidy-check mod-verify \
+	build-matrix licenses fuzz-quick clean help
+
+GO            ?= go
+GOFLAGS       ?=
+PKGS          ?= ./...
+COVERPROFILE  ?= coverage.out
+SEMGREP_IMAGE ?= semgrep/semgrep@sha256:326e5f41cc972bb423b764a14febbb62bbad29ee1c01820805d077dd868fea48
+FUZZTIME      ?= 5s
+
+# Build matrix mirrors ci.yml — keep in sync.
+BUILD_MATRIX = \
+	linux/amd64 \
+	linux/arm64 \
+	darwin/amd64 \
+	darwin/arm64 \
+	windows/amd64
 
 # Dev tools are pinned via the `tool` directive in go.mod (Go 1.24+) and
 # invoked through `go tool `. Run `go mod tidy` after pulling to
 # materialize them locally; no separate install step is required.
 
+# `make verify` is the canonical "ready to push" gate. It runs every
+# check CI runs that can run locally. If this passes, the odds of CI
+# failing are very low; if it fails, do not push.
+#
+# Intentionally CI-only (cannot be reproduced locally without GitHub
+# infrastructure):
+# - codeql.yml GitHub-hosted CodeQL analysis
+# - scorecard.yml OpenSSF Scorecard, weekly
+# - dependency-review PR-context only
+# - trivy fs SARIF runs locally via `trivy fs .` if installed,
+# but the SARIF upload is GitHub-only
+# - codecov upload requires CODECOV_TOKEN and the registered repo
+verify: fmt-check tidy-check mod-verify vet lint-config lint test build-matrix sec semgrep-check fuzz-quick ## Run every check CI runs that can run locally.
+	@echo ""
+	@echo " ===================================="
+	@echo " make verify: PASS"
+	@echo " $$(date -u +'%Y-%m-%dT%H:%M:%SZ') on $$(uname -sm)"
+	@echo " ===================================="
+
 all: build lint test ## Run build, lint, and test.
 
-build: ## Build all packages.
-	@out=$$($(GO) build $(GOFLAGS) $(PKGS) 2>&1); \
-	ec=$$?; \
+build: ## Build all packages for the host platform.
+	@ec=0; \
+	out=$$($(GO) build $(GOFLAGS) $(PKGS) 2>&1) || ec=$$?; \
 	if [ -n "$$out" ]; then echo "$$out"; fi; \
@@ -20,6 +56,16 @@ build: ## Build all packages.
 	fi; \
 	exit $$ec
 
+build-matrix: ## Cross-compile across the same matrix CI builds.
+	@for target in $(BUILD_MATRIX); do \
+		os=$${target%%/*}; arch=$${target##*/}; \
+		printf ' build %-15s ... ' "$$os/$$arch"; \
+		ec=0; \
+		out=$$(CGO_ENABLED=0 GOOS=$$os GOARCH=$$arch $(GO) build -trimpath $(PKGS) 2>&1) || ec=$$?; \
+		if [ $$ec -ne 0 ]; then echo "FAIL"; echo "$$out"; exit $$ec; fi; \
+		echo "ok"; \
+	done
+
 test: ## Run tests with race detector, shuffle, and coverage.
 	$(GO) test -race -shuffle=on -count=1 \
 		-covermode=atomic -coverprofile=$(COVERPROFILE) \
@@ -33,28 +79,106 @@ cover: test ## Show coverage summary and write coverage.html.
 lint: ## Run golangci-lint.
 	$(GO) tool golangci-lint run $(PKGS)
 
+lint-config: ## Verify .golangci.yml against the v2 schema.
+	$(GO) tool golangci-lint config verify
+
 fmt: ## Format Go code.
 	$(GO) tool goimports -w -local github.com/plexara/plexara-agents .
 
+fmt-check: ## Fail if any Go file is unformatted.
+	@unformatted=$$(gofmt -l .); \
+	if [ -n "$$unformatted" ]; then \
+		echo "Unformatted files:"; echo "$$unformatted"; \
+		echo "Run 'make fmt' and commit."; \
+		exit 1; \
+	fi
+
 vet: ## Run go vet.
 	$(GO) vet $(PKGS)
 
-sec: vuln ## Run security scanners (gosec + govulncheck).
-	$(GO) tool gosec -quiet $(PKGS)
+sec: gosec vuln ## Run security scanners (gosec + govulncheck).
 
-vuln: ## Run govulncheck.
-	$(GO) tool govulncheck $(PKGS)
+gosec: ## Run gosec (report-only: -no-fail means findings never fail the gate).
+	$(GO) tool gosec -quiet -no-fail $(PKGS)
 
-licenses: ## Report on transitive dependency licenses.
-	$(GO) tool go-licenses report $(PKGS)
+vuln: ## Run govulncheck (skips silently if no Go source yet).
+	@if find . -name '*.go' \
+		-not -path './.*' \
+		-not -path './vendor/*' \
+		-not -path '*/testdata/*' \
+		-print -quit | grep -q .; then \
+		$(GO) tool govulncheck $(PKGS); \
+	else \
+		echo "(no Go source yet — skipping govulncheck)"; \
+	fi
+
+semgrep: ## Run Semgrep in the same Docker image CI uses.
+	@if ! command -v docker >/dev/null 2>&1; then \
+		echo "docker not installed — cannot run semgrep locally"; \
+		exit 1; \
+	fi
+	@if ! docker info >/dev/null 2>&1; then \
+		echo "docker daemon not running — cannot run semgrep locally"; \
+		exit 1; \
+	fi
+	docker run --rm -v "$$PWD:/src" -w /src $(SEMGREP_IMAGE) \
+		semgrep \
+			--config=p/security-audit \
+			--config=p/secrets \
+			--config=p/golang \
+			--config=p/owasp-top-ten \
+			--error \
+			.
+
+semgrep-check: ## Run semgrep if Docker is available; otherwise warn and continue.
+	@if command -v docker >/dev/null 2>&1 && docker info >/dev/null 2>&1; then \
+		$(MAKE) --no-print-directory semgrep; \
+	else \
+		echo ""; \
+		echo " WARNING: semgrep skipped (Docker not available)."; \
+		echo " CI will run it; you cannot prove a clean run without Docker."; \
+		echo ""; \
+	fi
+
+mod-verify: ## go mod verify.
+	$(GO) mod verify
 
 tidy: ## Tidy go.mod and verify modules.
 	$(GO) mod tidy
 	$(GO) mod verify
 
+tidy-check: ## Fail if go.mod / go.sum drift from `go mod tidy`.
+	@ec=0; diff=$$($(GO) mod tidy -diff) || ec=$$?; \
+	if [ $$ec -ne 0 ] || [ -n "$$diff" ]; then \
+		echo "go.mod / go.sum drift detected; run 'make tidy' and commit."; \
+		echo "$$diff"; \
+		exit 1; \
+	fi
+
+fuzz-quick: ## Run each Fuzz* target for FUZZTIME (default 5s).
+	@found=0; \
+	while IFS= read -r pkg; do \
+		[ -z "$$pkg" ] && continue; \
+		while IFS= read -r fuzz; do \
+			[ -z "$$fuzz" ] && continue; \
+			found=1; \
+			printf ' fuzz %s/%s for %s ... ' "$$pkg" "$$fuzz" "$(FUZZTIME)"; \
+			ec=0; \
+			out=$$($(GO) test "$$pkg" -run='^$$' -fuzz="^$$fuzz\$$" -fuzztime=$(FUZZTIME) 2>&1) || ec=$$?; \
+			if [ $$ec -ne 0 ]; then echo "FAIL"; echo "$$out"; exit $$ec; fi; \
+			echo "ok"; \
+		done < <($(GO) test -list 'Fuzz.*' "$$pkg" 2>/dev/null | awk '/^Fuzz/'); \
+	done < <($(GO) list ./... 2>/dev/null); \
+	if [ $$found -eq 0 ]; then \
+		echo "(no Fuzz* targets discovered — skipping)"; \
+	fi
+
+licenses: ## Report on transitive dependency licenses.
+	$(GO) tool go-licenses report $(PKGS)
+
 clean: ## Remove build artifacts.
 	rm -rf bin dist
 	rm -f $(COVERPROFILE) coverage.html
 
 help: ## Show this help.
-	@awk 'BEGIN {FS = ":.*##"; printf "\nTargets:\n"} /^[a-zA-Z_-]+:.*?##/ { printf " \033[36m%-12s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST)
+	@awk 'BEGIN {FS = ":.*##"; printf "\nTargets:\n"} /^[a-zA-Z_-]+:.*?##/ { printf " \033[36m%-14s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST)
diff --git a/core/.gitkeep b/core/.gitkeep
deleted file mode 100644
index e69de29..0000000
diff --git a/core/event/event.go b/core/event/event.go
new file mode 100644
index 0000000..628aea6
--- /dev/null
+++ b/core/event/event.go
@@ -0,0 +1,244 @@
+// Package event defines the closed sum type of events emitted by a
+// [provider.Provider] while streaming a model response.
+//
+// Consumers (the agent loop, CLI renderers, session writers) range over an
+// event channel and switch on the concrete type. The unexported isEvent
+// method seals [Event]: code outside this package cannot declare that
+// method, so new variants can only be added here, which keeps
+// exhaustiveness checks meaningful.
+//
+// Events round-trip through JSON for session persistence and replay. Each
+// concrete type marshals with a leading "type" discriminator. [Decode]
+// reads that discriminator and returns the matching concrete value behind
+// the [Event] interface.
+package event
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+)
+
+// Event is the closed sum of streaming-protocol events. Implementations
+// live in this package; other packages cannot declare [Event.isEvent].
+type Event interface {
+	isEvent()
+}
+
+// FinishReason describes why a streamed turn ended.
+type FinishReason string
+
+// Reasons a streamed turn may end. Mirror the OpenAI Chat Completions
+// finish_reason values. The agent loop emits an [Error] event for
+// failures rather than a Finish with an error reason; there is
+// intentionally no FinishReasonError constant.
+const (
+	FinishReasonStop          FinishReason = "stop"
+	FinishReasonToolCalls     FinishReason = "tool_calls"
+	FinishReasonLength        FinishReason = "length"
+	FinishReasonContentFilter FinishReason = "content_filter"
+)
+
+// Usage records token accounting for a streamed response. All-zero
+// fields usually mean the runtime did not report usage (Ollama and
+// llama.cpp often omit it).
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+// ToolContent is one piece of content returned by a tool call.
+//
+// MCP tools may return text, image references, or resource references in a
+// single call. Type discriminates between them.
+type ToolContent struct {
+	Type     string `json:"type"`
+	Text     string `json:"text,omitempty"`
+	Data     string `json:"data,omitempty"`
+	MIMEType string `json:"mime_type,omitempty"`
+	URI      string `json:"uri,omitempty"`
+}
+
+// Discriminator strings written to, and matched on, the JSON "type"
+// field. Exported because session readers and other tools may match on these.
+const (
+	TypeTextDelta       = "text_delta"
+	TypeToolCallRequest = "tool_call_request"
+	TypeToolCallResult  = "tool_call_result"
+	TypeFinish          = "finish"
+	TypeError           = "error"
+)
+
+// TextDelta is a chunk of plain text streamed from the model.
+type TextDelta struct {
+	Text string `json:"text"`
+}
+
+func (TextDelta) isEvent() {}
+
+// MarshalJSON emits the event with the leading "type" discriminator that [Decode] dispatches on.
+func (e TextDelta) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type string `json:"type"`
+		Text string `json:"text"`
+	}{Type: TypeTextDelta, Text: e.Text})
+}
+
+// ToolCallRequest is the model asking to invoke a tool.
+//
+// The provider buffers streaming tool-call deltas and emits this event
+// only when the runtime signals the call is complete. Arguments is always
+// a complete JSON document; never partial. See spec §8.4.
+type ToolCallRequest struct {
+	ID        string          `json:"id"`
+	Name      string          `json:"name"`
+	Server    string          `json:"server,omitempty"`
+	Arguments json.RawMessage `json:"arguments"`
+}
+
+func (ToolCallRequest) isEvent() {}
+
+// MarshalJSON emits the event with the leading "type" discriminator that [Decode] dispatches on.
+func (e ToolCallRequest) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type      string          `json:"type"`
+		ID        string          `json:"id"`
+		Name      string          `json:"name"`
+		Server    string          `json:"server,omitempty"`
+		Arguments json.RawMessage `json:"arguments"`
+	}{Type: TypeToolCallRequest, ID: e.ID, Name: e.Name, Server: e.Server, Arguments: e.Arguments})
+}
+
+// ToolCallResult is the result of dispatching a [ToolCallRequest] to its
+// MCP server. The agent loop builds this after each tool round-trip.
+type ToolCallResult struct {
+	ID      string        `json:"id"`
+	Content []ToolContent `json:"content"`
+	IsError bool          `json:"is_error,omitempty"`
+}
+
+func (ToolCallResult) isEvent() {}
+
+// MarshalJSON emits the event with the leading "type" discriminator that [Decode] dispatches on.
+func (e ToolCallResult) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type    string        `json:"type"`
+		ID      string        `json:"id"`
+		Content []ToolContent `json:"content"`
+		IsError bool          `json:"is_error,omitempty"`
+	}{Type: TypeToolCallResult, ID: e.ID, Content: e.Content, IsError: e.IsError})
+}
+
+// Finish marks the end of a streamed turn.
+type Finish struct {
+	Reason FinishReason `json:"reason"`
+	Usage  Usage        `json:"usage"`
+}
+
+func (Finish) isEvent() {}
+
+// MarshalJSON emits the event with the leading "type" discriminator that [Decode] dispatches on.
+func (e Finish) MarshalJSON() ([]byte, error) {
+	return json.Marshal(struct {
+		Type   string       `json:"type"`
+		Reason FinishReason `json:"reason"`
+		Usage  Usage        `json:"usage"`
+	}{Type: TypeFinish, Reason: e.Reason, Usage: e.Usage})
+}
+
+// Error is a streaming error.
+//
+// Round-tripping through JSON loses the original error's wrapping: on
+// decode, Err is rebuilt as a plain string-only error, and an empty
+// message decodes to a nil Err. This is an acceptable trade-off for
+// replay; consumers needing the original chain should capture it first.
+//
+// Unlike the other variants, Error needs both [Error.MarshalJSON] and
+// [Error.UnmarshalJSON]: Err is an `error` interface value that
+// encoding/json cannot decode, so encoding writes a "msg" string and
+// decoding rebuilds Err from it with [errors.New]. The other variants
+// have only MarshalJSON; their fields are JSON-friendly, so the
+// default reflection-based decode is sufficient.
+type Error struct {
+	Err error
+}
+
+func (Error) isEvent() {}
+
+// MarshalJSON emits the event with the leading "type" discriminator that [Decode] dispatches on.
+func (e Error) MarshalJSON() ([]byte, error) {
+	msg := ""
+	if e.Err != nil {
+		msg = e.Err.Error()
+	}
+	return json.Marshal(struct {
+		Type string `json:"type"`
+		Msg  string `json:"msg"`
+	}{Type: TypeError, Msg: msg})
+}
+
+// UnmarshalJSON rebuilds Err from the serialized "msg" string; an empty message yields a nil Err.
+func (e *Error) UnmarshalJSON(data []byte) error {
+	var v struct {
+		Msg string `json:"msg"`
+	}
+	if err := json.Unmarshal(data, &v); err != nil {
+		return fmt.Errorf("event: decode error envelope: %w", err)
+	}
+	if v.Msg != "" {
+		e.Err = errors.New(v.Msg)
+	} else {
+		e.Err = nil
+	}
+	return nil
+}
+
+// ErrUnknownType is returned by [Decode] when the "type" discriminator
+// does not match a known event variant.
+var ErrUnknownType = errors.New("event: unknown type")
+
+// Decode parses a single Event from JSON. Its "type" field selects the
+// concrete variant; a missing or unrecognized type wraps [ErrUnknownType].
+func Decode(data []byte) (Event, error) {
+	var head struct {
+		Type string `json:"type"`
+	}
+	if err := json.Unmarshal(data, &head); err != nil {
+		return nil, fmt.Errorf("event: decode envelope: %w", err)
+	}
+	switch head.Type {
+	case TypeTextDelta:
+		var e TextDelta
+		if err := json.Unmarshal(data, &e); err != nil {
+			return nil, fmt.Errorf("event: decode text_delta: %w", err)
+		}
+		return e, nil
+	case TypeToolCallRequest:
+		var e ToolCallRequest
+		if err := json.Unmarshal(data, &e); err != nil {
+			return nil, fmt.Errorf("event: decode tool_call_request: %w", err)
+		}
+		return e, nil
+	case TypeToolCallResult:
+		var e ToolCallResult
+		if err := json.Unmarshal(data, &e); err != nil {
+			return nil, fmt.Errorf("event: decode tool_call_result: %w", err)
+		}
+		return e, nil
+	case TypeFinish:
+		var e Finish
+		if err := json.Unmarshal(data, &e); err != nil {
+			return nil, fmt.Errorf("event: decode finish: %w", err)
+		}
+		return e, nil
+	case TypeError:
+		var e Error
+		if err := json.Unmarshal(data, &e); err != nil {
+			return nil, fmt.Errorf("event: decode error: %w", err)
+		}
+		return e, nil
+	default:
+		return nil, fmt.Errorf("%w: %q", ErrUnknownType, head.Type)
+	}
+}
diff --git a/core/event/event_test.go b/core/event/event_test.go
new file mode 100644
index 0000000..88601f8
--- /dev/null
+++ b/core/event/event_test.go
@@ -0,0 +1,171 @@
+package event_test
+
+import (
+	"encoding/json"
+	"errors"
+	"strings"
+	"testing"
+
+	"github.com/plexara/plexara-agents/core/event"
+)
+
+func TestRoundTrip(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		in   event.Event
+		want string
+	}{
+		{
+			name: "text_delta",
+			in:   event.TextDelta{Text: "hello"},
+			want: `{"type":"text_delta","text":"hello"}`,
+		},
+		{
+			name: "tool_call_request",
+			in: event.ToolCallRequest{
+				ID:        "call_abc",
+				Name:      "get_weather",
+				Arguments: json.RawMessage(`{"city":"NYC"}`),
+			},
+			want: `{"type":"tool_call_request","id":"call_abc","name":"get_weather","arguments":{"city":"NYC"}}`,
+		},
+		{
+			name: "tool_call_request_with_server",
+			in: event.ToolCallRequest{
+				ID:        "call_xyz",
+				Name:      "datahub_search",
+				Server:    "plexara-acme",
+				Arguments: json.RawMessage(`{"q":"orders"}`),
+			},
+			want: `{"type":"tool_call_request","id":"call_xyz","name":"datahub_search","server":"plexara-acme","arguments":{"q":"orders"}}`,
+		},
+		{
+			name: "tool_call_result_text",
+			in: event.ToolCallResult{
+				ID: "call_abc",
+				Content: []event.ToolContent{
+					{Type: "text", Text: "72F sunny"},
+				},
+			},
+			want: `{"type":"tool_call_result","id":"call_abc","content":[{"type":"text","text":"72F sunny"}]}`,
+		},
+		{
+			name: "tool_call_result_error",
+			in: event.ToolCallResult{
+				ID:      "call_def",
+				Content: []event.ToolContent{{Type: "text", Text: "boom"}},
+				IsError: true,
+			},
+			want: `{"type":"tool_call_result","id":"call_def","content":[{"type":"text","text":"boom"}],"is_error":true}`,
+		},
+		{
+			name: "finish",
+			in: event.Finish{
+				Reason: event.FinishReasonStop,
+				Usage:  event.Usage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30},
+			},
+			want: `{"type":"finish","reason":"stop","usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}}`,
+		},
+		{
+			name: "finish_zero_usage",
+			in:   event.Finish{Reason: event.FinishReasonToolCalls},
+			want: `{"type":"finish","reason":"tool_calls","usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`,
+		},
+		{
+			name: "error",
+			in:   event.Error{Err: errors.New("boom")},
+			want: `{"type":"error","msg":"boom"}`,
+		},
+		{
+			name: "error_nil",
+			in:   event.Error{},
+			want: `{"type":"error","msg":""}`,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			got, err := json.Marshal(tt.in)
+			if err != nil {
+				t.Fatalf("Marshal: %v", err)
+			}
+			if string(got) != tt.want {
+				t.Errorf("Marshal:\n got %s\n want %s", got, tt.want)
+			}
+
+			decoded, err := event.Decode(got)
+			if err != nil {
+				t.Fatalf("Decode: %v", err)
+			}
+
+			rt, err := json.Marshal(decoded)
+			if err != nil {
+				t.Fatalf("Marshal(rt): %v", err)
+			}
+			if string(rt) != tt.want {
+				t.Errorf("round-trip mismatch:\n got %s\n want %s", rt, tt.want)
+			}
+		})
+	}
+}
+
+func TestDecodeErrors(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		input   string
+		wantErr string
+	}{
+		{name: "empty", input: ``, wantErr: "decode envelope"},
+		{name: "not_object", input: `42`, wantErr: "decode envelope"},
+		{name: "missing_type", input: `{"text":"x"}`, wantErr: "unknown type"},
+		{name: "unknown_type", input: `{"type":"banana"}`, wantErr: "unknown type"},
+		{name: "type_wrong_kind", input: `{"type":42}`, wantErr: "decode envelope"},
+		{name: "malformed_payload_text_delta", input: `{"type":"text_delta","text":42}`, wantErr: "decode text_delta"},
+		{name: "malformed_payload_finish", input: `{"type":"finish","usage":"nope"}`, wantErr: "decode finish"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			_, err := event.Decode([]byte(tt.input))
+			if err == nil {
+				t.Fatalf("Decode(%q) returned nil error; want %q", tt.input, tt.wantErr)
+			}
+			if !strings.Contains(err.Error(), tt.wantErr) {
+				t.Errorf("Decode(%q) error = %v; want contains %q", tt.input, err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestErrUnknownType(t *testing.T) {
+	t.Parallel()
+
+	_, err := event.Decode([]byte(`{"type":"banana"}`))
+	if !errors.Is(err, event.ErrUnknownType) {
+		t.Errorf("Decode unknown type: errors.Is(%v, ErrUnknownType) = false", err)
+	}
+}
+
+// TestEventInterfaceMembership confirms each declared variant satisfies
+// [event.Event]. It does not, and cannot, prove the type set is closed
+// at the language level: another package cannot declare the unexported
+// `isEvent` method itself, but a type that embeds [event.Event] or one
+// of the variants still gets it promoted and so satisfies the
+// interface. The test exists so that removing or renaming a variant
+// surfaces here, not deep in a downstream package.
+func TestEventInterfaceMembership(t *testing.T) {
+	t.Parallel()
+
+	var _ event.Event = event.TextDelta{}
+	var _ event.Event = event.ToolCallRequest{}
+	var _ event.Event = event.ToolCallResult{}
+	var _ event.Event = event.Finish{}
+	var _ event.Event = event.Error{}
+}
diff --git a/core/event/fuzz_test.go b/core/event/fuzz_test.go
new file mode 100644
index 0000000..a942f81
--- /dev/null
+++ b/core/event/fuzz_test.go
@@ -0,0 +1,74 @@
+package event_test
+
+import (
+	"encoding/json"
+	"errors"
+	"testing"
+
+	"github.com/plexara/plexara-agents/core/event"
+)
+
+// FuzzDecode exercises the discriminator dispatch and per-variant
+// unmarshaling against arbitrary input. The contract: Decode either
+// returns an event whose marshaled form re-decodes and re-marshals to
+// identical bytes, or it returns an error. It must never panic.
+func FuzzDecode(f *testing.F) {
+	seeds := []string{
+		`{"type":"text_delta","text":"hi"}`,
+		`{"type":"tool_call_request","id":"x","name":"y","arguments":{}}`,
+		`{"type":"tool_call_result","id":"x","content":[]}`,
+		`{"type":"finish","reason":"stop","usage":{}}`,
+		`{"type":"error","msg":"boom"}`,
+		`{"type":"unknown"}`,
+		``,
+		`{`,
+		`null`,
+		`{"type":"text_delta","text":null}`,
+		`{"type":"finish","reason":42}`,
+	}
+	for _, s := range seeds {
+		f.Add(s)
+	}
+
+	f.Fuzz(func(t *testing.T, raw string) {
+		evt, err := event.Decode([]byte(raw))
+		if err != nil {
+			// Errors are expected for bad input; nothing else to verify.
+			return
+		}
+		// Round-trip: a decoded event must remarshal cleanly.
+		out, err := json.Marshal(evt)
+		if err != nil {
+			t.Fatalf("Marshal of decoded event failed: %v\ninput: %q", err, raw)
+		}
+		// And the remarshaled form must decode again to the same shape.
+		evt2, err := event.Decode(out)
+		if err != nil {
+			t.Fatalf("Decode(Marshal(Decode(%q))) failed: %v", raw, err)
+		}
+		out2, err := json.Marshal(evt2)
+		if err != nil {
+			t.Fatalf("Marshal of re-decoded event failed: %v", err)
+		}
+		if string(out) != string(out2) {
+			t.Errorf("round-trip diverged:\n first: %s\n second: %s", out, out2)
+		}
+	})
+}
+
+// FuzzDecodeNeverPanics is a narrower contract used in CI to catch
+// panics in dispatch/unmarshal. Faster than the round-trip fuzz.
+func FuzzDecodeNeverPanics(f *testing.F) {
+	f.Add(``)
+	f.Add(`{}`)
+	f.Add(`{"type":"text_delta"}`)
+	f.Fuzz(func(_ *testing.T, raw string) {
+		_, err := event.Decode([]byte(raw))
+		// Decode must not panic on any input. errors are fine.
+		_ = err
+		// Cross-check ErrUnknownType is well-formed: errors.Is must not panic.
+		if err != nil {
+			_ = errors.Is(err, event.ErrUnknownType)
+		}
+	})
+}
diff --git a/core/provider/openai_compatible.go b/core/provider/openai_compatible.go
new file mode 100644
index 0000000..d0f53e6
--- /dev/null
+++ b/core/provider/openai_compatible.go
@@ -0,0 +1,540 @@
+package provider
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"os"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/plexara/plexara-agents/core/event"
+)
+
+// OpenAIConfig configures an [OpenAICompatible] provider.
+//
+// Auth: APIKey is preferred. If empty and APIKeyEnv is set, the value of
+// that environment variable is used. Headers may carry additional auth
+// or routing headers.
+type OpenAIConfig struct {
+	// BaseURL is the API root (e.g. "http://localhost:11434/v1"). One
+	// trailing slash is tolerated; the provider appends "/chat/completions".
+	BaseURL string
+
+	APIKey    string
+	APIKeyEnv string
+	Headers   map[string]string
+
+	// HTTPClient is optional. If nil, a default client with sane dial,
+	// header, and idle timeouts is used. Tests inject a custom client to
+	// reach a [net/http/httptest.Server].
+	HTTPClient *http.Client
+}
+
+// OpenAICompatible implements [Provider] against any server speaking the
+// OpenAI Chat Completions API. Validated against Ollama, mlx-lm,
+// llama.cpp's server, and vLLM.
+type OpenAICompatible struct {
+	cfg     OpenAIConfig
+	client  *http.Client
+	apiKey  string
+	headers map[string]string // deep-copied from cfg.Headers at construction
+	url     string
+}
+
+// ErrConfig is returned by [NewOpenAICompatible] when the configuration
+// is missing required fields.
+var ErrConfig = errors.New("provider: invalid configuration")
+
+// NewOpenAICompatible builds an OpenAI-compatible provider from cfg. The
+// API key is resolved from APIKey or APIKeyEnv once, at construction time;
+// an unset variable is not an error and just yields no Authorization header.
+func NewOpenAICompatible(cfg OpenAIConfig) (*OpenAICompatible, error) {
+	if cfg.BaseURL == "" {
+		return nil, fmt.Errorf("%w: BaseURL is required", ErrConfig)
+	}
+
+	apiKey := cfg.APIKey
+	if apiKey == "" && cfg.APIKeyEnv != "" {
+		apiKey = os.Getenv(cfg.APIKeyEnv)
+	}
+
+	client := cfg.HTTPClient
+	if client == nil {
+		client = defaultHTTPClient()
+	}
+
+	// Deep-copy Headers so post-construction mutations on the caller's
+	// map cannot affect future requests.
+	var headers map[string]string
+	if len(cfg.Headers) > 0 {
+		headers = make(map[string]string, len(cfg.Headers))
+		for k, v := range cfg.Headers {
+			headers[k] = v
+		}
+	}
+
+	url := strings.TrimSuffix(cfg.BaseURL, "/") + "/chat/completions"
+
+	return &OpenAICompatible{
+		cfg:     cfg,
+		client:  client,
+		apiKey:  apiKey,
+		headers: headers,
+		url:     url,
+	}, nil
+}
+
+func defaultHTTPClient() *http.Client {
+	return &http.Client{
+		// No overall Timeout: streams may run for many minutes. Per-stage
+		// limits are enforced via the dialer and ResponseHeaderTimeout.
+		Transport: &http.Transport{
+			DialContext: (&net.Dialer{
+				Timeout:   10 * time.Second,
+				KeepAlive: 30 * time.Second,
+			}).DialContext,
+			ResponseHeaderTimeout: 30 * time.Second,
+			IdleConnTimeout:       90 * time.Second,
+			MaxIdleConns:          10,
+			MaxIdleConnsPerHost:   2,
+		},
+	}
+}
+
+// Name returns the provider's name for logging.
+func (p *OpenAICompatible) Name() string { return "openai-compatible" }
+
+// chatRequest is the wire form of [Request] for the Chat Completions API.
+type chatRequest struct {
+	Model       string        `json:"model"`
+	Messages    []chatMessage `json:"messages"`
+	Tools       []chatTool    `json:"tools,omitempty"`
+	Stream      bool          `json:"stream"`
+	Temperature *float32      `json:"temperature,omitempty"`
+	TopP        *float32      `json:"top_p,omitempty"`
+	MaxTokens   *int          `json:"max_tokens,omitempty"`
+	StreamOpts  *streamOpts   `json:"stream_options,omitempty"`
+}
+
+type streamOpts struct {
+	IncludeUsage bool `json:"include_usage"`
+}
+
+type chatMessage struct {
+	Role       string         `json:"role"`
+	Content    string         `json:"content,omitempty"`
+	ToolCalls  []chatToolCall `json:"tool_calls,omitempty"`
+	ToolCallID string         `json:"tool_call_id,omitempty"`
+}
+
+type chatToolCall struct {
+	ID       string           `json:"id"`
+	Type     string           `json:"type"`
+	Function chatToolFunction `json:"function"`
+}
+
+type chatToolFunction struct {
+	Name      string `json:"name"`
+	Arguments string `json:"arguments"`
+}
+
+type chatTool struct {
+	Type     string               `json:"type"`
+	Function chatToolFunctionSpec `json:"function"`
+}
+
+type chatToolFunctionSpec struct {
+	Name        string          `json:"name"`
+	Description string          `json:"description,omitempty"`
+	Parameters  json.RawMessage `json:"parameters,omitempty"`
+}
+
+func buildChatRequest(req Request) chatRequest {
+	out := chatRequest{
+		Model:       req.Model,
+		Messages:    make([]chatMessage, 0, len(req.Messages)),
+		Stream:      true,
+		Temperature: req.Temperature,
+		TopP:        req.TopP,
+		MaxTokens:   req.MaxTokens,
+		StreamOpts:  &streamOpts{IncludeUsage: true},
+	}
+	for _, m := range req.Messages {
+		cm := chatMessage{
+			Role:       string(m.Role),
+			Content:    m.Content,
+			ToolCallID: m.ToolCallID,
+		}
+		if len(m.ToolCalls) > 0 {
+			cm.ToolCalls = make([]chatToolCall, 0, len(m.ToolCalls))
+			for _, tc := range m.ToolCalls {
+				cm.ToolCalls = append(cm.ToolCalls, chatToolCall{
+					ID:   tc.ID,
+					Type: "function",
+					Function: chatToolFunction{
+						Name:      tc.Name,
+						Arguments: string(tc.Arguments),
+					},
+				})
+			}
+		}
+		out.Messages = append(out.Messages, cm)
+	}
+	for _, t := range req.Tools {
+		// chatToolFunctionSpec(t) is a Go struct conversion: the two
+		// types must have identical field names, types, and order. A
+		// future field reorder or rename on either side becomes a
+		// compile error rather than a silent wire-format drift.
+		out.Tools = append(out.Tools, chatTool{
+			Type:     "function",
+			Function: chatToolFunctionSpec(t),
+		})
+	}
+	return out
+}
+
+// Stream POSTs req to the chat-completions endpoint with stream=true. A
+// transport error or non-200 status is returned as an error; otherwise the
+// SSE response is parsed into events delivered on the returned channel.
+//
+// Per spec §8.4, [event.ToolCallRequest] is emitted only on finish_reason
+// tool_calls, after per-index argument deltas are accumulated and validated as JSON.
+func (p *OpenAICompatible) Stream(ctx context.Context, req Request) (<-chan event.Event, error) { + body, err := json.Marshal(buildChatRequest(req)) + if err != nil { + return nil, fmt.Errorf("provider: marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("provider: build request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + if p.apiKey != "" { + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + } + for k, v := range p.headers { + httpReq.Header.Set(k, v) + } + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("provider: do request: %w", err) + } + if resp.StatusCode != http.StatusOK { + // Read up to a small window for diagnostics, then close. + buf, _ := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + return nil, fmt.Errorf("provider: http %d: %s", resp.StatusCode, strings.TrimSpace(string(buf))) + } + + out := make(chan event.Event) + go p.runStream(ctx, resp.Body, out) + return out, nil +} + +// chatChunk is one SSE frame from the chat-completions endpoint. 
+type chatChunk struct { + Choices []chatChoice `json:"choices"` + Usage *chatUsage `json:"usage,omitempty"` +} + +type chatChoice struct { + Index int `json:"index"` + Delta chatDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` +} + +type chatDelta struct { + Content string `json:"content,omitempty"` + ToolCalls []chatToolCallDelta `json:"tool_calls,omitempty"` +} + +type chatToolCallDelta struct { + Index int `json:"index"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function chatToolCallFunctionDelta `json:"function,omitempty"` +} + +type chatToolCallFunctionDelta struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` +} + +type chatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type toolCallAccumulator struct { + ID string + Name string + Args strings.Builder +} + +// streamState carries the running state of an in-progress stream so +// that line/chunk handling can be split across helpers without losing +// context. +type streamState struct { + toolCalls map[int]*toolCallAccumulator + usage event.Usage + // dataBuf accumulates `data:` lines for a single SSE event. Per the + // SSE spec, multiple `data:` lines before a blank-line boundary are + // concatenated with `\n`. None of the runtimes we target today emit + // multi-line frames, but the buffering is necessary for spec + // correctness and future-proofs us against servers that do. + dataBuf strings.Builder +} + +// streamStatus tells the caller whether a sub-step wants to keep +// reading, has finished cleanly, or has aborted (cancelled/error). 
+type streamStatus int + +const ( + streamContinue streamStatus = iota + streamDone + streamAbort +) + +func (p *OpenAICompatible) runStream(ctx context.Context, body io.ReadCloser, out chan<- event.Event) { + defer func() { _ = body.Close() }() + defer close(out) + + send := func(e event.Event) bool { + select { + case out <- e: + return true + case <-ctx.Done(): + return false + } + } + sendErr := func(err error) { + _ = send(event.Error{Err: err}) + } + + scanner := bufio.NewScanner(body) + // Allow up to 1 MiB per SSE line. The default 64 KiB ceiling is too + // tight for some servers that emit large delta blobs. + scanner.Buffer(make([]byte, 0, 64*1024), 1<<20) + + st := &streamState{toolCalls: map[int]*toolCallAccumulator{}} + + for scanner.Scan() { + if ctx.Err() != nil { + return + } + switch p.handleLine(scanner.Text(), st, send, sendErr) { + case streamDone, streamAbort: + return + case streamContinue: + } + } + if err := scanner.Err(); err != nil { + sendErr(fmt.Errorf("provider: read stream: %w", err)) + return + } + // EOF without a closing blank line — flush any buffered `data:` lines. + if st.dataBuf.Len() > 0 { + switch p.dispatchData(st.dataBuf.String(), st, send, sendErr) { + case streamDone, streamAbort: + return + case streamContinue: + } + st.dataBuf.Reset() + } + // Stream ended without a finish_reason or [DONE]; synthesize. + hadToolCalls := len(st.toolCalls) > 0 + if !p.flushPending(st.toolCalls, sendErr, send) { + return + } + reason := event.FinishReasonStop + if hadToolCalls { + reason = event.FinishReasonToolCalls + } + _ = send(event.Finish{Reason: reason, Usage: st.usage}) +} + +func (p *OpenAICompatible) handleLine( + line string, + st *streamState, + send func(event.Event) bool, + sendErr func(error), +) streamStatus { + // Empty line is the SSE event boundary: dispatch the buffered data. 
+ if line == "" { + if st.dataBuf.Len() == 0 { + return streamContinue + } + data := st.dataBuf.String() + st.dataBuf.Reset() + return p.dispatchData(data, st, send, sendErr) + } + // SSE comment / keep-alive. + if strings.HasPrefix(line, ":") { + return streamContinue + } + // Non-`data:` field lines (e.g. `event:`, `id:`, `retry:`) are not + // used by the chat-completions stream protocol; ignore them. + if !strings.HasPrefix(line, "data:") { + return streamContinue + } + // Per RFC 6, exactly one optional space after the colon is stripped. + payload := strings.TrimPrefix(line, "data:") + payload = strings.TrimPrefix(payload, " ") + if st.dataBuf.Len() > 0 { + st.dataBuf.WriteByte('\n') + } + st.dataBuf.WriteString(payload) + return streamContinue +} + +func (p *OpenAICompatible) dispatchData( + data string, + st *streamState, + send func(event.Event) bool, + sendErr func(error), +) streamStatus { + data = strings.TrimSpace(data) + if data == "" { + return streamContinue + } + if data == "[DONE]" { + hadToolCalls := len(st.toolCalls) > 0 + if !p.flushPending(st.toolCalls, sendErr, send) { + return streamAbort + } + reason := event.FinishReasonStop + if hadToolCalls { + // Some runtimes terminate with [DONE] without ever sending + // finish_reason: tool_calls. Preserve the tool-calls signal so + // the agent loop dispatches the call rather than treating the + // turn as complete. 
+ reason = event.FinishReasonToolCalls + } + _ = send(event.Finish{Reason: reason, Usage: st.usage}) + return streamDone + } + var chunk chatChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + sendErr(fmt.Errorf("provider: decode chunk: %w", err)) + return streamAbort + } + if chunk.Usage != nil { + st.usage = event.Usage{ + PromptTokens: chunk.Usage.PromptTokens, + CompletionTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, + } + } + for _, c := range chunk.Choices { + switch p.handleChoice(c, st, send, sendErr) { + case streamDone: + return streamDone + case streamAbort: + return streamAbort + case streamContinue: + } + } + return streamContinue +} + +func (p *OpenAICompatible) handleChoice( + c chatChoice, + st *streamState, + send func(event.Event) bool, + sendErr func(error), +) streamStatus { + if c.Delta.Content != "" { + if !send(event.TextDelta{Text: c.Delta.Content}) { + return streamAbort + } + } + for _, tc := range c.Delta.ToolCalls { + acc, ok := st.toolCalls[tc.Index] + if !ok { + acc = &toolCallAccumulator{} + st.toolCalls[tc.Index] = acc + } + if tc.ID != "" { + acc.ID = tc.ID + } + if tc.Function.Name != "" { + acc.Name = tc.Function.Name + } + if tc.Function.Arguments != "" { + acc.Args.WriteString(tc.Function.Arguments) + } + } + if c.FinishReason == nil { + return streamContinue + } + reason := event.FinishReason(*c.FinishReason) + if reason == event.FinishReasonToolCalls { + if !p.flushPending(st.toolCalls, sendErr, send) { + return streamAbort + } + } + if !send(event.Finish{Reason: reason, Usage: st.usage}) { + return streamAbort + } + return streamDone +} + +// flushPending emits one ToolCallRequest per accumulated tool call, in +// stable index order. Arguments are validated as JSON; an invalid +// payload aborts the stream with an Error event. 
+func (p *OpenAICompatible) flushPending( + toolCalls map[int]*toolCallAccumulator, + sendErr func(error), + send func(event.Event) bool, +) bool { + if len(toolCalls) == 0 { + return true + } + indices := make([]int, 0, len(toolCalls)) + for i := range toolCalls { + indices = append(indices, i) + } + sort.Ints(indices) + for _, i := range indices { + acc := toolCalls[i] + args := acc.Args.String() + if args == "" { + args = "{}" + } + if !json.Valid([]byte(args)) { + sendErr(fmt.Errorf("provider: tool call %q has invalid JSON args: %s", acc.ID, args)) + return false + } + // Tool arguments must be a JSON object per the OpenAI Chat + // Completions schema. A model that emits null/array/scalar will + // be rejected by downstream MCP servers with a less-actionable + // error; surface it here as an Error event instead. + trimmed := strings.TrimLeft(args, " \t\n\r") + if len(trimmed) == 0 || trimmed[0] != '{' { + sendErr(fmt.Errorf("provider: tool call %q args must be a JSON object, got: %s", acc.ID, args)) + return false + } + if !send(event.ToolCallRequest{ + ID: acc.ID, + Name: acc.Name, + Arguments: json.RawMessage(args), + }) { + return false + } + delete(toolCalls, i) + } + return true +} diff --git a/core/provider/openai_compatible_test.go b/core/provider/openai_compatible_test.go new file mode 100644 index 0000000..5ebe587 --- /dev/null +++ b/core/provider/openai_compatible_test.go @@ -0,0 +1,734 @@ +package provider_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/plexara/plexara-agents/core/event" + "github.com/plexara/plexara-agents/core/provider" +) + +func TestOpenAICompatible_Config(t *testing.T) { + t.Parallel() + + t.Run("empty_base_url_rejected", func(t *testing.T) { + t.Parallel() + _, err := provider.NewOpenAICompatible(provider.OpenAIConfig{}) + if !errors.Is(err, provider.ErrConfig) { + t.Errorf("err = %v; want errors.Is ErrConfig", err) + } 
+ }) + + t.Run("name", func(t *testing.T) { + t.Parallel() + p, err := provider.NewOpenAICompatible(provider.OpenAIConfig{BaseURL: "http://x"}) + if err != nil { + t.Fatalf("New: %v", err) + } + if got := p.Name(); got != "openai-compatible" { + t.Errorf("Name = %q; want openai-compatible", got) + } + }) + + t.Run("custom_headers_passed_through", func(t *testing.T) { + t.Parallel() + + var got string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = r.Header.Get("X-Plexara-Test") + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{ + BaseURL: srv.URL + "/v1", + Headers: map[string]string{"X-Plexara-Test": "yes"}, + HTTPClient: srv.Client(), + }) + drainAll(t, p, provider.Request{Model: "x"}) + + if got != "yes" { + t.Errorf("X-Plexara-Test = %q; want yes", got) + } + }) + + t.Run("base_url_trailing_slash_trimmed", func(t *testing.T) { + t.Parallel() + + var path string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path = r.URL.Path + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1/", HTTPClient: srv.Client()}) + drainAll(t, p, provider.Request{Model: "x"}) + + if path != "/v1/chat/completions" { + t.Errorf("path = %q; want /v1/chat/completions", path) + } + }) +} + +// TestOpenAICompatible_AuthFromEnv is intentionally serial: t.Setenv +// forbids running in a parent or sibling that has called t.Parallel. 
+func TestOpenAICompatible_AuthFromEnv(t *testing.T) { + const want = "from-env" + t.Setenv("PLEXARA_TEST_KEY", want) + + var got string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = r.Header.Get("Authorization") + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", APIKeyEnv: "PLEXARA_TEST_KEY", HTTPClient: srv.Client()}) + drainAll(t, p, provider.Request{Model: "x"}) + + if got != "Bearer "+want { + t.Errorf("Authorization = %q; want Bearer %s", got, want) + } +} + +func TestOpenAICompatible_AuthExplicitOverridesEnv(t *testing.T) { + t.Setenv("PLEXARA_TEST_KEY", "from-env") + + var got string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = r.Header.Get("Authorization") + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{ + BaseURL: srv.URL + "/v1", + APIKey: "explicit", + APIKeyEnv: "PLEXARA_TEST_KEY", + HTTPClient: srv.Client(), + }) + drainAll(t, p, provider.Request{Model: "x"}) + + if got != "Bearer explicit" { + t.Errorf("Authorization = %q; want Bearer explicit", got) + } +} + +func TestOpenAICompatible_Stream(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + frames []string + want []event.Event + }{ + { + name: "plain_text", + frames: ssePlain(), + want: []event.Event{ + event.TextDelta{Text: "Hello"}, + event.TextDelta{Text: " world"}, + event.Finish{Reason: event.FinishReasonStop, Usage: event.Usage{PromptTokens: 5, CompletionTokens: 3, TotalTokens: 8}}, + }, + }, + { + name: "single_tool_call_buffered", + frames: []string{ + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_a","type":"function","function":{"name":"get_weather","arguments":""}}]}}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"city\":"}}]}}]}`, + 
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"NYC\"}"}}]}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":7,"completion_tokens":4,"total_tokens":11}}`, + `[DONE]`, + }, + want: []event.Event{ + event.ToolCallRequest{ + ID: "call_a", Name: "get_weather", + Arguments: json.RawMessage(`{"city":"NYC"}`), + }, + event.Finish{Reason: event.FinishReasonToolCalls, Usage: event.Usage{PromptTokens: 7, CompletionTokens: 4, TotalTokens: 11}}, + }, + }, + { + name: "two_tool_calls_in_index_order", + frames: []string{ + // Interleave indices 0 and 1; output must still be ordered. + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_b","function":{"name":"second","arguments":"{}"}}]}}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_a","function":{"name":"first","arguments":"{\"k\":1}"}}]}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + `[DONE]`, + }, + want: []event.Event{ + event.ToolCallRequest{ID: "call_a", Name: "first", Arguments: json.RawMessage(`{"k":1}`)}, + event.ToolCallRequest{ID: "call_b", Name: "second", Arguments: json.RawMessage(`{}`)}, + event.Finish{Reason: event.FinishReasonToolCalls}, + }, + }, + { + name: "text_then_tool_call", + frames: []string{ + `{"choices":[{"index":0,"delta":{"content":"Let me check."}}]}`, + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_z","function":{"name":"lookup","arguments":"{}"}}]}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + `[DONE]`, + }, + want: []event.Event{ + event.TextDelta{Text: "Let me check."}, + event.ToolCallRequest{ID: "call_z", Name: "lookup", Arguments: json.RawMessage(`{}`)}, + event.Finish{Reason: event.FinishReasonToolCalls}, + }, + }, + { + name: "empty_tool_args_become_empty_object", + frames: []string{ + 
`{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_e","function":{"name":"noop"}}]}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + }, + want: []event.Event{ + event.ToolCallRequest{ID: "call_e", Name: "noop", Arguments: json.RawMessage(`{}`)}, + event.Finish{Reason: event.FinishReasonToolCalls}, + }, + }, + { + name: "sse_comments_ignored", + frames: []string{ + ": keep-alive comment", + `{"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`, + `[DONE]`, + }, + want: []event.Event{ + event.TextDelta{Text: "hi"}, + event.Finish{Reason: event.FinishReasonStop}, + }, + }, + { + name: "done_without_finish_reason", + frames: []string{ + `{"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + `[DONE]`, + }, + want: []event.Event{ + event.TextDelta{Text: "hi"}, + event.Finish{Reason: event.FinishReasonStop}, + }, + }, + { + name: "stream_ends_without_done", + frames: []string{ + `{"choices":[{"index":0,"delta":{"content":"truncated"}}]}`, + }, + want: []event.Event{ + event.TextDelta{Text: "truncated"}, + event.Finish{Reason: event.FinishReasonStop}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSE(w, tt.frames) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + if !equalEvents(t, got, tt.want) { + t.Errorf("events mismatch:\n got %#v\nwant %#v", got, tt.want) + } + }) + } +} + +func TestOpenAICompatible_DoneWithPendingToolCallsKeepsToolCallsReason(t *testing.T) { + t.Parallel() + + // Some runtimes terminate with [DONE] without ever sending + // finish_reason: tool_calls. The provider must still emit + // Finish{Reason: tool_calls} so the loop dispatches the call. 
+ frames := []string{ + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_z","type":"function","function":{"name":"lookup","arguments":"{\"k\":1}"}}]}}]}`, + `[DONE]`, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSE(w, frames) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + if len(got) != 2 { + t.Fatalf("got %d events; want 2 (ToolCallRequest + Finish)", len(got)) + } + if _, ok := got[0].(event.ToolCallRequest); !ok { + t.Errorf("got[0] = %T; want ToolCallRequest", got[0]) + } + finish, ok := got[1].(event.Finish) + if !ok { + t.Fatalf("got[1] = %T; want Finish", got[1]) + } + if finish.Reason != event.FinishReasonToolCalls { + t.Errorf("Finish.Reason = %q; want tool_calls (DONE flush must preserve the signal)", finish.Reason) + } +} + +func TestOpenAICompatible_MultiLineDataFrame(t *testing.T) { + t.Parallel() + + // Per the SSE spec, multiple `data:` lines before a blank-line + // boundary are concatenated with `\n`. Build a frame whose JSON + // is split across two `data:` lines. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + // First frame: JSON split across two `data:` lines. + fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":\n") + fmt.Fprint(w, "data: {\"content\":\"hello\"}}]}\n\n") + // Second frame: terminator. 
+ fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n") + fmt.Fprint(w, "data: [DONE]\n\n") + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + want := []event.Event{ + event.TextDelta{Text: "hello"}, + event.Finish{Reason: event.FinishReasonStop}, + } + if !equalEvents(t, got, want) { + t.Errorf("multi-line data: parse failed:\n got %#v\n want %#v", got, want) + } +} + +func TestOpenAICompatible_ContentFilterFinishReason(t *testing.T) { + t.Parallel() + + frames := []string{ + `{"choices":[{"index":0,"delta":{"content":"partial"}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"content_filter"}]}`, + `[DONE]`, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSE(w, frames) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + if len(got) != 2 { + t.Fatalf("got %d events; want 2", len(got)) + } + finish, ok := got[1].(event.Finish) + if !ok { + t.Fatalf("got[1] = %T; want Finish", got[1]) + } + if finish.Reason != event.FinishReasonContentFilter { + t.Errorf("Finish.Reason = %q; want content_filter", finish.Reason) + } +} + +func TestOpenAICompatible_OversizedLineSurfacesAsError(t *testing.T) { + t.Parallel() + + // A single SSE line exceeding the 1 MiB scanner ceiling must + // surface as an event.Error rather than panic, truncate silently, + // or be parsed as malformed JSON. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + // 2 MiB of `a` between two valid JSON braces. 
+ fmt.Fprint(w, "data: {\"x\":\"") + blob := strings.Repeat("a", 2*1024*1024) + fmt.Fprint(w, blob) + fmt.Fprint(w, "\"}\n\n") + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + // Either the first event is an Error (scanner exceeded its buffer) + // or — if some runtime supplies enough buffering — we tolerate other + // outcomes as long as nothing panics. We assert the strict case. + if len(got) == 0 { + t.Fatal("got 0 events; want at least one") + } + errEvt, ok := got[0].(event.Error) + if !ok { + t.Fatalf("got[0] = %T; want event.Error for oversized line", got[0]) + } + if !strings.Contains(errEvt.Err.Error(), "read stream") { + t.Errorf("err = %v; want contains 'read stream'", errEvt.Err) + } +} + +func TestOpenAICompatible_NonObjectToolArgsEmitsError(t *testing.T) { + t.Parallel() + + frames := []string{ + // The model emits `null` as the arguments — valid JSON but not + // an object. MCP servers expect objects. 
+ `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_n","function":{"name":"f","arguments":"null"}}]}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + `[DONE]`, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSE(w, frames) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + if len(got) == 0 { + t.Fatal("got 0 events; want an Error") + } + errEvt, ok := got[0].(event.Error) + if !ok { + t.Fatalf("got[0] = %T; want event.Error", got[0]) + } + if !strings.Contains(errEvt.Err.Error(), "must be a JSON object") { + t.Errorf("err = %v; want contains 'must be a JSON object'", errEvt.Err) + } +} + +func TestOpenAICompatible_HeadersDeepCopiedAtConstruction(t *testing.T) { + t.Parallel() + + // Construct the provider, then mutate the caller's Headers map. + // Subsequent requests must NOT see the post-construction mutation. + headers := map[string]string{"X-Plexara-Test": "original"} + + var seen string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + seen = r.Header.Get("X-Plexara-Test") + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{ + BaseURL: srv.URL + "/v1", + Headers: headers, + HTTPClient: srv.Client(), + }) + + // Mutate AFTER construction. 
+ headers["X-Plexara-Test"] = "MUTATED" + + drainAll(t, p, provider.Request{Model: "x"}) + + if seen != "original" { + t.Errorf("header = %q; want %q (mutation after New must not propagate)", seen, "original") + } +} + +func TestOpenAICompatible_InvalidToolArgsEmitsError(t *testing.T) { + t.Parallel() + + frames := []string{ + `{"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_x","function":{"name":"f","arguments":"{not json"}}]}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + `[DONE]`, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSE(w, frames) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + if len(got) != 1 { + t.Fatalf("got %d events; want 1 (an Error)", len(got)) + } + errEvt, ok := got[0].(event.Error) + if !ok { + t.Fatalf("got %T; want event.Error", got[0]) + } + if !strings.Contains(errEvt.Err.Error(), "invalid JSON") { + t.Errorf("err = %v; want contains 'invalid JSON'", errEvt.Err) + } +} + +func TestOpenAICompatible_DecodeChunkError(t *testing.T) { + t.Parallel() + + frames := []string{`{not json}`} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeSSE(w, frames) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + got := drainAll(t, p, provider.Request{Model: "x"}) + + if len(got) != 1 { + t.Fatalf("got %d events; want 1 (an Error)", len(got)) + } + if _, ok := got[0].(event.Error); !ok { + t.Errorf("got %T; want event.Error", got[0]) + } +} + +func TestOpenAICompatible_HTTPError(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, "boom") + })) + 
t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + _, err := p.Stream(t.Context(), provider.Request{Model: "x"}) + if err == nil || !strings.Contains(err.Error(), "http 500") { + t.Errorf("err = %v; want contains 'http 500'", err) + } +} + +func TestOpenAICompatible_RequestPayloadShape(t *testing.T) { + t.Parallel() + + temp := float32(0.7) + maxTok := 256 + + var captured map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &captured) + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + drainAll(t, p, provider.Request{ + Model: "qwen3:30b", + Messages: []provider.Message{{Role: provider.RoleUser, Content: "hi"}}, + Tools: []provider.Tool{{Name: "t", Description: "d", Parameters: json.RawMessage(`{"type":"object"}`)}}, + Temperature: &temp, + MaxTokens: &maxTok, + }) + + if captured["model"] != "qwen3:30b" { + t.Errorf("model = %v; want qwen3:30b", captured["model"]) + } + if captured["stream"] != true { + t.Errorf("stream = %v; want true", captured["stream"]) + } + if captured["temperature"] != float64(0.7) { + t.Errorf("temperature = %v; want 0.7", captured["temperature"]) + } + if captured["max_tokens"] != float64(256) { + t.Errorf("max_tokens = %v; want 256", captured["max_tokens"]) + } + tools, _ := captured["tools"].([]any) + if len(tools) != 1 { + t.Errorf("tools length = %d; want 1", len(tools)) + } +} + +func TestOpenAICompatible_ChatHistoryShape(t *testing.T) { + t.Parallel() + + var captured map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &captured) + writeSSE(w, ssePlain()) + })) + t.Cleanup(srv.Close) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: 
srv.URL + "/v1", HTTPClient: srv.Client()}) + + drainAll(t, p, provider.Request{ + Model: "x", + Messages: []provider.Message{ + {Role: provider.RoleSystem, Content: "be terse"}, + {Role: provider.RoleUser, Content: "weather?"}, + { + Role: provider.RoleAssistant, + Content: "", + ToolCalls: []provider.ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: json.RawMessage(`{"city":"NYC"}`)}, + }, + }, + {Role: provider.RoleTool, ToolCallID: "call_1", Content: "72F sunny"}, + }, + }) + + msgs, _ := captured["messages"].([]any) + if len(msgs) != 4 { + t.Fatalf("messages length = %d; want 4", len(msgs)) + } + + asst, _ := msgs[2].(map[string]any) + tcs, _ := asst["tool_calls"].([]any) + if len(tcs) != 1 { + t.Fatalf("assistant tool_calls length = %d; want 1", len(tcs)) + } + tc0, _ := tcs[0].(map[string]any) + if tc0["id"] != "call_1" || tc0["type"] != "function" { + t.Errorf("tool_call[0] = %#v; want id=call_1, type=function", tc0) + } + fn, _ := tc0["function"].(map[string]any) + if fn["name"] != "get_weather" { + t.Errorf("function.name = %v; want get_weather", fn["name"]) + } + if fn["arguments"] != `{"city":"NYC"}` { + t.Errorf("function.arguments = %v; want json string", fn["arguments"]) + } + + tool, _ := msgs[3].(map[string]any) + if tool["role"] != "tool" || tool["tool_call_id"] != "call_1" { + t.Errorf("tool message = %#v; want role=tool, tool_call_id=call_1", tool) + } +} + +func TestOpenAICompatible_ContextCancel(t *testing.T) { + t.Parallel() + + // Server that holds the connection open and emits one frame, then waits. 
+ gate := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "data: {\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}]}\n\n") + w.(http.Flusher).Flush() + <-gate + })) + t.Cleanup(func() { close(gate); srv.Close() }) + + p := mustOpenAI(t, provider.OpenAIConfig{BaseURL: srv.URL + "/v1", HTTPClient: srv.Client()}) + ctx, cancel := context.WithCancel(t.Context()) + + ch, err := p.Stream(ctx, provider.Request{Model: "x"}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + + // Pull at least one event then cancel. + select { + case e := <-ch: + if _, ok := e.(event.TextDelta); !ok { + t.Errorf("first event = %T; want TextDelta", e) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first event") + } + cancel() + + // Channel must close shortly after cancel. + deadline := time.After(2 * time.Second) + for { + select { + case _, ok := <-ch: + if !ok { + return + } + case <-deadline: + t.Fatal("channel did not close after context cancel") + } + } +} + +func TestOpenAICompatible_DialFailure(t *testing.T) { + t.Parallel() + + // Use an address we know nothing listens on. 
+ p := mustOpenAI(t, provider.OpenAIConfig{ + BaseURL: "http://127.0.0.1:1", + HTTPClient: &http.Client{Timeout: 500 * time.Millisecond}, + }) + + _, err := p.Stream(t.Context(), provider.Request{Model: "x"}) + if err == nil { + t.Fatal("Stream returned nil error; want a dial error") + } +} + +// --- helpers ----------------------------------------------------------------- + +func mustOpenAI(t *testing.T, cfg provider.OpenAIConfig) *provider.OpenAICompatible { + t.Helper() + p, err := provider.NewOpenAICompatible(cfg) + if err != nil { + t.Fatalf("NewOpenAICompatible: %v", err) + } + return p +} + +func writeSSE(w http.ResponseWriter, frames []string) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + for _, f := range frames { + if strings.HasPrefix(f, ":") { + fmt.Fprintf(w, "%s\n\n", f) + } else { + fmt.Fprintf(w, "data: %s\n\n", f) + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } +} + +func ssePlain() []string { + return []string{ + `{"choices":[{"index":0,"delta":{"content":"Hello"}}]}`, + `{"choices":[{"index":0,"delta":{"content":" world"}}]}`, + `{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":3,"total_tokens":8}}`, + `[DONE]`, + } +} + +func drainAll(t *testing.T, p provider.Provider, req provider.Request) []event.Event { + t.Helper() + ch, err := p.Stream(t.Context(), req) + if err != nil { + t.Fatalf("Stream: %v", err) + } + var out []event.Event + for e := range ch { + out = append(out, e) + } + return out +} + +func equalEvents(t *testing.T, got, want []event.Event) bool { + t.Helper() + if len(got) != len(want) { + return false + } + for i := range got { + gb, _ := json.Marshal(got[i]) + wb, _ := json.Marshal(want[i]) + if string(gb) != string(wb) { + t.Logf("event[%d]:\n got %s\n want %s", i, gb, wb) + return false + } + } + return true +} diff --git a/core/provider/provider.go 
b/core/provider/provider.go
new file mode 100644
index 0000000..48b5e06
--- /dev/null
+++ b/core/provider/provider.go
@@ -0,0 +1,100 @@
+// Package provider defines the agent's view of an inference backend.
+//
+// A [Provider] takes a [Request] (chat-style messages, tool catalog,
+// sampling parameters) and returns a channel of [event.Event] values.
+// Implementations buffer protocol-level details — SSE framing, partial
+// tool-call deltas — so the agent loop sees a clean event stream.
+//
+// v1 ships a single implementation: [OpenAICompatible], which targets
+// any server speaking the OpenAI Chat Completions API. That covers
+// Ollama, mlx-lm, llama.cpp's server, and vLLM. Switching between them
+// is a configuration change.
+package provider
+
+import (
+	"context"
+	"encoding/json"
+
+	"github.com/plexara/plexara-agents/core/event"
+)
+
+// Provider produces a stream of [event.Event] values in response to a [Request].
+//
+// Implementations satisfy the streaming-discipline contract from spec
+// §8.4: tool-call events are emitted only when the runtime signals the
+// call is complete; arguments are always complete JSON.
+type Provider interface {
+	// Stream initiates a request. The returned channel is closed when the
+	// stream ends — on completion (with an [event.Finish]), on error (with
+	// an [event.Error]), or on context cancellation. Callers must drain
+	// the channel until it is closed.
+	Stream(ctx context.Context, req Request) (<-chan event.Event, error)
+
+	// Name identifies the provider for logging and diagnostics.
+	Name() string
+}
+
+// Request is a single chat-style call to a [Provider].
+//
+// Sampling fields use pointers so the zero value means "use the
+// provider's default" rather than literally zero. Setting Temperature
+// to a non-nil pointer with value 0 means "deterministic sampling".
+type Request struct {
+	// Model is the runtime-side model identifier (e.g. "qwen3:30b-a3b").
+	Model string
+
+	// Messages is the chat history, in order. The last entry is normally
+	// the user's most recent input.
+	Messages []Message
+
+	// Tools is the catalog of tools the model may call this turn. May be
+	// empty.
+	Tools []Tool
+	// Sampling overrides; a nil pointer means "use the provider's default".
+	Temperature *float32
+	TopP        *float32
+	MaxTokens   *int
+}
+
+// Role enumerates the message authors recognized by chat-style providers.
+type Role string
+
+// Recognized message roles. Mirror the OpenAI Chat Completions schema.
+const (
+	RoleSystem    Role = "system"
+	RoleUser      Role = "user"
+	RoleAssistant Role = "assistant"
+	RoleTool      Role = "tool"
+)
+
+// Message is one entry in a chat history.
+//
+// Assistant messages may carry [Message.ToolCalls]. Tool messages carry
+// [Message.ToolCallID] referring to the assistant turn that requested
+// the call. This shape mirrors the OpenAI Chat Completions schema and
+// is the lowest-friction representation across the runtimes we target.
+type Message struct {
+	Role    Role
+	Content string
+	// ToolCalls holds the tool invocations carried by an assistant message.
+	ToolCalls []ToolCall
+	// ToolCallID is carried by tool messages; it refers to the requesting turn.
+	ToolCallID string
+}
+
+// ToolCall is a tool invocation embedded in an assistant message. Used
+// when persisting or replaying a session that already has tool calls in
+// the history.
+type ToolCall struct {
+	ID        string
+	Name      string
+	Arguments json.RawMessage
+}
+
+// Tool is a callable function the model may invoke. Parameters is a
+// JSON Schema describing the expected arguments.
+type Tool struct {
+	Name        string
+	Description string
+	Parameters  json.RawMessage
+}
diff --git a/core/provider/testing.go b/core/provider/testing.go
new file mode 100644
index 0000000..92a0096
--- /dev/null
+++ b/core/provider/testing.go
@@ -0,0 +1,102 @@
+package provider
+
+import (
+	"context"
+	"errors"
+	"sync"
+
+	"github.com/plexara/plexara-agents/core/event"
+)
+
+// FakeScript is one scripted Stream call's behavior.
+//
+// If InitError is non-nil, Stream returns that error immediately.
Otherwise +// the events are emitted in order on the returned channel and the channel +// is closed. +type FakeScript struct { + Events []event.Event + InitError error +} + +// FakeProvider is an in-memory [Provider] for tests. +// +// It plays a sequence of [FakeScript] entries: the first call to Stream +// uses the first entry, the second call uses the second, and so on. +// After the scripts are exhausted, Stream returns [ErrFakeExhausted]. +// +// Captured calls are available via [FakeProvider.Calls] for assertions. +type FakeProvider struct { + nameStr string + + mu sync.Mutex + scripts []FakeScript + calls []Request + idx int +} + +// FakeOption configures a [FakeProvider]. +type FakeOption func(*FakeProvider) + +// WithFakeName overrides the provider name reported by [FakeProvider.Name]. +func WithFakeName(name string) FakeOption { + return func(f *FakeProvider) { f.nameStr = name } +} + +// ErrFakeExhausted is returned by [FakeProvider.Stream] after all +// scripted entries have been consumed. +var ErrFakeExhausted = errors.New("provider: fake script exhausted") + +// NewFake constructs a [FakeProvider] that plays scripts in order. +func NewFake(scripts []FakeScript, opts ...FakeOption) *FakeProvider { + f := &FakeProvider{ + nameStr: "fake", + scripts: scripts, + } + for _, o := range opts { + o(f) + } + return f +} + +// Name returns the provider's name for logging. +func (f *FakeProvider) Name() string { return f.nameStr } + +// Calls returns a copy of the requests received so far, in order. +func (f *FakeProvider) Calls() []Request { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]Request, len(f.calls)) + copy(out, f.calls) + return out +} + +// Stream replays the next script in turn. If no scripts remain, it +// returns [ErrFakeExhausted]. 
+func (f *FakeProvider) Stream(ctx context.Context, req Request) (<-chan event.Event, error) {
+	f.mu.Lock()
+	if f.idx >= len(f.scripts) {
+		f.mu.Unlock()
+		return nil, ErrFakeExhausted
+	}
+	next := f.scripts[f.idx]
+	f.calls = append(f.calls, req)
+	f.idx++
+	f.mu.Unlock()
+
+	if initErr := next.InitError; initErr != nil {
+		return nil, initErr
+	}
+
+	events := make(chan event.Event, len(next.Events))
+	go func() {
+		defer close(events)
+		for _, ev := range next.Events {
+			select {
+			case <-ctx.Done():
+				return
+			case events <- ev:
+			}
+		}
+	}()
+	return events, nil
+}
diff --git a/core/provider/testing_test.go b/core/provider/testing_test.go
new file mode 100644
index 0000000..8f640d3
--- /dev/null
+++ b/core/provider/testing_test.go
@@ -0,0 +1,133 @@
+package provider_test
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/plexara/plexara-agents/core/event"
+	"github.com/plexara/plexara-agents/core/provider"
+)
+
+func TestFakeProvider_RepliesInOrder(t *testing.T) {
+	t.Parallel()
+
+	firstWant := []event.Event{
+		event.TextDelta{Text: "hi"},
+		event.Finish{Reason: event.FinishReasonStop},
+	}
+	secondWant := []event.Event{
+		event.TextDelta{Text: "again"},
+		event.Finish{Reason: event.FinishReasonStop},
+	}
+
+	fake := provider.NewFake([]provider.FakeScript{
+		{Events: firstWant},
+		{Events: secondWant},
+	})
+
+	firstGot := drain(t, fake, provider.Request{Model: "m1"})
+	secondGot := drain(t, fake, provider.Request{Model: "m2"})
+
+	if !sameEvents(firstGot, firstWant) {
+		t.Errorf("first call events mismatch: got %v want %v", firstGot, firstWant)
+	}
+	if !sameEvents(secondGot, secondWant) {
+		t.Errorf("second call events mismatch: got %v want %v", secondGot, secondWant)
+	}
+
+	recorded := fake.Calls()
+	if len(recorded) != 2 || recorded[0].Model != "m1" || recorded[1].Model != "m2" {
+		t.Errorf("Calls() = %#v; want [{m1}, {m2}]", recorded)
+	}
+}
+
+func TestFakeProvider_InitError(t *testing.T) {
+	t.Parallel()
+
+	want := errors.New("nope")
+	p := provider.NewFake([]provider.FakeScript{{InitError: want}})
+
+	_, err := p.Stream(t.Context(),
provider.Request{}) + if !errors.Is(err, want) { + t.Errorf("Stream err = %v; want %v", err, want) + } +} + +func TestFakeProvider_Exhausted(t *testing.T) { + t.Parallel() + + p := provider.NewFake([]provider.FakeScript{ + {Events: []event.Event{event.Finish{}}}, + }) + + _ = drain(t, p, provider.Request{}) + _, err := p.Stream(t.Context(), provider.Request{}) + if !errors.Is(err, provider.ErrFakeExhausted) { + t.Errorf("Stream err = %v; want ErrFakeExhausted", err) + } +} + +func TestFakeProvider_WithName(t *testing.T) { + t.Parallel() + + p := provider.NewFake(nil, provider.WithFakeName("custom")) + if got := p.Name(); got != "custom" { + t.Errorf("Name() = %q; want %q", got, "custom") + } +} + +func TestFakeProvider_DefaultName(t *testing.T) { + t.Parallel() + + if got := provider.NewFake(nil).Name(); got != "fake" { + t.Errorf("Name() default = %q; want %q", got, "fake") + } +} + +func TestFakeProvider_ContextCancelled(t *testing.T) { + t.Parallel() + + p := provider.NewFake([]provider.FakeScript{ + {Events: []event.Event{event.TextDelta{Text: "a"}, event.TextDelta{Text: "b"}}}, + }) + + ctx, cancel := context.WithCancel(t.Context()) + ch, err := p.Stream(ctx, provider.Request{}) + if err != nil { + t.Fatalf("Stream: %v", err) + } + cancel() + + // Drain — channel must close (possibly without delivering all events). + for range ch { //nolint:revive // Intentional empty drain. + } +} + +// drain reads all events from a Stream call until the channel closes. +func drain(t *testing.T, p provider.Provider, req provider.Request) []event.Event { + t.Helper() + ch, err := p.Stream(t.Context(), req) + if err != nil { + t.Fatalf("Stream: %v", err) + } + var out []event.Event + for e := range ch { + out = append(out, e) + } + return out +} + +func sameEvents(a, b []event.Event) bool { + if len(a) != len(b) { + return false + } + for i := range a { + // Concrete-type comparison is fine for the simple values in + // these tests (no slices, no pointers). 
+		if a[i] != b[i] {
+			return false
+		}
+	}
+	return true
+}
diff --git a/docs/ci.md b/docs/ci.md
index 43b86fa..f98ddc8 100644
--- a/docs/ci.md
+++ b/docs/ci.md
@@ -15,6 +15,29 @@ Operator's runbook for the GitHub Actions workflows under `.github/workflows/` a
 | `fuzz.yml` | nightly cron + manual | Extended fuzz cycles, one matrix entry per `Fuzz*` target |
 | `release.yml` | tag `v*.*.*` | GoReleaser, cosign keyless signing, SBOMs, SLSA build provenance |
 
+## Local verification — `make verify`
+
+Before pushing, run `make verify`. It executes every check CI runs that can run locally (lint, vet, mod tidy, race tests with coverage, the full cross-compile matrix, gosec, govulncheck, **Semgrep in the same Docker image CI uses**, plus a brief fuzz pass). If `make verify` is green, CI is very likely to be green.
+
+The repo ships a pre-push hook at `.githooks/pre-push` that runs `make verify` automatically. Enable it with:
+
+```sh
+git config core.hooksPath .githooks
+```
+
+`SKIP_VERIFY=1 git push` bypasses the hook; reserve it for emergencies.
+
+`CONTRIBUTING.md` is the contributor-facing reference; this section is the operator's reference.
+
+## Codecov registration
+
+The repo is configured to upload coverage to Codecov from `ci.yml`. Until the repo is registered with Codecov:
+
+1. The Codecov upload step runs but emits "Repository not found" and exits non-zero.
+2. `fail_ci_if_error: false` is set on the upload step so this does not fail CI.
+
+Once the repo is registered (`gh secret list` should show `CODECOV_TOKEN`; visit the Codecov site and complete onboarding), set `fail_ci_if_error` to `true` in `ci.yml` so a real upload regression starts blocking merges again.
+
 ## Required and optional secrets
 
 | Secret | Required for | Notes |