From 99c241e6f4f2a49de5255f7b0ddb40f6bcac32a5 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 23 Apr 2026 21:58:45 +0000 Subject: [PATCH] memcache: add pipelining, modernize some of the Go Also: * add synctest tests, which requires newer Go, so build tag those out. We need to bump go.mod to get working synctest, so: * ... also add a CI check that we keep building with Go 1.18 for now, not using Go language/API changes newer than Go 1.18. * few more config knobs Fixes #160 --- .github/workflows/test.yml | 46 ++- go.mod | 2 +- memcache/fakenet_test.go | 298 +++++++++++++++ memcache/memcache.go | 632 ++++++++++++++----------------- memcache/memcache_test.go | 73 ++-- memcache/pipeline.go | 642 ++++++++++++++++++++++++++++++++ memcache/pipeline_bench_test.go | 85 +++++ memcache/pipeline_test.go | 490 ++++++++++++++++++++++++ memcache/selector.go | 12 +- 9 files changed, 1897 insertions(+), 383 deletions(-) create mode 100644 memcache/fakenet_test.go create mode 100644 memcache/pipeline.go create mode 100644 memcache/pipeline_bench_test.go create mode 100644 memcache/pipeline_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9288fefc..1ef8c99e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,22 +4,42 @@ on: pull_request: jobs: - build: - + # Verify the library still builds with Go 1.18. GOTOOLCHAIN=local prevents + # auto-upgrade; Go 1.18 itself predates the toolchain directive anyway, so + # it ignores go.mod's "go 1.26" line and just tries to compile. This job + # catches accidental use of post-1.18 language/stdlib features. + # Go 1.18 is super old, but is new enough to have generics. + # TODO(bradfitz): decide on an actual support Go version policy. It'd be nice + # to depend on newer Go. 
+ build-go118: + name: Test on old Go 1.18 runs-on: ubuntu-latest - strategy: - matrix: - go-version: [ '1.18', '1.21' ] + steps: + - name: install memcached + run: | + sudo apt-get update + sudo apt-get install -y memcached + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.18' + - name: go test + env: + GOTOOLCHAIN: local + run: go test ./... + # Full test suite on modern Go, including the synctest-gated pipeline tests. + test: + name: Run Go tests + runs-on: ubuntu-latest steps: - name: install memcached run: | - sudo apt update - sudo apt install memcached - - uses: actions/checkout@v3 - - name: Setup Go ${{ matrix.go-version }} - uses: actions/setup-go@v4 + sudo apt-get update + sudo apt-get install -y memcached + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} - - name: Test - run: go test -v ./... + go-version: '1.26' + - name: go test + run: go test -v -race ./... diff --git a/go.mod b/go.mod index 1bf1615a..7d81911f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/bradfitz/gomemcache -go 1.18 +go 1.26 diff --git a/memcache/fakenet_test.go b/memcache/fakenet_test.go new file mode 100644 index 00000000..32f72232 --- /dev/null +++ b/memcache/fakenet_test.go @@ -0,0 +1,298 @@ +/* +Copyright 2026 The gomemcache AUTHORS + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +*/ + +package memcache + +import ( + "errors" + "io" + "net" + "sync" + "time" +) + +// This file provides a minimal in-memory net.Conn pair with per-direction +// synthetic delay. +// +// Under testing/synctest, time.Sleep and time.AfterFunc use fake time, so +// the "delay" is deterministic and instant in wall-clock terms. 
+ +type fakeAddr string + +func (a fakeAddr) Network() string { return "fakenet" } +func (a fakeAddr) String() string { return string(a) } + +// newFakePipe returns two connected fakeConns. oneWayDelay models link +// latency (the time between Write on one side and the bytes becoming +// readable on the other). A Write records its target arrival time when it +// queues the chunk; the per-direction worker then sleeps to each arrival +// time in order. Under back-to-back Writes, chunks arrive back-to-back, +// emulating an in-flight pipe rather than a queue of serialized latencies. +func newFakePipe(oneWayDelay time.Duration, nameA, nameB string) (a, b *fakeConn) { + a = &fakeConn{ + local: fakeAddr(nameA), + remote: fakeAddr(nameB), + delay: oneWayDelay, + done: make(chan struct{}), + sendCh: make(chan fakeChunk, 256), + } + b = &fakeConn{ + local: fakeAddr(nameB), + remote: fakeAddr(nameA), + delay: oneWayDelay, + done: make(chan struct{}), + sendCh: make(chan fakeChunk, 256), + } + aToB := make(chan []byte, 256) + bToA := make(chan []byte, 256) + a.outCh = aToB + b.inCh = aToB + b.outCh = bToA + a.inCh = bToA + a.peer = b + b.peer = a + go a.deliverLoop() + go b.deliverLoop() + return +} + +type fakeChunk struct { + arriveAt time.Time + data []byte +} + +type fakeConn struct { + local fakeAddr + remote fakeAddr + delay time.Duration + + peer *fakeConn + + sendCh chan fakeChunk // Write pushes here with arrival timestamp + outCh chan []byte // deliverLoop writes here when arrival time passes (peer reads) + inCh chan []byte // chunks arriving from peer (already delayed) + + closeOnce sync.Once + done chan struct{} + + mu sync.Mutex + readBuf []byte + readDeadline time.Time + writeDeadline time.Time +} + +// deliverLoop pops chunks in FIFO order and sleeps until each chunk's +// arrival time. Arrival times are monotonic because Write stamps them at +// submission time on a single-writer send queue. 
+func (c *fakeConn) deliverLoop() { + for { + select { + case ch, ok := <-c.sendCh: + if !ok { + return + } + if d := time.Until(ch.arriveAt); d > 0 { + timer := time.NewTimer(d) + select { + case <-timer.C: + case <-c.done: + timer.Stop() + return + case <-c.peer.done: + timer.Stop() + return + } + } + select { + case c.outCh <- ch.data: + case <-c.done: + return + case <-c.peer.done: + return + } + case <-c.done: + return + case <-c.peer.done: + return + } + } +} + +func (c *fakeConn) LocalAddr() net.Addr { return c.local } +func (c *fakeConn) RemoteAddr() net.Addr { return c.remote } + +func (c *fakeConn) SetDeadline(t time.Time) error { + c.SetReadDeadline(t) + c.SetWriteDeadline(t) + return nil +} + +func (c *fakeConn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + c.readDeadline = t + c.mu.Unlock() + return nil +} + +func (c *fakeConn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + c.writeDeadline = t + c.mu.Unlock() + return nil +} + +func (c *fakeConn) Close() error { + c.closeOnce.Do(func() { + close(c.done) + }) + return nil +} + +// Read returns data from the per-conn read buffer, refilling it from inCh +// when empty. Respects the read deadline (via time.Timer, fake-time-friendly). +func (c *fakeConn) Read(p []byte) (int, error) { + c.mu.Lock() + if len(c.readBuf) > 0 { + n := copy(p, c.readBuf) + c.readBuf = c.readBuf[n:] + c.mu.Unlock() + return n, nil + } + dl := c.readDeadline + c.mu.Unlock() + + var dlCh <-chan time.Time + if !dl.IsZero() { + timer := time.NewTimer(time.Until(dl)) + defer timer.Stop() + dlCh = timer.C + } + + select { + case chunk, ok := <-c.inCh: + if !ok { + return 0, io.EOF + } + c.mu.Lock() + c.readBuf = append(c.readBuf, chunk...) 
+ n := copy(p, c.readBuf) + c.readBuf = c.readBuf[n:] + c.mu.Unlock() + return n, nil + case <-dlCh: + return 0, errFakeTimeout + case <-c.done: + return 0, io.EOF + case <-c.peer.done: + return 0, io.EOF + } +} + +// Write copies the bytes, stamps the link arrival time, and queues for +// delivery. Chunks arrive at the peer at (send_time + delay); back-to-back +// Writes result in back-to-back arrivals, not serialized ones. +func (c *fakeConn) Write(p []byte) (int, error) { + select { + case <-c.done: + return 0, io.ErrClosedPipe + case <-c.peer.done: + return 0, io.ErrClosedPipe + default: + } + b := make([]byte, len(p)) + copy(b, p) + ch := fakeChunk{arriveAt: time.Now().Add(c.delay), data: b} + select { + case c.sendCh <- ch: + case <-c.done: + return 0, io.ErrClosedPipe + case <-c.peer.done: + return 0, io.ErrClosedPipe + } + return len(p), nil +} + +var errFakeTimeout = &fakeTimeoutError{} + +type fakeTimeoutError struct{} + +func (*fakeTimeoutError) Error() string { return "fakenet: i/o timeout" } +func (*fakeTimeoutError) Timeout() bool { return true } +func (*fakeTimeoutError) Temporary() bool { return true } + +// fakeListener is a net.Listener backed by fakeConn pairs. Each dial creates +// one pair; one end is returned to the dialer, the other end is enqueued for +// the next Accept. delay is the per-direction one-way delay. +type fakeListener struct { + addr fakeAddr + delay time.Duration + accepts chan net.Conn + closed chan struct{} + once sync.Once +} + +func newFakeListener(addr string, oneWayDelay time.Duration) *fakeListener { + return &fakeListener{ + addr: fakeAddr(addr), + delay: oneWayDelay, + accepts: make(chan net.Conn, 16), + closed: make(chan struct{}), + } +} + +// dial creates a new conn pair and hands the server side to the next Accept. +// It returns the client side. 
+func (l *fakeListener) dial() (net.Conn, error) { + select { + case <-l.closed: + return nil, errors.New("fakenet: listener closed") + default: + } + client, server := newFakePipe(l.delay, "client-"+string(l.addr), string(l.addr)) + select { + case l.accepts <- server: + return client, nil + case <-l.closed: + client.Close() + server.Close() + return nil, errors.New("fakenet: listener closed") + } +} + +func (l *fakeListener) Accept() (net.Conn, error) { + select { + case c := <-l.accepts: + return c, nil + case <-l.closed: + return nil, errors.New("fakenet: listener closed") + } +} + +func (l *fakeListener) Close() error { + l.once.Do(func() { close(l.closed) }) + return nil +} + +func (l *fakeListener) Addr() net.Addr { return l.addr } + +// singleServerSelector is a ServerSelector that always picks the same addr. +type singleServerSelector struct{ addr net.Addr } + +func (s singleServerSelector) PickServer(key string) (net.Addr, error) { return s.addr, nil } +func (s singleServerSelector) Each(f func(net.Addr) error) error { return f(s.addr) } + +// Compile-time assertions. +var ( + _ net.Conn = (*fakeConn)(nil) + _ net.Addr = fakeAddr("") + _ error = (*fakeTimeoutError)(nil) + _ net.Listener = (*fakeListener)(nil) + _ ServerSelector = singleServerSelector{} +) diff --git a/memcache/memcache.go b/memcache/memcache.go index 6f48caac..805af943 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -26,6 +26,7 @@ import ( "io" "math" "net" + "net/netip" "strconv" "strings" "sync" @@ -68,6 +69,9 @@ const ( // DefaultTimeout is the default socket read/write timeout. DefaultTimeout = 500 * time.Millisecond + // DefaultDialTimeout is the default per-attempt connect timeout. + DefaultDialTimeout = 5 * time.Second + // DefaultMaxIdleConns is the default maximum number of idle connections // kept for any single address. DefaultMaxIdleConns = 2 @@ -145,18 +149,55 @@ type Client struct { // If zero, DefaultTimeout is used. 
Timeout time.Duration - // MaxIdleConns specifies the maximum number of idle connections that will - // be maintained per address. If less than one, DefaultMaxIdleConns will be - // used. + // DialTimeout is the per-attempt timeout for establishing a new + // connection to a server. It is intentionally separate from Timeout + // because the connection that a dial produces may be reused by many + // later operations whose individual contexts have nothing to do with + // how long this dial is allowed to take. // - // Consider your expected traffic rates and latency carefully. This should - // be set to a number higher than your peak parallel requests. + // If zero, DefaultDialTimeout is used. + DialTimeout time.Duration + + // MaxIdleConns specifies the maximum number of idle connections kept + // warm per address after a period of inactivity. A background reaper + // wakes after ~5s of quiescence on a backend and closes idle conns + // beyond this cap. If less than one, DefaultMaxIdleConns is used. MaxIdleConns int + // MaxPipelineDepth is the maximum number of requests allowed in flight + // on a single connection at once. Once every connection to a backend + // is at this cap and no more can be dialed (see MaxConns, MaxDials), + // subsequent operations block in their submit call until a slot frees. + // + // A value of 1 disables pipelining (each connection serves one request + // at a time, matching the pre-pipelining behavior). + // + // A value of 0 selects an implementation-chosen default (currently 8; + // subject to change). + MaxPipelineDepth int + + // MaxConns is the maximum number of open connections per server address + // (in-use + idle). If zero, DefaultMaxConns is used. A negative value + // means unlimited. + MaxConns int + + // MaxDials is the maximum number of concurrent dial attempts per server + // address. If zero, DefaultMaxDials is used. A negative value means + // unlimited. 
+ MaxDials int + selector ServerSelector mu sync.Mutex - freeconn map[string][]*conn + backends map[backendKey]*backend +} + +// backendKey identifies a backend without allocating. For TCP addresses +// (the common case) it's a pure netip.AddrPort value; for unix sockets +// or unknown networks, path holds the address's String() value. +type backendKey struct { + ap netip.AddrPort + path string } // Item is an item to be got or stored in a memcached server. @@ -183,64 +224,6 @@ type Item struct { CasID uint64 } -// conn is a connection to a server. -type conn struct { - nc net.Conn - rw *bufio.ReadWriter - addr net.Addr - c *Client -} - -// release returns this connection back to the client's free pool -func (cn *conn) release() { - cn.c.putFreeConn(cn.addr, cn) -} - -func (cn *conn) extendDeadline() { - cn.nc.SetDeadline(time.Now().Add(cn.c.netTimeout())) -} - -// condRelease releases this connection if the error pointed to by err -// is nil (not an error) or is only a protocol level error (e.g. a -// cache miss). The purpose is to not recycle TCP connections that -// are bad. 
-func (cn *conn) condRelease(err *error) { - if *err == nil || resumableError(*err) { - cn.release() - } else { - cn.nc.Close() - } -} - -func (c *Client) putFreeConn(addr net.Addr, cn *conn) { - c.mu.Lock() - defer c.mu.Unlock() - if c.freeconn == nil { - c.freeconn = make(map[string][]*conn) - } - freelist := c.freeconn[addr.String()] - if len(freelist) >= c.maxIdleConns() { - cn.nc.Close() - return - } - c.freeconn[addr.String()] = append(freelist, cn) -} - -func (c *Client) getFreeConn(addr net.Addr) (cn *conn, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - if c.freeconn == nil { - return nil, false - } - freelist, ok := c.freeconn[addr.String()] - if !ok || len(freelist) == 0 { - return nil, false - } - cn = freelist[len(freelist)-1] - c.freeconn[addr.String()] = freelist[:len(freelist)-1] - return cn, true -} - func (c *Client) netTimeout() time.Duration { if c.Timeout != 0 { return c.Timeout @@ -248,6 +231,13 @@ func (c *Client) netTimeout() time.Duration { return DefaultTimeout } +func (c *Client) dialTimeout() time.Duration { + if c.DialTimeout != 0 { + return c.DialTimeout + } + return DefaultDialTimeout +} + func (c *Client) maxIdleConns() int { if c.MaxIdleConns > 0 { return c.MaxIdleConns @@ -267,13 +257,13 @@ func (cte *ConnectTimeoutError) Error() string { } func (c *Client) dial(addr net.Addr) (net.Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.netTimeout()) + ctx, cancel := context.WithTimeout(context.Background(), c.dialTimeout()) defer cancel() dialerContext := c.DialContext if dialerContext == nil { dialer := net.Dialer{ - Timeout: c.netTimeout(), + Timeout: c.dialTimeout(), } dialerContext = dialer.DialContext } @@ -290,40 +280,72 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) { return nil, err } -func (c *Client) getConn(addr net.Addr) (*conn, error) { - cn, ok := c.getFreeConn(addr) - if ok { - cn.extendDeadline() - return cn, nil - } - nc, err := c.dial(addr) - if err != nil { - return nil, err +// 
getBackend returns (creating if needed) the per-address backend. +func (c *Client) getBackend(addr net.Addr) *backend { + key := addrBackendKey(addr) + c.mu.Lock() + defer c.mu.Unlock() + if c.backends == nil { + c.backends = make(map[backendKey]*backend) + } + if b, ok := c.backends[key]; ok { + return b + } + b := newBackend(c, addr) + c.backends[key] = b + return b +} + +// addrBackendKey derives a comparable, alloc-free key from a net.Addr. +// TCP-shaped addresses (the common case produced by ServerList) resolve +// to a pure netip.AddrPort with no string conversion. Unix sockets and +// unknown shapes fall back to the address's String() representation. +func addrBackendKey(a net.Addr) backendKey { + switch v := a.(type) { + case *net.TCPAddr: + return backendKey{ap: v.AddrPort()} + case *staticAddr: + if v.ap.IsValid() { + return backendKey{ap: v.ap} + } + return backendKey{path: v.str} + case *net.UnixAddr: + return backendKey{path: v.Name} } - cn = &conn{ - nc: nc, - addr: addr, - rw: bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)), - c: c, + if a.Network() != "unix" { + if ap, err := netip.ParseAddrPort(a.String()); err == nil { + return backendKey{ap: ap} + } } - cn.extendDeadline() - return cn, nil + return backendKey{path: a.String()} } -func (c *Client) onItem(item *Item, fn func(*Client, *bufio.ReadWriter, *Item) error) error { - addr, err := c.selector.PickServer(item.Key) - if err != nil { - return err +// ctxBG is a package-wide context.Background() shared by the non-context +// command paths; it saves repeating the call (which does not itself allocate). +var ctxBG = context.Background() + +// runCmd submits one operation to addr and waits for its completion. +// verb is used to classify idempotency for retry on conn failure. +// Callers without a context should pass ctxBG. 
+func (c *Client) runCmd(ctx context.Context, addr net.Addr, verb string, + write func(*bufio.Writer) error, + read func(*bufio.Reader) error) error { + req := newPipeReq(verb, write, read) + return c.getBackend(addr).submit(ctx, req) +} + +// runCmdKey picks the server for key (after a legality check) and runs the op. +func (c *Client) runCmdKey(ctx context.Context, key, verb string, + write func(*bufio.Writer) error, + read func(*bufio.Reader) error) error { + if !legalKey(key) { + return ErrMalformedKey } - cn, err := c.getConn(addr) + addr, err := c.selector.PickServer(key) if err != nil { return err } - defer cn.condRelease(&err) - if err = fn(c, cn.rw, item); err != nil { - return err - } - return nil + return c.runCmd(ctx, addr, verb, write, read) } func (c *Client) FlushAll() error { @@ -349,7 +371,7 @@ func (c *Client) Get(key string) (item *Item, err error) { // The key must be at most 250 bytes in length. func (c *Client) Touch(key string, seconds int32) (err error) { return c.withKeyAddr(key, func(addr net.Addr) error { - return c.touchFromAddr(addr, []string{key}, seconds) + return c.touchFromAddr(addr, key, seconds) }) } @@ -364,111 +386,76 @@ func (c *Client) withKeyAddr(key string, fn func(net.Addr) error) (err error) { return fn(addr) } -func (c *Client) withAddrRw(addr net.Addr, fn func(*conn) error) (err error) { - cn, err := c.getConn(addr) - if err != nil { - return err - } - defer cn.condRelease(&err) - return fn(cn) -} - -func (c *Client) withKeyRw(key string, fn func(*conn) error) error { - return c.withKeyAddr(key, func(addr net.Addr) error { - return c.withAddrRw(addr, fn) - }) -} - func (c *Client) getFromAddr(addr net.Addr, keys []string, cb func(*Item)) error { - return c.withAddrRw(addr, func(conn *conn) error { - rw := conn.rw - if _, err := fmt.Fprintf(rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { - return err - } - if err := rw.Flush(); err != nil { - return err - } - if err := parseGetResponse(rw.Reader, conn, cb); err 
!= nil { + return c.runCmd(ctxBG, addr, "gets", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "gets %s\r\n", strings.Join(keys, " ")) return err - } - return nil - }) + }, + func(r *bufio.Reader) error { + return parseGetResponse(r, cb) + }) } -// flushAllFromAddr send the flush_all command to the given addr +// flushAllFromAddr sends the flush_all command to the given addr. func (c *Client) flushAllFromAddr(addr net.Addr) error { - return c.withAddrRw(addr, func(conn *conn) error { - rw := conn.rw - if _, err := fmt.Fprintf(rw, "flush_all\r\n"); err != nil { - return err - } - if err := rw.Flush(); err != nil { - return err - } - line, err := rw.ReadSlice('\n') - if err != nil { + return c.runCmd(ctxBG, addr, "flush_all", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "flush_all\r\n") return err - } - switch { - case bytes.Equal(line, resultOk): - break - default: - return fmt.Errorf("memcache: unexpected response line from flush_all: %q", string(line)) - } - return nil - }) + }, + func(r *bufio.Reader) error { + line, err := r.ReadSlice('\n') + if err != nil { + return err + } + if !bytes.Equal(line, resultOk) { + return fmt.Errorf("memcache: unexpected response line from flush_all: %q", string(line)) + } + return nil + }) } -// ping sends the version command to the given addr +// ping sends the version command to the given addr. 
func (c *Client) ping(addr net.Addr) error { - return c.withAddrRw(addr, func(conn *conn) error { - rw := conn.rw - if _, err := fmt.Fprintf(rw, "version\r\n"); err != nil { + return c.runCmd(ctxBG, addr, "version", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "version\r\n") return err - } - if err := rw.Flush(); err != nil { - return err - } - line, err := rw.ReadSlice('\n') - if err != nil { - return err - } - - switch { - case bytes.HasPrefix(line, versionPrefix): - break - default: - return fmt.Errorf("memcache: unexpected response line from ping: %q", string(line)) - } - return nil - }) -} - -func (c *Client) touchFromAddr(addr net.Addr, keys []string, expiration int32) error { - return c.withAddrRw(addr, func(conn *conn) error { - rw := conn.rw - for _, key := range keys { - if _, err := fmt.Fprintf(rw, "touch %s %d\r\n", key, expiration); err != nil { + }, + func(r *bufio.Reader) error { + line, err := r.ReadSlice('\n') + if err != nil { return err } - if err := rw.Flush(); err != nil { - return err + if !bytes.HasPrefix(line, versionPrefix) { + return fmt.Errorf("memcache: unexpected response line from ping: %q", string(line)) } - line, err := rw.ReadSlice('\n') + return nil + }) +} + +// touchFromAddr sends a single touch for key to addr. +func (c *Client) touchFromAddr(addr net.Addr, key string, expiration int32) error { + return c.runCmd(ctxBG, addr, "touch", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "touch %s %d\r\n", key, expiration) + return err + }, + func(r *bufio.Reader) error { + line, err := r.ReadSlice('\n') if err != nil { return err } switch { case bytes.Equal(line, resultTouched): - break + return nil case bytes.Equal(line, resultNotFound): return ErrCacheMiss default: return fmt.Errorf("memcache: unexpected response line from touch: %q", string(line)) } - } - return nil - }) + }) } // GetMulti is a batch version of Get. 
The returned map from keys to @@ -513,13 +500,13 @@ func (c *Client) GetMulti(keys []string) (map[string]*Item, error) { } // parseGetResponse reads a GET response from r and calls cb for each -// read and allocated Item -func parseGetResponse(r *bufio.Reader, conn *conn, cb func(*Item)) error { +// read and allocated Item. +// +// The conn-level deadline is set once by the pipeline reader when the +// request is dispatched; a huge response may need Client.Timeout to be +// raised accordingly. +func parseGetResponse(r *bufio.Reader, cb func(*Item)) error { for { - // extend deadline before each additional call, otherwise all cumulative - // calls use the same overall deadline - conn.extendDeadline() - line, err := r.ReadSlice('\n') if err != nil { return err @@ -590,7 +577,7 @@ func scanGetResponseLine(line []byte, it *Item) (size int, err error) { return int(size64), nil } -// Similar to strings.Cut in Go 1.18, but sep can only be 1 byte. +// cut is similar to strings.Cut in Go 1.18, but sep can only be 1 byte. func cut(s string, sep byte) (before, after string, found bool) { if i := strings.IndexByte(s, sep); i >= 0 { return s[:i], s[i+1:], true @@ -599,53 +586,23 @@ func cut(s string, sep byte) (before, after string, found bool) { } // Set writes the given item, unconditionally. -func (c *Client) Set(item *Item) error { - return c.onItem(item, (*Client).set) -} - -func (c *Client) set(rw *bufio.ReadWriter, item *Item) error { - return c.populateOne(rw, "set", item) -} +func (c *Client) Set(item *Item) error { return c.populateOne("set", item) } // Add writes the given item, if no value already exists for its // key. ErrNotStored is returned if that condition is not met. 
-func (c *Client) Add(item *Item) error { - return c.onItem(item, (*Client).add) -} - -func (c *Client) add(rw *bufio.ReadWriter, item *Item) error { - return c.populateOne(rw, "add", item) -} +func (c *Client) Add(item *Item) error { return c.populateOne("add", item) } // Replace writes the given item, but only if the server *does* -// already hold data for this key -func (c *Client) Replace(item *Item) error { - return c.onItem(item, (*Client).replace) -} - -func (c *Client) replace(rw *bufio.ReadWriter, item *Item) error { - return c.populateOne(rw, "replace", item) -} +// already hold data for this key. +func (c *Client) Replace(item *Item) error { return c.populateOne("replace", item) } // Append appends the given item to the existing item, if a value already // exists for its key. ErrNotStored is returned if that condition is not met. -func (c *Client) Append(item *Item) error { - return c.onItem(item, (*Client).append) -} - -func (c *Client) append(rw *bufio.ReadWriter, item *Item) error { - return c.populateOne(rw, "append", item) -} +func (c *Client) Append(item *Item) error { return c.populateOne("append", item) } // Prepend prepends the given item to the existing item, if a value already // exists for its key. ErrNotStored is returned if that condition is not met. -func (c *Client) Prepend(item *Item) error { - return c.onItem(item, (*Client).prepend) -} - -func (c *Client) prepend(rw *bufio.ReadWriter, item *Item) error { - return c.populateOne(rw, "prepend", item) -} +func (c *Client) Prepend(item *Item) error { return c.populateOne("prepend", item) } // CompareAndSwap writes the given item that was previously returned // by Get, if the value was neither modified or evicted between the @@ -654,104 +611,104 @@ func (c *Client) prepend(rw *bufio.ReadWriter, item *Item) error { // is returned if the value was modified in between the // calls. ErrNotStored is returned if the value was evicted in between // the calls. 
-func (c *Client) CompareAndSwap(item *Item) error { - return c.onItem(item, (*Client).cas) -} - -func (c *Client) cas(rw *bufio.ReadWriter, item *Item) error { - return c.populateOne(rw, "cas", item) -} +func (c *Client) CompareAndSwap(item *Item) error { return c.populateOne("cas", item) } -func (c *Client) populateOne(rw *bufio.ReadWriter, verb string, item *Item) error { +func (c *Client) populateOne(verb string, item *Item) error { if !legalKey(item.Key) { return ErrMalformedKey } - var err error - if verb == "cas" { - _, err = fmt.Fprintf(rw, "%s %s %d %d %d %d\r\n", - verb, item.Key, item.Flags, item.Expiration, len(item.Value), item.CasID) - } else { - _, err = fmt.Fprintf(rw, "%s %s %d %d %d\r\n", - verb, item.Key, item.Flags, item.Expiration, len(item.Value)) - } - if err != nil { - return err - } - if _, err = rw.Write(item.Value); err != nil { - return err - } - if _, err := rw.Write(crlf); err != nil { - return err - } - if err := rw.Flush(); err != nil { - return err - } - line, err := rw.ReadSlice('\n') - if err != nil { - return err - } - switch { - case bytes.Equal(line, resultStored): - return nil - case bytes.Equal(line, resultNotStored): - return ErrNotStored - case bytes.Equal(line, resultExists): - return ErrCASConflict - case bytes.Equal(line, resultNotFound): - return ErrCacheMiss - } - return fmt.Errorf("memcache: unexpected response line from %q: %q", verb, string(line)) -} - -func writeReadLine(rw *bufio.ReadWriter, format string, args ...interface{}) ([]byte, error) { - _, err := fmt.Fprintf(rw, format, args...) - if err != nil { - return nil, err - } - if err := rw.Flush(); err != nil { - return nil, err - } - line, err := rw.ReadSlice('\n') - return line, err -} - -func writeExpectf(rw *bufio.ReadWriter, expect []byte, format string, args ...interface{}) error { - line, err := writeReadLine(rw, format, args...) 
+ addr, err := c.selector.PickServer(item.Key) if err != nil { return err } - switch { - case bytes.Equal(line, resultOK): - return nil - case bytes.Equal(line, expect): - return nil - case bytes.Equal(line, resultNotStored): - return ErrNotStored - case bytes.Equal(line, resultExists): - return ErrCASConflict - case bytes.Equal(line, resultNotFound): - return ErrCacheMiss - } - return fmt.Errorf("memcache: unexpected response line: %q", string(line)) + return c.runCmd(ctxBG, addr, verb, + func(w *bufio.Writer) error { + var hdrErr error + if verb == "cas" { + _, hdrErr = fmt.Fprintf(w, "%s %s %d %d %d %d\r\n", + verb, item.Key, item.Flags, item.Expiration, len(item.Value), item.CasID) + } else { + _, hdrErr = fmt.Fprintf(w, "%s %s %d %d %d\r\n", + verb, item.Key, item.Flags, item.Expiration, len(item.Value)) + } + if hdrErr != nil { + return hdrErr + } + if _, err := w.Write(item.Value); err != nil { + return err + } + _, err := w.Write(crlf) + return err + }, + func(r *bufio.Reader) error { + line, err := r.ReadSlice('\n') + if err != nil { + return err + } + switch { + case bytes.Equal(line, resultStored): + return nil + case bytes.Equal(line, resultNotStored): + return ErrNotStored + case bytes.Equal(line, resultExists): + return ErrCASConflict + case bytes.Equal(line, resultNotFound): + return ErrCacheMiss + } + return fmt.Errorf("memcache: unexpected response line from %q: %q", verb, string(line)) + }) } // Delete deletes the item with the provided key. The error ErrCacheMiss is // returned if the item didn't already exist in the cache. +// +// If a retry occurs on a new connection because the first attempt's +// connection died, a successful original delete followed by a retry will +// see ErrCacheMiss on the retry: safe, but worth noting. 
func (c *Client) Delete(key string) error { - return c.withKeyRw(key, func(conn *conn) error { - return writeExpectf(conn.rw, resultDeleted, "delete %s\r\n", key) - }) + return c.runCmdKey(ctxBG, key, "delete", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "delete %s\r\n", key) + return err + }, + expectOneOf("delete", resultDeleted)) } -// DeleteAll deletes all items in the cache. +// DeleteAll deletes all items from the server picked for the empty key. func (c *Client) DeleteAll() error { - return c.withKeyRw("", func(conn *conn) error { - return writeExpectf(conn.rw, resultDeleted, "flush_all\r\n") - }) + return c.runCmdKey(ctxBG, "", "flush_all", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "flush_all\r\n") + return err + }, + expectOneOf("flush_all", resultDeleted)) } -// Get and Touch the item with the provided key. The error ErrCacheMiss is -// returned if the item didn't already exist in the cache. +// expectOneOf returns a response parser that treats OK or expect as success +// and maps the common error lines to their error types. +func expectOneOf(verb string, expect []byte) func(*bufio.Reader) error { + return func(r *bufio.Reader) error { + line, err := r.ReadSlice('\n') + if err != nil { + return err + } + switch { + case bytes.Equal(line, resultOK), + bytes.Equal(line, expect): + return nil + case bytes.Equal(line, resultNotStored): + return ErrNotStored + case bytes.Equal(line, resultExists): + return ErrCASConflict + case bytes.Equal(line, resultNotFound): + return ErrCacheMiss + } + return fmt.Errorf("memcache: unexpected response line from %s: %q", verb, string(line)) + } +} + +// GetAndTouch gets and updates the expiry of the given key. The error +// ErrCacheMiss is returned if the item didn't already exist in the cache. 
func (c *Client) GetAndTouch(key string, expiration int32) (item *Item, err error) { err = c.withKeyAddr(key, func(addr net.Addr) error { return c.getAndTouchFromAddr(addr, key, expiration, func(it *Item) { item = it }) @@ -763,22 +720,17 @@ func (c *Client) GetAndTouch(key string, expiration int32) (item *Item, err erro } func (c *Client) getAndTouchFromAddr(addr net.Addr, key string, expiration int32, cb func(*Item)) error { - return c.withAddrRw(addr, func(conn *conn) error { - rw := conn.rw - if _, err := fmt.Fprintf(rw, "gat %d %s\r\n", expiration, key); err != nil { - return err - } - if err := rw.Flush(); err != nil { + return c.runCmd(ctxBG, addr, "gat", + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "gat %d %s\r\n", expiration, key) return err - } - if err := parseGetResponse(rw.Reader, conn, cb); err != nil { - return err - } - return nil - }) + }, + func(r *bufio.Reader) error { + return parseGetResponse(r, cb) + }) } -// Ping checks all instances if they are alive. Returns error if any +// Ping checks all instances if they are alive. It returns an error if any // of them is down. 
func (c *Client) Ping() error { return c.selector.Each(c.ping) @@ -805,45 +757,39 @@ func (c *Client) Decrement(key string, delta uint64) (newValue uint64, err error func (c *Client) incrDecr(verb, key string, delta uint64) (uint64, error) { var val uint64 - err := c.withKeyRw(key, func(conn *conn) error { - rw := conn.rw - line, err := writeReadLine(rw, "%s %s %d\r\n", verb, key, delta) - if err != nil { + err := c.runCmdKey(ctxBG, key, verb, + func(w *bufio.Writer) error { + _, err := fmt.Fprintf(w, "%s %s %d\r\n", verb, key, delta) return err - } - switch { - case bytes.Equal(line, resultNotFound): - return ErrCacheMiss - case bytes.HasPrefix(line, resultClientErrorPrefix): - errMsg := line[len(resultClientErrorPrefix) : len(line)-2] - return errors.New("memcache: client error: " + string(errMsg)) - } - val, err = strconv.ParseUint(string(line[:len(line)-2]), 10, 64) - if err != nil { + }, + func(r *bufio.Reader) error { + line, err := r.ReadSlice('\n') + if err != nil { + return err + } + switch { + case bytes.Equal(line, resultNotFound): + return ErrCacheMiss + case bytes.HasPrefix(line, resultClientErrorPrefix): + errMsg := line[len(resultClientErrorPrefix) : len(line)-2] + return errors.New("memcache: client error: " + string(errMsg)) + } + val, err = strconv.ParseUint(string(line[:len(line)-2]), 10, 64) return err - } - return nil - }) + }) return val, err } // Close closes any open connections. // -// It returns the first error encountered closing connections, but always -// closes all connections. -// -// After Close, the Client may still be used. +// It returns nil. After Close, the Client may still be used. 
func (c *Client) Close() error { c.mu.Lock() - defer c.mu.Unlock() - var ret error - for _, conns := range c.freeconn { - for _, c := range conns { - if err := c.nc.Close(); err != nil && ret == nil { - ret = err - } - } + backends := c.backends + c.backends = nil + c.mu.Unlock() + for _, b := range backends { + b.close() } - c.freeconn = nil - return ret + return nil } diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index a0fa746d..fbf3dc78 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -18,14 +18,12 @@ limitations under the License. package memcache import ( - "bufio" "bytes" "context" "crypto/tls" "flag" "fmt" "io" - "io/ioutil" "net" "os" "os/exec" @@ -395,33 +393,60 @@ func testTouchWithClient(t *testing.T, c *Client) { } } -func BenchmarkOnItem(b *testing.B) { - fakeServer, err := net.Listen("tcp", "localhost:0") +// BenchmarkSet measures the overhead of the command path against a local +// in-process testServer, exercising the full encode/pipeline/decode loop. 
+func BenchmarkSet(b *testing.B) { + ln, err := net.Listen("tcp", "localhost:0") if err != nil { - b.Fatal("Could not open fake server: ", err) - } - defer fakeServer.Close() - go func() { - for { - if c, err := fakeServer.Accept(); err == nil { - go func() { io.Copy(ioutil.Discard, c) }() - } else { - return - } - } - }() - - addr := fakeServer.Addr() - c := New(addr.String()) - if _, err := c.getConn(addr); err != nil { - b.Fatal("failed to initialize connection to fake server") + b.Fatal("Could not open listener: ", err) } + defer ln.Close() + srv := &testServer{} + go srv.Serve(ln) + + c := New(ln.Addr().String()) + defer c.Close() - item := Item{Key: "foo"} - dummyFn := func(_ *Client, _ *bufio.ReadWriter, _ *Item) error { return nil } + item := &Item{Key: "foo", Value: []byte("bar")} b.ResetTimer() for i := 0; i < b.N; i++ { - c.onItem(&item, dummyFn) + if err := c.Set(item); err != nil { + b.Fatal(err) + } + } +} + +// TestAddrBackendKeyAllocs guards the hot-path map-key extraction against +// regressions that introduce allocations (e.g. re-adding addr.String() for a +// case handled by a type switch). 
+func TestAddrBackendKeyAllocs(t *testing.T) { + tcp, err := net.ResolveTCPAddr("tcp", "127.0.0.1:11211") + if err != nil { + t.Fatal(err) + } + unix, err := net.ResolveUnixAddr("unix", "/tmp/gomemcache.sock") + if err != nil { + t.Fatal(err) + } + static := newStaticAddr(tcp) // what ServerList produces + + cases := []struct { + name string + addr net.Addr + }{ + {"TCPAddr", tcp}, + {"staticAddr", static}, + {"UnixAddr", unix}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := testing.AllocsPerRun(1000, func() { + _ = addrBackendKey(tc.addr) + }) + if got != 0 { + t.Errorf("addrBackendKey(%T) allocs = %v; want 0", tc.addr, got) + } + }) } } diff --git a/memcache/pipeline.go b/memcache/pipeline.go new file mode 100644 index 00000000..5b293f45 --- /dev/null +++ b/memcache/pipeline.go @@ -0,0 +1,642 @@ +/* +Copyright 2026 The gomemcache AUTHORS + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +*/ + +package memcache + +import ( + "bufio" + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + // DefaultMaxConns is the default value for Client.MaxConns. + DefaultMaxConns = 100 + + // DefaultMaxDials is the default value for Client.MaxDials. + DefaultMaxDials = 10 +) + +// defaultMaxPipelineDepth is the per-conn pipeline depth used when +// Client.MaxPipelineDepth is zero. The scheduler caps each conn's in-flight +// count at the effective MaxPipelineDepth; once a conn reaches the cap it +// prefers to grow (dial a new conn) subject to MaxConns and MaxDials. +const defaultMaxPipelineDepth = 8 + +// maxAttempts caps retries of idempotent requests across dead conns. +const maxAttempts = 2 + +var errBackendClosed = errors.New("memcache: backend closed") + +// set is a tiny generic set helper used internally. 
+type set[T comparable] map[T]struct{} + +func (s set[T]) add(v T) { s[v] = struct{}{} } +func (s set[T]) remove(v T) { delete(s, v) } + +// idempotentVerb reports whether a memcached ASCII verb is safe to retry +// after a connection failure of undetermined phase. +func idempotentVerb(verb string) bool { + switch verb { + case "get", "gets", "gat", "gats", + "set", "delete", "touch", "flush_all", "version": + return true + } + // add, replace, cas, append, prepend, incr, decr are stateful. + return false +} + +// pipeReq is one pipelined operation. +// +// write is invoked by the conn's writer goroutine when this req is at the +// head of the write queue; it formats the request directly onto the +// bufio.Writer. read is invoked by the reader goroutine when the +// corresponding response is at the head of the response FIFO. +type pipeReq struct { + write func(w *bufio.Writer) error + read func(r *bufio.Reader) error + verb string // for idempotent classification + idempotent bool + attempts int32 // scheduler-only + cancelled int32 // atomic; 0 or 1 + done chan error +} + +func (r *pipeReq) isCancelled() bool { return atomic.LoadInt32(&r.cancelled) != 0 } +func (r *pipeReq) markCancelled() { atomic.StoreInt32(&r.cancelled, 1) } + +// newPipeReq constructs a pipeReq. +func newPipeReq(verb string, write func(w *bufio.Writer) error, read func(r *bufio.Reader) error) *pipeReq { + return &pipeReq{ + write: write, + read: read, + verb: verb, + idempotent: idempotentVerb(verb), + done: make(chan error, 1), + } +} + +// dialResult carries the outcome of a single dial attempt to the scheduler. +type dialResult struct { + cn *pipeConn + err error +} + +// backend is the per-address connection manager. +// +// All mutable state (conns, queue, dialing) is owned by the scheduler +// goroutine run(). Callers communicate only through channels, so blocking +// operations compose with ctx.Done() for future context-aware APIs. 
+type backend struct { + c *Client + addr net.Addr + + incoming chan *pipeReq + dialDone chan dialResult + connLost chan *pipeConn + connIdle chan struct{} // non-blocking; kicks scheduler to drain queue after an op completes + reapIdle chan struct{} // non-blocking; idleWatch asks scheduler to reap idle conns + quit chan struct{} // closed on shutdown + runDone chan struct{} // closed when run() returns + + // Idle-reaper state (see idleWatch). + lastUsedUnixNano int64 // atomic; updated on every submit + idleTimerActive int32 // atomic; 0 = no idleWatch goroutine, 1 = running + + // Scheduler-owned state (only touched inside run()). + conns set[*pipeConn] + queue []*pipeReq + dialing int +} + +func newBackend(c *Client, addr net.Addr) *backend { + b := &backend{ + c: c, + addr: addr, + incoming: make(chan *pipeReq), + dialDone: make(chan dialResult, 16), + connLost: make(chan *pipeConn, 16), + connIdle: make(chan struct{}, 1), + reapIdle: make(chan struct{}, 1), + quit: make(chan struct{}), + runDone: make(chan struct{}), + conns: set[*pipeConn]{}, + } + go b.run() + return b +} + +func (b *backend) maxConns() int { + m := b.c.MaxConns + switch { + case m < 0: + return -1 + case m == 0: + return DefaultMaxConns + default: + return m + } +} + +func (b *backend) maxDials() int { + m := b.c.MaxDials + switch { + case m < 0: + return -1 + case m == 0: + return DefaultMaxDials + default: + return m + } +} + +// maxDepth returns the effective per-conn pipeline depth cap. +func (b *backend) maxDepth() int32 { + m := b.c.MaxPipelineDepth + if m <= 0 { + return defaultMaxPipelineDepth + } + return int32(m) +} + +// submit sends req to the backend, blocks until done or ctx is cancelled. +// ctx is accepted as context.Background() today; a public ctx-aware API can +// plumb real contexts in later without touching scheduler logic. 
+func (b *backend) submit(ctx context.Context, req *pipeReq) error { + atomic.StoreInt64(&b.lastUsedUnixNano, time.Now().UnixNano()) + if atomic.CompareAndSwapInt32(&b.idleTimerActive, 0, 1) { + go b.idleWatch() + } + select { + case b.incoming <- req: + case <-ctx.Done(): + return ctx.Err() + case <-b.quit: + return errBackendClosed + } + select { + case err := <-req.done: + return err + case <-ctx.Done(): + req.markCancelled() + return ctx.Err() + } +} + +// idleReapIntervalNanos is how often idleWatch checks for inactivity, in +// nanoseconds. After one interval with no submits, surplus idle conns are +// reaped (see doReap). Accessed atomically so tests can shorten it without +// racing the idleWatch goroutines that read it. +var idleReapIntervalNanos int64 = int64(5 * time.Second) + +func idleReapInterval() time.Duration { + return time.Duration(atomic.LoadInt64(&idleReapIntervalNanos)) +} + +// idleWatch runs once per period-of-activity on a backend. It ticks every +// idleReapInterval; if lastUsedUnixNano hasn't advanced since the previous +// tick, it signals the scheduler to reap excess idle conns, flips +// idleTimerActive back to 0, and exits. The next submit will CAS it back +// to 1 and spawn a fresh idleWatch. +func (b *backend) idleWatch() { + // Seed observed to the value submit already stored before spawning us. + observed := atomic.LoadInt64(&b.lastUsedUnixNano) + for { + select { + case <-time.After(idleReapInterval()): + case <-b.quit: + atomic.StoreInt32(&b.idleTimerActive, 0) + return + } + cur := atomic.LoadInt64(&b.lastUsedUnixNano) + if cur == observed { + select { + case b.reapIdle <- struct{}{}: + default: + } + atomic.StoreInt32(&b.idleTimerActive, 0) + return + } + observed = cur + } +} + +// close stops the scheduler and all its conns. It is safe to call multiple +// times. 
+func (b *backend) close() { + select { + case <-b.quit: + return + default: + } + close(b.quit) + <-b.runDone +} + +func (b *backend) run() { + defer close(b.runDone) + for { + select { + case req := <-b.incoming: + if req.isCancelled() { + continue + } + b.tryDispatch(req) + case r := <-b.dialDone: + b.dialing-- + if r.err == nil { + b.conns.add(r.cn) + go r.cn.writer() + go r.cn.reader() + } + // Whether success or failure, try to drain the queue. On + // failure, tryDispatch will re-queue and may restart a dial. + b.drainQueueLocked(r.err) + case cn := <-b.connLost: + b.handleConnLost(cn) + case <-b.connIdle: + b.drainQueueLocked(nil) + case <-b.reapIdle: + b.doReap() + case <-b.quit: + b.shutdown() + return + } + } +} + +// canGrow reports whether the scheduler may start another dial. +func (b *backend) canGrow() bool { + mc := b.maxConns() + md := b.maxDials() + if mc >= 0 && len(b.conns)+b.dialing >= mc { + return false + } + if md >= 0 && b.dialing >= md { + return false + } + return true +} + +// tryDispatch assigns req to a conn, queues it, or starts a dial. +// It is called only from run(). +func (b *backend) tryDispatch(req *pipeReq) { + // Pick conn with smallest inFlight. + var best *pipeConn + minIF := int32(-1) + for cn := range b.conns { + v := atomic.LoadInt32(&cn.inFlight) + if minIF == -1 || v < minIF { + minIF = v + best = cn + } + } + + maxDepth := b.maxDepth() + canGrow := b.canGrow() + + // Dispatch only if some conn has room below the depth cap. With + // MaxPipelineDepth=1 this means fully-idle conns only, i.e. strict + // one-op-per-conn serialization (no pipelining). + if best != nil && minIF < maxDepth { + atomic.AddInt32(&best.inFlight, 1) + select { + case best.reqCh <- req: + return + default: + atomic.AddInt32(&best.inFlight, -1) + } + } + + b.queue = append(b.queue, req) + + // Dial only when we're actually behind: if existing dials-in-flight can + // absorb the queue, don't pile on more. 
Each future conn is expected to + // take ~maxDepth reqs before we'd want another. Without this, bursts + // of 50 concurrent submits trigger ~50 dials instead of ~50/maxDepth. + if canGrow && len(b.queue) > b.dialing*int(maxDepth) { + b.startDial() + } +} + +// startDial launches a dial goroutine. It must be called from run(). +func (b *backend) startDial() { + b.dialing++ + go func() { + nc, err := b.c.dial(b.addr) + var res dialResult + if err != nil { + res = dialResult{err: err} + } else { + res = dialResult{cn: newPipeConn(b, nc)} + } + select { + case b.dialDone <- res: + case <-b.quit: + if res.cn != nil { + res.cn.nc.Close() + } + } + }() +} + +// drainQueueLocked retries dispatching all queued reqs. If lastDialErr != nil +// and we still can't dispatch and can't grow, fail the queue with that error +// so callers don't hang forever on a dead backend. +func (b *backend) drainQueueLocked(lastDialErr error) { + q := b.queue + b.queue = nil + for _, req := range q { + if req.isCancelled() { + continue + } + b.tryDispatch(req) + } + // If nothing in flight can help and the last dial failed, fail the queue. + if lastDialErr != nil && len(b.conns) == 0 && b.dialing == 0 { + failQueue := b.queue + b.queue = nil + for _, req := range failQueue { + select { + case req.done <- lastDialErr: + default: + } + } + } +} + +// handleConnLost is invoked by the scheduler when a pipeConn has died. +// +// Critical: between the conn's writer/reader exiting and the scheduler +// seeing connLost, the scheduler may have dispatched more requests to +// cn.reqCh. We must drain those here (nobody else will). +func (b *backend) handleConnLost(cn *pipeConn) { + b.conns.remove(cn) + + // Drain any stragglers in reqCh that we dispatched after the conn's + // writer exited. Once we remove cn from b.conns above, the scheduler + // won't dispatch any new reqs to it, so this drain is bounded. 
+ for { + select { + case req := <-cn.reqCh: + cn.recordReqChAtDeath(req) + default: + goto drained + } + } +drained: + + retryable := cn.harvest() + for _, req := range retryable { + if req.isCancelled() { + continue + } + if req.idempotent && atomic.AddInt32(&req.attempts, 1) <= maxAttempts { + b.queue = append(b.queue, req) + } else { + select { + case req.done <- cn.closeErr: + default: + } + } + } + b.drainQueueLocked(nil) +} + +// doReap closes surplus idle conns, keeping at most MaxIdleConns warm. +// Called from the scheduler goroutine, so it has exclusive access to +// b.conns. A conn counts as "idle" when its atomic inFlight is 0. Busy +// conns (mid-op) are never closed; they'll be eligible at the next reap. +func (b *backend) doReap() { + keep := b.c.maxIdleConns() + // Count idle conns first so we only close surplus. + idleCount := 0 + for cn := range b.conns { + if atomic.LoadInt32(&cn.inFlight) == 0 { + idleCount++ + } + } + surplus := idleCount - keep + if surplus <= 0 { + return + } + for cn := range b.conns { + if surplus <= 0 { + break + } + if atomic.LoadInt32(&cn.inFlight) == 0 { + // The error value here is immaterial; the conn is being discarded. + cn.close(errBackendClosed) + surplus-- + } + } +} + +// shutdown is called when quit is signaled. It closes each conn and fails +// any queued reqs. Late dialDone/connLost sends won't block because those +// channels are buffered and any senders fall through to `<-b.quit` selects. +func (b *backend) shutdown() { + for cn := range b.conns { + cn.close(errBackendClosed) + } + for _, req := range b.queue { + select { + case req.done <- errBackendClosed: + default: + } + } + b.queue = nil +} + +// pipeConn is a single pipelined connection to a server. 
+type pipeConn struct { + be *backend + nc net.Conn + br *bufio.Reader + bw *bufio.Writer + + reqCh chan *pipeReq // scheduler → writer + pending chan *pipeReq // writer → reader FIFO + + inFlight int32 // atomic; reqs in reqCh + pending + being served + + closeOnce sync.Once + closeErr error + closeCh chan struct{} + + // harvest state (protected by harvestMu) + harvestMu sync.Mutex + harvestSent bool + writeFailures []*pipeReq // failed before hitting the wire + pendingAtDeath []*pipeReq // in pending at time of close + reqChAtDeath []*pipeReq // still in reqCh at time of close +} + +func newPipeConn(be *backend, nc net.Conn) *pipeConn { + return &pipeConn{ + be: be, + nc: nc, + br: bufio.NewReader(nc), + bw: bufio.NewWriter(nc), + reqCh: make(chan *pipeReq, 256), + pending: make(chan *pipeReq, 256), + closeCh: make(chan struct{}), + } +} + +func (cn *pipeConn) extendDeadline() { + cn.nc.SetDeadline(time.Now().Add(cn.be.c.netTimeout())) +} + +// close marks the conn dead and closes the underlying net.Conn. The first +// caller wins; subsequent calls are no-ops. It is safe to call from any +// goroutine. +func (cn *pipeConn) close(err error) { + cn.closeOnce.Do(func() { + cn.closeErr = err + cn.nc.Close() + close(cn.closeCh) + }) +} + +// writer serializes writes onto the conn. Reads reqs from reqCh; on success, +// pushes them onto pending for the reader to decode. On error, records the +// failed req and stops. Does NOT drain reqCh on exit; the scheduler does +// that in handleConnLost (it's the only one that knows no more reqs will +// be dispatched to this conn). +func (cn *pipeConn) writer() { + for { + select { + case req := <-cn.reqCh: + if req.isCancelled() { + // Caller already moved on; nothing hit the wire yet, so + // no response framing to worry about. Just release the + // slot and let the scheduler dispatch more. 
+ atomic.AddInt32(&cn.inFlight, -1) + cn.signalIdle() + continue + } + cn.extendDeadline() + if err := req.write(cn.bw); err != nil { + cn.recordWriteFailure(req) + cn.close(err) + close(cn.pending) + return + } + if err := cn.bw.Flush(); err != nil { + // Request may have partially reached the wire. Treat it as + // "in flight" so retry honors idempotency. + cn.pending <- req + cn.close(err) + close(cn.pending) + return + } + cn.pending <- req + case <-cn.closeCh: + close(cn.pending) + return + } + } +} + +// reader reads responses in FIFO order from pending. On a non-resumable error, +// closes the conn; pending reqs in the channel at that moment are harvested. +// +// Cancelled reqs (whose caller's context expired after the request bytes +// hit the wire but before the response arrived) still have their response +// consumed here. Skipping the read would leave those bytes on the wire and +// misalign the parse of the NEXT req in the pipeline, corrupting it. The +// discarded result is dropped via non-blocking send to req.done. +func (cn *pipeConn) reader() { + for req := range cn.pending { + cn.extendDeadline() + err := req.read(cn.br) + atomic.AddInt32(&cn.inFlight, -1) + + if err != nil && !resumableError(err) { + // Framing is suspect; close conn and harvest the rest. + select { + case req.done <- err: + default: + } + cn.close(err) + for remainder := range cn.pending { + cn.recordPendingAtDeath(remainder) + } + break + } + // Non-blocking send: if the caller already moved on (ctx expired), + // the buffer may be empty and we fill it (and it gets GC'd), or the + // caller is still waiting and receives err. + select { + case req.done <- err: + default: + } + cn.signalIdle() + } + select { + case cn.be.connLost <- cn: + case <-cn.be.quit: + } +} + +// signalIdle wakes the scheduler to consider dispatching queued reqs to a +// newly-idle conn. 
Non-blocking: the connIdle chan has capacity 1 and the +// scheduler's single drain pass dispatches across all idle conns, so dropped +// signals are harmless. +func (cn *pipeConn) signalIdle() { + select { + case cn.be.connIdle <- struct{}{}: + default: + } +} + +func (cn *pipeConn) recordWriteFailure(req *pipeReq) { + cn.harvestMu.Lock() + cn.writeFailures = append(cn.writeFailures, req) + cn.harvestMu.Unlock() +} + +func (cn *pipeConn) recordPendingAtDeath(req *pipeReq) { + cn.harvestMu.Lock() + cn.pendingAtDeath = append(cn.pendingAtDeath, req) + cn.harvestMu.Unlock() +} + +func (cn *pipeConn) recordReqChAtDeath(req *pipeReq) { + cn.harvestMu.Lock() + cn.reqChAtDeath = append(cn.reqChAtDeath, req) + cn.harvestMu.Unlock() +} + +// harvest returns all reqs that need scheduler disposition. writeFailures +// never reached the wire (always safe to retry). pendingAtDeath were written +// and may or may not have been server-processed (idempotent-only retry per +// scheduler policy). reqChAtDeath were dispatched by the scheduler but never +// pulled by the writer: safe (never sent). Scheduler uses req.idempotent to +// decide per-req which to retry. +func (cn *pipeConn) harvest() []*pipeReq { + cn.harvestMu.Lock() + defer cn.harvestMu.Unlock() + if cn.harvestSent { + return nil + } + cn.harvestSent = true + n := len(cn.writeFailures) + len(cn.pendingAtDeath) + len(cn.reqChAtDeath) + out := make([]*pipeReq, 0, n) + out = append(out, cn.writeFailures...) + out = append(out, cn.pendingAtDeath...) + out = append(out, cn.reqChAtDeath...) + return out +} diff --git a/memcache/pipeline_bench_test.go b/memcache/pipeline_bench_test.go new file mode 100644 index 00000000..f6736b10 --- /dev/null +++ b/memcache/pipeline_bench_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2026 The gomemcache AUTHORS + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +*/ + +package memcache + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" +) + +// BenchmarkPipelining measures wall-clock time for a burst of N concurrent +// Gets over a single backend with a fixed per-direction delay, comparing +// pipelined vs MaxPipelineDepth=1 (non-pipelined). With MaxConns bounded, +// non-pipelined mode must serialize ops across a few conns while pipelined +// mode keeps the pipe saturated; the gap grows with RTT and N. +// +// This benchmark uses real time (no synctest), so RTTs are kept small to +// keep total wall time reasonable. It still exhibits the expected relative +// behavior because pipelining turns N sequential RTTs into one. +func BenchmarkPipelining(b *testing.B) { + for _, rtt := range []time.Duration{ + 100 * time.Microsecond, + 500 * time.Microsecond, + 2 * time.Millisecond, + } { + for _, n := range []int{10, 100} { + for _, pipelined := range []bool{true, false} { + name := fmt.Sprintf("rtt=%s/n=%d/pipe=%v", rtt, n, pipelined) + b.Run(name, func(b *testing.B) { + benchPipelining(b, rtt, n, pipelined) + }) + } + } + } +} + +func benchPipelining(b *testing.B, rtt time.Duration, n int, pipelined bool) { + b.Helper() + ln := newFakeListener("mc-bench", rtt/2) + srv := &testServer{} + go srv.Serve(ln) + defer ln.Close() + + c := NewFromSelector(singleServerSelector{addr: ln.Addr()}) + c.Timeout = 10 * time.Second + if !pipelined { + c.MaxPipelineDepth = 1 + } + c.MaxConns = 4 + c.MaxDials = 4 + c.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + return ln.dial() + } + defer c.Close() + + if err := c.Set(&Item{Key: "k", Value: []byte("v")}); err != nil { + b.Fatalf("seed Set: %v", err) + } + + b.ResetTimer() + for iter := 0; iter < b.N; iter++ { + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := c.Get("k"); err != nil { + 
b.Errorf("Get: %v", err) + } + }() + } + wg.Wait() + } +} diff --git a/memcache/pipeline_test.go b/memcache/pipeline_test.go new file mode 100644 index 00000000..accc2f19 --- /dev/null +++ b/memcache/pipeline_test.go @@ -0,0 +1,490 @@ +//go:build go1.25 + +/* +Copyright 2026 The gomemcache AUTHORS + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +*/ + +package memcache + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net" + "strings" + "sync" + "sync/atomic" + "testing" + "testing/synctest" + "time" +) + +// countingListener wraps a fakeListener to count dial-accepts. +type countingListener struct { + *fakeListener + accepted atomic.Int64 +} + +func (l *countingListener) dial() (net.Conn, error) { + l.accepted.Add(1) + return l.fakeListener.dial() +} + +// newSyncTestClient returns a Client wired to an in-test testServer through +// a fakeListener with the given one-way delay. +func newSyncTestClient(t *testing.T, oneWayDelay time.Duration, cfg func(*Client)) (*Client, *countingListener) { + t.Helper() + ln := &countingListener{fakeListener: newFakeListener("mc", oneWayDelay)} + srv := &testServer{} + go srv.Serve(ln.fakeListener) + t.Cleanup(func() { ln.Close() }) + + c := NewFromSelector(singleServerSelector{addr: ln.Addr()}) + c.Timeout = 10 * time.Second + c.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + return ln.dial() + } + if cfg != nil { + cfg(c) + } + t.Cleanup(func() { c.Close() }) + return c, ln +} + +// serveSlowGets is a minimal gets-only test server. For each `gets \r\n` +// it replies with a canned `VALUE 0 len\r\nval-\r\nEND\r\n`. +// If the key has prefix "slow-", it first sleeps for slowDelay (fake time +// under synctest), holding the pipeline so later requests' bytes sit in the +// wire buffer. 
//
// Callers should invoke this as `go serveSlowGets(ln, delay)`.
func serveSlowGets(ln net.Listener, slowDelay time.Duration) {
	const cmdPrefix = "gets "

	// session answers requests on one accepted conn until a read or flush
	// fails or a line that is not a gets command arrives.
	session := func(conn net.Conn) {
		defer conn.Close()
		in := bufio.NewReader(conn)
		out := bufio.NewWriter(conn)
		for {
			raw, err := in.ReadSlice('\n')
			if err != nil {
				if !errors.Is(err, io.EOF) {
					// swallow; conn closed
				}
				return
			}
			request := strings.TrimSuffix(string(raw), "\r\n")
			if !strings.HasPrefix(request, cmdPrefix) {
				return
			}
			key := request[len(cmdPrefix):]
			if strings.HasPrefix(key, "slow-") {
				// Stall before replying so later requests queue up
				// behind this one.
				time.Sleep(slowDelay)
			}
			value := "val-" + key
			fmt.Fprintf(out, "VALUE %s 0 %d\r\n%s\r\nEND\r\n", key, len(value), value)
			if out.Flush() != nil {
				return
			}
		}
	}

	// Accept until the listener fails, serving each conn on its own
	// goroutine.
	for {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		go session(conn)
	}
}

// TestPipelineCancelPreservesFraming regression-tests the scenario where a
// pipelined request's context expires after its bytes are on the wire but
// before its response is consumed. The reader must still read and discard
// the cancelled request's response; otherwise the next pipelined request's
// parse picks up the cancelled one's bytes and either errors or (worse)
// silently returns the wrong value.
//
// Pipeline on one conn: A ("slow-a", slow server response), B ("b",
// short ctx that expires during A's delay), C ("c"). After the fix C must
// return its own value, not B's.
func TestPipelineCancelPreservesFraming(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		// serverDelay is how long the server stalls on "slow-" keys;
		// bCtxDeadline is shorter, so B's context expires while A's reply
		// is still being held back.
		const serverDelay = 500 * time.Millisecond
		const bCtxDeadline = 100 * time.Millisecond
		const linkDelay = 1 * time.Millisecond

		ln := newFakeListener("mc-cancel", linkDelay)
		go serveSlowGets(ln, serverDelay)
		t.Cleanup(func() { ln.Close() })

		c := NewFromSelector(singleServerSelector{addr: ln.Addr()})
		c.Timeout = 10 * time.Second
		// Route every dial to the in-memory listener.
		c.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
			return ln.dial()
		}
		t.Cleanup(func() { c.Close() })

		// A: slow Get. Its server-side delay is what lets B and C be
		// pipelined behind it before A's response comes back.
		var aItem *Item
		var aErr error
		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			defer wg.Done()
			aItem, aErr = c.Get("slow-a")
		}()
		synctest.Wait() // let A hit the wire before B, C queue behind it.

		// B: internal submit with a short ctx so we can observe cancellation
		// before the public API exposes contexts.
		addr, err := c.selector.PickServer("b")
		if err != nil {
			t.Fatal(err)
		}
		var bErr error
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(ctxBG, bCtxDeadline)
			defer cancel()
			// Read closure intentionally discards the Item; after ctx expires
			// the reader may still call this closure (that's the whole point
			// of the fix), so capturing any test-goroutine-local state would
			// race with the test goroutine reading it after wg.Wait.
			bErr = c.runCmd(ctx, addr, "gets",
				func(w *bufio.Writer) error {
					_, werr := fmt.Fprintf(w, "gets b\r\n")
					return werr
				},
				func(r *bufio.Reader) error {
					return parseGetResponse(r, func(*Item) {})
				})
		}()
		synctest.Wait() // ensure B's bytes are written before submitting C.

		// C: normal Get, same conn (pipelined behind A and B).
		var cItem *Item
		var cErr error
		wg.Add(1)
		go func() {
			defer wg.Done()
			cItem, cErr = c.Get("c")
		}()

		wg.Wait()

		// A and C must each see their own value; B must report its
		// deadline.
		if aErr != nil {
			t.Errorf("A: unexpected error %v", aErr)
		} else if got := string(aItem.Value); got != "val-slow-a" {
			t.Errorf("A: got %q, want val-slow-a", got)
		}
		if bErr == nil {
			t.Errorf("B: want ctx error, got nil")
		} else if !errors.Is(bErr, context.DeadlineExceeded) {
			t.Errorf("B: want DeadlineExceeded, got %v", bErr)
		}
		if cErr != nil {
			t.Errorf("C: unexpected error %v (framing-corruption regression)", cErr)
		} else if got := string(cItem.Value); got != "val-c" {
			t.Errorf("C: got %q, want val-c (framing-corruption regression)", got)
		}
	})
}

// TestPipelineMaxConns verifies that total conns opened per backend never
// exceeds MaxConns under a burst of concurrent ops in non-pipelined mode
// (where each op needs its own conn slot).
func TestPipelineMaxConns(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const maxConns = 3
		// MaxDials is set far above MaxConns so that MaxConns is the limit
		// being exercised.
		c, ln := newSyncTestClient(t, 1*time.Millisecond, func(c *Client) {
			c.MaxPipelineDepth = 1
			c.MaxConns = maxConns
			c.MaxDials = 50
		})

		if err := c.Set(&Item{Key: "k", Value: []byte("v")}); err != nil {
			t.Fatalf("seed Set: %v", err)
		}

		// Burst of N concurrent Gets against the one seeded key.
		const N = 30
		var wg sync.WaitGroup
		for i := 0; i < N; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				if _, err := c.Get("k"); err != nil {
					t.Errorf("Get: %v", err)
				}
			}()
		}
		wg.Wait()

		// ln.accepted counts every dial made through the listener wrapper,
		// including the one used by the seed Set.
		if got := ln.accepted.Load(); got > int64(maxConns) {
			t.Errorf("accepted conns = %d; want ≤ MaxConns=%d", got, maxConns)
		} else {
			t.Logf("accepted %d conns (cap %d)", got, maxConns)
		}
	})
}

// TestPipelineMaxDials verifies the peak number of concurrent dials per
// backend never exceeds MaxDials.
+func TestPipelineMaxDials(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + const maxDials = 2 + + // testCtx lets us cancel any dial still sleeping when the test + // body is done, so we don't leave synctest-durable blocked goroutines. + testCtx, testCancel := context.WithCancel(context.Background()) + + var dialInFlight atomic.Int32 + var peak atomic.Int32 + + ln := &countingListener{fakeListener: newFakeListener("mc-dials", time.Millisecond)} + srv := &testServer{} + go srv.Serve(ln.fakeListener) + + c := NewFromSelector(singleServerSelector{addr: ln.Addr()}) + c.Timeout = time.Hour // don't let Client.dial's inner ctx fire + c.MaxPipelineDepth = 1 + c.MaxConns = 100 + c.MaxDials = maxDials + c.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + n := dialInFlight.Add(1) + for { + old := peak.Load() + if n <= old || peak.CompareAndSwap(old, n) { + break + } + } + // Hold the dial long enough for overlap to be observable. The dial counts as in flight from the Add(1) above until one of the Add(-1) calls below; peak keeps the high-water mark. + select { + case <-time.After(20 * time.Millisecond): + case <-testCtx.Done(): + dialInFlight.Add(-1) + return nil, context.Canceled + case <-ctx.Done(): + dialInFlight.Add(-1) + return nil, ctx.Err() + } + dialInFlight.Add(-1) + return ln.dial() + } + + if err := c.Set(&Item{Key: "k", Value: []byte("v")}); err != nil { + t.Fatalf("seed Set: %v", err) + } + + const N = 15 + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = c.Get("k") + }() + } + wg.Wait() + + // Release any lingering dial goroutines and tear the client down + // before returning, so synctest doesn't see durable blocks. NOTE(review): this teardown is not deferred, so a t.Fatalf on the seed Set above would skip it; confirm that is acceptable.
+ testCancel() + c.Close() + ln.Close() + + if got := peak.Load(); got > int32(maxDials) { + t.Errorf("peak concurrent dials = %d; want ≤ MaxDials=%d", got, maxDials) + } else { + t.Logf("peak concurrent dials = %d (cap %d)", got, maxDials) + } + }) +} + +// TestPipelineMaxIdleConns verifies the idle reaper closes surplus idle +// conns down to MaxIdleConns after a period of inactivity. Uses +// MaxPipelineDepth = 1 (no pipelining) so a burst genuinely opens many conns. +func TestPipelineMaxIdleConns(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + // Shorten the reap interval so the test doesn't wait long in fake + // time. Restore after. + origInterval := atomic.LoadInt64(&idleReapIntervalNanos) + atomic.StoreInt64(&idleReapIntervalNanos, int64(100*time.Millisecond)) + t.Cleanup(func() { + atomic.StoreInt64(&idleReapIntervalNanos, origInterval) + }) + + const maxIdle = 2 + c, ln := newSyncTestClient(t, 1*time.Millisecond, func(c *Client) { + c.MaxPipelineDepth = 1 + c.MaxIdleConns = maxIdle + c.MaxConns = 100 + c.MaxDials = 100 + }) + + if err := c.Set(&Item{Key: "k", Value: []byte("v")}); err != nil { + t.Fatalf("seed Set: %v", err) + } + + // Burst: opens many conns (MaxPipelineDepth = 1 => one per op). + const N = 10 + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := c.Get("k"); err != nil { + t.Errorf("Get: %v", err) + } + }() + } + wg.Wait() + + beforeReap := ln.accepted.Load() + + // Wait through at least two reap intervals (the sleep below is three). First interval: idleWatch + // observes lastUsed unchanged and queues the reap. Second: the + // reap has been processed by the scheduler. + time.Sleep(3 * idleReapInterval()) + synctest.Wait() + + // Probe the backend directly; it's the only one. Only its existence is checked below. + var b *backend + c.mu.Lock() + for _, bb := range c.backends { + b = bb + break + } + c.mu.Unlock() + if b == nil { + t.Fatal("no backend found") + } + + // The scheduler owns b.conns, but this goroutine isn't the scheduler,
+ // so b.conns is not inspected here. Instead a fresh op is pushed through: + // presumably by the time it completes, the scheduler has processed every + // prior event (reapIdle included), and the listener's accepted count + // before vs. after shows whether that op needed a new dial. + // NOTE(review): this appears to pass even if nothing was reaped, so it + // does not by itself prove the pool shrank to MaxIdleConns; confirm intent. + + // Trigger one more op; if reaping worked, this should NOT cause a + // new dial (one of the ≤maxIdle surviving conns handles it). + accBefore := ln.accepted.Load() + if _, err := c.Get("k"); err != nil { + t.Fatalf("post-reap Get: %v", err) + } + accAfter := ln.accepted.Load() + newDials := accAfter - accBefore + + if newDials != 0 { + t.Errorf("post-reap Get required %d new dial(s); want 0 (idle pool should have served it)", newDials) + } + + t.Logf("burst opened %d conns; after reap+1 op: no new dials (idle pool ≤ %d)", beforeReap, maxIdle) + }) +} + +// TestPipelineConcurrency runs N concurrent Gets against a single pipelined +// backend and verifies every response is correct (order-preserving demux).
+func TestPipelineConcurrency(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + c, _ := newSyncTestClient(t, 1*time.Millisecond, nil) + + const N = 100 + for i := 0; i < N; i++ { + if err := c.Set(&Item{Key: fmt.Sprintf("k%d", i), Value: []byte(fmt.Sprintf("v%d", i))}); err != nil { + t.Fatalf("Set k%d: %v", i, err) + } + } + + var wg sync.WaitGroup + errs := make([]error, N) + vals := make([]string, N) + for i := 0; i < N; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + it, err := c.Get(fmt.Sprintf("k%d", i)) + errs[i] = err + if err == nil { + vals[i] = string(it.Value) + } + }(i) + } + wg.Wait() + + for i := 0; i < N; i++ { + if errs[i] != nil { + t.Errorf("Get k%d: %v", i, errs[i]) + continue + } + want := fmt.Sprintf("v%d", i) + if vals[i] != want { + t.Errorf("Get k%d: got %q, want %q", i, vals[i], want) + } + } + }) +} + +// TestPipelineConnCount verifies that with pipelining enabled, a burst of N +// concurrent Gets uses substantially fewer connections than N, while +// MaxPipelineDepth=1 opens about one conn per concurrent op (asserted as at least N/2). +func TestPipelineConnCount(t *testing.T) { + const N = 50 + for _, tc := range []struct { + name string + depth int + }{ + {"pipelined", 0}, // default depth + {"depth_1", 1}, // disables pipelining + } { + t.Run(tc.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + c, ln := newSyncTestClient(t, 5*time.Millisecond, func(c *Client) { + c.MaxPipelineDepth = tc.depth + c.MaxConns = 100 + c.MaxDials = 100 + }) + + if err := c.Set(&Item{Key: "k", Value: []byte("v")}); err != nil { + t.Fatalf("Set: %v", err) + } + + var wg sync.WaitGroup + for i := 0; i < N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := c.Get("k"); err != nil { + t.Errorf("Get: %v", err) + } + }() + } + wg.Wait() + + got := ln.accepted.Load() + if tc.depth == 1 { + // One conn per concurrent op (minus any reused), hence the loose N/2 lower bound.
+ if got < int64(N/2) { + t.Errorf("depth=1: want ≥ %d conns (burst needs many), got %d", N/2, got) + } + } else { + // Pipelined: default cap is 8, so roughly N/8 conns are expected; the N/4+2 bound below leaves slack. + const maxExpected = N/4 + 2 + if got > maxExpected { + t.Errorf("pipelined: want ≤ %d conns, got %d", maxExpected, got) + } + } + t.Logf("%s: %d conns for %d concurrent ops", tc.name, got, N) + }) + }) + } +} + diff --git a/memcache/selector.go b/memcache/selector.go index 964dbdb6..849ddc33 100644 --- a/memcache/selector.go +++ b/memcache/selector.go @@ -19,6 +19,7 @@ package memcache import ( "hash/crc32" "net" + "net/netip" "strings" "sync" ) @@ -42,15 +43,22 @@ type ServerList struct { } // staticAddr caches the Network() and String() values from any net.Addr. +// For TCP addresses it also caches the netip.AddrPort, which the Client uses +// as a fast, alloc-free map key for per-backend state. type staticAddr struct { ntw, str string + ap netip.AddrPort // zero (!IsValid) if ntw is not tcp-shaped } func newStaticAddr(a net.Addr) net.Addr { - return &staticAddr{ + sa := &staticAddr{ ntw: a.Network(), str: a.String(), } + if tcp, ok := a.(*net.TCPAddr); ok { + sa.ap = tcp.AddrPort() + } + return sa } func (s *staticAddr) Network() string { return s.ntw } @@ -89,7 +97,7 @@ func (ss *ServerList) SetServers(servers ...string) error { return nil } -// Each iterates over each server calling the given function +// Each iterates over each server calling the given function. func (ss *ServerList) Each(f func(net.Addr) error) error { ss.mu.RLock() defer ss.mu.RUnlock()