34 changes: 1 addition & 33 deletions .github/workflows/build-and-push-images.yaml
@@ -62,19 +62,6 @@ jobs:
tags: ${{ steps.parse_tags.outputs.tags }}
has_custom_tags: ${{ steps.parse_tags.outputs.has_custom_tags }}
steps:
- name: Debug - Show received tags
run: |
echo "Event name: ${{ github.event_name }}"
echo "Received tag input: '${{ github.event.inputs.tag }}'"
if [ -n "${{ github.event.inputs.tag }}" ]; then
echo "Tags breakdown:"
TAGS="${{ github.event.inputs.tag }}"
IFS=',' read -ra TAG_ARRAY <<< "$TAGS"
for tag in "${TAG_ARRAY[@]}"; do
echo " - $tag"
done
fi

- name: Parse tags
id: parse_tags
shell: bash
@@ -97,6 +84,7 @@ jobs:
echo "has_custom_tags=true" >> $GITHUB_OUTPUT
echo "Parsed tags successfully"
else
echo "tags=" >> $GITHUB_OUTPUT
echo "has_custom_tags=false" >> $GITHUB_OUTPUT
echo "No custom tags provided (event: ${{ github.event_name }}, has tag input: $([ -n '${{ github.event.inputs.tag }}' ] && echo 'yes' || echo 'no'))"
fi
@@ -130,16 +118,6 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.AETHERLAY_GITHUB_TOKEN }}

- name: Debug - Show tags for health-checker
run: |
echo "Tags from parse_tags job:"
echo "${{ needs.parse_tags.outputs.tags }}"
echo ""
echo "Has custom tags: ${{ needs.parse_tags.outputs.has_custom_tags }}"
echo "Event name: ${{ github.event_name }}"
echo "Ref name: ${{ github.ref_name }}"
echo "Default branch: ${{ github.event.repository.default_branch }}"

- name: Docker meta for health-checker
id: meta_hc
uses: docker/metadata-action@v5
@@ -198,16 +176,6 @@ jobs:
username: ${{ github.repository_owner }}
password: ${{ secrets.AETHERLAY_GITHUB_TOKEN }}

- name: Debug - Show tags for load-balancer
run: |
echo "Tags from parse_tags job:"
echo "${{ needs.parse_tags.outputs.tags }}"
echo ""
echo "Has custom tags: ${{ needs.parse_tags.outputs.has_custom_tags }}"
echo "Event name: ${{ github.event_name }}"
echo "Ref name: ${{ github.ref_name }}"
echo "Default branch: ${{ github.event.repository.default_branch }}"

- name: Docker meta for load-balancer
id: meta_lb
uses: docker/metadata-action@v5
2 changes: 1 addition & 1 deletion internal/health/checker.go
@@ -422,7 +422,7 @@ func (c *Checker) makeRPCCall(ctx context.Context, url, method, chain, endpointI
Int("status_code", resp.StatusCode).
Str("body", string(bodyBytes[:n])).
Msg("RPC call failed: endpoint returned non-2xx status")
return nil, err
return nil, errors.New("HTTP " + strconv.Itoa(resp.StatusCode) + ": " + resp.Status)
}

// Read response
210 changes: 186 additions & 24 deletions internal/server/server.go
@@ -34,6 +34,18 @@ func (e *RateLimitError) Error() string {
return e.Message
}

// BadRequestError represents a 400 Bad Request error that may need special handling
type BadRequestError struct {
StatusCode int
Message string
Body []byte
Headers http.Header
}

func (e *BadRequestError) Error() string {
return e.Message
}

// HealthCheckerIface defines the interface for health checker operations needed by the server
type HealthCheckerIface interface {
IsReady() bool
Expand Down Expand Up @@ -408,6 +420,8 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc {
var triedEndpoints []string
retryCount := 0
publicAttemptCount := 0
var first400Error *BadRequestError
var first400EndpointID string

for retryCount < s.maxRetries && len(allEndpoints) > 0 {
select {
@@ -447,6 +461,35 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc {
tryCancel() // Always cancel the per-try context

if err != nil {
// Check if this is a 400 Bad Request error
if badReqErr, ok := err.(*BadRequestError); ok {
if first400Error == nil {
// First 400 response. Cache it and retry with next endpoint.
first400Error = badReqErr
first400EndpointID = endpoint.ID
log.Debug().Str("endpoint", endpoint.ID).Msg("First endpoint returned 400, will retry with next endpoint")
} else {
// Second endpoint also returned 400. This is the user's fault, pass it through.
log.Debug().Str("first_endpoint", first400EndpointID).Str("second_endpoint", endpoint.ID).Msg("Both endpoints returned 400, passing through to the user.")

// Copy response headers from the error
for key, values := range badReqErr.Headers {
// Skip CORS headers to avoid duplication (we set our own)
if strings.HasPrefix(key, "Access-Control-") {
continue
}
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(badReqErr.StatusCode)
if _, err := w.Write(badReqErr.Body); err != nil {
log.Debug().Err(err).Msg("Failed to write 400 response body to client")
}
return
}
}

log.Debug().Str("error", helpers.RedactAPIKey(err.Error())).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.HTTPURL)).Int("retry", retryCount).Msg("HTTP request failed, will retry with different endpoint")
triedEndpoints = append(triedEndpoints, endpoint.ID)

@@ -460,12 +503,26 @@ func (s *Server) handleRequestHTTP(chain string) http.HandlerFunc {
allEndpoints = remainingEndpoints
retryCount++

// If we got a 400 from first endpoint, continue retrying
if first400Error != nil && len(allEndpoints) > 0 && retryCount < s.maxRetries {
log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying HTTP request with different endpoint after 400")
continue
}

if len(allEndpoints) > 0 && retryCount < s.maxRetries {
log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying HTTP request with different endpoint")
continue
}
} else {
// Success. Increment the request count and track success for debouncing.
// Success. If we had a cached 400 error, mark that endpoint as unhealthy (confirmed it's actually unhealthy)
if first400Error != nil {
log.Debug().Str("endpoint", first400EndpointID).Msg("Second endpoint succeeded, marking first endpoint (that returned 400) as unhealthy")
s.markEndpointUnhealthyProtocol(chain, first400EndpointID, "http")
first400Error = nil
first400EndpointID = ""
}

// Increment the request count and track success for debouncing
log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.HTTPURL)).Int("retry", retryCount).Msg("HTTP request succeeded")
if err := s.valkeyClient.IncrementRequestCount(ctx, chain, endpoint.ID, "proxy_requests"); err != nil {
log.Error().Err(err).Str("endpoint", endpoint.ID).Msg("Failed to increment request count")
@@ -516,6 +573,8 @@ func (s *Server) handleRequestWS(chain string) http.HandlerFunc {
var triedEndpoints []string
retryCount := 0
publicAttemptCount := 0
var first400Error *BadRequestError
var first400EndpointID string
Comment on lines +576 to +577
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

wc -l internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 102


🏁 Script executed:

sed -n '570,590p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 690


🏁 Script executed:

sed -n '1175,1235p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 2855


🏁 Script executed:

sed -n '617,670p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 2536


🏁 Script executed:

sed -n '720,740p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 1451


🌐 Web query:

Go http.ResponseWriter after websocket.Upgrader.Upgrade - is connection hijacked, can you call WriteHeader or Write after upgrade?

💡 Result:

Short answer: Yes — Upgrade hijacks the underlying connection, so you must not call ResponseWriter.WriteHeader or Write after Upgrade; use the returned *websocket.Conn instead. If you do call Write/WriteHeader after the hijack you’ll get errors like "http: response.Write on hijacked connection". Also note hijack is only supported for HTTP/1.x (HTTP/2 ResponseWriters don’t implement Hijacker). [1][2]

Sources:

  • gorilla/websocket Upgrade docs (Upgrade calls hijack as part of the handshake). [1]
  • Explanation and common error ("response.Write on hijacked connection") on Stack Overflow. [2]
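
To make the hijack behavior concrete, here is a minimal illustrative sketch (not code from this PR; the handler name and package-level upgrader are assumed) showing that once Upgrade succeeds, all further writes must go through the returned *websocket.Conn rather than the http.ResponseWriter:

package main

import (
	"log"
	"net/http"

	"github.com/gorilla/websocket"
)

var upgrader = websocket.Upgrader{}

func wsHandler(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already written an HTTP error to the client; only log here.
		log.Println("upgrade failed:", err)
		return
	}
	defer conn.Close()

	// w.WriteHeader(http.StatusOK) // after the hijack this only logs "http: response.Write on hijacked connection"
	if err := conn.WriteMessage(websocket.TextMessage, []byte("connected")); err != nil {
		log.Println("write failed:", err)
	}
}

func main() {
	http.HandleFunc("/ws", wsHandler)
	log.Fatal(http.ListenAndServe(":8080", nil))
}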

🌐 Web query:

gorilla/websocket Dial error - do I need to close resp.Body when Dial fails?

💡 Result:

No — you do not need to close resp.Body when gorilla/websocket Dial returns an error. The Dialer returns a non-nil *http.Response on handshake failure for caller inspection, but the response body may be incomplete and “does not need to be closed by the application.” [1][2]

Sources:

  • gorilla/websocket package docs. [1]
  • gorilla/websocket client.go (source). [2]

🏁 Script executed:

sed -n '560,580p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 877


🏁 Script executed:

sed -n '1170,1185p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 617


🏁 Script executed:

# Get the full function signature and structure to understand flow
grep -B 30 "var first400Error" internal/server/server.go | head -40

Repository: project-aethermesh/aetherlay

Length of output: 1515


🏁 Script executed:

# Understand the exact function this code is in
sed -n '540,580p' internal/server/server.go

Repository: project-aethermesh/aetherlay

Length of output: 1670


WebSocket handshake error handling: Upgrade happens too early in the retry loop, making HTTP response writes invalid.

The WS path attempts to mirror HTTP 400 logic, but there's a critical architectural issue:

  1. Premature Upgrade prevents HTTP error responses in the retry loop.

    • defaultProxyWebSocket upgrades the client connection immediately via upgrader.Upgrade(w, r, nil) (Line 1178). The connection is hijacked and http.ResponseWriter becomes invalid for further HTTP operations.
    • However, handleRequestWS calls defaultProxyWebSocket on each retry iteration (lines 576+). If the backend handshake fails with a 400 error on the first endpoint, the function returns a BadRequestError to handleRequestWS.
    • handleRequestWS then attempts to write HTTP headers and body via w.Header().Add, w.WriteHeader, and w.Write (lines 638–641) to pass through the error to the client. But since the connection was already hijacked by the Upgrade call in the previous defaultProxyWebSocket invocation, these writes will fail with "response.Write on hijacked connection" errors.
    • To fix this, the backend handshake should be completed (and its status code checked) before calling Upgrade. Only call Upgrade once you confirm the backend will accept the connection, or restructure so defaultProxyWebSocket does not upgrade on retry attempts that fail.
  2. Response body closure on handshake error (lines 1203–1207) is not required.

    • Per gorilla/websocket documentation, resp.Body does not need to be closed by the application when Dial returns an error. The current code reads the body without closing it, which is acceptable per the library's contract and does not cause resource leaks.

Given the hijack issue, the WS 400 pass-through path will not work as written. Consider redesigning the flow to probe or perform the backend handshake before upgrading the client connection.

🤖 Prompt for AI Agents
internal/server/server.go around lines 576-577: the WebSocket upgrade is being
performed inside defaultProxyWebSocket on each retry which hijacks the client
connection and makes subsequent HTTP error writes invalid; change the flow so
the backend handshake (dialing the backend and checking its HTTP status) is
completed and its status code examined before calling upgrader.Upgrade on the
client side—either by refactoring defaultProxyWebSocket to separate "probe
backend handshake without upgrading client" and only upgrade after a successful
backend response, or by moving the upgrade out of the retry loop so retries only
probe the backend; also remove or leave unchanged any explicit resp.Body.Close
on dial errors per gorilla/websocket contract (no body close required) and
ensure error responses are written to the client only when the connection has
not been hijacked.
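
As one way to address the hijack issue described above, here is a minimal sketch of the suggested restructuring: dial the backend first and upgrade the client connection only after the backend handshake succeeds, so 400s and other handshake failures can still be retried or answered over plain HTTP. The function name proxyWebSocketProbeFirst and the package-level upgrader are illustrative, not the repo's actual identifiers; BadRequestError is repeated from this PR so the sketch stands on its own.

package server

import (
	"context"
	"fmt"
	"io"
	"net/http"

	"github.com/gorilla/websocket"
)

// BadRequestError mirrors the type added in this PR (repeated here for self-containment).
type BadRequestError struct {
	StatusCode int
	Message    string
	Body       []byte
	Headers    http.Header
}

func (e *BadRequestError) Error() string { return e.Message }

var upgrader = websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}

// proxyWebSocketProbeFirst probes the backend handshake before hijacking the client connection.
func proxyWebSocketProbeFirst(ctx context.Context, w http.ResponseWriter, r *http.Request, backendURL string) error {
	// 1. Dial the backend first; w has not been hijacked yet, so the caller can still retry or reply over HTTP.
	backendConn, resp, err := websocket.DefaultDialer.DialContext(ctx, backendURL, nil)
	if err != nil {
		if resp != nil && resp.StatusCode == http.StatusBadRequest {
			// Per gorilla/websocket, resp.Body may be incomplete and does not need to be closed.
			var body []byte
			if resp.Body != nil {
				body, _ = io.ReadAll(resp.Body)
			}
			return &BadRequestError{
				StatusCode: resp.StatusCode,
				Message:    fmt.Sprintf("WebSocket handshake was rejected: HTTP %d", resp.StatusCode),
				Body:       body,
				Headers:    resp.Header,
			}
		}
		return err
	}
	defer backendConn.Close()

	// 2. Only now upgrade (hijack) the client connection.
	clientConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return err // Upgrade already sent the HTTP error response to the client
	}
	defer clientConn.Close()

	// 3. Relay frames in both directions until either side fails; closing both conns unblocks the other goroutine.
	errCh := make(chan error, 2)
	relay := func(dst, src *websocket.Conn) {
		for {
			msgType, msg, err := src.ReadMessage()
			if err != nil {
				errCh <- err
				return
			}
			if err := dst.WriteMessage(msgType, msg); err != nil {
				errCh <- err
				return
			}
		}
	}
	go relay(backendConn, clientConn)
	go relay(clientConn, backendConn)
	return <-errCh
}

With this shape, handleRequestWS could keep its 400 pass-through logic unchanged, since the ResponseWriter remains a normal HTTP writer until the backend has actually accepted the connection.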


for retryCount < s.maxRetries && len(allEndpoints) > 0 {
select {
@@ -555,6 +614,53 @@ func (s *Server) handleRequestWS(chain string) http.HandlerFunc {
tryCancel() // Always cancel the per-try context

if err != nil {
// Check if this is a 400 Bad Request error
if badReqErr, ok := err.(*BadRequestError); ok {
if first400Error == nil {
// First 400 response. Cache it and retry with next endpoint.
first400Error = badReqErr
first400EndpointID = endpoint.ID
log.Debug().Str("endpoint", endpoint.ID).Msg("First WebSocket endpoint returned 400, will retry with next endpoint")
} else {
// Second endpoint also returned 400. This is the user's fault, pass it through.
log.Debug().Str("first_endpoint", first400EndpointID).Str("second_endpoint", endpoint.ID).Msg("Both WebSocket endpoints returned 400, passing through to the user.")

// Copy response headers from the error
for key, values := range badReqErr.Headers {
// Skip CORS headers to avoid duplication (we set our own)
if strings.HasPrefix(key, "Access-Control-") {
continue
}
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(badReqErr.StatusCode)
if _, err := w.Write(badReqErr.Body); err != nil {
log.Debug().Err(err).Msg("Failed to write 400 response body to client")
}
return
}

// Remove the failed endpoint from the list
var remainingEndpoints []EndpointWithID
for _, ep := range allEndpoints {
if ep.ID != endpoint.ID {
remainingEndpoints = append(remainingEndpoints, ep)
}
}
allEndpoints = remainingEndpoints
retryCount++

// If we got a 400 from first endpoint, continue retrying
if len(allEndpoints) > 0 && retryCount < s.maxRetries {
log.Debug().Str("chain", chain).Str("failed_endpoint", endpoint.ID).Int("public_attempt_count", publicAttemptCount).Int("remaining_endpoints", len(allEndpoints)).Int("retry", retryCount).Msg("Retrying WebSocket with different endpoint after 400")
continue
}
// If no more endpoints, break and handle cached 400 error
break
}

// Check if this is a 429 rate limiting error during handshake
if _, ok := err.(*RateLimitError); ok {
log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Int("retry", retryCount).Msg("WebSocket handshake rate limited")
@@ -617,7 +723,15 @@ func (s *Server) handleRequestWS(chain string) http.HandlerFunc {
continue
}
} else {
// Success. Increment the request count and track success for debouncing.
// Success. If we had a cached 400 error, mark that endpoint as unhealthy (confirmed it's actually unhealthy)
if first400Error != nil {
log.Debug().Str("endpoint", first400EndpointID).Msg("Second WebSocket endpoint succeeded, marking first endpoint (that returned 400) as unhealthy")
s.markEndpointUnhealthyProtocol(chain, first400EndpointID, "ws")
first400Error = nil
first400EndpointID = ""
}

// Increment the request count and track success for debouncing.
log.Debug().Str("chain", chain).Str("endpoint", endpoint.ID).Str("endpoint_url", helpers.RedactAPIKey(endpoint.Endpoint.WSURL)).Int("retry", retryCount).Msg("WebSocket connection succeeded")
if err := s.valkeyClient.IncrementRequestCount(ctx, chain, endpoint.ID, "proxy_requests"); err != nil {
log.Error().Err(err).Str("endpoint", endpoint.ID).Msg("Failed to increment WebSocket request count")
@@ -983,26 +1097,47 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co
}
defer resp.Body.Close()

// Check for HTTP status codes that should trigger retries
if s.shouldRetry(resp.StatusCode) {
if chain, endpointID, found := s.findChainAndEndpointByURL(targetURL); found {
// Check for non-2xx status codes, all of them should trigger retries
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
chain, endpointID, found := s.findChainAndEndpointByURL(targetURL)

// Special handling for 400 Bad Request - defer health decision to caller's retry logic.
// The HTTP handler caches the first 400 and retries with another endpoint. Only if the
// second endpoint succeeds will it mark the first endpoint as unhealthy. This prevents
// marking endpoints unhealthy due to client errors (bad requests) rather than endpoint failures.
// We return early here to avoid any health marking, allowing the caller to make the decision.
if resp.StatusCode == 400 {
// Read response body for logging and passing through
respBodyBytes, readErr := io.ReadAll(resp.Body)
if readErr != nil {
respBodyBytes = []byte{}
}
return &BadRequestError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, resp.Status),
Body: respBodyBytes,
Headers: resp.Header,
}
}

// For all other non-2xx responses (400 already handled above), mark endpoint as unhealthy
if found {
if resp.StatusCode == 429 {
// For 429 (Too Many Requests), use the rate limit handler
s.markEndpointUnhealthyProtocol(chain, endpointID, "http")
s.handleRateLimit(chain, endpointID, "http")
log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned 429 (Too Many Requests), handling rate limit")
} else {
// For 5xx errors, mark as unhealthy
s.markEndpointUnhealthyProtocol(chain, endpointID, "http")
log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned server error, marked unhealthy")
log.Debug().Str("url", helpers.RedactAPIKey(targetURL)).Int("status_code", resp.StatusCode).Msg("Endpoint returned non-2xx status, marked unhealthy")
}
}
// Drain to enable connection reuse
io.Copy(io.Discard, resp.Body)

// Return error for all non-2xx responses (400 already handled above)
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
}

// Copy response headers, but skip CORS headers since we set our own
// Copy response headers
for key, values := range resp.Header {
// Skip CORS headers to avoid duplication
if strings.HasPrefix(key, "Access-Control-") {
@@ -1021,12 +1156,6 @@ func (s *Server) defaultForwardRequestWithBodyFunc(w http.ResponseWriter, ctx co
return err
}

// shouldRetry returns true if the HTTP status code should trigger a retry
func (s *Server) shouldRetry(statusCode int) bool {
// Retry on 5xx server errors and 429 Too Many Requests
return (statusCode >= 500 && statusCode < 600) || statusCode == 429
}

// proxyWebSocketCopy copies messages from src to dst
func proxyWebSocketCopy(src, dst *websocket.Conn) error {
for {
@@ -1056,16 +1185,49 @@ func (s *Server) defaultProxyWebSocket(w http.ResponseWriter, r *http.Request, b
// Connect to the backend WebSocket
backendConn, resp, err := websocket.DefaultDialer.Dial(backendURL, nil)
if err != nil {
// Check if this is a 429 rate limit response during handshake
if resp != nil && resp.StatusCode == 429 {
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake rate limited")
return &RateLimitError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("WebSocket handshake was rate-limited: HTTP %d", resp.StatusCode),
chain, endpointID, found := s.findChainAndEndpointByURL(backendURL)

// Check for non-2xx status codes during handshake
if resp != nil && (resp.StatusCode < 200 || resp.StatusCode >= 300) {
if resp.StatusCode == 429 {
// For 429 (Too Many Requests), mark unhealthy and return RateLimitError as signal
if found {
s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
}
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake rate limited")
return &RateLimitError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("WebSocket handshake was rate-limited: HTTP %d", resp.StatusCode),
}
}
}

if chain, endpointID, found := s.findChainAndEndpointByURL(backendURL); found {
// Special handling for 400 Bad Request - defer health decision to caller's retry logic.
// The WebSocket handler caches the first 400 and retries with another endpoint. Only if the
// second endpoint succeeds will it mark the first endpoint as unhealthy. This prevents
// marking endpoints unhealthy due to client errors (bad requests) rather than endpoint failures.
// We return early here to avoid any health marking, allowing the caller to make the decision.
if resp.StatusCode == 400 {
// Read response body for logging and passing through
var respBodyBytes []byte
if resp.Body != nil {
respBodyBytes, _ = io.ReadAll(resp.Body)
}
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake returned 400 Bad Request")
return &BadRequestError{
StatusCode: resp.StatusCode,
Message: fmt.Sprintf("WebSocket handshake was rejected: HTTP %d", resp.StatusCode),
Body: respBodyBytes,
Headers: resp.Header,
}
}

// Mark endpoint as unhealthy for any other non-2xx status code (skip 400)
if found {
s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
log.Debug().Str("url", helpers.RedactAPIKey(backendURL)).Int("status_code", resp.StatusCode).Msg("WebSocket handshake returned non-2xx status, marked unhealthy")
}
} else if found {
// Network/connection error - mark as unhealthy
s.markEndpointUnhealthyProtocol(chain, endpointID, "ws")
} else {
log.Warn().Str("url", helpers.RedactAPIKey(backendURL)).Msg("Failed to find chain and endpoint for failed WS endpoint URL, cannot mark it as unhealthy.")