Skip to content

Commit f20d9b6

Browse files
committed
Add back ratelimit
1 parent 095b183 commit f20d9b6

File tree

3 files changed

+177
-0
lines changed

3 files changed

+177
-0
lines changed

config/confighttp/confighttp.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ type ServerConfig struct {
290290
// Auth for this receiver
291291
Auth *AuthConfig `mapstructure:"auth"`
292292

293+
// RateLimit for this receiver
294+
RateLimit *RateLimit `mapstructure:"rate_limit"`
295+
293296
// MaxRequestBodySize sets the maximum request body size in bytes. Default: 20MiB.
294297
MaxRequestBodySize int64 `mapstructure:"max_request_body_size"`
295298

@@ -441,6 +444,17 @@ func (hss *ServerConfig) ToServer(_ context.Context, host component.Host, settin
441444
handler = authInterceptor(handler, server, hss.Auth.RequestParameters)
442445
}
443446

447+
// The RateLimit interceptor should always be right after auth to ensure
448+
// the request rate is within an acceptable threshold.
449+
if hss.RateLimit != nil {
450+
limiter, err := hss.RateLimit.rateLimiter(host.GetExtensions())
451+
if err != nil {
452+
return nil, err
453+
}
454+
455+
handler = rateLimitInterceptor(handler, limiter)
456+
}
457+
444458
if hss.CORS != nil && len(hss.CORS.AllowedOrigins) > 0 {
445459
co := cors.Options{
446460
AllowedOrigins: hss.CORS.AllowedOrigins,

config/confighttp/ratelimit.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright The OpenTelemetry Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package confighttp // import "go.opentelemetry.io/collector/config/confighttp"
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"net/http"
10+
11+
"go.opentelemetry.io/collector/component"
12+
"go.opentelemetry.io/collector/extension"
13+
)
14+
15+
// RateLimit defines rate limiter settings for the receiver.
16+
type RateLimit struct {
17+
// RateLimiterID specifies the name of the extension to use in order to rate limit the incoming data point.
18+
RateLimiterID component.ID `mapstructure:"rate_limiter"`
19+
}
20+
21+
type rateLimiter interface {
22+
extension.Extension
23+
24+
Take(context.Context, string, http.Header) error
25+
}
26+
27+
// rateLimiter attempts to select the appropriate rateLimiter from the list of extensions,
28+
// based on the component id of the extension. If a rateLimiter is not found, an error is returned.
29+
func (rl RateLimit) rateLimiter(extensions map[component.ID]component.Component) (rateLimiter, error) {
30+
if ext, found := extensions[rl.RateLimiterID]; found {
31+
if limiter, ok := ext.(rateLimiter); ok {
32+
return limiter, nil
33+
}
34+
return nil, fmt.Errorf("extension %q is not a rate limit", rl.RateLimiterID)
35+
}
36+
return nil, fmt.Errorf("rate limit %q not found", rl.RateLimiterID)
37+
}
38+
39+
// rateLimitInterceptor adds interceptor for rate limit check.
40+
// It returns TooManyRequests(429) status code if rate limiter rejects the request.
41+
func rateLimitInterceptor(next http.Handler, limiter rateLimiter) http.Handler {
42+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
43+
err := limiter.Take(r.Context(), r.URL.Path, r.Header)
44+
if err != nil {
45+
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
46+
return
47+
}
48+
49+
next.ServeHTTP(w, r)
50+
})
51+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
// Copyright The OpenTelemetry Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package confighttp
5+
6+
import (
7+
"context"
8+
"errors"
9+
"fmt"
10+
"net/http"
11+
"net/http/httptest"
12+
"testing"
13+
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
17+
"go.opentelemetry.io/collector/component"
18+
"go.opentelemetry.io/collector/component/componenttest"
19+
)
20+
21+
func TestServerRateLimit(t *testing.T) {
22+
// prepare
23+
hss := ServerConfig{
24+
Endpoint: "localhost:0",
25+
RateLimit: &RateLimit{
26+
RateLimiterID: component.NewID(component.MustNewType("mock")),
27+
},
28+
}
29+
30+
limiter := &mockRateLimiter{}
31+
32+
host := &mockHost{
33+
ext: map[component.ID]component.Component{
34+
component.NewID(component.MustNewType("mock")): limiter,
35+
},
36+
}
37+
38+
var handlerCalled bool
39+
handler := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
40+
handlerCalled = true
41+
})
42+
43+
srv, err := hss.ToServer(context.Background(), host, componenttest.NewNopTelemetrySettings(), handler)
44+
require.NoError(t, err)
45+
46+
// test
47+
srv.Handler.ServeHTTP(&httptest.ResponseRecorder{}, httptest.NewRequest("GET", "/", nil))
48+
49+
// verify
50+
assert.True(t, handlerCalled)
51+
assert.Equal(t, 1, limiter.calls)
52+
}
53+
54+
func TestInvalidServerRateLimit(t *testing.T) {
55+
hss := ServerConfig{
56+
RateLimit: &RateLimit{
57+
RateLimiterID: component.NewID(component.MustNewType("non_existing")),
58+
},
59+
}
60+
61+
srv, err := hss.ToServer(context.Background(), componenttest.NewNopHost(), componenttest.NewNopTelemetrySettings(), http.NewServeMux())
62+
require.Error(t, err)
63+
require.Nil(t, srv)
64+
}
65+
66+
func TestRejectedServerRateLimit(t *testing.T) {
67+
// prepare
68+
hss := ServerConfig{
69+
Endpoint: "localhost:0",
70+
RateLimit: &RateLimit{
71+
RateLimiterID: component.NewID(component.MustNewType("mock")),
72+
},
73+
}
74+
host := &mockHost{
75+
ext: map[component.ID]component.Component{
76+
component.NewID(component.MustNewType("mock")): &mockRateLimiter{
77+
err: errors.New("rate limited"),
78+
},
79+
},
80+
}
81+
82+
srv, err := hss.ToServer(context.Background(), host, componenttest.NewNopTelemetrySettings(), http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
83+
require.NoError(t, err)
84+
85+
// test
86+
response := &httptest.ResponseRecorder{}
87+
srv.Handler.ServeHTTP(response, httptest.NewRequest("GET", "/", nil))
88+
89+
// verify
90+
assert.Equal(t, response.Result().StatusCode, http.StatusTooManyRequests)
91+
assert.Equal(t, response.Result().Status, fmt.Sprintf("%v %s", http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)))
92+
}
93+
94+
// Mocks
95+
96+
type mockRateLimiter struct {
97+
calls int
98+
err error
99+
}
100+
101+
func (m *mockRateLimiter) Take(context.Context, string, http.Header) error {
102+
m.calls++
103+
return m.err
104+
}
105+
106+
func (m *mockRateLimiter) Start(_ context.Context, _ component.Host) error {
107+
return nil
108+
}
109+
110+
func (m *mockRateLimiter) Shutdown(_ context.Context) error {
111+
return nil
112+
}

0 commit comments

Comments
 (0)