Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion internal/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,29 @@ const (
HeaderGroupID = "X-Saturn-Group-Id"
HeaderResourceID = "X-Saturn-Resource-Id"
HeaderResourceType = "X-Saturn-Resource-Type"

// HeaderAuthID carries the token / API-key identity — the JWT `sub` claim,
// which in Atlas is the IdentityAuth.id (the same value for both browser-
// session and API-key tokens; they share one token mechanism). This is the
// stable key to attribute consumption to a specific API key. Org / user /
// group are resolved DOWNSTREAM (out of band, at rating time) from this id
// via the IdentityAuth record, so the hot path never has to resolve the
// active-org context (which, for a user in multiple orgs, isn't in the
// token).
//
// NOTE: auth-server does not inject this header yet — it currently emits
// only User/Group/Resource. Wiring `sub` → this header in auth-server, and
// adding it to Traefik's authResponseHeaders allowlist, is a separate
// (small) change. Phoebe reads it defensively: absent = empty string.
HeaderAuthID = "X-Saturn-Auth-Id"
)

// Identity is the trusted, pre-resolved caller identity for a request.
// Identity is the trusted, pre-resolved caller identity for a request. Phoebe
// captures everything atlas-auth gives it and attributes downstream; it does
// not decide which field is "the tenant" on the hot path.
type Identity struct {
// AuthID is the token / API-key identity (JWT sub). Primary attribution key.
AuthID string
UserID string
GroupID string
ResourceID string
Expand All @@ -27,6 +46,7 @@ type Identity struct {
// validation beyond reading the values; authorization happened at the edge.
func FromRequest(r *http.Request) Identity {
return Identity{
AuthID: r.Header.Get(HeaderAuthID),
UserID: r.Header.Get(HeaderUserID),
GroupID: r.Header.Get(HeaderGroupID),
ResourceID: r.Header.Get(HeaderResourceID),
Expand Down
42 changes: 42 additions & 0 deletions internal/identity/identity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package identity

import (
"net/http"
"net/http/httptest"
"testing"
)

// TestFromRequestCapturesAllHeaders sets all five identity headers on a request
// and checks that FromRequest returns each value in the matching Identity field.
func TestFromRequestCapturesAllHeaders(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)

	// Header name / value pairs the edge would have injected.
	for _, hv := range [][2]string{
		{HeaderAuthID, "auth-abc"},
		{HeaderUserID, "user-1"},
		{HeaderGroupID, "group-2"},
		{HeaderResourceID, "model-x"},
		{HeaderResourceType, "deployment"},
	} {
		req.Header.Set(hv[0], hv[1])
	}

	got := FromRequest(req)

	// Errorf (not Fatalf) so a single run reports every mismatching field.
	if got.AuthID != "auth-abc" {
		t.Errorf("AuthID = %q, want auth-abc", got.AuthID)
	}
	if got.UserID != "user-1" {
		t.Errorf("UserID = %q, want user-1", got.UserID)
	}
	if got.GroupID != "group-2" {
		t.Errorf("GroupID = %q, want group-2", got.GroupID)
	}
	if got.ResourceID != "model-x" {
		t.Errorf("ResourceID = %q, want model-x", got.ResourceID)
	}
	if got.ResourceType != "deployment" {
		t.Errorf("ResourceType = %q, want deployment", got.ResourceType)
	}
}

// TestFromRequestMissingHeadersAreEmpty checks that a request carrying none of
// the identity headers yields an Identity whose five string fields are all "".
func TestFromRequestMissingHeadersAreEmpty(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/", nil)
	got := FromRequest(req)

	fields := []string{got.AuthID, got.UserID, got.GroupID, got.ResourceID, got.ResourceType}
	for _, v := range fields {
		if v != "" {
			t.Fatalf("expected all-empty identity, got %+v", got)
		}
	}
}
44 changes: 31 additions & 13 deletions internal/metering/metering.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,36 @@ func (u Usage) CachedTokens() int {

// Event is one immutable, idempotent metering record per request. It is keyed
// by RequestID for downstream dedup (at-least-once delivery).
//
// Phoebe records the RAW identity it was given (every X-Saturn-* header) plus
// raw token counts. It does NOT resolve org/tenant on the hot path: AuthID (the
// token / API-key id) is the stable attribution key, and rating resolves
// auth_id → IdentityAuth → user/group/org out of band. UserID, GroupID,
// ResourceID, ResourceType are captured verbatim so no information the edge
// gave us is lost.
type Event struct {
RequestID string `json:"request_id"`
GroupID string `json:"group_id"` // tenant / org
UserID string `json:"user_id"`
Model string `json:"model"`
Adapter string `json:"adapter,omitempty"`
PromptTokens int `json:"prompt_tokens"`
CachedTokens int `json:"cached_tokens"`
CompletionTokens int `json:"completion_tokens"`
FinishReason string `json:"finish_reason,omitempty"`
GPUType string `json:"gpu_type,omitempty"` // for margin; echoed by router/engine
Aborted bool `json:"aborted,omitempty"`
RequestID string `json:"request_id"`

// Identity, captured verbatim from atlas-auth headers.
AuthID string `json:"auth_id,omitempty"` // token / API-key id (JWT sub) — primary key
UserID string `json:"user_id,omitempty"` // present on user tokens
GroupID string `json:"group_id,omitempty"` // present on group tokens
ResourceID string `json:"resource_id,omitempty"` // model / deployment id
ResourceType string `json:"resource_type,omitempty"` // e.g. workspace, deployment

// Workload.
Model string `json:"model,omitempty"`
Adapter string `json:"adapter,omitempty"`

// Token counts (the engine's own usage block; never re-tokenized).
PromptTokens int `json:"prompt_tokens"`
CachedTokens int `json:"cached_tokens"`
CompletionTokens int `json:"completion_tokens"`

FinishReason string `json:"finish_reason,omitempty"`
GPUType string `json:"gpu_type,omitempty"` // for margin; echoed by router/engine
Aborted bool `json:"aborted,omitempty"`

// TimestampUnixMs is stamped by the emitter, not in the hot path here.
TimestampUnixMs int64 `json:"timestamp_unix_ms"`
}
Expand All @@ -66,7 +84,7 @@ type LogEmitter struct {
}

func (l *LogEmitter) Emit(_ context.Context, e Event) {
l.Log.Info.Printf("metering event: request_id=%s group=%s user=%s model=%s prompt=%d cached=%d completion=%d finish=%s aborted=%t",
e.RequestID, e.GroupID, e.UserID, e.Model,
l.Log.Info.Printf("metering event: request_id=%s auth_id=%s group=%s user=%s resource=%s/%s model=%s prompt=%d cached=%d completion=%d finish=%s aborted=%t",
e.RequestID, e.AuthID, e.GroupID, e.UserID, e.ResourceType, e.ResourceID, e.Model,
e.PromptTokens, e.CachedTokens, e.CompletionTokens, e.FinishReason, e.Aborted)
}
4 changes: 4 additions & 0 deletions internal/proxy/abort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func doAbortRequest(t *testing.T, srv *Server, delayBeforeCancel time.Duration)
if err != nil {
t.Fatalf("new request: %v", err)
}
req.Header.Set(identity.HeaderAuthID, "auth-1")
req.Header.Set(identity.HeaderResourceID, "model-abc")
req.Header.Set(identity.HeaderGroupID, "org-1")
req.Header.Set(identity.HeaderUserID, "user-1")
Expand Down Expand Up @@ -250,6 +251,7 @@ func TestNormalCompletionNotAffectedByAbortWatcher(t *testing.T) {

req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions",
strings.NewReader(`{"model":"m","stream":true,"messages":[]}`))
req.Header.Set(identity.HeaderAuthID, "auth-1")
req.Header.Set(identity.HeaderResourceID, "model-abc")
req.Header.Set(identity.HeaderGroupID, "org-1")
req.Header.Set(identity.HeaderUserID, "user-1")
Expand Down Expand Up @@ -307,6 +309,7 @@ func TestAbortRaceStress(t *testing.T) {
}()
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "/v1/chat/completions",
strings.NewReader(`{"model":"m","stream":true,"messages":[]}`))
req.Header.Set(identity.HeaderAuthID, "auth-1")
req.Header.Set(identity.HeaderResourceID, "model-abc")
req.Header.Set(identity.HeaderGroupID, "org-1")
req.Header.Set(identity.HeaderUserID, "user-1")
Expand Down Expand Up @@ -373,6 +376,7 @@ func TestLongStreamNoDeadlineSever(t *testing.T) {

req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions",
strings.NewReader(`{"model":"m","stream":true,"messages":[]}`))
req.Header.Set(identity.HeaderAuthID, "auth-1")
req.Header.Set(identity.HeaderResourceID, "model-abc")
req.Header.Set(identity.HeaderGroupID, "org-1")
req.Header.Set(identity.HeaderUserID, "user-1")
Expand Down
42 changes: 36 additions & 6 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,16 @@ func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
func (s *Server) handleProxy(w http.ResponseWriter, r *http.Request) {
id := identity.FromRequest(r)

if id.ResourceID == "" {
http.Error(w, "missing "+identity.HeaderResourceID, http.StatusBadRequest)
// Billing-identity gate: fail closed if we lack what we need to attribute
// consumption. A billing product must not serve traffic it can't bill — a
// missing identity header means the edge contract is broken (auth-server
// not emitting it, or Traefik not allowlisting it), not a normal request.
// Report every missing field at once so the misconfiguration is obvious.
if missing := missingBillingFields(id); len(missing) > 0 {
s.log.Warn.Printf("rejecting unbillable request: missing %s (request_id=%s)",
strings.Join(missing, ", "), r.Header.Get(requestIDHeader))
http.Error(w, "missing required billing identity: "+strings.Join(missing, ", "),
http.StatusBadRequest)
return
}

Expand Down Expand Up @@ -197,10 +205,15 @@ func (s *Server) emit(ctx context.Context, id identity.Identity, requestID strin
}

e := metering.Event{
RequestID: requestID,
GroupID: id.GroupID,
UserID: id.UserID,
Model: id.ResourceID,
RequestID: requestID,
// Identity captured verbatim — attribution resolved downstream.
AuthID: id.AuthID,
UserID: id.UserID,
GroupID: id.GroupID,
ResourceID: id.ResourceID,
ResourceType: id.ResourceType,
Model: id.ResourceID,

PromptTokens: res.Usage.PromptTokens,
CachedTokens: res.Usage.CachedTokens(),
CompletionTokens: res.Usage.CompletionTokens,
Expand All @@ -210,6 +223,23 @@ func (s *Server) emit(ctx context.Context, id identity.Identity, requestID strin
s.emitter.Emit(ctx, e)
}

// missingBillingFields lists the header names of the billing-critical identity
// fields that are empty on id. A nil result means the request is billable.
//
// Two fields are mandatory: AuthID (the token / API-key id used as the
// attribution key) and ResourceID (the model being billed). UserID and GroupID
// are resolved downstream from AuthID and are deliberately not checked here.
// Names are returned in a fixed order: auth-id first, then resource-id.
func missingBillingFields(id identity.Identity) []string {
	required := [...]struct {
		value  string
		header string
	}{
		{id.AuthID, identity.HeaderAuthID},
		{id.ResourceID, identity.HeaderResourceID},
	}

	var absent []string
	for _, field := range required {
		if field.value == "" {
			absent = append(absent, field.header)
		}
	}
	return absent
}

// isEventStream reports whether the response is an SSE stream.
func isEventStream(resp *http.Response) bool {
ct := resp.Header.Get("Content-Type")
Expand Down
64 changes: 57 additions & 7 deletions internal/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,53 @@ func TestHealthz(t *testing.T) {
}
}

func TestProxyMissingResourceID(t *testing.T) {
srv := newTestServer(t, &url.URL{Scheme: "http", Host: "localhost:1"})
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
srv.Handler().ServeHTTP(rr, req)
if rr.Code != http.StatusBadRequest {
t.Fatalf("missing resource id: got %d, want 400", rr.Code)
// TestProxyBillingGate verifies the fail-closed billing-identity gate: a
// request missing the auth-id and/or resource-id header is rejected with 400
// (we never serve traffic we can't attribute), and the error body names every
// header that is missing — not just the first one. A recording emitter is
// inspected to confirm that a rejected request produces no billing event.
func TestProxyBillingGate(t *testing.T) {
	tests := []struct {
		name       string
		authID     string
		resourceID string
		wantStatus int
		// wantInBody lists every header name the response body must mention.
		wantInBody []string
	}{
		{"missing both", "", "", http.StatusBadRequest,
			[]string{identity.HeaderAuthID, identity.HeaderResourceID}},
		{"missing auth-id", "", "model-abc", http.StatusBadRequest,
			[]string{identity.HeaderAuthID}},
		{"missing resource-id", "auth-1", "", http.StatusBadRequest,
			[]string{identity.HeaderResourceID}},
		{"both present", "auth-1", "model-abc", http.StatusOK, nil},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
				// Best-effort write; a failure here surfaces as a status mismatch below.
				_, _ = w.Write([]byte(`{"ok":true}`))
			}))
			defer backend.Close()

			upstream, err := url.Parse(backend.URL)
			if err != nil {
				t.Fatalf("parse backend url %q: %v", backend.URL, err)
			}
			em := &recordingEmitter{}
			srv := newTestServerE(t, upstream, em)

			rr := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
			if tt.authID != "" {
				req.Header.Set(identity.HeaderAuthID, tt.authID)
			}
			if tt.resourceID != "" {
				req.Header.Set(identity.HeaderResourceID, tt.resourceID)
			}
			srv.Handler().ServeHTTP(rr, req)

			if rr.Code != tt.wantStatus {
				t.Fatalf("status = %d, want %d", rr.Code, tt.wantStatus)
			}
			body := rr.Body.String()
			for _, want := range tt.wantInBody {
				if !strings.Contains(body, want) {
					t.Fatalf("body %q does not name missing field %q", body, want)
				}
			}
			if tt.wantStatus == http.StatusBadRequest {
				if n := len(em.all()); n != 0 {
					t.Fatalf("rejected request should emit no billing event, got %d", n)
				}
			}
		})
	}
}

Expand All @@ -80,6 +120,7 @@ func TestProxyForwardsToUpstream(t *testing.T) {

rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
req.Header.Set(identity.HeaderAuthID, "auth-key-7")
req.Header.Set(identity.HeaderResourceID, "model-abc")
srv.Handler().ServeHTTP(rr, req)

Expand All @@ -100,6 +141,7 @@ func TestProxyNotFound(t *testing.T) {

rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
req.Header.Set(identity.HeaderAuthID, "auth-1")
req.Header.Set(identity.HeaderResourceID, "gone")
srv.Handler().ServeHTTP(rr, req)

Expand Down Expand Up @@ -142,8 +184,10 @@ func TestProxyStreamingEndToEnd(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions",
strings.NewReader(`{"model":"m","stream":true,"messages":[]}`))
req.Header.Set(identity.HeaderResourceID, "model-abc")
req.Header.Set(identity.HeaderResourceType, "deployment")
req.Header.Set(identity.HeaderGroupID, "org-1")
req.Header.Set(identity.HeaderUserID, "user-1")
req.Header.Set(identity.HeaderAuthID, "auth-key-7")
req.Header.Set("X-Request-Id", "req-123")

srv.Handler().ServeHTTP(rr, req)
Expand All @@ -166,6 +210,12 @@ func TestProxyStreamingEndToEnd(t *testing.T) {
if e.RequestID != "req-123" || e.GroupID != "org-1" || e.UserID != "user-1" || e.Model != "model-abc" {
t.Fatalf("event identity wrong: %+v", e)
}
if e.AuthID != "auth-key-7" {
t.Fatalf("event AuthID = %q, want auth-key-7", e.AuthID)
}
if e.ResourceID != "model-abc" || e.ResourceType != "deployment" {
t.Fatalf("event resource fields wrong: id=%q type=%q", e.ResourceID, e.ResourceType)
}
if e.PromptTokens != 2006 || e.CompletionTokens != 300 || e.CachedTokens != 1920 {
t.Fatalf("event token counts wrong: %+v", e)
}
Expand Down
Loading