Skip to content

Commit 67d28b3

Browse files
committed
protocol: prove erroneous timeout after timeout is set
As reported in #75 Signed-off-by: Pires <[email protected]>
1 parent a55009f commit 67d28b3

File tree

1 file changed

+91
-2
lines changed

1 file changed

+91
-2
lines changed

protocol_test.go

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ import (
99
"context"
1010
"crypto/tls"
1111
"crypto/x509"
12+
"errors"
1213
"fmt"
1314
"io"
1415
"io/ioutil"
1516
"net"
17+
"os"
1618
"testing"
1719
"time"
1820
)
@@ -29,7 +31,7 @@ func TestPassthrough(t *testing.T) {
2931
conn, err := net.Dial("tcp", pl.Addr().String())
3032
if err != nil {
3133
t.Fatalf("err: %v", err)
32-
}
34+
}
3335
defer conn.Close()
3436

3537
conn.Write([]byte("ping"))
@@ -71,7 +73,7 @@ func TestReadHeaderTimeout(t *testing.T) {
7173

7274
pl := &Listener{
7375
Listener: l,
74-
ReadHeaderTimeout: 1 * time.Millisecond,
76+
ReadHeaderTimeout: time.Millisecond * 250,
7577
}
7678

7779
ctx, cancel := context.WithCancel(context.Background())
@@ -97,6 +99,93 @@ func TestReadHeaderTimeout(t *testing.T) {
9799
recv := make([]byte, 4)
98100
_, err = conn.Read(recv)
99101

102+
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded){
103+
t.Fatal("should timeout")
104+
}
105+
}
106+
107+
func TestReadHeaderTimeoutIsReset(t *testing.T) {
108+
const timeout = time.Millisecond * 250
109+
110+
l, err := net.Listen("tcp", "127.0.0.1:0")
111+
if err != nil {
112+
t.Fatalf("err: %v", err)
113+
}
114+
115+
pl := &Listener{
116+
Listener: l,
117+
ReadHeaderTimeout: timeout,
118+
}
119+
120+
header := &Header{
121+
Version: 2,
122+
Command: PROXY,
123+
TransportProtocol: TCPv4,
124+
SourceAddr: &net.TCPAddr{
125+
IP: net.ParseIP("10.1.1.1"),
126+
Port: 1000,
127+
},
128+
DestinationAddr: &net.TCPAddr{
129+
IP: net.ParseIP("20.2.2.2"),
130+
Port: 2000,
131+
},
132+
}
133+
go func() {
134+
conn, err := net.Dial("tcp", pl.Addr().String())
135+
if err != nil {
136+
t.Fatalf("err: %v", err)
137+
}
138+
defer conn.Close()
139+
140+
// Write out the header!
141+
header.WriteTo(conn)
142+
143+
// Sleep here longer than the configured timeout.
144+
time.Sleep(timeout * 2)
145+
146+
conn.Write([]byte("ping"))
147+
recv := make([]byte, 4)
148+
_, err = conn.Read(recv)
149+
if err != nil {
150+
t.Fatalf("err: %v", err)
151+
}
152+
if !bytes.Equal(recv, []byte("pong")) {
153+
t.Fatalf("bad: %v", recv)
154+
}
155+
}()
156+
157+
conn, err := pl.Accept()
158+
if err != nil {
159+
t.Fatalf("err: %v", err)
160+
}
161+
defer conn.Close()
162+
163+
recv := make([]byte, 4)
164+
_, err = conn.Read(recv)
165+
if err != nil {
166+
t.Fatalf("err: %v", err)
167+
}
168+
if !bytes.Equal(recv, []byte("ping")) {
169+
t.Fatalf("bad: %v", recv)
170+
}
171+
172+
if _, err := conn.Write([]byte("pong")); err != nil {
173+
t.Fatalf("err: %v", err)
174+
}
175+
176+
// Check the remote addr
177+
addr := conn.RemoteAddr().(*net.TCPAddr)
178+
if addr.IP.String() != "10.1.1.1" {
179+
t.Fatalf("bad: %v", addr)
180+
}
181+
if addr.Port != 1000 {
182+
t.Fatalf("bad: %v", addr)
183+
}
184+
185+
h := conn.(*Conn).ProxyHeader()
186+
if !h.EqualsTo(header) {
187+
t.Errorf("bad: %v", h)
188+
}
100189
}
101190

102191
func TestParse_ipv4(t *testing.T) {

0 commit comments

Comments
 (0)