Skip to content

Commit 4124b20

Browse files
authored
[agent] Cleanup unused, shimmed websocket connections after a configurable timeout.
This PR adds a configurable value for how long shimmed websocket connections can remain idle before they are garbage collected. Previously, the underlying connection from the proxy agent to the backend server was only cleaned up when there was either an error while sending or receiving messages, or if the agent received an explicit close message. When a client of shim protocol stopped communicating without first sending a close message, this caused the underlying connection between the proxy agent and backend server to be retained indefinitely. That, in turn, can result in a resource leak within the agent (and possibly in the backend server as well). To account for that scenario, we needed to find some way to identify that a client of the shim protocol had stopped communicating, so that we could clean up those lost connections. The shim protocol does not include any sort of ping or pong messages, but it does have polling requests sent by the client which can be used as a dead-man switch; once a sufficient amount of time has elapsed since the last such request, we can assume that the client is gone and the underlying connection can be removed. This PR adds a configuration option to the agent to specify what that amount of time should be (with a default of 1 hour), and implements the agent-side cleanup logic to identify and close expired connections.
2 parents d9a3e67 + 5588f04 commit 4124b20

6 files changed

Lines changed: 182 additions & 27 deletions

File tree

agent/agent.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ var (
7373
injectBanner = flag.String("inject-banner", "", "HTML snippet to inject in served webpages")
7474
bannerHeight = flag.String("banner-height", "40px", "Height of the injected banner. This is ignored if no banner is set.")
7575
shimWebsockets = flag.Bool("shim-websockets", false, "Whether or not to replace websockets with a shim")
76+
websocketShimTimeout = flag.Duration("websocket-shim-timeout", 60*time.Minute, "Timeout for websocket shim connections to expire due to inactivity.")
7677
shimPath = flag.String("shim-path", "", "Path under which to handle websocket shim requests")
7778
healthCheckPath = flag.String("health-check-path", "/", "Path on backend host to issue health checks against. Defaults to the root.")
7879
healthCheckFreq = flag.Int("health-check-interval-seconds", 0, "Wait time in seconds between health checks. Set to zero to disable health checks. Checks disabled by default.")
@@ -126,7 +127,8 @@ func hostProxy(ctx context.Context, host, shimPath string, injectShimCode, force
126127
// restricted to a path prefix not equal to "/" will fail for websocket open requests. Passing in the
127128
// sessionHandler twice allows the websocket handler to ensure that cookies are applied based on the
128129
// correct, restored path.
129-
h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler, metricHandler)
130+
h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler,
131+
metricHandler, *websocketShimTimeout)
130132
if injectShimCode {
131133
shimFunc, err := websockets.ShimBody(shimPath)
132134
if err != nil {

agent/websockets/connection.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ limitations under the License.
1717
package websockets
1818

1919
import (
20+
"context"
2021
"encoding/base64"
2122
"encoding/json"
2223
"errors"
2324
"fmt"
2425
"log"
2526
"net/http"
27+
"sync"
2628
"time"
2729

28-
"context"
29-
3030
"github.com/gorilla/websocket"
3131
)
3232

@@ -57,12 +57,14 @@ func (m *message) Serialize(version int) interface{} {
5757
// and encapsulates it in an API that is a little more amenable to how the server side
5858
// of our websocket shim is implemented.
5959
type Connection struct {
60-
done func() <-chan struct{}
61-
cancel context.CancelFunc
62-
clientMessages chan *message
63-
serverMessages chan *message
64-
protocolVersion int
65-
subprotocol string
60+
done func() <-chan struct{}
61+
cancel context.CancelFunc
62+
clientMessages chan *message
63+
serverMessages chan *message
64+
protocolVersion int
65+
subprotocol string
66+
mu sync.Mutex
67+
lastActivityTime time.Time
6668
}
6769

6870
// This map defines the set of headers that should be stripped from the WS request, as they
@@ -87,6 +89,20 @@ func stripWSHeader(header http.Header) http.Header {
8789
return result
8890
}
8991

92+
// updateActivity updates the last activity timestamp.
93+
func (conn *Connection) updateActivity() {
94+
conn.mu.Lock()
95+
defer conn.mu.Unlock()
96+
conn.lastActivityTime = time.Now()
97+
}
98+
99+
// lastActivity returns the last activity timestamp.
100+
func (conn *Connection) lastActivity() time.Time {
101+
conn.mu.Lock()
102+
defer conn.mu.Unlock()
103+
return conn.lastActivityTime
104+
}
105+
90106
// NewConnection creates and returns a new Connection.
91107
func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) {
92108
ctx, cancel := context.WithCancel(ctx)
@@ -162,11 +178,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
162178
}
163179
}()
164180
return &Connection{
165-
done: ctx.Done,
166-
cancel: cancel,
167-
clientMessages: clientMessages,
168-
serverMessages: serverMessages,
169-
subprotocol: serverConn.Subprotocol(),
181+
done: ctx.Done,
182+
cancel: cancel,
183+
clientMessages: clientMessages,
184+
serverMessages: serverMessages,
185+
subprotocol: serverConn.Subprotocol(),
186+
lastActivityTime: time.Now(),
170187
}, nil
171188
}
172189

@@ -184,6 +201,7 @@ func (conn *Connection) Close() {
184201
//
185202
// The returned error value is non-nill if the connection has been closed.
186203
func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error {
204+
conn.updateActivity()
187205
var clientMessage *message
188206
if textMsg, ok := msg.(string); ok {
189207
clientMessage = &message{
@@ -236,7 +254,9 @@ func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool
236254
//
237255
// The returned []string value is nil if the error is non-nil, or if the method
238256
// times out while waiting for a server message.
239-
func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
257+
func (conn *Connection) ReadServerMessages(readTimeout time.Duration) ([]interface{}, error) {
258+
conn.updateActivity()
259+
defer conn.updateActivity()
240260
var msgs []interface{}
241261
select {
242262
case serverMsg, ok := <-conn.serverMessages:
@@ -257,7 +277,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
257277
return msgs, nil
258278
}
259279
}
260-
case <-time.After(time.Second * 20):
280+
case <-time.After(readTimeout):
261281
return nil, nil
262282
}
263283
}

agent/websockets/shim.go

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package websockets
1818

1919
import (
2020
"bytes"
21+
"context"
2122
"encoding/json"
2223
"fmt"
2324
"io"
@@ -31,8 +32,8 @@ import (
3132
"sync"
3233
"sync/atomic"
3334
"text/template"
35+
"time"
3436

35-
"context"
3637
"github.com/google/inverting-proxy/agent/metrics"
3738
)
3839

@@ -320,9 +321,33 @@ func (c *connectionErrorHandler) ReportError(err error) {
320321
}
321322
}
322323

323-
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler {
324+
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler {
324325
var connections sync.Map
325326
var sessionCount uint64
327+
328+
// Background goroutine to clean up inactive websocket shim connections.
329+
go func() {
330+
ticker := time.NewTicker(min(timeout, 30*time.Second))
331+
defer ticker.Stop()
332+
for {
333+
select {
334+
case <-ctx.Done():
335+
return
336+
case <-ticker.C:
337+
connections.Range(func(key, value any) bool {
338+
sessionID := key.(string)
339+
conn := value.(*Connection)
340+
if time.Since(conn.lastActivity()) > timeout {
341+
log.Printf("Closing inactive websocket shim session %q after timeout", sessionID)
342+
conn.Close()
343+
connections.Delete(sessionID)
344+
}
345+
return true // Continue iteration
346+
})
347+
}
348+
}
349+
}()
350+
326351
mux := http.NewServeMux()
327352
errorHandler := &connectionErrorHandler{}
328353
openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -351,9 +376,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
351376
}
352377
}
353378
resp := &sessionMessage{
354-
ID: sessionID,
355-
Message: targetURL.String(),
356-
Version: conn.protocolVersion,
379+
ID: sessionID,
380+
Message: targetURL.String(),
381+
Version: conn.protocolVersion,
357382
Subprotocol: conn.Subprotocol(),
358383
}
359384
respBytes, err := json.Marshal(resp)
@@ -512,7 +537,7 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
512537
metricHandler.WriteResponseCodeMetric(statusCode)
513538
return
514539
}
515-
serverMsgs, err := conn.ReadServerMessages()
540+
serverMsgs, err := conn.ReadServerMessages(min(20*time.Second, timeout/2))
516541
if err != nil {
517542
statusCode := http.StatusBadRequest
518543
errorMessage := fmt.Sprintf("attempt to read data from a closed session: %q", msg.ID)
@@ -548,11 +573,11 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
548573
// openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original
549574
// targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it
550575
// is finished processing the request.
551-
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) {
576+
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) {
552577
mux := http.NewServeMux()
553578
if shimPath != "" {
554579
shimPath = path.Clean("/"+shimPath) + "/"
555-
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler)
580+
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler, timeout)
556581
mux.Handle(shimPath, shimServer)
557582
}
558583
mux.Handle("/", wrapped)

agent/websockets/websockets_test.go

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@ import (
2323
"errors"
2424
"fmt"
2525
"io"
26+
"io/ioutil"
2627
"net/http"
2728
"net/http/httptest"
2829
"net/url"
2930
"path"
3031
"strings"
3132
"sync"
3233
"testing"
34+
"time"
3335

3436
"github.com/google/go-cmp/cmp"
3537
"github.com/google/go-cmp/cmp/cmpopts"
@@ -239,7 +241,7 @@ func TestShimHandlers(t *testing.T) {
239241
openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler {
240242
return h
241243
}
242-
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil)
244+
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil, 60*time.Second)
243245
if err != nil {
244246
t.Fatalf("Failure creating the websocket shim proxy: %+v", err)
245247
}
@@ -354,3 +356,107 @@ func TestShimHandlers(t *testing.T) {
354356
}
355357
}
356358
}
359+
360+
func TestShimPolling(t *testing.T) {
361+
ctx, cancel := context.WithCancel(context.Background())
362+
defer cancel()
363+
// Setup a fake backend that accepts websocket connections but sends no messages.
364+
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
365+
upgrader := websocket.Upgrader{}
366+
conn, err := upgrader.Upgrade(w, r, nil)
367+
if err != nil {
368+
t.Logf("Failed to upgrade websocket: %v", err)
369+
return
370+
}
371+
defer conn.Close()
372+
// Keep connection open until client closes it.
373+
for {
374+
if _, _, err := conn.NextReader(); err != nil {
375+
break
376+
}
377+
}
378+
}))
379+
defer backend.Close()
380+
381+
backendURL, err := url.Parse(backend.URL)
382+
if err != nil {
383+
t.Fatalf("Failed to parse backend URL: %v", err)
384+
}
385+
386+
shimPath := "/shim/"
387+
idleTimeout := 3 * time.Second
388+
shim := createShimChannel(
389+
ctx,
390+
backendURL.Host,
391+
shimPath,
392+
false,
393+
func(h http.Handler, m *metrics.MetricHandler) http.Handler { return h },
394+
false,
395+
nil,
396+
idleTimeout,
397+
)
398+
shimServer := httptest.NewServer(shim)
399+
defer shimServer.Close()
400+
401+
// 1. Open a websocket connection via the shim.
402+
openURL := shimServer.URL + shimPath + "open"
403+
resp, err := http.Post(openURL, "text/plain", strings.NewReader(backendURL.String()))
404+
if err != nil {
405+
t.Fatalf("Failed to open shim connection: %v", err)
406+
}
407+
if resp.StatusCode != http.StatusOK {
408+
t.Fatalf("Failed to open shim connection, status: %d", resp.StatusCode)
409+
}
410+
body, err := ioutil.ReadAll(resp.Body)
411+
if err != nil {
412+
t.Fatalf("Failed to read open response body: %v", err)
413+
}
414+
resp.Body.Close()
415+
var openResp sessionMessage
416+
if err := json.Unmarshal(body, &openResp); err != nil {
417+
t.Fatalf("Failed to unmarshal open response: %v", err)
418+
}
419+
sessionID := openResp.ID
420+
if sessionID == "" {
421+
t.Fatal("No sessionID in open response")
422+
}
423+
424+
// 2. Poll repeatedly without any messages being sent.
425+
pollURL := shimServer.URL + shimPath + "poll"
426+
pollReq := fmt.Sprintf(`{"id": %q}`, sessionID)
427+
timeout := time.Now().Add(idleTimeout + 1*time.Second)
428+
for time.Now().Before(timeout) {
429+
resp, err := http.Post(pollURL, "application/json", strings.NewReader(pollReq))
430+
if err != nil {
431+
t.Fatalf("Failed to poll shim connection: %v", err)
432+
}
433+
resp.Body.Close()
434+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusRequestTimeout {
435+
t.Fatalf("Unexpected status code during polling: %d", resp.StatusCode)
436+
}
437+
// Sleep for half of read timeout to simulate polling faster than idle timeout.
438+
time.Sleep(500 * time.Millisecond)
439+
}
440+
441+
// 3. After idleTimeout + 1s, one more poll should succeed.
442+
resp, err = http.Post(pollURL, "application/json", strings.NewReader(pollReq))
443+
if err != nil {
444+
t.Fatalf("Failed to poll shim connection: %v", err)
445+
}
446+
defer resp.Body.Close()
447+
if resp.StatusCode != http.StatusRequestTimeout {
448+
t.Errorf("Polling after idle timeout got status %d, want %d", resp.StatusCode, http.StatusRequestTimeout)
449+
}
450+
451+
// 4. Close connection.
452+
closeURL := shimServer.URL + shimPath + "close"
453+
closeReq := fmt.Sprintf(`{"id": %q}`, sessionID)
454+
resp, err = http.Post(closeURL, "application/json", strings.NewReader(closeReq))
455+
if err != nil {
456+
t.Fatalf("Failed to close shim connection: %v", err)
457+
}
458+
defer resp.Body.Close()
459+
if resp.StatusCode != http.StatusOK {
460+
t.Errorf("Close shim connection got status %d, want %d", resp.StatusCode, http.StatusOK)
461+
}
462+
}

server/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ func (p *proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
220220
p.Lock()
221221
p.requests[id] = pending
222222
p.Unlock()
223-
defer func(){
223+
defer func() {
224224
p.Lock()
225225
delete(p.requests, id)
226226
p.Unlock()

testing/websockets/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"net/http"
3232
"net/http/httputil"
3333
"net/url"
34+
"time"
3435

3536
"github.com/google/inverting-proxy/agent/metrics"
3637
"github.com/google/inverting-proxy/agent/websockets"
@@ -48,6 +49,7 @@ var (
4849
monitoringEndpoint = flag.String("monitoring-endpoint", "staging-monitoring.sandbox.googleapis.com:443", "The endpoint to which to write metrics. Eg: monitoring.googleapis.com corresponds to Cloud Monarch.")
4950
monitoringResourceType = flag.String("monitoring-resource-type", "gce_instance", "The monitoring resource type. Eg: gce_instance")
5051
monitoringResourceLabels = flag.String("monitoring-resource-labels", "instance-id=fake-instance-id,instance-zone=us-west1-a", "Comma separated key value pairs for the purpose of monitoring configuration. Eg: 'instance-id=my-instance-id,instance-zone=us-west1-a")
52+
websocketShimTimeout = flag.Duration("websocket-shim-timeout", 60*time.Minute, "Timeout for websocket shim connections to expire due to inactivity.")
5153
)
5254

5355
func main() {
@@ -69,7 +71,7 @@ func main() {
6971
}
7072

7173
backendProxy := httputil.NewSingleHostReverseProxy(backendURL)
72-
shimmingProxy, err := websockets.Proxy(context.Background(), backendProxy, backendURL.Host, *shimPath, true, *enableWebsocketInjection, func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { return h }, metricHandler)
74+
shimmingProxy, err := websockets.Proxy(context.Background(), backendProxy, backendURL.Host, *shimPath, true, *enableWebsocketInjection, func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { return h }, metricHandler, *websocketShimTimeout)
7375
if err != nil {
7476
log.Fatalf("Failure starting the websocket-shimming proxy: %v", err)
7577
}

0 commit comments

Comments
 (0)