diff --git a/.github/workflows/build-and-push-images.yaml b/.github/workflows/build-and-push-images.yaml
index a4bfcc9..1aaaf3e 100644
--- a/.github/workflows/build-and-push-images.yaml
+++ b/.github/workflows/build-and-push-images.yaml
@@ -62,19 +62,6 @@ jobs:
       tags: ${{ steps.parse_tags.outputs.tags }}
       has_custom_tags: ${{ steps.parse_tags.outputs.has_custom_tags }}
     steps:
-      - name: Debug - Show received tags
-        run: |
-          echo "Event name: ${{ github.event_name }}"
-          echo "Received tag input: '${{ github.event.inputs.tag }}'"
-          if [ -n "${{ github.event.inputs.tag }}" ]; then
-            echo "Tags breakdown:"
-            TAGS="${{ github.event.inputs.tag }}"
-            IFS=',' read -ra TAG_ARRAY <<< "$TAGS"
-            for tag in "${TAG_ARRAY[@]}"; do
-              echo "  - $tag"
-            done
-          fi
-
       - name: Parse tags
         id: parse_tags
         shell: bash
@@ -97,6 +84,7 @@ jobs:
            echo "has_custom_tags=true" >> $GITHUB_OUTPUT
            echo "Parsed tags successfully"
          else
+           echo "tags=" >> $GITHUB_OUTPUT
            echo "has_custom_tags=false" >> $GITHUB_OUTPUT
            echo "No custom tags provided (event: ${{ github.event_name }}, has tag input: $([ -n '${{ github.event.inputs.tag }}' ] && echo 'yes' || echo 'no'))"
          fi
@@ -130,16 +118,6 @@ jobs:
          username: ${{ github.repository_owner }}
          password: ${{ secrets.AETHERLAY_GITHUB_TOKEN }}
 
-      - name: Debug - Show tags for health-checker
-        run: |
-          echo "Tags from parse_tags job:"
-          echo "${{ needs.parse_tags.outputs.tags }}"
-          echo ""
-          echo "Has custom tags: ${{ needs.parse_tags.outputs.has_custom_tags }}"
-          echo "Event name: ${{ github.event_name }}"
-          echo "Ref name: ${{ github.ref_name }}"
-          echo "Default branch: ${{ github.event.repository.default_branch }}"
-
       - name: Docker meta for health-checker
         id: meta_hc
         uses: docker/metadata-action@v5
@@ -198,16 +176,6 @@ jobs:
          username: ${{ github.repository_owner }}
          password: ${{ secrets.AETHERLAY_GITHUB_TOKEN }}
 
-      - name: Debug - Show tags for load-balancer
-        run: |
-          echo "Tags from parse_tags job:"
-          echo "${{ needs.parse_tags.outputs.tags }}"
-          echo ""
-          echo "Has custom tags: ${{ needs.parse_tags.outputs.has_custom_tags }}"
-          echo "Event name: ${{ github.event_name }}"
-          echo "Ref name: ${{ github.ref_name }}"
-          echo "Default branch: ${{ github.event.repository.default_branch }}"
-
       - name: Docker meta for load-balancer
         id: meta_lb
         uses: docker/metadata-action@v5
diff --git a/internal/health/checker.go b/internal/health/checker.go
index 5edd672..3a3dc90 100644
--- a/internal/health/checker.go
+++ b/internal/health/checker.go
@@ -422,7 +422,7 @@ func (c *Checker) makeRPCCall(ctx context.Context, url, method, chain, endpointI
 			Int("status_code", resp.StatusCode).
 			Str("body", string(bodyBytes[:n])).
Msg("RPC call failed: endpoint returned non-2xx status") - return nil, err + return nil, errors.New("HTTP " + strconv.Itoa(resp.StatusCode) + ": " + resp.Status) } // Read response diff --git a/internal/server/server.go b/internal/server/server.go index 7672622..82981f9 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -34,6 +34,18 @@ func (e *RateLimitError) Error() string { return e.Message } +// BadRequestError represents a 400 Bad Request error that may need special handling +type BadRequestError struct { + StatusCode int + Message string + Body []byte + Headers http.Header +} + +func (e *BadRequestError) Error() string { + return e.Message +} + // HealthCheckerIface defines the interface for health checker operations needed by the server type HealthCheckerIface interface { IsReady() bool @@ -408,6 +420,8 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc { var triedEndpoints []string retryCount := 0 publicAttemptCount := 0 + var first400Error *BadRequestError + var first400EndpointID string for retryCount < s.maxRetries && len(allEndpoints) > 0 { select { @@ -447,6 +461,35 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc { tryCancel() // Always cancel the per-try context if err != nil { + // Check if this is a 400 Bad Request error + if badReqErr, ok := err.(*BadRequestError); ok { + if first400Error == nil { + // First 400 response. Cache it and retry with next endpoint. + first400Error = badReqErr + first400EndpointID = endpoint.ID + log.Debug().Str("endpoint", endpoint.ID).Msg("First endpoint returned 400, will retry with next endpoint") + } else { + // Second endpoint also returned 400. This is the user's fault, pass it through. + log.Debug().Str("first_endpoint", first400EndpointID).Str("second_endpoint", endpoint.ID).Msg("Both endpoints returned 400, passing through to the user.") + + // Copy response headers from the error + for key, values := range badReqErr.Headers { + // Skip CORS headers to avoid duplication (we set our own) + if strings.HasPrefix(key, "Access-Control-") { + continue + } + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(badReqErr.StatusCode) + if _, err := w.Write(badReqErr.Body); err != nil { + log.Debug().Err(err).Msg("Failed to write 400 response body to client") + } + return + } + } + log.Debug().Str("error", helpers.RedactAPIKey(err.Error())).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.HTTPURL)).Int("retry", retryCount).Msg("HTTP request failed, will retry with different endpoint") triedEndpoints = append(triedEndpoints, endpoint.ID) @@ -460,12 +503,26 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc { allEndpoints = remainingEndpoints retryCount++ + // If we got a 400 from first endpoint, continue retrying + if first400Error != nil && len(allEndpoints) > 0 && retryCount < s.maxRetries { + log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying HTTP request with different endpoint after 400") + continue + } + if len(allEndpoints) > 0 && retryCount < s.maxRetries { log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying HTTP request with different endpoint") continue } } else { - // Success. 
+				// Success. If we had a cached 400 error, mark that endpoint as unhealthy (confirmed it's actually unhealthy)
+				if first400Error != nil {
+					log.Debug().Str("endpoint", first400EndpointID).Msg("Second endpoint succeeded, marking first endpoint (that returned 400) as unhealthy")
+					s.markEndpointUnhealthyProtocol(chain, first400EndpointID, "http")
+					first400Error = nil
+					first400EndpointID = ""
+				}
+
+				// Increment the request count and track success for debouncing
 				log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.HTTPURL)).Int("retry", retryCount).Msg("HTTP request succeeded")
 				if err := s.valkeyClient.IncrementRequestCount(ctx, chain, endpoint.ID, "proxy_requests"); err != nil {
 					log.Error().Err(err).Str("endpoint", endpoint.ID).Msg("Failed to increment request count")
@@ -516,6 +573,8 @@ func (s *Server) handleRequestWS(chain string) http.HandlerFunc {
 		var triedEndpoints []string
 		retryCount := 0
 		publicAttemptCount := 0
+		var first400Error *BadRequestError
+		var first400EndpointID string
 
 		for retryCount < s.maxRetries && len(allEndpoints) > 0 {
 			select {
@@ -555,6 +614,53 @@
 			tryCancel() // Always cancel the per-try context
 
 			if err != nil {
+				// Check if this is a 400 Bad Request error
+				if badReqErr, ok := err.(*BadRequestError); ok {
+					if first400Error == nil {
+						// First 400 response. Cache it and retry with next endpoint.
+						first400Error = badReqErr
+						first400EndpointID = endpoint.ID
+						log.Debug().Str("endpoint", endpoint.ID).Msg("First WebSocket endpoint returned 400, will retry with next endpoint")
+					} else {
+						// Second endpoint also returned 400. This is the user's fault, pass it through.
+						log.Debug().Str("first_endpoint", first400EndpointID).Str("second_endpoint", endpoint.ID).Msg("Both WebSocket endpoints returned 400, passing through to the user.")
+
+						// Copy response headers from the error
+						for key, values := range badReqErr.Headers {
+							// Skip CORS headers to avoid duplication (we set our own)
+							if strings.HasPrefix(key, "Access-Control-") {
+								continue
+							}
+							for _, value := range values {
+								w.Header().Add(key, value)
+							}
+						}
+						w.WriteHeader(badReqErr.StatusCode)
+						if _, err := w.Write(badReqErr.Body); err != nil {
+							log.Debug().Err(err).Msg("Failed to write 400 response body to client")
+						}
+						return
+					}
+
+					// Remove the failed endpoint from the list
+					var remainingEndpoints []EndpointWithID
+					for _, ep := range allEndpoints {
+						if ep.ID != endpoint.ID {
+							remainingEndpoints = append(remainingEndpoints, ep)
+						}
+					}
+					allEndpoints = remainingEndpoints
+					retryCount++
+
+					// If we got a 400 from first endpoint, continue retrying
+					if len(allEndpoints) > 0 && retryCount < s.maxRetries {
+						log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying WebSocket with different endpoint after 400")
+						continue
+					}
+					// If no more endpoints, break and handle cached 400 error
+					break
+				}
+
 				// Check if this is a 429 rate limiting error during handshake
 				if _, ok := err.(*RateLimitError); ok {
 					log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Int("retry", retryCount).Msg("WebSocket handshake rate limited")
@@ -617,7 +723,15 @@
 					continue
 				}
 			} else {
-				// Success. Increment the request count and track success for debouncing.
+				// Success. If we had a cached 400 error, mark that endpoint as unhealthy (confirmed it's actually unhealthy)
+				if first400Error != nil {
+					log.Debug().Str("endpoint", first400EndpointID).Msg("Second WebSocket endpoint succeeded, marking first endpoint (that returned 400) as unhealthy")
+					s.markEndpointUnhealthyProtocol(chain, first400EndpointID, "ws")
+					first400Error = nil
+					first400EndpointID = ""
+				}
+
+				// Increment the request count and track success for debouncing.
 				log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.WSURL)).Int("retry", retryCount).Msg("WebSocket connection succeeded")
 				if err := s.valkeyClient.IncrementRequestCount(ctx, chain, endpoint.ID, "proxy_requests"); err != nil {
 					log.Error().Err(err).Str("endpoint", endpoint.ID).Msg("Failed to increment WebSocket request count")
@@ -983,26 +1097,47 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co
 	}
 	defer resp.Body.Close()
 
-	// Check for HTTP status codes that should trigger retries
-	if s.shouldRetry(resp.StatusCode) {
-		if chain, endpointID, found := s.findChainAndEndpointByURL(targetURL); found {
+	// Check for non-2xx status codes; all of them should trigger retries
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		chain, endpointID, found := s.findChainAndEndpointByURL(targetURL)
+
+		// Special handling for 400 Bad Request - defer health decision to caller's retry logic.
+		// The HTTP handler caches the first 400 and retries with another endpoint. Only if the
+		// second endpoint succeeds will it mark the first endpoint as unhealthy. This prevents
+		// marking endpoints unhealthy due to client errors (bad requests) rather than endpoint failures.
+		// We return early here to avoid any health marking, allowing the caller to make the decision.
+		if resp.StatusCode == 400 {
+			// Read response body for logging and passing through
+			respBodyBytes, readErr := io.ReadAll(resp.Body)
+			if readErr != nil {
+				respBodyBytes = []byte{}
+			}
+			return &BadRequestError{
+				StatusCode: resp.StatusCode,
+				Message:    fmt.Sprintf("HTTP %d: %s", resp.StatusCode, resp.Status),
+				Body:       respBodyBytes,
+				Headers:    resp.Header,
+			}
+		}
+
+		// For all other non-2xx responses (400 already handled above), mark endpoint as unhealthy
+		if found {
 			if resp.StatusCode == 429 {
 				// For 429 (Too Many Requests), use the rate limit handler
 				s.markEndpointUnhealthyProtocol(chain, endpointID, "http")
 				s.handleRateLimit(chain, endpointID, "http")
 				log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned 429 (Too Many Requests), handling rate limit")
 			} else {
-				// For 5xx errors, mark as unhealthy
 				s.markEndpointUnhealthyProtocol(chain, endpointID, "http")
-				log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned server error, marked unhealthy")
+				log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned non-2xx status, marked unhealthy")
 			}
 		}
-		// Drain to enable connection reuse
-		io.Copy(io.Discard, resp.Body)
+
+		// Return error for all non-2xx responses (400 already handled above)
 		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
 	}
 
-	// Copy response headers, but skip CORS headers since we set our own
+	// Copy response headers
 	for key, values := range resp.Header {
 		// Skip CORS headers to avoid duplication
 		if strings.HasPrefix(key, "Access-Control-") {
@@ -1021,12 +1156,6 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co
 	return err
 }
 
-// shouldRetry returns true if the HTTP status code should trigger a retry
-func (s *Server) shouldRetry(statusCode int) bool {
-	// Retry on 5xx server errors and 429 Too Many Requests
-	return (statusCode >= 500 && statusCode < 600) || statusCode == 429
-}
-
 // proxyWebSocketCopy copies messages from src to dst
 func proxyWebSocketCopy(src, dst *websocket.Conn) error {
 	for {
@@ -1056,16 +1185,49 @@
 	// Connect to the backend WebSocket
 	backendConn, resp, err := websocket.DefaultDialer.Dial(backendURL, nil)
 	if err != nil {
-		// Check if this is a 429 rate limit response during handshake
-		if resp != nil && resp.StatusCode == 429 {
-			log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake rate limited")
-			return &RateLimitError{
-				StatusCode: resp.StatusCode,
-				Message:    fmt.Sprintf("WebSocket handshake was rate-limited: HTTP %d", resp.StatusCode),
+		chain, endpointID, found := s.findChainAndEndpointByURL(backendURL)
+
+		// Check for non-2xx status codes during handshake
+		if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
+			if resp.StatusCode == 429 {
+				// For 429 (Too Many Requests), mark unhealthy and return RateLimitError as signal
+				if found {
+					s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
+				}
+				log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake rate limited")
+				return &RateLimitError{
+					StatusCode: resp.StatusCode,
+					Message:    fmt.Sprintf("WebSocket handshake was rate-limited: HTTP %d", resp.StatusCode),
+				}
 			}
-		}
 
-		if chain, endpointID, found := s.findChainAndEndpointByURL(backendURL); found {
+			// Special handling for 400 Bad Request - defer health decision to caller's retry logic.
+			// The WebSocket handler caches the first 400 and retries with another endpoint. Only if the
+			// second endpoint succeeds will it mark the first endpoint as unhealthy. This prevents
+			// marking endpoints unhealthy due to client errors (bad requests) rather than endpoint failures.
+			// We return early here to avoid any health marking, allowing the caller to make the decision.
+			if resp.StatusCode == 400 {
+				// Read response body for logging and passing through
+				var respBodyBytes []byte
+				if resp.Body != nil {
+					respBodyBytes, _ = io.ReadAll(resp.Body)
+				}
+				log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake returned 400 Bad Request")
+				return &BadRequestError{
+					StatusCode: resp.StatusCode,
+					Message:    fmt.Sprintf("WebSocket handshake was rejected: HTTP %d", resp.StatusCode),
+					Body:       respBodyBytes,
+					Headers:    resp.Header,
+				}
+			}
+
+			// Mark endpoint as unhealthy for any other non-2xx status code (skip 400)
+			if found {
+				s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
+				log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake returned non-2xx status, marked unhealthy")
+			}
+		} else if found {
+			// Network/connection error - mark as unhealthy
 			s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
 		} else {
 			log.Warn().Str("url", helpers.RedactAPIKey(backendURL)).Msg("Failed to find chain and endpoint for failed WS endpoint URL, cannot mark it as unhealthy.")
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index 12ff1ea..6fdb687 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -561,38 +561,6 @@ func TestMarkEndpointUnhealthy_WS(t *testing.T) {
 	}
 }
 
-// TestShouldRetry tests the HTTP status code retry logic.
-func TestShouldRetry(t *testing.T) {
-	cfg := &config.Config{}
-	valkeyClient := store.NewMockValkeyClient()
-	server := NewServer(cfg, valkeyClient, createTestConfig())
-
-	tests := []struct {
-		statusCode  int
-		shouldRetry bool
-		description string
-	}{
-		{200, false, "2xx success should NOT retry"},
-		{201, false, "2xx success should NOT retry"},
-		{400, false, "4xx client error should NOT retry"},
-		{404, false, "4xx client error should NOT retry"},
-		{429, true, "429 Too Many Requests should retry"},
-		{500, true, "5xx server error should retry"},
-		{504, true, "5xx server error should retry"},
-		{599, true, "5xx server error should retry"},
-		{600, false, "6xx should NOT retry"},
-	}
-
-	for _, test := range tests {
-		t.Run(test.description, func(t *testing.T) {
-			result := server.shouldRetry(test.statusCode)
-			if result != test.shouldRetry {
-				t.Errorf("shouldRetry(%d) = %v, expected %v", test.statusCode, result, test.shouldRetry)
-			}
-		})
-	}
-}
-
 func TestHandleRateLimit(t *testing.T) {
 	// Create a test config
 	cfg := &config.Config{
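
Aside (not part of the patch): the retry rule the server.go hunks implement is easy to lose track of across the interleaved HTTP and WebSocket changes, so below is a minimal, self-contained Go sketch of the same "two-strike" 400 handling: the first 400 from an endpoint is cached and the request is retried elsewhere; a second 400 means the request itself is bad and is passed through to the client, while a success on the retry implies the first endpoint misbehaved and it gets marked unhealthy. The names `badRequestError`, `tryEndpoints`, and `markUnhealthy` are local to this example and do not exist in the repository, and the real handlers additionally track retry counts, copy headers, and mark health per protocol, all of which this sketch omits.

```go
package main

import (
	"errors"
	"fmt"
)

// badRequestError mirrors the idea of server.go's BadRequestError: it carries
// the upstream 400 so the caller, not the transport layer, decides what to do.
type badRequestError struct {
	endpointID string
}

func (e *badRequestError) Error() string { return "HTTP 400 from " + e.endpointID }

// tryEndpoints applies the two-strike rule over a list of endpoints.
// call performs the request against one endpoint; markUnhealthy records a bad one.
func tryEndpoints(endpoints []string, call func(id string) error, markUnhealthy func(id string)) error {
	var first *badRequestError
	for _, id := range endpoints {
		err := call(id)
		if err == nil {
			// Success: if an earlier endpoint answered 400, it was the outlier.
			if first != nil {
				markUnhealthy(first.endpointID)
			}
			return nil
		}
		var badReq *badRequestError
		if errors.As(err, &badReq) {
			if first == nil {
				first = badReq // first strike: cache it and retry elsewhere
				continue
			}
			return badReq // second strike: the request itself is bad, pass it through
		}
		// Any other error: try the next endpoint.
	}
	if first != nil {
		return first // only one endpoint answered, and it said 400
	}
	return fmt.Errorf("all endpoints failed")
}

func main() {
	err := tryEndpoints(
		[]string{"a", "b"},
		func(id string) error {
			if id == "a" {
				return &badRequestError{endpointID: "a"}
			}
			return nil // "b" succeeds, so "a" gets marked unhealthy
		},
		func(id string) { fmt.Println("marking unhealthy:", id) },
	)
	fmt.Println("result:", err)
}
```

The point of the rule is that a single 400 is ambiguous: it can come from a broken endpoint or from a malformed client request, and only a second opinion from another endpoint disambiguates the two cases.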