@@ -9,10 +9,12 @@ import (
99 "context"
1010 "crypto/tls"
1111 "crypto/x509"
12+ "errors"
1213 "fmt"
1314 "io"
1415 "io/ioutil"
1516 "net"
17+ "os"
1618 "testing"
1719 "time"
1820)
@@ -29,7 +31,7 @@ func TestPassthrough(t *testing.T) {
2931 conn , err := net .Dial ("tcp" , pl .Addr ().String ())
3032 if err != nil {
3133 t .Fatalf ("err: %v" , err )
32- }
34+ }
3335 defer conn .Close ()
3436
3537 conn .Write ([]byte ("ping" ))
@@ -71,7 +73,7 @@ func TestReadHeaderTimeout(t *testing.T) {
7173
7274 pl := & Listener {
7375 Listener : l ,
74- ReadHeaderTimeout : 1 * time .Millisecond ,
76+ ReadHeaderTimeout : time .Millisecond * 250 ,
7577 }
7678
7779 ctx , cancel := context .WithCancel (context .Background ())
@@ -97,6 +99,93 @@ func TestReadHeaderTimeout(t *testing.T) {
9799 recv := make ([]byte , 4 )
98100 _ , err = conn .Read (recv )
99101
102+ if err != nil && ! errors .Is (err , os .ErrDeadlineExceeded ){
103+ t .Fatal ("should timeout" )
104+ }
105+ }
106+
107+ func TestReadHeaderTimeoutIsReset (t * testing.T ) {
108+ const timeout = time .Millisecond * 250
109+
110+ l , err := net .Listen ("tcp" , "127.0.0.1:0" )
111+ if err != nil {
112+ t .Fatalf ("err: %v" , err )
113+ }
114+
115+ pl := & Listener {
116+ Listener : l ,
117+ ReadHeaderTimeout : timeout ,
118+ }
119+
120+ header := & Header {
121+ Version : 2 ,
122+ Command : PROXY ,
123+ TransportProtocol : TCPv4 ,
124+ SourceAddr : & net.TCPAddr {
125+ IP : net .ParseIP ("10.1.1.1" ),
126+ Port : 1000 ,
127+ },
128+ DestinationAddr : & net.TCPAddr {
129+ IP : net .ParseIP ("20.2.2.2" ),
130+ Port : 2000 ,
131+ },
132+ }
133+ go func () {
134+ conn , err := net .Dial ("tcp" , pl .Addr ().String ())
135+ if err != nil {
136+ t .Fatalf ("err: %v" , err )
137+ }
138+ defer conn .Close ()
139+
140+ // Write out the header!
141+ header .WriteTo (conn )
142+
143+ // Sleep here longer than the configured timeout.
144+ time .Sleep (timeout * 2 )
145+
146+ conn .Write ([]byte ("ping" ))
147+ recv := make ([]byte , 4 )
148+ _ , err = conn .Read (recv )
149+ if err != nil {
150+ t .Fatalf ("err: %v" , err )
151+ }
152+ if ! bytes .Equal (recv , []byte ("pong" )) {
153+ t .Fatalf ("bad: %v" , recv )
154+ }
155+ }()
156+
157+ conn , err := pl .Accept ()
158+ if err != nil {
159+ t .Fatalf ("err: %v" , err )
160+ }
161+ defer conn .Close ()
162+
163+ recv := make ([]byte , 4 )
164+ _ , err = conn .Read (recv )
165+ if err != nil {
166+ t .Fatalf ("err: %v" , err )
167+ }
168+ if ! bytes .Equal (recv , []byte ("ping" )) {
169+ t .Fatalf ("bad: %v" , recv )
170+ }
171+
172+ if _ , err := conn .Write ([]byte ("pong" )); err != nil {
173+ t .Fatalf ("err: %v" , err )
174+ }
175+
176+ // Check the remote addr
177+ addr := conn .RemoteAddr ().(* net.TCPAddr )
178+ if addr .IP .String () != "10.1.1.1" {
179+ t .Fatalf ("bad: %v" , addr )
180+ }
181+ if addr .Port != 1000 {
182+ t .Fatalf ("bad: %v" , addr )
183+ }
184+
185+ h := conn .(* Conn ).ProxyHeader ()
186+ if ! h .EqualsTo (header ) {
187+ t .Errorf ("bad: %v" , h )
188+ }
100189}
101190
102191func TestParse_ipv4 (t * testing.T ) {
0 commit comments