Merged
Changes from 2 commits
2 changes: 1 addition & 1 deletion internal/health/checker.go
@@ -422,7 +422,7 @@ func (c *Checker) makeRPCCall(ctx context.Context, url, method, chain, endpointI
Int("status_code", resp.StatusCode).
Str("body", string(bodyBytes[:n])).
Msg("RPC call failed: endpoint returned non-2xx status")
return nil, err
return nil, errors.New("HTTP " + strconv.Itoa(resp.StatusCode) + ": " + resp.Status)
}

// Read response
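
Why the checker.go change matters: at this point in makeRPCCall the HTTP request itself has already succeeded, so the transport-level err is nil; the old "return nil, err" therefore returned (nil, nil) on a non-2xx status, and callers could not tell the check had failed. A condensed sketch of the failure mode (the surrounding code here is hypothetical; only the status check mirrors the diff):

package sketch

import (
	"errors"
	"net/http"
	"strconv"
)

// makeCall condenses the status handling in makeRPCCall.
func makeCall(client *http.Client, req *http.Request) (*http.Response, error) {
	resp, err := client.Do(req)
	if err != nil {
		return nil, err // transport failure: err is non-nil on this path
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		resp.Body.Close()
		// Old code returned (nil, err) here, but err is nil on this path,
		// so the caller received (nil, nil). The fix constructs a real error:
		return nil, errors.New("HTTP " + strconv.Itoa(resp.StatusCode) + ": " + resp.Status)
	}
	return resp, nil
}
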
143 changes: 124 additions & 19 deletions internal/server/server.go
@@ -34,6 +34,18 @@ func (e *RateLimitError) Error() string {
return e.Message
}

// BadRequestError represents a 400 Bad Request error that may need special handling
type BadRequestError struct {
StatusCode int
Message string
Body []byte
Headers http.Header
}

func (e *BadRequestError) Error() string {
return e.Message
}
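
The new BadRequestError follows the same pattern as the existing RateLimitError: a typed error carrying enough of the upstream response (status, body, headers) to replay it to the client later. The retry loop detects it with a plain type assertion; a short sketch of the detection, with errors.As as an alternative that also survives error wrapping (assumes the errors package is imported):

var badReq *BadRequestError
if errors.As(err, &badReq) { // equivalent to err.(*BadRequestError) while the error is unwrapped
	_ = badReq.Body // cache it for a retry, or replay body/headers to the client
}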

// HealthCheckerIface defines the interface for health checker operations needed by the server
type HealthCheckerIface interface {
IsReady() bool
@@ -408,6 +420,8 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc {
var triedEndpoints []string
retryCount := 0
publicAttemptCount := 0
var first400Error *BadRequestError
var first400EndpointID string

for retryCount < s.maxRetries && len(allEndpoints) > 0 {
select {
@@ -447,6 +461,33 @@
tryCancel() // Always cancel the per-try context

if err != nil {
// Check if this is a 400 Bad Request error
if badReqErr, ok := err.(*BadRequestError); ok {
if first400Error == nil {
// First 400 response. Cache it and retry with the next endpoint.
first400Error = badReqErr
first400EndpointID = endpoint.ID
log.Debug().Str("endpoint", endpoint.ID).Msg("First endpoint returned 400, will retry with next endpoint")
} else {
// Second endpoint also returned 400, so the request itself is at fault; pass it through to the user.
log.Debug().Str("first_endpoint", first400EndpointID).Str("second_endpoint", endpoint.ID).Msg("Both endpoints returned 400, passing through to the user.")

// Copy response headers from the error
for key, values := range badReqErr.Headers {
// Skip CORS headers to avoid duplication (we set our own)
if strings.HasPrefix(key, "Access-Control-") {
continue
}
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(badReqErr.StatusCode)
w.Write(badReqErr.Body)
return
}
}

log.Debug().Str("error", helpers.RedactAPIKey(err.Error())).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.HTTPURL)).Int("retry", retryCount).Msg("HTTP request failed, will retry with different endpoint")
triedEndpoints = append(triedEndpoints, endpoint.ID)
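
Taken together, the 400 handling added in this hunk implements a simple policy: trust a 400 only when a second endpoint agrees (or no endpoints remain). A self-contained sketch of the decision sequence, reduced to status codes (names and structure are illustrative, not the server's actual API):

package main

import "fmt"

// decide mirrors the 400-handling policy: the first 400 is cached, a
// second 400 confirms a bad request, and a success instead blames the
// first endpoint.
func decide(statuses []int) string {
	saw400 := false
	for _, s := range statuses {
		switch {
		case s == 400 && !saw400:
			saw400 = true // cache the first 400 and keep retrying
		case s == 400 && saw400:
			return "pass cached 400 through to the client"
		case s >= 200 && s < 300 && saw400:
			return "serve response; mark the 400 endpoint unhealthy"
		case s >= 200 && s < 300:
			return "serve response"
		}
	}
	if saw400 {
		return "endpoints exhausted: pass cached 400 through"
	}
	return "all retries failed"
}

func main() {
	fmt.Println(decide([]int{400, 200})) // serve response; mark the 400 endpoint unhealthy
	fmt.Println(decide([]int{400, 400})) // pass cached 400 through to the client
}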

@@ -460,12 +501,26 @@
allEndpoints = remainingEndpoints
retryCount++

// If we got a 400 from the first endpoint, continue retrying
if first400Error != nil && len(allEndpoints) > 0 && retryCount < s.maxRetries {
log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying HTTP request with different endpoint after 400")
continue
}

if len(allEndpoints) > 0 && retryCount < s.maxRetries {
log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying HTTP request with different endpoint")
continue
}
} else {
// Success. Increment the request count and track success for debouncing.
// Success. If we had a cached 400 error, mark that endpoint as unhealthy: the success here confirms the 400 was the endpoint's fault, not the request's
if first400Error != nil {
log.Debug().Str("endpoint", first400EndpointID).Msg("Second endpoint succeeded, marking first endpoint (that returned 400) as unhealthy")
s.markEndpointUnhealthyProtocol(chain, first400EndpointID, "http")
first400Error = nil
first400EndpointID = ""
}

// Increment the request count and track success for debouncing
log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.HTTPURL)).Int("retry", retryCount).Msg("HTTP request succeeded")
if err := s.valkeyClient.IncrementRequestCount(ctx, chain, endpoint.ID, "proxy_requests"); err != nil {
log.Error().Err(err).Str("endpoint", endpoint.ID).Msg("Failed to increment request count")
@@ -476,6 +531,24 @@
}
}

// If we have a cached 400 error and have run out of endpoints, pass it through
if first400Error != nil {
log.Debug().Str("endpoint", first400EndpointID).Msg("Ran out of endpoints after first 400, passing through to user")
// Copy response headers from the error
for key, values := range first400Error.Headers {
// Skip CORS headers to avoid duplication (we set our own)
if strings.HasPrefix(key, "Access-Control-") {
continue
}
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(first400Error.StatusCode)
w.Write(first400Error.Body)
return
}
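
The header-copy loop above is now duplicated (here and in the both-endpoints-returned-400 path earlier in the handler). If it grows a third call site, it could be factored into a helper along these lines (a hypothetical refactor, not part of this PR):

// copyNonCORSHeaders copies src headers to dst, skipping CORS headers
// because the server sets its own.
func copyNonCORSHeaders(dst, src http.Header) {
	for key, values := range src {
		if strings.HasPrefix(key, "Access-Control-") {
			continue
		}
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}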

// If we get here, all retries failed
if retryCount >= s.maxRetries {
log.Error().Str("chain", chain).Strs("tried_endpoints", triedEndpoints).Int("max_retries", s.maxRetries).Msg("Max retries reached")
@@ -983,26 +1056,44 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co
}
defer resp.Body.Close()

// Check for HTTP status codes that should trigger retries
if s.shouldRetry(resp.StatusCode) {
if chain, endpointID, found := s.findChainAndEndpointByURL(targetURL); found {
// Check for non-2xx status codes; all of them should trigger retries
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
chain, endpointID, found := s.findChainAndEndpointByURL(targetURL)

// Read the response body for logging and potential pass-through
bodyBytes, readErr := io.ReadAll(resp.Body)
if readErr != nil {
bodyBytes = []byte{}
}

// Mark endpoint as unhealthy for any non-2xx response
if found {
if resp.StatusCode == 429 {
// For 429 (Too Many Requests), use the rate limit handler
s.markEndpointUnhealthyProtocol(chain, endpointID, "http")
s.handleRateLimit(chain, endpointID, "http")
log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned 429 (Too Many Requests), handling rate limit")
} else {
// For all other non-2xx statuses, mark as unhealthy
s.markEndpointUnhealthyProtocol(chain, endpointID, "http")
log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned server error, marked unhealthy")
log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned non-2xx status, marked unhealthy")
}
}
// Drain to enable connection reuse
io.Copy(io.Discard, resp.Body)

// Special handling for 400 Bad Request
if resp.StatusCode == 400 {
return &BadRequestError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, resp.Status),
Body: bodyBytes,
Headers: resp.Header,
}
}

// For all other non-2xx errors, return a generic error
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
}

// Copy response headers, but skip CORS headers since we set our own
// Copy response headers
for key, values := range resp.Header {
// Skip CORS headers to avoid duplication
if strings.HasPrefix(key, "Access-Control-") {
Expand All @@ -1023,8 +1114,8 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co

// shouldRetry returns true if the HTTP status code should trigger a retry
func (s *Server) shouldRetry(statusCode int) bool {
// Retry on 5xx server errors and 429 Too Many Requests
return (statusCode >= 500 && statusCode < 600) || statusCode == 429
// Retry on all non-2xx status codes
return statusCode < 200 || statusCode >= 300
}
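
The widened predicate is easy to spot-check, and it has one behavioral consequence worth calling out: 3xx responses (redirects) and other 4xx codes now count as retryable failures, not just 429 and 5xx. A runnable check:

package main

import "fmt"

// shouldRetry as changed in this PR: any non-2xx status triggers a retry.
func shouldRetry(statusCode int) bool {
	return statusCode < 200 || statusCode >= 300
}

func main() {
	for _, code := range []int{200, 204, 301, 404, 429, 502} {
		fmt.Printf("shouldRetry(%d) = %v\n", code, shouldRetry(code))
	}
	// Before: only 429 and 5xx returned true. After: 301 and 404 do as well.
}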

// proxyWebSocketCopy copies messages from src to dst
@@ -1056,16 +1147,30 @@ func (s *Server) defaultProxyWebSocket(w http.ResponseWriter, r *http.Request, b
// Connect to the backend WebSocket
backendConn, resp, err := websocket.DefaultDialer.Dial(backendURL, nil)
if err != nil {
// Check if this is a 429 rate limit response during handshake
if resp != nil && resp.StatusCode == 429 {
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake rate limited")
return &RateLimitError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("WebSocket handshake was rate-limited: HTTP %d", resp.StatusCode),
chain, endpointID, found := s.findChainAndEndpointByURL(backendURL)

// Check for non-2xx status codes during handshake
if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
if resp.StatusCode == 429 {
// For 429 (Too Many Requests), use the rate limit handler
if found {
s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
s.handleRateLimit(chain, endpointID, "ws")
}
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake rate limited")
return &RateLimitError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("WebSocket handshake was rate-limited: HTTP %d", resp.StatusCode),
}
}
}

if chain, endpointID, found := s.findChainAndEndpointByURL(backendURL); found {
// Mark endpoint as unhealthy for any other non-2xx status code
if found {
s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake returned non-2xx status, marked unhealthy")
}
} else if found {
// Network/connection error - mark as unhealthy
s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
} else {
log.Warn().Str("url", helpers.RedactAPIKey(backendURL)).Msg("Failed to find chain and endpoint for failed WS endpoint URL, cannot mark it as unhealthy.")
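
One note on the handshake handling above: assuming the dialer is gorilla/websocket (the Dial signature and websocket.DefaultDialer suggest it), a rejected handshake returns websocket.ErrBadHandshake together with the server's *http.Response, which is what makes the status-code inspection possible. A minimal sketch of that contract:

conn, resp, err := websocket.DefaultDialer.Dial(backendURL, nil)
if err != nil {
	if errors.Is(err, websocket.ErrBadHandshake) && resp != nil {
		// resp.StatusCode carries the handshake's HTTP status (e.g. 429);
		// resp.Body may hold a truncated portion of the server's response.
	}
}
_ = conn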