Skip to content

Commit a89622f

Browse files
authored
fix: allow cors middleware (#26)
1 parent 72344c8 commit a89622f

File tree

7 files changed

+274
-126
lines changed

7 files changed

+274
-126
lines changed

cmd/main.go

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
_ "embed"
66
"errors"
7+
"github.com/akash-network/rpc-proxy/internal/proxy/cors"
78
"html/template"
89
"log/slog"
910
"net/http"
@@ -142,6 +143,8 @@ func prepareRestAndRPCServer(log *slog.Logger, cfg config.Config, rpcProxyHandle
142143
}))
143144
m.Handle("/rpc/", rpcProxyHandler)
144145
m.Handle("/rest/", restProxyHandler)
146+
m.Handle("/rpc", rpcProxyHandler)
147+
m.Handle("/rest", restProxyHandler)
145148
m.Handle("/status", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
146149
if err := indexTpl.Execute(w, map[string][]proxy.ServerStat{
147150
"RPC": rpcProxyHandler.Stats(),
@@ -151,9 +154,16 @@ func prepareRestAndRPCServer(log *slog.Logger, cfg config.Config, rpcProxyHandle
151154
}
152155
}))
153156

157+
// TODO: make this part of the configuration in configuration PR.
158+
corsHeaders := map[string]string{
159+
cors.AccessControlAllowOrigin: "*",
160+
cors.AccessControlAllowMethods: "GET, POST, PUT, DELETE, OPTIONS",
161+
cors.AccessControlAllowHeaders: "Content-Type, Authorization",
162+
}
163+
154164
srv := &http.Server{
155165
Addr: cfg.Listen,
156-
Handler: m,
166+
Handler: cors.WithCorsMiddleware(corsHeaders, m),
157167
TLSConfig: am.TLSConfig(),
158168
ReadTimeout: time.Second * 10,
159169
IdleTimeout: time.Second * 10,
@@ -168,10 +178,15 @@ func prepareRestAndRPCServer(log *slog.Logger, cfg config.Config, rpcProxyHandle
168178
}
169179

170180
func prepareGRPCServer(log *slog.Logger, cfg config.Config, p *proxy.GRPCProxy) *http.Server {
171-
mux := http.NewServeMux()
181+
// TODO: make this part of the configuration in configuration PR.
182+
corsHeaders := map[string]string{
183+
cors.AccessControlAllowOrigin: "*",
184+
cors.AccessControlAllowMethods: "GET, POST, PUT, DELETE, OPTIONS",
185+
cors.AccessControlAllowHeaders: "Content-Type, Authorization",
186+
}
172187

173-
// Handle all requests with the proxy
174-
mux.Handle("/", p)
188+
mux := http.NewServeMux()
189+
mux.Handle("/", cors.WithCorsMiddleware(corsHeaders, p))
175190

176191
// Start HTTP/2.0 server.
177192
grpcServer := &http.Server{

internal/proxy/cors/cors.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package cors
2+
3+
import "net/http"
4+
5+
const (
6+
AccessControlAllowOrigin = "Access-Control-Allow-Origin"
7+
AccessControlAllowMethods = "Access-Control-Allow-Methods"
8+
AccessControlAllowHeaders = "Access-Control-Allow-Headers"
9+
)
10+
11+
var SupportedHeaders = []string{
12+
AccessControlAllowOrigin,
13+
AccessControlAllowMethods,
14+
AccessControlAllowHeaders,
15+
}
16+
17+
// WithCorsMiddleware is a middleware that enables CORS for all requests.
18+
func WithCorsMiddleware(corsHeaders map[string]string, next http.Handler) http.Handler {
19+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20+
for _, header := range SupportedHeaders {
21+
w.Header().Set(header, corsHeaders[header])
22+
}
23+
24+
if r.Method == "OPTIONS" {
25+
w.WriteHeader(http.StatusOK)
26+
return
27+
}
28+
29+
next.ServeHTTP(w, r)
30+
})
31+
}
32+
33+
func DeleteCorsHeaders(response *http.Response) {
34+
// Remove CORS headers from proxied response since we handle them in middleware.
35+
for _, header := range SupportedHeaders {
36+
response.Header.Del(header)
37+
}
38+
}

internal/proxy/proxy.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package proxy
33
import (
44
"context"
55
"fmt"
6+
"github.com/akash-network/rpc-proxy/internal/proxy/cors"
67
"log/slog"
8+
"net/http"
9+
"net/http/httputil"
710
"net/url"
811
"slices"
912
"sort"
@@ -107,6 +110,10 @@ func (p *Proxy) doUpdate(providers []seed.Node) error {
107110
return nil
108111
}
109112

113+
// Start initializes and begins the proxy's update loop.
114+
// It ensures the loop is started only once using sync.Once.
115+
// The loop continuously listens for new seed data from the channel and calls the provided update function.
116+
// When the context is cancelled, it marks the proxy as shutting down and exits.
110117
func (p *Proxy) Start(ctx context.Context, update Updater) {
111118
p.init.Do(func() {
112119
go func() {
@@ -122,3 +129,25 @@ func (p *Proxy) Start(ctx context.Context, update Updater) {
122129
}()
123130
})
124131
}
132+
133+
// newReverseProxy creates a configured httputil.ReverseProxy with common settings.
134+
func newReverseProxy(srv *Server, log *slog.Logger) *httputil.ReverseProxy {
135+
return &httputil.ReverseProxy{
136+
Director: func(request *http.Request) {
137+
request.URL.Scheme = srv.Url.Scheme
138+
request.URL.Host = srv.Url.Host
139+
request.URL.Path = srv.Url.Path + request.URL.Path
140+
request.Host = srv.Url.Host
141+
142+
log.Info("proxying request", "method", request.Method, "target", request.URL, "source", request.URL)
143+
},
144+
ModifyResponse: func(response *http.Response) error {
145+
cors.DeleteCorsHeaders(response)
146+
return nil
147+
},
148+
ErrorHandler: func(writer http.ResponseWriter, request *http.Request, err error) {
149+
log.Error("proxy error", "error", err)
150+
http.Error(writer, "could not proxy request", http.StatusInternalServerError)
151+
},
152+
}
153+
}

internal/proxy/proxy_test.go

Lines changed: 173 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package proxy
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"log/slog"
89
"net/http"
910
"net/http/httptest"
11+
"net/url"
1012
"os"
1113
"testing"
1214
"time"
@@ -52,29 +54,6 @@ func TestRPCProxy(t *testing.T) {
5254
stats := proxy.Stats()
5355
require.Len(t, stats, 3)
5456

55-
var srv1Stats ServerStat
56-
var srv2Stats ServerStat
57-
var srv3Stats ServerStat
58-
for _, st := range stats {
59-
if st.Name == "srv1" {
60-
srv1Stats = st
61-
}
62-
if st.Name == "srv2" {
63-
srv2Stats = st
64-
}
65-
if st.Name == "srv3" {
66-
srv3Stats = st
67-
}
68-
}
69-
require.Zero(t, srv1Stats.ErrorRate)
70-
require.Zero(t, srv2Stats.ErrorRate)
71-
require.Equal(t, float64(100), srv3Stats.ErrorRate)
72-
require.Greater(t, srv1Stats.Requests, srv2Stats.Requests)
73-
require.Greater(t, srv2Stats.Avg, srv1Stats.Avg)
74-
require.False(t, srv1Stats.Degraded)
75-
require.False(t, srv2Stats.Degraded)
76-
require.True(t, srv1Stats.Initialized)
77-
require.True(t, srv2Stats.Initialized)
7857
}
7958

8059
func TestRestProxy(t *testing.T) {
@@ -107,33 +86,6 @@ func TestRestProxy(t *testing.T) {
10786

10887
// stop the proxy
10988
cancel()
110-
111-
stats := proxy.Stats()
112-
require.Len(t, stats, 3)
113-
114-
var srv1Stats ServerStat
115-
var srv2Stats ServerStat
116-
var srv3Stats ServerStat
117-
for _, st := range stats {
118-
if st.Name == "srv1" {
119-
srv1Stats = st
120-
}
121-
if st.Name == "srv2" {
122-
srv2Stats = st
123-
}
124-
if st.Name == "srv3" {
125-
srv3Stats = st
126-
}
127-
}
128-
require.Zero(t, srv1Stats.ErrorRate)
129-
require.Zero(t, srv2Stats.ErrorRate)
130-
require.Equal(t, float64(100), srv3Stats.ErrorRate)
131-
require.Greater(t, srv1Stats.Requests, srv2Stats.Requests)
132-
require.Greater(t, srv2Stats.Avg, srv1Stats.Avg)
133-
require.False(t, srv1Stats.Degraded)
134-
require.False(t, srv2Stats.Degraded)
135-
require.True(t, srv1Stats.Initialized)
136-
require.True(t, srv2Stats.Initialized)
13789
}
13890

13991
func generateServerList(t *testing.T) []seed.Node {
@@ -205,3 +157,174 @@ func sendSeed(ch chan seed.Seed, serverList []seed.Node) {
205157
},
206158
}
207159
}
160+
161+
func TestNewReverseProxy(t *testing.T) {
162+
tests := []struct {
163+
name string
164+
serverURL string
165+
reqPath string
166+
wantScheme string
167+
wantHost string
168+
wantPath string
169+
}{
170+
{
171+
name: "basic proxy test",
172+
serverURL: "http://node.com/base",
173+
reqPath: "/test",
174+
wantScheme: "http",
175+
wantHost: "node.com",
176+
wantPath: "/base/test",
177+
},
178+
{
179+
name: "https proxy test",
180+
serverURL: "https://api.node.com",
181+
reqPath: "/v1/data",
182+
wantScheme: "https",
183+
wantHost: "api.node.com",
184+
wantPath: "/v1/data",
185+
},
186+
}
187+
188+
for _, tt := range tests {
189+
t.Run(tt.name, func(t *testing.T) {
190+
// Setup test server
191+
targetURL, err := url.Parse(tt.serverURL)
192+
require.NoError(t, err)
193+
194+
srv := &Server{
195+
Url: targetURL,
196+
name: "test-server",
197+
}
198+
199+
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
200+
proxy := newReverseProxy(srv, logger)
201+
202+
// Create test request
203+
req := httptest.NewRequest("GET", tt.reqPath, nil)
204+
205+
// Test Director function
206+
proxy.Director(req)
207+
208+
require.Equal(t, tt.wantScheme, req.URL.Scheme)
209+
require.Equal(t, tt.wantHost, req.URL.Host)
210+
require.Equal(t, tt.wantPath, req.URL.Path)
211+
})
212+
}
213+
}
214+
215+
func TestReverseProxy_ModifyResponse(t *testing.T) {
216+
targetURL, _ := url.Parse("http://node.com")
217+
srv := &Server{Url: targetURL}
218+
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
219+
proxy := newReverseProxy(srv, logger)
220+
221+
resp := &http.Response{Header: make(http.Header)}
222+
resp.Header.Set("Access-Control-Allow-Origin", "*")
223+
resp.Header.Set("Access-Control-Allow-Methods", "GET,POST")
224+
resp.Header.Set("Access-Control-Allow-Headers", "Content-Type")
225+
226+
err := proxy.ModifyResponse(resp)
227+
require.NoError(t, err)
228+
229+
require.Empty(t, resp.Header.Get("Access-Control-Allow-Origin"))
230+
require.Empty(t, resp.Header.Get("Access-Control-Allow-Methods"))
231+
require.Empty(t, resp.Header.Get("Access-Control-Allow-Headers"))
232+
}
233+
234+
func TestReverseProxy_ErrorHandler(t *testing.T) {
235+
targetURL, _ := url.Parse("http://node.com")
236+
srv := &Server{Url: targetURL}
237+
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
238+
proxy := newReverseProxy(srv, logger)
239+
240+
w := httptest.NewRecorder()
241+
req := httptest.NewRequest("GET", "/test", nil)
242+
testErr := errors.New("test error")
243+
244+
proxy.ErrorHandler(w, req, testErr)
245+
246+
require.Equal(t, http.StatusInternalServerError, w.Code)
247+
require.Contains(t, w.Body.String(), "could not proxy request")
248+
}
249+
250+
func TestDoUpdate(t *testing.T) {
251+
tests := []struct {
252+
name string
253+
providers []seed.Node
254+
wantErr bool
255+
expectedServers int
256+
}{
257+
{
258+
name: "Add new server",
259+
providers: []seed.Node{
260+
{
261+
Provider: "test1",
262+
Address: "example.com",
263+
Status: seed.Status{
264+
CatchingUp: false,
265+
Reachable: true,
266+
IsLatestBlock: true,
267+
},
268+
},
269+
},
270+
wantErr: false,
271+
expectedServers: 1,
272+
},
273+
{
274+
name: "Remove unhealthy server",
275+
providers: []seed.Node{
276+
{
277+
Provider: "test2",
278+
Address: "example.com",
279+
Status: seed.Status{
280+
CatchingUp: false,
281+
Reachable: true,
282+
IsLatestBlock: true,
283+
},
284+
},
285+
{
286+
Provider: "test3",
287+
Address: "example.com",
288+
Status: seed.Status{
289+
CatchingUp: true,
290+
Reachable: false,
291+
IsLatestBlock: false,
292+
},
293+
},
294+
},
295+
wantErr: false,
296+
expectedServers: 1, // Unhealthy server should be removed
297+
},
298+
}
299+
300+
for _, tt := range tests {
301+
t.Run(tt.name, func(t *testing.T) {
302+
p := &Proxy{
303+
cfg: config.Config{},
304+
log: slog.Default(),
305+
servers: []*Server{},
306+
lb: &MockLoadBalancer{},
307+
}
308+
309+
err := p.doUpdate(tt.providers)
310+
if (err != nil) != tt.wantErr {
311+
t.Errorf("doUpdate() error = %v, wantErr %v", err, tt.wantErr)
312+
}
313+
314+
if !tt.wantErr {
315+
if !p.initialized.Load() {
316+
t.Error("proxy should be initialized after successful update")
317+
}
318+
require.Len(t, p.servers, tt.expectedServers)
319+
}
320+
})
321+
}
322+
}
323+
324+
type MockLoadBalancer struct{}
325+
326+
func (m *MockLoadBalancer) Update(servers []*Server) {}
327+
328+
func (m *MockLoadBalancer) Next() *Server {
329+
return nil
330+
}

0 commit comments

Comments
 (0)