diff --git a/internal/identity/identity.go b/internal/identity/identity.go index 80e9bd4..1a04e60 100644 --- a/internal/identity/identity.go +++ b/internal/identity/identity.go @@ -13,10 +13,29 @@ const ( HeaderGroupID = "X-Saturn-Group-Id" HeaderResourceID = "X-Saturn-Resource-Id" HeaderResourceType = "X-Saturn-Resource-Type" + + // HeaderAuthID carries the token / API-key identity — the JWT `sub` claim, + // which in Atlas is the IdentityAuth.id (the same value for both browser- + // session and API-key tokens; they share one token mechanism). This is the + // stable key to attribute consumption to a specific API key. Org / user / + // group are resolved DOWNSTREAM (out of band, at rating time) from this id + // via the IdentityAuth record, so the hot path never has to resolve the + // active-org context (which, for a user in multiple orgs, isn't in the + // token). + // + // NOTE: auth-server does not inject this header yet — it currently emits + // only User/Group/Resource. Wiring `sub` → this header in auth-server, and + // adding it to Traefik's authResponseHeaders allowlist, is a separate + // (small) change. Phoebe reads it defensively: absent = empty string. NOTE(review): the proxy's billing gate answers 400 when AuthID is empty, so that change has to land before this one is deployed. + HeaderAuthID = "X-Saturn-Auth-Id" ) -// Identity is the trusted, pre-resolved caller identity for a request. +// Identity is the trusted, pre-resolved caller identity for a request. Phoebe +// captures everything atlas-auth gives it and attributes downstream; it does +// not decide which field is "the tenant" on the hot path. type Identity struct { + // AuthID is the token / API-key identity (JWT sub). Primary attribution key. + AuthID string UserID string GroupID string ResourceID string @@ -27,6 +46,7 @@ type Identity struct { // validation beyond reading the values; authorization happened at the edge. 
func FromRequest(r *http.Request) Identity { return Identity{ + AuthID: r.Header.Get(HeaderAuthID), UserID: r.Header.Get(HeaderUserID), GroupID: r.Header.Get(HeaderGroupID), ResourceID: r.Header.Get(HeaderResourceID), diff --git a/internal/identity/identity_test.go b/internal/identity/identity_test.go new file mode 100644 index 0000000..f4abecd --- /dev/null +++ b/internal/identity/identity_test.go @@ -0,0 +1,42 @@ +package identity + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestFromRequestCapturesAllHeaders(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + r.Header.Set(HeaderAuthID, "auth-abc") + r.Header.Set(HeaderUserID, "user-1") + r.Header.Set(HeaderGroupID, "group-2") + r.Header.Set(HeaderResourceID, "model-x") + r.Header.Set(HeaderResourceType, "deployment") + + id := FromRequest(r) + + if id.AuthID != "auth-abc" { + t.Errorf("AuthID = %q, want auth-abc", id.AuthID) + } + if id.UserID != "user-1" { + t.Errorf("UserID = %q, want user-1", id.UserID) + } + if id.GroupID != "group-2" { + t.Errorf("GroupID = %q, want group-2", id.GroupID) + } + if id.ResourceID != "model-x" { + t.Errorf("ResourceID = %q, want model-x", id.ResourceID) + } + if id.ResourceType != "deployment" { + t.Errorf("ResourceType = %q, want deployment", id.ResourceType) + } +} + +func TestFromRequestMissingHeadersAreEmpty(t *testing.T) { + r := httptest.NewRequest(http.MethodPost, "/", nil) + id := FromRequest(r) + if id.AuthID != "" || id.UserID != "" || id.GroupID != "" || id.ResourceID != "" || id.ResourceType != "" { + t.Fatalf("expected all-empty identity, got %+v", id) + } +} diff --git a/internal/metering/metering.go b/internal/metering/metering.go index 3cbe5e5..5d91122 100644 --- a/internal/metering/metering.go +++ b/internal/metering/metering.go @@ -36,18 +36,36 @@ func (u Usage) CachedTokens() int { // Event is one immutable, idempotent metering record per request. 
It is keyed // by RequestID for downstream dedup (at-least-once delivery). +// +// Phoebe records the RAW identity it was given (every X-Saturn-* header) plus +// raw token counts. It does NOT resolve org/tenant on the hot path: AuthID (the +// token / API-key id) is the stable attribution key, and rating resolves +// auth_id → IdentityAuth → user/group/org out of band. UserID, GroupID, +// ResourceID, ResourceType are captured verbatim so no information the edge +// gave us is lost. type Event struct { - RequestID string `json:"request_id"` - GroupID string `json:"group_id"` // tenant / org - UserID string `json:"user_id"` - Model string `json:"model"` - Adapter string `json:"adapter,omitempty"` - PromptTokens int `json:"prompt_tokens"` - CachedTokens int `json:"cached_tokens"` - CompletionTokens int `json:"completion_tokens"` - FinishReason string `json:"finish_reason,omitempty"` - GPUType string `json:"gpu_type,omitempty"` // for margin; echoed by router/engine - Aborted bool `json:"aborted,omitempty"` + RequestID string `json:"request_id"` + + // Identity, captured verbatim from atlas-auth headers. NOTE(review): user_id and group_id gain omitempty in this change (as does model), so empty values are now omitted from the JSON; confirm consumers tolerate absent keys. + AuthID string `json:"auth_id,omitempty"` // token / API-key id (JWT sub) — primary key + UserID string `json:"user_id,omitempty"` // present on user tokens + GroupID string `json:"group_id,omitempty"` // present on group tokens + ResourceID string `json:"resource_id,omitempty"` // model / deployment id + ResourceType string `json:"resource_type,omitempty"` // e.g. workspace, deployment + + // Workload. + Model string `json:"model,omitempty"` // the proxy currently fills this from the resource id, so it mirrors ResourceID + Adapter string `json:"adapter,omitempty"` + + // Token counts (the engine's own usage block; never re-tokenized). 
+ PromptTokens int `json:"prompt_tokens"` + CachedTokens int `json:"cached_tokens"` + CompletionTokens int `json:"completion_tokens"` + + FinishReason string `json:"finish_reason,omitempty"` + GPUType string `json:"gpu_type,omitempty"` // for margin; echoed by router/engine + Aborted bool `json:"aborted,omitempty"` + // TimestampUnixMs is stamped by the emitter, not in the hot path here. TimestampUnixMs int64 `json:"timestamp_unix_ms"` } @@ -66,7 +84,7 @@ type LogEmitter struct { } func (l *LogEmitter) Emit(_ context.Context, e Event) { - l.Log.Info.Printf("metering event: request_id=%s group=%s user=%s model=%s prompt=%d cached=%d completion=%d finish=%s aborted=%t", - e.RequestID, e.GroupID, e.UserID, e.Model, + l.Log.Info.Printf("metering event: request_id=%s auth_id=%s group=%s user=%s resource=%s/%s model=%s prompt=%d cached=%d completion=%d finish=%s aborted=%t", + e.RequestID, e.AuthID, e.GroupID, e.UserID, e.ResourceType, e.ResourceID, e.Model, e.PromptTokens, e.CachedTokens, e.CompletionTokens, e.FinishReason, e.Aborted) } diff --git a/internal/proxy/abort_test.go b/internal/proxy/abort_test.go index 5e1a3de..f1739ba 100644 --- a/internal/proxy/abort_test.go +++ b/internal/proxy/abort_test.go @@ -83,6 +83,7 @@ func doAbortRequest(t *testing.T, srv *Server, delayBeforeCancel time.Duration) if err != nil { t.Fatalf("new request: %v", err) } + req.Header.Set(identity.HeaderAuthID, "auth-1") req.Header.Set(identity.HeaderResourceID, "model-abc") req.Header.Set(identity.HeaderGroupID, "org-1") req.Header.Set(identity.HeaderUserID, "user-1") @@ -250,6 +251,7 @@ func TestNormalCompletionNotAffectedByAbortWatcher(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"m","stream":true,"messages":[]}`)) + req.Header.Set(identity.HeaderAuthID, "auth-1") req.Header.Set(identity.HeaderResourceID, "model-abc") req.Header.Set(identity.HeaderGroupID, "org-1") req.Header.Set(identity.HeaderUserID, "user-1") @@ 
-307,6 +309,7 @@ func TestAbortRaceStress(t *testing.T) { }() req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"m","stream":true,"messages":[]}`)) + req.Header.Set(identity.HeaderAuthID, "auth-1") req.Header.Set(identity.HeaderResourceID, "model-abc") req.Header.Set(identity.HeaderGroupID, "org-1") req.Header.Set(identity.HeaderUserID, "user-1") @@ -373,6 +376,7 @@ func TestLongStreamNoDeadlineSever(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"m","stream":true,"messages":[]}`)) + req.Header.Set(identity.HeaderAuthID, "auth-1") req.Header.Set(identity.HeaderResourceID, "model-abc") req.Header.Set(identity.HeaderGroupID, "org-1") req.Header.Set(identity.HeaderUserID, "user-1") diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 660a6bc..9944e0b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -79,8 +79,16 @@ func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) { func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) { id := identity.FromRequest(r) - if id.ResourceID == "" { - http.Error(w, "missing "+identity.HeaderResourceID, http.StatusBadRequest) + // Billing-identity gate: fail closed if we lack what we need to attribute + // consumption. A billing product must not serve traffic it can't bill — a + // missing identity header means the edge contract is broken (auth-server + // not emitting it, or Traefik not allowlisting it), not a normal request. + // Report every missing field at once so the misconfiguration is obvious. The header names are both logged and returned in the 400 body. NOTE(review): this uses the strings package and no import hunk for proxy.go is visible here; confirm it is already imported. 
+ if missing := missingBillingFields(id); len(missing) > 0 { + s.log.Warn.Printf("rejecting unbillable request: missing %s (request_id=%s)", + strings.Join(missing, ", "), r.Header.Get(requestIDHeader)) + http.Error(w, "missing required billing identity: "+strings.Join(missing, ", "), + http.StatusBadRequest) return } @@ -197,10 +205,15 @@ func (s *Server) emit(ctx context.Context, id identity.Identity, requestID strin } e := metering.Event{ - RequestID: requestID, - GroupID: id.GroupID, - UserID: id.UserID, - Model: id.ResourceID, + RequestID: requestID, + // Identity captured verbatim — attribution resolved downstream. Model is filled from the same value as ResourceID. + AuthID: id.AuthID, + UserID: id.UserID, + GroupID: id.GroupID, + ResourceID: id.ResourceID, + ResourceType: id.ResourceType, + Model: id.ResourceID, + PromptTokens: res.Usage.PromptTokens, CachedTokens: res.Usage.CachedTokens(), CompletionTokens: res.Usage.CompletionTokens, @@ -210,6 +223,23 @@ func (s *Server) emit(ctx context.Context, id identity.Identity, requestID strin s.emitter.Emit(ctx, e) } +// missingBillingFields returns the names of the identity headers required to +// bill a request that are absent. A nil result means the request is billable; names are returned in a fixed order, HeaderAuthID before HeaderResourceID. +// +// AuthID (token / API-key id) is the attribution key; ResourceID identifies +// the model being billed. Both are mandatory. UserID/GroupID are resolved +// downstream from AuthID and are NOT required here. +func missingBillingFields(id identity.Identity) []string { + var missing []string + if id.AuthID == "" { + missing = append(missing, identity.HeaderAuthID) + } + if id.ResourceID == "" { + missing = append(missing, identity.HeaderResourceID) + } + return missing +} + // isEventStream reports whether the response is an SSE stream. 
func isEventStream(resp *http.Response) bool { ct := resp.Header.Get("Content-Type") diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2df6c8d..82dfa6e 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -59,13 +59,53 @@ func TestHealthz(t *testing.T) { } } -func TestProxyMissingResourceID(t *testing.T) { - srv := newTestServer(t, &url.URL{Scheme: "http", Host: "localhost:1"}) - rr := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) - srv.Handler().ServeHTTP(rr, req) - if rr.Code != http.StatusBadRequest { - t.Fatalf("missing resource id: got %d, want 400", rr.Code) +// TestProxyBillingGate verifies the fail-closed billing-identity gate: a +// request missing the auth-id and/or resource-id headers is rejected with 400 +// (we never serve traffic we can't attribute), and the error names what's +// missing. A recording emitter is checked to ensure no metering event is emitted for a rejected request. NOTE(review): the "missing both" row only asserts that the auth-id header is named in the body; the resource-id header is not checked there. +func TestProxyBillingGate(t *testing.T) { + tests := []struct { + name string + authID string + resourceID string + wantStatus int + wantInBody string + }{ + {"missing both", "", "", http.StatusBadRequest, identity.HeaderAuthID}, + {"missing auth-id", "", "model-abc", http.StatusBadRequest, identity.HeaderAuthID}, + {"missing resource-id", "auth-1", "", http.StatusBadRequest, identity.HeaderResourceID}, + {"both present", "auth-1", "model-abc", http.StatusOK, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer backend.Close() + upstream, _ := url.Parse(backend.URL) // parse error ignored: backend.URL comes from httptest.NewServer + em := &recordingEmitter{} + srv := newTestServerE(t, upstream, em) + + rr := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + if tt.authID != "" { + req.Header.Set(identity.HeaderAuthID, tt.authID) + } + if 
tt.resourceID != "" { + req.Header.Set(identity.HeaderResourceID, tt.resourceID) + } + srv.Handler().ServeHTTP(rr, req) + + if rr.Code != tt.wantStatus { + t.Fatalf("status = %d, want %d", rr.Code, tt.wantStatus) + } + if tt.wantInBody != "" && !strings.Contains(rr.Body.String(), tt.wantInBody) { + t.Fatalf("body %q does not name missing field %q", rr.Body.String(), tt.wantInBody) + } + if tt.wantStatus == http.StatusBadRequest && len(em.all()) != 0 { + t.Fatalf("rejected request should emit no billing event, got %d", len(em.all())) + } + }) } } @@ -80,6 +120,7 @@ func TestProxyForwardsToUpstream(t *testing.T) { rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set(identity.HeaderAuthID, "auth-key-7") req.Header.Set(identity.HeaderResourceID, "model-abc") srv.Handler().ServeHTTP(rr, req) @@ -100,6 +141,7 @@ func TestProxyNotFound(t *testing.T) { rr := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set(identity.HeaderAuthID, "auth-1") req.Header.Set(identity.HeaderResourceID, "gone") srv.Handler().ServeHTTP(rr, req) @@ -142,8 +184,10 @@ func TestProxyStreamingEndToEnd(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"model":"m","stream":true,"messages":[]}`)) req.Header.Set(identity.HeaderResourceID, "model-abc") + req.Header.Set(identity.HeaderResourceType, "deployment") req.Header.Set(identity.HeaderGroupID, "org-1") req.Header.Set(identity.HeaderUserID, "user-1") + req.Header.Set(identity.HeaderAuthID, "auth-key-7") req.Header.Set("X-Request-Id", "req-123") srv.Handler().ServeHTTP(rr, req) @@ -166,6 +210,12 @@ func TestProxyStreamingEndToEnd(t *testing.T) { if e.RequestID != "req-123" || e.GroupID != "org-1" || e.UserID != "user-1" || e.Model != "model-abc" { t.Fatalf("event identity wrong: %+v", e) } + if e.AuthID != "auth-key-7" { + t.Fatalf("event AuthID = %q, want auth-key-7", 
e.AuthID) + } + if e.ResourceID != "model-abc" || e.ResourceType != "deployment" { + t.Fatalf("event resource fields wrong: id=%q type=%q", e.ResourceID, e.ResourceType) + } if e.PromptTokens != 2006 || e.CompletionTokens != 300 || e.CachedTokens != 1920 { t.Fatalf("event token counts wrong: %+v", e) }