diff --git a/connect.go b/connect.go index 274a41ee..2bf3cede 100644 --- a/connect.go +++ b/connect.go @@ -30,6 +30,7 @@ import ( "io" "net/http" "net/url" + "time" ) // Version is the semantic version of the connect module. @@ -319,6 +320,8 @@ type Spec struct { Procedure string // for example, "/acme.foo.v1.FooService/Bar" IsClient bool // otherwise we're in a handler IdempotencyLevel IdempotencyLevel + ReadTimeout time.Duration + WriteTimeout time.Duration } // Peer describes the other party to an RPC. diff --git a/handler.go b/handler.go index e33934ea..4c883adc 100644 --- a/handler.go +++ b/handler.go @@ -17,6 +17,7 @@ package connect import ( "context" "net/http" + "time" ) // A Handler is the server-side implementation of a single RPC defined by a @@ -253,6 +254,12 @@ func NewBidiStreamHandler[Req, Res any]( // ServeHTTP implements [http.Handler]. func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + if h.spec.ReadTimeout != 0 { + rc := http.NewResponseController(responseWriter) + rc.SetReadDeadline(time.Now().Add(h.spec.ReadTimeout)) + rc.SetWriteDeadline(time.Now().Add(h.spec.WriteTimeout)) + } + // We don't need to defer functions to close the request body or read to // EOF: the stream we construct later on already does that, and we only // return early when dealing with misbehaving clients. In those cases, it's @@ -348,6 +355,8 @@ type handlerConfig struct { ReadMaxBytes int SendMaxBytes int StreamType StreamType + ReadTimeout time.Duration + WriteTimeout time.Duration } func newHandlerConfig(procedure string, streamType StreamType, options []HandlerOption) *handlerConfig { @@ -374,6 +383,8 @@ func (c *handlerConfig) newSpec() Spec { Schema: c.Schema, StreamType: c.StreamType, IdempotencyLevel: c.IdempotencyLevel, + ReadTimeout: c.ReadTimeout, + WriteTimeout: c.WriteTimeout, } } diff --git a/option.go b/option.go index fe0a2cd9..d4831bc1 100644 --- a/option.go +++ b/option.go @@ -19,6 +19,7 @@ import ( "context" "io" "net/http" + "time" ) // A ClientOption configures a [Client]. @@ -351,6 +352,34 @@ func WithInterceptors(interceptors ...Interceptor) Option { return &interceptorsOption{interceptors} } +// WithReadTimeout option specifies the maximum amount of time that a service +// handler is allowed to take when reading a message in a stream. +// If the total time exceeds WithReadTimeout, then that particular stream is +// closed. +// This enables the user to close only that particular stream instead of the +// entire connection. +// This prevents malicious or slow clients from using up resources. +// This option is passed to the handler config and then to the spec. +// Finally, ServeHTTP function of the handler reads the timeout values from +// the spec and enforces them using ResponseController. +func WithReadTimeout(value time.Duration) HandlerOption { + return &readTimeoutOption{value: value} +} + +// WithWriteTimeout option specifies the maximum amount of time that a service +// handler is allowed to take when writing a message to a stream. +// If the total time exceeds WithReadTimeout, then that particular stream is +// closed. +// This enables the user to close only that particular stream instead of the +// entire connection. +// This prevents malicious or slow clients from using up resources. +// This option is passed to the handler config and then to the spec. +// Finally, ServeHTTP function of the handler reads the timeout values from +// the spec and enforces them using ResponseController. +func WithWriteTimeout(value time.Duration) HandlerOption { + return &writeTimeoutOption{value: value} +} + // WithOptions composes multiple Options into one. func WithOptions(options ...Option) Option { return &optionsOption{options} @@ -645,3 +674,15 @@ func (o *conditionalHandlerOptions) applyToHandler(config *handlerConfig) { option.applyToHandler(config) } } + +type readTimeoutOption struct{ value time.Duration } + +func (o *readTimeoutOption) applyToHandler(config *handlerConfig) { + config.ReadTimeout = o.value +} + +type writeTimeoutOption struct{ value time.Duration } + +func (o *writeTimeoutOption) applyToHandler(config *handlerConfig) { + config.WriteTimeout = o.value +}