diff --git a/internal/connmgr/connmanager.go b/internal/connmgr/connmanager.go index 7e7d6d63d..60e1afcc1 100644 --- a/internal/connmgr/connmanager.go +++ b/internal/connmgr/connmanager.go @@ -151,15 +151,9 @@ type Config struct { // to. If nil, no new connections will be made automatically. GetNewAddress func() (net.Addr, error) - // Dial connects to the address on the named network. Either Dial or - // DialAddr need to be specified (but not both). + // Dial connects to the address on the named network. Dial func(ctx context.Context, network, addr string) (net.Conn, error) - // DialAddr is an alternative to Dial which receives a full net.Addr instead - // of just the protocol family and address. Either DialAddr or Dial need - // to be specified (but not both). - DialAddr func(context.Context, net.Addr) (net.Conn, error) - // DialTimeout specifies the amount of time to wait for a connection to // complete before giving up. DialTimeout time.Duration @@ -388,12 +382,7 @@ func (cm *ConnManager) Connect(ctx context.Context, c *ConnReq) { defer cancel() } var conn net.Conn - var err error - if cm.cfg.Dial != nil { - conn, err = cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String()) - } else { - conn, err = cm.cfg.DialAddr(ctx, c.Addr) - } + conn, err := cm.cfg.Dial(ctx, c.Addr.Network(), c.Addr.String()) if err != nil { cm.connMtx.Lock() cm.handleFailedPending(ctx, c, err) @@ -648,13 +637,9 @@ func (cm *ConnManager) Run(ctx context.Context) { // // Use Run to start listening and/or connecting to the network. func New(cfg *Config) (*ConnManager, error) { - if cfg.Dial == nil && cfg.DialAddr == nil { + if cfg.Dial == nil { return nil, MakeError(ErrDialNil, "dial cannot be nil") } - if cfg.Dial != nil && cfg.DialAddr != nil { - return nil, MakeError(ErrBothDialsFilled, - "cannot specify both Dial and DialAddr") - } // Default to sane values if cfg.RetryDuration <= 0 { cfg.RetryDuration = defaultRetryDuration diff --git a/internal/connmgr/connmanager_test.go b/internal/connmgr/connmanager_test.go index bc4e19bc8..65900d583 100644 --- a/internal/connmgr/connmanager_test.go +++ b/internal/connmgr/connmanager_test.go @@ -86,16 +86,6 @@ func mockDialer(ctx context.Context, network, addr string) (net.Conn, error) { return c, ctx.Err() } -// mockDialer mocks the net.Dial interface by returning a mock connection to -// the given address. -func mockDialerAddr(ctx context.Context, addr net.Addr) (net.Conn, error) { - r, w := io.Pipe() - c := &mockConn{rAddr: addr} - c.Reader = r - c.Writer = w - return c, nil -} - // TestNewConfig tests that new ConnManager config is validated as expected. func TestNewConfig(t *testing.T) { _, err := New(&Config{}) @@ -108,21 +98,6 @@ func TestNewConfig(t *testing.T) { if err != nil { t.Fatalf("New unexpected error: %v", err) } - - _, err = New(&Config{ - Dial: mockDialer, - DialAddr: mockDialerAddr, - }) - if err == nil { - t.Fatal("New expected error: 'Dial and DialAddr can't be both nil', got nil") - } - - _, err = New(&Config{ - DialAddr: mockDialerAddr, - }) - if err != nil { - t.Fatalf("New unexpected error: %v", err) - } } // assertConnReqID ensures the provided connection request has the given ID. @@ -248,54 +223,6 @@ func TestTargetOutbound(t *testing.T) { wg.Wait() } -// TestPassAddrAlongDialAddr tests if when using the DialAddr config option, -// any address object returned by GetNewAddress will be correctly passed along -// to DialAddr to be used for connecting to a host. -func TestPassAddrAlongDialAddr(t *testing.T) { - dailedAddr := make(chan net.Addr) - detectDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) { - dailedAddr <- addr - return nil, errors.New("error") - } - - // targetAddr will be the specific address we'll use to connect. It _could_ - // be carrying more info than a standard (tcp/udp) network address, so it - // needs to be relayed to dialAddr. - targetAddr := mockAddr{ - net: "invalid", - address: "unreachable", - } - - cmgr, err := New(&Config{ - TargetOutbound: 1, - DialAddr: detectDialer, - GetNewAddress: func() (net.Addr, error) { - return targetAddr, nil - }, - }) - if err != nil { - t.Fatalf("New error: %v", err) - } - _, shutdown, wg := runConnMgrAsync(context.Background(), cmgr) - - select { - case addr := <-dailedAddr: - receivedMock, isMockAddr := addr.(mockAddr) - if !isMockAddr { - t.Fatal("connected to an address that was not a mockAddr") - } - if receivedMock != targetAddr { - t.Fatal("connected to an address different than the expected target") - } - case <-time.After(time.Millisecond * 20): - t.Fatal("did not get connection to target address before timeout") - } - - // Ensure clean shutdown of connection manager. - shutdown() - wg.Wait() -} - // TestRetryPermanent tests that permanent connection requests are retried. // // We make a permanent connection request using Connect, disconnect it using @@ -589,13 +516,13 @@ func TestRemovePendingConnection(t *testing.T) { // succeed. dialed := make(chan struct{}) wait := make(chan struct{}) - indefiniteDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) { + indefiniteDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { close(dialed) <-wait return nil, errors.New("error") } cmgr, err := New(&Config{ - DialAddr: indefiniteDialer, + Dial: indefiniteDialer, }) if err != nil { t.Fatalf("New error: %v", err) @@ -647,10 +574,10 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { // connect chan is signaled. The dial attempt immediately after that // will succeed in returning a connection. connect := make(chan struct{}) - failingDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) { + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { select { case <-connect: - return mockDialerAddr(ctx, addr) + return mockDialer(ctx, network, addr) default: } @@ -659,7 +586,7 @@ func TestCancelIgnoreDelayedConnection(t *testing.T) { connected := make(chan *ConnReq) cmgr, err := New(&Config{ - DialAddr: failingDialer, + Dial: failingDialer, RetryDuration: retryTimeout, OnConnection: func(c *ConnReq, conn net.Conn) { connected <- c @@ -823,17 +750,17 @@ func TestForEachConnReq(t *testing.T) { targetOutbound := uint32(5) connected := make(chan *ConnReq) pending := make(chan struct{}) - delayDialer := func(ctx context.Context, addr net.Addr) (net.Conn, error) { - if addr.String() == "127.0.0.1:18557" { + delayDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + if addr == "127.0.0.1:18557" { close(pending) time.Sleep(time.Second) return nil, errors.New("error") } - return mockDialerAddr(ctx, addr) + return mockDialer(ctx, network, addr) } cmgr, err := New(&Config{ TargetOutbound: targetOutbound, - DialAddr: delayDialer, + Dial: delayDialer, GetNewAddress: func() (net.Addr, error) { return &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), diff --git a/internal/connmgr/error.go b/internal/connmgr/error.go index e1ed71f11..932a13f28 100644 --- a/internal/connmgr/error.go +++ b/internal/connmgr/error.go @@ -14,10 +14,6 @@ const ( // the configuration. ErrDialNil = ErrorKind("ErrDialNil") - // ErrBothDialsFilled is used to indicate that Dial and DialAddr - // cannot both be specified in the configuration. - ErrBothDialsFilled = ErrorKind("ErrBothDialsFilled") - // ErrNotFound indicates a specified connection ID or address is unknown to // the connection manager. ErrNotFound = ErrorKind("ErrNotFound") diff --git a/internal/connmgr/error_test.go b/internal/connmgr/error_test.go index f1f3bcf05..1e177b97e 100644 --- a/internal/connmgr/error_test.go +++ b/internal/connmgr/error_test.go @@ -17,7 +17,6 @@ func TestErrorKindStringer(t *testing.T) { want string }{ {ErrDialNil, "ErrDialNil"}, - {ErrBothDialsFilled, "ErrBothDialsFilled"}, {ErrNotFound, "ErrNotFound"}, {ErrTorInvalidAddressResponse, "ErrTorInvalidAddressResponse"}, {ErrTorInvalidProxyResponse, "ErrTorInvalidProxyResponse"}, @@ -91,35 +90,35 @@ func TestErrorKindIsAs(t *testing.T) { wantMatch: true, wantAs: ErrDialNil, }, { - name: "ErrBothDialsFilled != ErrDialNil", - err: ErrBothDialsFilled, + name: "ErrNotFound != ErrDialNil", + err: ErrNotFound, target: ErrDialNil, wantMatch: false, - wantAs: ErrBothDialsFilled, + wantAs: ErrNotFound, }, { - name: "Error.ErrBothDialsFilled != ErrDialNil", - err: MakeError(ErrBothDialsFilled, ""), + name: "Error.ErrNotFound != ErrDialNil", + err: MakeError(ErrNotFound, ""), target: ErrDialNil, wantMatch: false, - wantAs: ErrBothDialsFilled, + wantAs: ErrNotFound, }, { - name: "ErrBothDialsFilled != Error.ErrDialNil", - err: ErrBothDialsFilled, + name: "ErrNotFound != Error.ErrDialNil", + err: ErrNotFound, target: MakeError(ErrDialNil, ""), wantMatch: false, - wantAs: ErrBothDialsFilled, + wantAs: ErrNotFound, }, { - name: "Error.ErrBothDialsFilled != Error.ErrDialNil", - err: MakeError(ErrBothDialsFilled, ""), + name: "Error.ErrNotFound != Error.ErrDialNil", + err: MakeError(ErrNotFound, ""), target: MakeError(ErrDialNil, ""), wantMatch: false, - wantAs: ErrBothDialsFilled, + wantAs: ErrNotFound, }, { - name: "Error.ErrBothDialsFilled != io.EOF", - err: MakeError(ErrBothDialsFilled, ""), + name: "Error.ErrNotFound != io.EOF", + err: MakeError(ErrNotFound, ""), target: io.EOF, wantMatch: false, - wantAs: ErrBothDialsFilled, + wantAs: ErrNotFound, }} for _, test := range tests {