diff --git a/app/integration.go b/app/integration.go index d3c6234..0487584 100644 --- a/app/integration.go +++ b/app/integration.go @@ -213,24 +213,16 @@ func prepareIntegration(i *Integration) error { i.proxy = httputil.NewSingleHostReverseProxy(u) oldDirector := i.proxy.Director - if hasWildcard { - i.proxy.Director = func(req *http.Request) { - dest, ok := resolvedDestinationFromContext(req.Context()) - if !ok { - oldDirector(req) - req.Host = u.Host - return - } - if resolvedDestinationApplied(req.Context()) { - return - } - applyResolvedDestination(req, dest) + i.proxy.Director = func(req *http.Request) { + if resolvedDestinationApplied(req.Context()) { + return } - } else { - i.proxy.Director = func(req *http.Request) { - oldDirector(req) - req.Host = u.Host + if dest, ok := resolvedDestinationFromContext(req.Context()); ok { + applyResolvedDestination(req, dest) + return } + oldDirector(req) + req.Host = u.Host } i.proxy.ModifyResponse = func(resp *http.Response) error { diff --git a/app/main.go b/app/main.go index 2cf5808..538980b 100644 --- a/app/main.go +++ b/app/main.go @@ -1418,6 +1418,14 @@ func proxyHandler(w http.ResponseWriter, r *http.Request) { } } + metricsHost := r.Host + metricsRequestURI := r.RequestURI + var metricsURL *url.URL + if r.URL != nil { + u := *r.URL + metricsURL = &u + } + resolvedDest, err := integ.resolveRequestDestination(r) if err != nil { logger.Warn("invalid destination header", "integration", integ.Name, "error", err) @@ -1429,8 +1437,8 @@ func proxyHandler(w http.ResponseWriter, r *http.Request) { return } r = r.WithContext(contextWithResolvedDestination(r.Context(), resolvedDest)) + applyResolvedDestination(r, resolvedDest) if integ.requiresDestinationHeader { - applyResolvedDestination(r, resolvedDest) r.Header.Del("X-AT-Destination") } @@ -1459,7 +1467,18 @@ func proxyHandler(w http.ResponseWriter, r *http.Request) { return } + proxyHost := r.Host + proxyRequestURI := r.RequestURI + proxyURL := r.URL + // Metrics hooks observe the proxy-facing route, but must use the live + // request so body reads are restored for the upstream proxy. + r.Host = metricsHost + r.RequestURI = metricsRequestURI + r.URL = metricsURL metrics.OnRequest(integ.Name, r) + r.Host = proxyHost + r.RequestURI = proxyRequestURI + r.URL = proxyURL handoffStart := time.Now() r = r.WithContext(metrics.WithUpstreamRoundtripStart(r.Context(), handoffStart)) metrics.RecordPreProxyDuration(integ.Name, handoffStart.Sub(start)) diff --git a/app/proxy_test.go b/app/proxy_test.go index c9f38e3..0c80c3e 100644 --- a/app/proxy_test.go +++ b/app/proxy_test.go @@ -122,6 +122,34 @@ func (*delayedMetricsPlugin) WriteProm(http.ResponseWriter) {} var delayedMetricsPluginOnce sync.Once +type bodyReadingMetricsPlugin struct { + matchHost string + err error + body string + host string + path string + rawQuery string +} + +func (p *bodyReadingMetricsPlugin) OnRequest(_ string, r *http.Request) { + if r.Host != p.matchHost { + return + } + p.host = r.Host + p.path = r.URL.Path + p.rawQuery = r.URL.RawQuery + body, err := authplugins.GetBody(r) + if err != nil { + p.err = err + return + } + p.body = string(body) +} + +func (*bodyReadingMetricsPlugin) OnResponse(string, string, *http.Request, *http.Response) {} + +func (*bodyReadingMetricsPlugin) WriteProm(http.ResponseWriter) {} + func promCounterValue(t *testing.T, prefix string) float64 { t.Helper() @@ -332,6 +360,63 @@ func TestProxyHandlerPreProxyIncludesMetricsRequestHooks(t *testing.T) { } } +func TestProxyHandlerMetricsBodyReadPreservesUpstreamBody(t *testing.T) { + denylists.Lock() + denylists.m = make(map[string]map[string][]CallRule) + denylists.Unlock() + + plugin := &bodyReadingMetricsPlugin{matchHost: "body-metrics"} + metrics.Register(plugin) + + var upstreamBody, upstreamPath, upstreamQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read upstream body: %v", err) + } + upstreamBody = string(body) + upstreamPath = r.URL.Path + upstreamQuery = r.URL.RawQuery + w.WriteHeader(http.StatusAccepted) + })) + defer srv.Close() + + integ := Integration{Name: "body-metrics", Destination: srv.URL + "/base?static=1", InRateLimit: 1, OutRateLimit: 1} + if err := AddIntegration(&integ); err != nil { + t.Fatalf("failed to add integration: %v", err) + } + t.Cleanup(func() { + integ.inLimiter.Stop() + integ.outLimiter.Stop() + DeleteIntegration(integ.Name) + }) + + req := httptest.NewRequest(http.MethodPost, "http://body-metrics/submit?foo=bar", strings.NewReader("payload")) + req.Host = "body-metrics" + rr := httptest.NewRecorder() + + proxyHandler(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("expected 202, got %d", rr.Code) + } + if plugin.err != nil { + t.Fatalf("metrics body read failed: %v", plugin.err) + } + if plugin.host != "body-metrics" || plugin.path != "/submit" || plugin.rawQuery != "foo=bar" { + t.Fatalf("metrics saw host/path/query %q %q %q", plugin.host, plugin.path, plugin.rawQuery) + } + if plugin.body != "payload" { + t.Fatalf("metrics saw body %q", plugin.body) + } + if upstreamBody != "payload" { + t.Fatalf("upstream saw body %q", upstreamBody) + } + if upstreamPath != "/base/submit" || upstreamQuery != "static=1&foo=bar" { + t.Fatalf("upstream saw path/query %q %q", upstreamPath, upstreamQuery) + } +} + func TestProxyHandlerHostCaseInsensitive(t *testing.T) { denylists.Lock() denylists.m = make(map[string]map[string][]CallRule) @@ -1133,6 +1218,82 @@ func TestProxyHandlerWildcardAddAuthSeesResolvedDestination(t *testing.T) { } } +func TestProxyHandlerStaticAddAuthSeesResolvedDestination(t *testing.T) { + denylists.Lock() + denylists.m = make(map[string]map[string][]CallRule) + denylists.Unlock() + + captureAddAuthCount = 0 + captureLastURL = "" + + upstream, err := url.Parse("http://backend.example.com/base?static=1") + if err != nil { + t.Fatalf("parse upstream: %v", err) + } + + integ := Integration{ + Name: "static-auth-url", + Destination: upstream.String(), + InRateLimit: 1, + OutRateLimit: 1, + OutgoingAuth: []AuthPluginConfig{{ + Type: "test_capture", + Params: map[string]interface{}{"expect_host": upstream.Host}, + }}, + } + if err := AddIntegration(&integ); err != nil { + t.Fatalf("failed to add integration: %v", err) + } + t.Cleanup(func() { + integ.inLimiter.Stop() + integ.outLimiter.Stop() + DeleteIntegration("static-auth-url") + }) + + called := false + integ.proxy.Transport = roundTripFunc(func(req *http.Request) (*http.Response, error) { + called = true + if req.URL.Host != upstream.Host { + t.Fatalf("unexpected upstream host: %s", req.URL.Host) + } + if req.Host != upstream.Host { + t.Fatalf("unexpected request host: %s", req.Host) + } + if req.URL.Path != "/base/test" { + t.Fatalf("unexpected upstream path: %s", req.URL.Path) + } + if req.URL.RawQuery != "static=1&foo=bar" { + t.Fatalf("unexpected query: %s", req.URL.RawQuery) + } + resp := &http.Response{ + StatusCode: http.StatusNoContent, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + Request: req, + } + return resp, nil + }) + + req := httptest.NewRequest(http.MethodGet, "http://static-auth-url/test?foo=bar", nil) + req.Host = "static-auth-url" + rr := httptest.NewRecorder() + + proxyHandler(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("expected 204, got %d", rr.Code) + } + if !called { + t.Fatal("expected transport to be invoked") + } + if captureAddAuthCount != 1 { + t.Fatalf("expected AddAuth to be called once, got %d", captureAddAuthCount) + } + if captureLastURL != "http://backend.example.com/base/test?static=1&foo=bar" { + t.Fatalf("unexpected URL seen by AddAuth: %s", captureLastURL) + } +} + func TestProxyHandlerOutgoingAuthError(t *testing.T) { denylists.Lock() denylists.m = make(map[string]map[string][]CallRule)