Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 3 additions & 18 deletions internal/connmgr/connmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,9 @@ type Config struct {
// to. If nil, no new connections will be made automatically.
GetNewAddress func() (net.Addr, error)

// Dial connects to the address on the named network. Either Dial or
// DialAddr need to be specified (but not both).
// Dial connects to the address on the named network.
Dial func(ctx context.Context, network, addr string) (net.Conn, error)

// DialAddr is an alternative to Dial which receives a full net.Addr instead
// of just the protocol family and address. Either DialAddr or Dial need
// to be specified (but not both).
DialAddr func(context.Context, net.Addr) (net.Conn, error)

// DialTimeout specifies the amount of time to wait for a connection to
// complete before giving up.
DialTimeout time.Duration
Expand Down Expand Up @@ -388,12 +382,7 @@ func (cm *ConnManager) Connect(ctx context.Context, c *ConnReq) {
defer cancel()
}
var conn net.Conn
var err error
if cm.cfg.Dial != nil {
conn, err = cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String())
} else {
conn, err = cm.cfg.DialAddr(ctx, c.Addr)
}
conn, err := cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String())
if err != nil {
cm.connMtx.Lock()
cm.handleFailedPending(ctx, c, err)
Expand Down Expand Up @@ -648,13 +637,9 @@ func (cm *ConnManager) Run(ctx context.Context) {
//
// Use Run to start listening and/or connecting to the network.
func New(cfg *Config) (*ConnManager, error) {
if cfg.Dial == nil && cfg.DialAddr == nil {
if cfg.Dial == nil {
return nil, MakeError(ErrDialNil, "dial cannot be nil")
}
if cfg.Dial != nil && cfg.DialAddr != nil {
return nil, MakeError(ErrBothDialsFilled,
"cannot specify both Dial and DialAddr")
}
// Default to sane values
if cfg.RetryDuration <= 0 {
cfg.RetryDuration = defaultRetryDuration
Expand Down
91 changes: 9 additions & 82 deletions internal/connmgr/connmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,6 @@ func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) {
return c, ctx.Err()
}

// mockDialer mocks the net.Dial interface by returning a mock connection to
// the given address.
func mockDialerAddr(ctx context.Context, addr net.Addr) (net.Conn, error) {
r, w := io.Pipe()
c := &mockConn{rAddr: addr}
c.Reader = r
c.Writer = w
return c, nil
}

// TestNewConfig tests that new ConnManager config is validated as expected.
func TestNewConfig(t *testing.T) {
_, err := New(&Config{})
Expand All @@ -108,21 +98,6 @@ func TestNewConfig(t *testing.T) {
if err != nil {
t.Fatalf("New unexpected error: %v", err)
}

_, err = New(&Config{
Dial: mockDialer,
DialAddr: mockDialerAddr,
})
if err == nil {
t.Fatal("New expected error: 'Dial and DialAddr can't be both nil', got nil")
}

_, err = New(&Config{
DialAddr: mockDialerAddr,
})
if err != nil {
t.Fatalf("New unexpected error: %v", err)
}
}

// assertConnReqID ensures the provided connection request has the given ID.
Expand Down Expand Up @@ -248,54 +223,6 @@ func TestTargetOutbound(t *testing.T) {
wg.Wait()
}

// TestPassAddrAlongDialAddr tests if when using the DialAddr config option,
// any address object returned by GetNewAddress will be correctly passed along
// to DialAddr to be used for connecting to a host.
func TestPassAddrAlongDialAddr(t *testing.T) {
dailedAddr := make(chan net.Addr)
detectDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) {
dailedAddr <- addr
return nil, errors.New("error")
}

// targetAddr will be the specific address we'll use to connect. It _could_
// be carrying more info than a standard (tcp/udp) network address, so it
// needs to be relayed to dialAddr.
targetAddr := mockAddr{
net: "invalid",
address: "unreachable",
}

cmgr, err := New(&Config{
TargetOutbound: 1,
DialAddr: detectDialer,
GetNewAddress: func() (net.Addr, error) {
return targetAddr, nil
},
})
if err != nil {
t.Fatalf("New error: %v", err)
}
_, shutdown, wg := runConnMgrAsync(context.Background(), cmgr)

select {
case addr := <-dailedAddr:
receivedMock, isMockAddr := addr.(mockAddr)
if !isMockAddr {
t.Fatal("connected to an address that was not a mockAddr")
}
if receivedMock != targetAddr {
t.Fatal("connected to an address different than the expected target")
}
case <-time.After(time.Millisecond * 20):
t.Fatal("did not get connection to target address before timeout")
}

// Ensure clean shutdown of connection manager.
shutdown()
wg.Wait()
}

// TestRetryPermanent tests that permanent connection requests are retried.
//
// We make a permanent connection request using Connect, disconnect it using
Expand Down Expand Up @@ -589,13 +516,13 @@ func TestRemovePendingConnection(t *testing.T) {
// succeed.
dialed := make(chan struct{})
wait := make(chan struct{})
indefiniteDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) {
indefiniteDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
close(dialed)
<-wait
return nil, errors.New("error")
}
cmgr, err := New(&Config{
DialAddr: indefiniteDialer,
Dial: indefiniteDialer,
})
if err != nil {
t.Fatalf("New error: %v", err)
Expand Down Expand Up @@ -647,10 +574,10 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) {
// connect chan is signaled. The dial attempt immediately after that
// will succeed in returning a connection.
connect := make(chan struct{})
failingDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) {
failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
select {
case <-connect:
return mockDialerAddr(ctx, addr)
return mockDialer(ctx, network, addr)
default:
}

Expand All @@ -659,7 +586,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) {

connected := make(chan *ConnReq)
cmgr, err := New(&Config{
DialAddr: failingDialer,
Dial: failingDialer,
RetryDuration: retryTimeout,
OnConnection: func(c *ConnReq, conn net.Conn) {
connected <- c
Expand Down Expand Up @@ -823,17 +750,17 @@ func TestForEachConnReq(t *testing.T) {
targetOutbound := uint32(5)
connected := make(chan *ConnReq)
pending := make(chan struct{})
delayDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) {
if addr.String() == "127.0.0.1:18557" {
delayDialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
if addr == "127.0.0.1:18557" {
close(pending)
time.Sleep(time.Second)
return nil, errors.New("error")
}
return mockDialerAddr(ctx, addr)
return mockDialer(ctx, network, addr)
}
cmgr, err := New(&Config{
TargetOutbound: targetOutbound,
DialAddr: delayDialer,
Dial: delayDialer,
GetNewAddress: func() (net.Addr, error) {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Expand Down
4 changes: 0 additions & 4 deletions internal/connmgr/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ const (
// the configuration.
ErrDialNil = ErrorKind("ErrDialNil")

// ErrBothDialsFilled is used to indicate that Dial and DialAddr
// cannot both be specified in the configuration.
ErrBothDialsFilled = ErrorKind("ErrBothDialsFilled")

// ErrNotFound indicates a specified connection ID or address is unknown to
// the connection manager.
ErrNotFound = ErrorKind("ErrNotFound")
Expand Down
31 changes: 15 additions & 16 deletions internal/connmgr/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ func TestErrorKindStringer(t *testing.T) {
want string
}{
{ErrDialNil, "ErrDialNil"},
{ErrBothDialsFilled, "ErrBothDialsFilled"},
{ErrNotFound, "ErrNotFound"},
{ErrTorInvalidAddressResponse, "ErrTorInvalidAddressResponse"},
{ErrTorInvalidProxyResponse, "ErrTorInvalidProxyResponse"},
Expand Down Expand Up @@ -91,35 +90,35 @@ func TestErrorKindIsAs(t *testing.T) {
wantMatch: true,
wantAs: ErrDialNil,
}, {
name: "ErrBothDialsFilled != ErrDialNil",
err: ErrBothDialsFilled,
name: "ErrNotFound != ErrDialNil",
err: ErrNotFound,
target: ErrDialNil,
wantMatch: false,
wantAs: ErrBothDialsFilled,
wantAs: ErrNotFound,
}, {
name: "Error.ErrBothDialsFilled != ErrDialNil",
err: MakeError(ErrBothDialsFilled, ""),
name: "Error.ErrNotFound != ErrDialNil",
err: MakeError(ErrNotFound, ""),
target: ErrDialNil,
wantMatch: false,
wantAs: ErrBothDialsFilled,
wantAs: ErrNotFound,
}, {
name: "ErrBothDialsFilled != Error.ErrDialNil",
err: ErrBothDialsFilled,
name: "ErrNotFound != Error.ErrDialNil",
err: ErrNotFound,
target: MakeError(ErrDialNil, ""),
wantMatch: false,
wantAs: ErrBothDialsFilled,
wantAs: ErrNotFound,
}, {
name: "Error.ErrBothDialsFilled != Error.ErrDialNil",
err: MakeError(ErrBothDialsFilled, ""),
name: "Error.ErrNotFound != Error.ErrDialNil",
err: MakeError(ErrNotFound, ""),
target: MakeError(ErrDialNil, ""),
wantMatch: false,
wantAs: ErrBothDialsFilled,
wantAs: ErrNotFound,
}, {
name: "Error.ErrBothDialsFilled != io.EOF",
err: MakeError(ErrBothDialsFilled, ""),
name: "Error.ErrNotFound != io.EOF",
err: MakeError(ErrNotFound, ""),
target: io.EOF,
wantMatch: false,
wantAs: ErrBothDialsFilled,
wantAs: ErrNotFound,
}}

for _, test := range tests {
Expand Down