Skip to content

Commit 0e1579f

Browse files
authored
fix(writer): buffer race condition (#578)
* fix(writer): buffer race condition * fix: handle all buffers write ops
1 parent e03395b commit 0e1579f

File tree

6 files changed

+160
-59
lines changed

6 files changed

+160
-59
lines changed

pkg/middleware/middleware.go

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,9 @@ func (s *SouinBaseHandler) Upstream(
467467
}
468468

469469
err := s.Store(customWriter, rq, requestCc, cachedKey, uri)
470-
defer customWriter.Buf.Reset()
470+
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
471+
b.Reset()
472+
})
471473

472474
return singleflightValue{
473475
body: customWriter.Buf.Bytes(),
@@ -521,7 +523,9 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
521523
statusCode := customWriter.GetStatusCode()
522524
if err == nil {
523525
if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified {
524-
customWriter.Buf.Reset()
526+
customWriter.handleBuffer(func(b *bytes.Buffer) {
527+
b.Reset()
528+
})
525529
customWriter.Rw.WriteHeader(http.StatusPreconditionFailed)
526530

527531
return nil, errors.New("")
@@ -542,7 +546,9 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
542546
),
543547
)
544548

545-
defer customWriter.Buf.Reset()
549+
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
550+
b.Reset()
551+
})
546552
return singleflightValue{
547553
body: customWriter.Buf.Bytes(),
548554
headers: customWriter.Header().Clone(),
@@ -598,6 +604,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
598604

599605
req := s.context.SetBaseContext(rq)
600606
cacheName := req.Context().Value(context.CacheName).(string)
607+
601608
if rq.Header.Get("Upgrade") == "websocket" || rq.Header.Get("Accept") == "text/event-stream" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) {
602609
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI")
603610
return next(rw, req)
@@ -689,14 +696,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
689696
}
690697
if validator.NotModified {
691698
customWriter.WriteHeader(http.StatusNotModified)
692-
customWriter.Buf.Reset()
699+
customWriter.handleBuffer(func(b *bytes.Buffer) {
700+
b.Reset()
701+
})
693702
_, _ = customWriter.Send()
694703

695704
return nil
696705
}
697706

698707
customWriter.WriteHeader(response.StatusCode)
699-
_, _ = io.Copy(customWriter.Buf, response.Body)
708+
customWriter.handleBuffer(func(b *bytes.Buffer) {
709+
_, _ = io.Copy(b, response.Body)
710+
})
700711
_, _ = customWriter.Send()
701712

702713
return nil
@@ -722,7 +733,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
722733
}
723734
customWriter.WriteHeader(response.StatusCode)
724735
s.Configuration.GetLogger().Debugf("Serve from cache %+v", req)
725-
_, _ = io.Copy(customWriter.Buf, response.Body)
736+
customWriter.handleBuffer(func(b *bytes.Buffer) {
737+
_, _ = io.Copy(b, response.Body)
738+
})
726739
_, err := customWriter.Send()
727740
prometheus.Increment(prometheus.CachedResponseCounter)
728741

@@ -742,7 +755,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
742755
}
743756
customWriter.WriteHeader(response.StatusCode)
744757
rfc.HitStaleCache(&response.Header)
745-
_, _ = io.Copy(customWriter.Buf, response.Body)
758+
customWriter.handleBuffer(func(b *bytes.Buffer) {
759+
_, _ = io.Copy(b, response.Body)
760+
})
746761
_, err := customWriter.Send()
747762
customWriter = NewCustomWriter(req, rw, bufPool)
748763
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
@@ -766,14 +781,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
766781
response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code)
767782
maps.Copy(customWriter.Header(), response.Header)
768783
customWriter.WriteHeader(response.StatusCode)
769-
customWriter.Buf.Reset()
770-
_, _ = io.Copy(customWriter.Buf, response.Body)
784+
customWriter.handleBuffer(func(b *bytes.Buffer) {
785+
b.Reset()
786+
_, _ = io.Copy(b, response.Body)
787+
})
771788
_, err := customWriter.Send()
772789

773790
return err
774791
}
775792
rw.WriteHeader(http.StatusGatewayTimeout)
776-
customWriter.Buf.Reset()
793+
customWriter.handleBuffer(func(b *bytes.Buffer) {
794+
b.Reset()
795+
})
777796
_, err := customWriter.Send()
778797

779798
return err
@@ -784,7 +803,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
784803
rfc.SetCacheStatusHeader(response, storerName)
785804
customWriter.WriteHeader(response.StatusCode)
786805
maps.Copy(customWriter.Header(), response.Header)
787-
_, _ = io.Copy(customWriter.Buf, response.Body)
806+
customWriter.handleBuffer(func(b *bytes.Buffer) {
807+
_, _ = io.Copy(b, response.Body)
808+
})
788809
_, _ = customWriter.Send()
789810

790811
return err
@@ -793,7 +814,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
793814

794815
if statusCode != http.StatusNotModified && validator.Matched {
795816
customWriter.WriteHeader(http.StatusNotModified)
796-
customWriter.Buf.Reset()
817+
customWriter.handleBuffer(func(b *bytes.Buffer) {
818+
b.Reset()
819+
})
797820
_, _ = customWriter.Send()
798821

799822
return err
@@ -808,7 +831,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
808831
customWriter.WriteHeader(response.StatusCode)
809832
rfc.HitStaleCache(&response.Header)
810833
maps.Copy(customWriter.Header(), response.Header)
811-
_, _ = io.Copy(customWriter.Buf, response.Body)
834+
customWriter.handleBuffer(func(b *bytes.Buffer) {
835+
_, _ = io.Copy(b, response.Body)
836+
})
812837
_, err := customWriter.Send()
813838

814839
return err
@@ -822,7 +847,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
822847
customWriter.WriteHeader(response.StatusCode)
823848
rfc.HitStaleCache(&response.Header)
824849
maps.Copy(customWriter.Header(), response.Header)
825-
_, _ = io.Copy(customWriter.Buf, response.Body)
850+
customWriter.handleBuffer(func(b *bytes.Buffer) {
851+
_, _ = io.Copy(b, response.Body)
852+
})
826853
_, err := customWriter.Send()
827854

828855
return err
@@ -846,8 +873,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
846873
response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code)
847874
maps.Copy(customWriter.Header(), response.Header)
848875
customWriter.WriteHeader(response.StatusCode)
849-
customWriter.Buf.Reset()
850-
_, _ = io.Copy(customWriter.Buf, response.Body)
876+
customWriter.handleBuffer(func(b *bytes.Buffer) {
877+
b.Reset()
878+
_, _ = io.Copy(b, response.Body)
879+
})
851880
_, err := customWriter.Send()
852881

853882
return err

pkg/middleware/writer.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ type CustomWriter struct {
3939
statusCode int
4040
}
4141

42+
func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) {
43+
r.mutex.Lock()
44+
callback(r.Buf)
45+
r.mutex.Unlock()
46+
}
47+
4248
// Header will write the response headers
4349
func (r *CustomWriter) Header() http.Header {
4450
r.mutex.Lock()
@@ -71,17 +77,19 @@ func (r *CustomWriter) WriteHeader(code int) {
7177

7278
// Write will write the response body
7379
func (r *CustomWriter) Write(b []byte) (int, error) {
74-
r.mutex.Lock()
75-
defer r.mutex.Unlock()
76-
r.Buf.Grow(len(b))
77-
_, _ = r.Buf.Write(b)
80+
r.handleBuffer(func(actual *bytes.Buffer) {
81+
actual.Grow(len(b))
82+
_, _ = actual.Write(b)
83+
})
7884

7985
return len(b), nil
8086
}
8187

8288
// Send delays the response to handle Cache-Status
8389
func (r *CustomWriter) Send() (int, error) {
84-
defer r.Buf.Reset()
90+
defer r.handleBuffer(func(b *bytes.Buffer) {
91+
b.Reset()
92+
})
8593
storedLength := r.Header().Get(rfc.StoredLengthHeader)
8694
if storedLength != "" {
8795
r.Header().Set("Content-Length", storedLength)

plugins/traefik/override/middleware/middleware.go

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ func (s *SouinBaseHandler) Store(
189189
ma = time.Duration(responseCc.SMaxAge) * time.Second
190190
} else if responseCc.MaxAge >= 0 {
191191
ma = time.Duration(responseCc.MaxAge) * time.Second
192-
} else if customWriter.Header().Get("Expires") != "" {
192+
} else if !modeContext.Bypass_response && customWriter.Header().Get("Expires") != "" {
193193
exp, err := time.Parse(time.RFC1123, customWriter.Header().Get("Expires"))
194194
if err != nil {
195195
return nil
@@ -249,7 +249,7 @@ func (s *SouinBaseHandler) Store(
249249
}
250250
res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length"))
251251
response, err := httputil.DumpResponse(&res, true)
252-
if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode)) {
252+
if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode) || s.hasAllowedAdditionalStatusCodesToCache(statusCode)) {
253253
variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header)
254254
if isVaryStar {
255255
// "Implies that the response is uncacheable"
@@ -372,7 +372,9 @@ func (s *SouinBaseHandler) Upstream(
372372
}
373373

374374
err := s.Store(customWriter, rq, requestCc, cachedKey)
375-
defer customWriter.Buf.Reset()
375+
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
376+
b.Reset()
377+
})
376378

377379
return singleflightValue{
378380
body: customWriter.Buf.Bytes(),
@@ -423,7 +425,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler
423425
statusCode := customWriter.GetStatusCode()
424426
if err == nil {
425427
if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified {
426-
customWriter.Buf.Reset()
428+
customWriter.handleBuffer(func(b *bytes.Buffer) {
429+
b.Reset()
430+
})
427431
customWriter.Rw.WriteHeader(http.StatusPreconditionFailed)
428432

429433
return nil, errors.New("")
@@ -444,7 +448,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler
444448
),
445449
)
446450

447-
defer customWriter.Buf.Reset()
451+
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
452+
b.Reset()
453+
})
448454
return singleflightValue{
449455
body: customWriter.Buf.Bytes(),
450456
headers: customWriter.Header().Clone(),
@@ -493,6 +499,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
493499
handler(rw, rq)
494500
return nil
495501
}
502+
496503
req := s.context.SetBaseContext(rq)
497504
cacheName := req.Context().Value(context.CacheName).(string)
498505

@@ -526,7 +533,6 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
526533
requestCc, coErr := cacheobject.ParseRequestCacheControl(rfc.HeaderAllCommaSepValuesString(req.Header, "Cache-Control"))
527534

528535
modeContext := req.Context().Value(context.Mode).(*context.ModeContext)
529-
530536
if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) {
531537
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR")
532538

@@ -593,14 +599,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
593599
}
594600
if validator.NotModified {
595601
customWriter.WriteHeader(http.StatusNotModified)
596-
customWriter.Buf.Reset()
602+
customWriter.handleBuffer(func(b *bytes.Buffer) {
603+
b.Reset()
604+
})
597605
_, _ = customWriter.Send()
598606

599607
return nil
600608
}
601609

602610
customWriter.WriteHeader(response.StatusCode)
603-
_, _ = io.Copy(customWriter.Buf, response.Body)
611+
customWriter.handleBuffer(func(b *bytes.Buffer) {
612+
_, _ = io.Copy(b, response.Body)
613+
})
604614
_, _ = customWriter.Send()
605615

606616
return nil
@@ -624,7 +634,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
624634
customWriter.Header()[h] = v
625635
}
626636
customWriter.WriteHeader(response.StatusCode)
627-
_, _ = io.Copy(customWriter.Buf, response.Body)
637+
customWriter.handleBuffer(func(b *bytes.Buffer) {
638+
_, _ = io.Copy(b, response.Body)
639+
})
628640
_, err := customWriter.Send()
629641

630642
return err
@@ -643,7 +655,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
643655
}
644656
customWriter.WriteHeader(response.StatusCode)
645657
rfc.HitStaleCache(&response.Header)
646-
_, _ = io.Copy(customWriter.Buf, response.Body)
658+
customWriter.handleBuffer(func(b *bytes.Buffer) {
659+
_, _ = io.Copy(b, response.Body)
660+
})
647661
_, err := customWriter.Send()
648662
customWriter = NewCustomWriter(req, rw, bufPool)
649663
go func(v *types.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
@@ -656,7 +670,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
656670
return err
657671
}
658672

659-
if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
673+
if modeContext.Bypass_response || responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
660674
req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag)
661675
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
662676
statusCode := customWriter.GetStatusCode()
@@ -670,14 +684,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
670684
customWriter.Header().Set(k, response.Header.Get(k))
671685
}
672686
customWriter.WriteHeader(response.StatusCode)
673-
customWriter.Buf.Reset()
674-
_, _ = io.Copy(customWriter.Buf, response.Body)
687+
customWriter.handleBuffer(func(b *bytes.Buffer) {
688+
b.Reset()
689+
_, _ = io.Copy(b, response.Body)
690+
})
675691
_, err := customWriter.Send()
676692

677693
return err
678694
}
679695
rw.WriteHeader(http.StatusGatewayTimeout)
680-
customWriter.Buf.Reset()
696+
customWriter.handleBuffer(func(b *bytes.Buffer) {
697+
b.Reset()
698+
})
681699
_, err := customWriter.Send()
682700

683701
return err
@@ -691,7 +709,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
691709
for k := range response.Header {
692710
customWriter.Header().Set(k, response.Header.Get(k))
693711
}
694-
_, _ = io.Copy(customWriter.Buf, response.Body)
712+
customWriter.handleBuffer(func(b *bytes.Buffer) {
713+
_, _ = io.Copy(b, response.Body)
714+
})
695715
_, _ = customWriter.Send()
696716

697717
return err
@@ -700,7 +720,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
700720

701721
if statusCode != http.StatusNotModified && validator.Matched {
702722
customWriter.WriteHeader(http.StatusNotModified)
703-
customWriter.Buf.Reset()
723+
customWriter.handleBuffer(func(b *bytes.Buffer) {
724+
b.Reset()
725+
})
704726
_, _ = customWriter.Send()
705727

706728
return err
@@ -718,7 +740,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
718740
for k := range response.Header {
719741
customWriter.Header().Set(k, response.Header.Get(k))
720742
}
721-
_, _ = io.Copy(customWriter.Buf, response.Body)
743+
customWriter.handleBuffer(func(b *bytes.Buffer) {
744+
_, _ = io.Copy(b, response.Body)
745+
})
722746
_, err := customWriter.Send()
723747

724748
return err
@@ -747,8 +771,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
747771
customWriter.Header().Set(k, response.Header.Get(k))
748772
}
749773
customWriter.WriteHeader(response.StatusCode)
750-
customWriter.Buf.Reset()
751-
_, _ = io.Copy(customWriter.Buf, response.Body)
774+
customWriter.handleBuffer(func(b *bytes.Buffer) {
775+
b.Reset()
776+
_, _ = io.Copy(b, response.Body)
777+
})
752778
_, err := customWriter.Send()
753779

754780
return err

0 commit comments

Comments
 (0)