diff --git a/peers/e2e_test.go b/peers/e2e_test.go index 3e0457c..59ee260 100644 --- a/peers/e2e_test.go +++ b/peers/e2e_test.go @@ -7,6 +7,8 @@ import ( "fmt" "log" "net/http" + "net/netip" + "sync" "testing" "time" @@ -64,3 +66,252 @@ backend st_src_global } }) } + +func TestE2EWriter(t *testing.T) { + writerCh := make(chan *Writer, 1) + a := Peer{HandlerSource: func() Handler { + return &writerE2EHandler{writerCh: writerCh} + }} + + l := testutil.TCPListener(t) + go a.Serve(l) + + cfg := testutil.HAProxyConfig{ + FrontendPort: fmt.Sprintf("%d", testutil.TCPPort(t)), + CustomFrontendConfig: ` + http-request track-sc0 src table st_blocklist + http-request deny deny_status 403 if { sc0_get_gpc0 gt 0 } +`, + CustomConfig: ` +backend st_blocklist + stick-table type ip size 200k expire 5m store gpc0 peers mypeers +`, + PeerAddr: l.Addr().String(), + } + + t.Run("push entry blocks request", func(t *testing.T) { + cfg.Run(t) + + var w *Writer + select { + case w = <-writerCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for HAProxy peer connection") + } + + time.Sleep(500 * time.Millisecond) + + resp, err := http.Get("http://127.0.0.1:" + cfg.FrontendPort) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 before push, got %d", resp.StatusCode) + } + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "st_blocklist", + KeyType: sticktable.KeyTypeIPv4Address, + KeyLength: 4, + DataTypes: []sticktable.DataTypeDefinition{ + {DataType: sticktable.DataTypeGPC0}, + }, + Expiry: 300000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + key := sticktable.IPv4AddressKey(netip.MustParseAddr("127.0.0.1")) + gpc0 := sticktable.UnsignedIntegerData(1) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + Data: []sticktable.MapData{&gpc0}, + } + if err := w.SendEntry(entry); err != nil { + t.Fatal(err) + } + + time.Sleep(1 * time.Second) + + resp, err = http.Get("http://127.0.0.1:" + cfg.FrontendPort) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Errorf("expected 403 after push, got %d", resp.StatusCode) + } + }) +} + +type writerE2EHandler struct { + writerCh chan *Writer + once sync.Once +} + +func (h *writerE2EHandler) HandleUpdate(_ context.Context, u *sticktable.EntryUpdate) { + log.Println(u) +} + +func (h *writerE2EHandler) HandleHandshake(ctx context.Context, _ *Handshake) { + h.once.Do(func() { + h.writerCh <- WriterFromContext(ctx) + }) +} + +func (h *writerE2EHandler) Close() error { return nil } + +func TestE2EWriterTimedEntry(t *testing.T) { + writerCh := make(chan *Writer, 1) + a := Peer{HandlerSource: func() Handler { + return &writerE2EHandler{writerCh: writerCh} + }} + + l := testutil.TCPListener(t) + go a.Serve(l) + + cfg := testutil.HAProxyConfig{ + FrontendPort: fmt.Sprintf("%d", testutil.TCPPort(t)), + CustomConfig: ` +backend st_timed + stick-table type ip size 200k expire 5m peers mypeers +`, + BackendConfig: ` + http-request set-var(txn.lookup_ip) str(127.0.0.2) + http-request return status 200 content-type "text/plain" hdr X-Expire %[var(txn.lookup_ip),table_expire(st_timed)] string "OK\n" +`, + PeerAddr: l.Addr().String(), + } + + t.Run("push timed entry with 60s expiry", func(t *testing.T) { + cfg.Run(t) + + var w *Writer + select { + case w = <-writerCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for HAProxy peer connection") + } + + time.Sleep(500 * time.Millisecond) + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "st_timed", + KeyType: sticktable.KeyTypeIPv4Address, + KeyLength: 4, + Expiry: 300000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + key := sticktable.IPv4AddressKey(netip.MustParseAddr("127.0.0.2")) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + WithExpiry: true, + Expiry: 60000, + } + if err := w.SendEntry(entry); err != nil { + t.Fatal(err) + } + + time.Sleep(1 * time.Second) + + resp, err := http.Get("http://127.0.0.1:" + cfg.FrontendPort) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + xexpire := resp.Header.Get("X-Expire") + t.Logf("X-Expire: %s", xexpire) + + if xexpire == "" || xexpire == "0" { + t.Errorf("expected non-zero X-Expire header, got %q", xexpire) + } + }) +} + +func TestE2EWriterBulkEntries(t *testing.T) { + writerCh := make(chan *Writer, 1) + a := Peer{HandlerSource: func() Handler { + return &writerE2EHandler{writerCh: writerCh} + }} + + l := testutil.TCPListener(t) + go a.Serve(l) + + cfg := testutil.HAProxyConfig{ + FrontendPort: fmt.Sprintf("%d", testutil.TCPPort(t)), + CustomConfig: ` +backend st_bulk + stick-table type ip size 200k expire 5m peers mypeers +`, + BackendConfig: ` + http-request return status 200 content-type "text/plain" hdr X-Count %[table_cnt(st_bulk)] string "OK\n" +`, + PeerAddr: l.Addr().String(), + } + + t.Run("push 20 entries", func(t *testing.T) { + cfg.Run(t) + + var w *Writer + select { + case w = <-writerCh: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for HAProxy peer connection") + } + + time.Sleep(500 * time.Millisecond) + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "st_bulk", + KeyType: sticktable.KeyTypeIPv4Address, + KeyLength: 4, + Expiry: 300000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + for i := 0; i < 20; i++ { + ip := netip.AddrFrom4([4]byte{10, 0, 0, byte(i + 1)}) + key := sticktable.IPv4AddressKey(ip) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + WithExpiry: true, + Expiry: 60000, + } + if err := w.SendEntry(entry); err != nil { + t.Fatalf("sending entry %d (%s): %v", i, ip, err) + } + } + + time.Sleep(1 * time.Second) + + resp, err := http.Get("http://127.0.0.1:" + cfg.FrontendPort) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + xcount := resp.Header.Get("X-Count") + t.Logf("X-Count: %s", xcount) + + if xcount != "20" { + t.Errorf("expected X-Count=20, got %q", xcount) + } + }) +} diff --git a/peers/example/push/go.mod b/peers/example/push/go.mod new file mode 100644 index 0000000..87604ea --- /dev/null +++ b/peers/example/push/go.mod @@ -0,0 +1,7 @@ +module github.com/dropmorepackets/haproxy-go/peers/example/push + +go 1.21 + +replace github.com/dropmorepackets/haproxy-go => ../../../ + +require github.com/dropmorepackets/haproxy-go v0.0.0-00010101000000-000000000000 diff --git a/peers/example/push/main.go b/peers/example/push/main.go new file mode 100644 index 0000000..159d245 --- /dev/null +++ b/peers/example/push/main.go @@ -0,0 +1,62 @@ +// push is an example that demonstrates how to push stick table entries +// to HAProxy over an existing peer connection. When HAProxy connects to +// this peer, the handler uses WriterFromContext to obtain a Writer and +// sends a table definition followed by entry updates. +package main + +import ( + "context" + "log" + "net/netip" + + "github.com/dropmorepackets/haproxy-go/peers" + "github.com/dropmorepackets/haproxy-go/peers/sticktable" +) + +func main() { + log.SetFlags(log.LstdFlags | log.Lshortfile) + + err := peers.ListenAndServe(":21000", peers.HandlerFunc(func(ctx context.Context, u *sticktable.EntryUpdate) { + log.Println("received:", u.String()) + + // Get the writer for this connection to push entries back. + w := peers.WriterFromContext(ctx) + + // Define the stick table we want to push to. + // Matches: stick-table type ip size 200k expire 5m store gpc0 peers local-peers + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "my_blocklist", + KeyType: sticktable.KeyTypeIPv4Address, + KeyLength: 4, + DataTypes: []sticktable.DataTypeDefinition{ + {DataType: sticktable.DataTypeGPC0}, + }, + Expiry: 300000, // 5 minutes in ms + } + + if err := w.SendTableDefinition(tableDef); err != nil { + log.Printf("error sending table definition: %v", err) + return + } + + // Push an entry marking an IP as blocked (gpc0 = 1). + key := sticktable.IPv4AddressKey(netip.MustParseAddr("10.0.0.1")) + gpc0 := sticktable.UnsignedIntegerData(1) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + Data: []sticktable.MapData{&gpc0}, + } + + if err := w.SendEntry(entry); err != nil { + log.Printf("error sending entry: %v", err) + return + } + + log.Println("pushed blocklist entry for 10.0.0.1") + })) + if err != nil { + log.Fatal(err) + } +} diff --git a/peers/peers.go b/peers/peers.go index 6e949c7..2a4d056 100644 --- a/peers/peers.go +++ b/peers/peers.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net" + "sync" ) type Peer struct { @@ -59,7 +60,10 @@ func (a *Peer) Serve(l net.Listener) error { // Wrap the context to provide access to the underlying connection. // TODO(tim): Do we really want this? ctx := context.WithValue(a.BaseContext, connectionKey, nc) - p := newProtocolClient(ctx, nc, a.HandlerSource()) + wmu := &sync.Mutex{} + w := newWriter(nc, wmu) + ctx = context.WithValue(ctx, writerKey, w) + p := newProtocolClient(ctx, nc, a.HandlerSource(), wmu, w.bufferedWriter()) go func() { defer nc.Close() defer p.Close() @@ -75,6 +79,7 @@ type contextKey string const ( connectionKey = contextKey("connection") + writerKey = contextKey("writer") ) // Connection returns the underlying connection used in calls @@ -82,3 +87,11 @@ const ( func Connection(ctx context.Context) net.Conn { return ctx.Value(connectionKey).(net.Conn) } + +// WriterFromContext returns the Writer associated with the current peer +// connection. Use this inside a Handler to push stick table updates back +// to HAProxy over the same connection that HAProxy established to us. +// Panics if called outside a handler context. +func WriterFromContext(ctx context.Context) *Writer { + return ctx.Value(writerKey).(*Writer) +} diff --git a/peers/protocol.go b/peers/protocol.go index 04d913d..f51b2b7 100644 --- a/peers/protocol.go +++ b/peers/protocol.go @@ -8,6 +8,7 @@ import ( "io" "log" "net" + "sync" "time" "github.com/dropmorepackets/haproxy-go/peers/sticktable" @@ -19,6 +20,8 @@ type protocolClient struct { ctxCancel context.CancelFunc rw io.ReadWriter br *bufio.Reader + bw *bufio.Writer + wmu *sync.Mutex nextHeartbeat *time.Ticker lastMessageTimer *time.Timer @@ -28,11 +31,13 @@ type protocolClient struct { handler Handler } -func newProtocolClient(ctx context.Context, rw io.ReadWriter, handler Handler) *protocolClient { +func newProtocolClient(ctx context.Context, rw io.ReadWriter, handler Handler, wmu *sync.Mutex, bw *bufio.Writer) *protocolClient { var c protocolClient c.rw = rw c.br = bufio.NewReader(rw) + c.bw = bw c.handler = handler + c.wmu = wmu c.ctx, c.ctxCancel = context.WithCancel(ctx) return &c } @@ -46,6 +51,16 @@ func (c *protocolClient) Close() error { return c.handler.Close() } +func (c *protocolClient) lockedWrite(data []byte) (int, error) { + c.wmu.Lock() + defer c.wmu.Unlock() + n, err := c.bw.Write(data) + if err != nil { + return n, err + } + return n, c.bw.Flush() +} + func (c *protocolClient) peerHandshake() error { var h Handshake if _, err := h.ReadFrom(c.br); err != nil { @@ -54,7 +69,7 @@ func (c *protocolClient) peerHandshake() error { c.handler.HandleHandshake(c.ctx, &h) - if _, err := c.rw.Write([]byte(fmt.Sprintf("%d\n", HandshakeStatusHandshakeSucceeded))); err != nil { + if _, err := c.lockedWrite([]byte(fmt.Sprintf("%d\n", HandshakeStatusHandshakeSucceeded))); err != nil { return fmt.Errorf("handshake failed: %v", err) } @@ -90,7 +105,7 @@ func (c *protocolClient) resetLastMessage() { func (c *protocolClient) heartbeat() { for range c.nextHeartbeat.C { - _, err := c.rw.Write([]byte{byte(MessageClassControl), byte(ControlMessageHeartbeat)}) + _, err := c.lockedWrite([]byte{byte(MessageClassControl), byte(ControlMessageHeartbeat)}) if err != nil { _ = c.Close() return @@ -207,7 +222,7 @@ func (t ErrorMessageType) OnMessage(m *rawMessage, c *protocolClient) error { func (t ControlMessageType) OnMessage(m *rawMessage, c *protocolClient) error { switch t { case ControlMessageSyncRequest: - _, _ = c.rw.Write([]byte{byte(MessageClassControl), byte(ControlMessageSyncPartial)}) + _, _ = c.lockedWrite([]byte{byte(MessageClassControl), byte(ControlMessageSyncPartial)}) return nil case ControlMessageSyncFinished: return nil @@ -236,7 +251,9 @@ func (t StickTableUpdateMessageType) OnMessage(m *rawMessage, c *protocolClient) log.Printf("not implemented: %s", t) return nil case StickTableUpdateMessageTypeUpdateAcknowledge: - log.Printf("not implemented: %s", t) + // HAProxy sends ack messages after receiving our pushed updates. + // The ack contains the remote table ID and last committed update ID. + // We currently don't track these, so just accept silently. return nil case StickTableUpdateMessageTypeEntryUpdate, StickTableUpdateMessageTypeUpdateTimed, diff --git a/peers/sticktable/sticktables.go b/peers/sticktable/sticktables.go index 13ac81c..134d0b1 100644 --- a/peers/sticktable/sticktables.go +++ b/peers/sticktable/sticktables.go @@ -209,7 +209,7 @@ func (e *EntryUpdate) Marshal(b []byte) (int, error) { } for _, data := range e.Data { - n, err := data.Unmarshal(b[offset:]) + n, err := data.Marshal(b[offset:]) offset += n if err != nil { return offset, err diff --git a/peers/writer.go b/peers/writer.go new file mode 100644 index 0000000..2a378e3 --- /dev/null +++ b/peers/writer.go @@ -0,0 +1,233 @@ +package peers + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "sync" + + "github.com/dropmorepackets/haproxy-go/peers/sticktable" + "github.com/dropmorepackets/haproxy-go/pkg/encoding" +) + +// Writer sends stick table updates over an existing peer connection. +// It is safe for concurrent use. Obtain a Writer from a handler's context +// using WriterFromContext. +type Writer struct { + bw *bufio.Writer + mu *sync.Mutex + buf []byte // reusable scratch buffer for marshaling + + nextUpdateID uint32 +} + +func newWriter(w io.Writer, mu *sync.Mutex) *Writer { + bw := bufio.NewWriterSize(w, 64*1024) + return &Writer{ + bw: bw, + mu: mu, + buf: make([]byte, 65536), + } +} + +// bufferedWriter returns the underlying bufio.Writer so the protocol +// client can share the same buffered output (under the shared mutex). +func (w *Writer) bufferedWriter() *bufio.Writer { + return w.bw +} + +// Flush flushes any buffered data to the underlying connection. +// The caller must hold the mutex or call this after a batch of writes. +func (w *Writer) Flush() error { + w.mu.Lock() + defer w.mu.Unlock() + return w.bw.Flush() +} + +// writeMessage writes a peer protocol message. Messages with type >= 128 +// include a varint-encoded data length prefix before the payload. +// Caller must NOT hold the mutex — this method acquires it. +func (w *Writer) writeMessage(class MessageClass, msgType byte, data []byte) error { + w.mu.Lock() + defer w.mu.Unlock() + return w.writeMessageLocked(class, msgType, data) +} + +// writeMessageLocked writes a peer protocol message. +// Caller MUST hold the mutex. +func (w *Writer) writeMessageLocked(class MessageClass, msgType byte, data []byte) error { + var lenBuf [10]byte + var lenBytes int + if msgType >= 128 { + n, err := encoding.PutVarint(lenBuf[:], uint64(len(data))) + if err != nil { + return fmt.Errorf("encoding data length: %w", err) + } + lenBytes = n + } + + // Write header (class + type) + if _, err := w.bw.Write([]byte{byte(class), msgType}); err != nil { + return err + } + + // Write length prefix if present + if lenBytes > 0 { + if _, err := w.bw.Write(lenBuf[:lenBytes]); err != nil { + return err + } + } + + // Write payload + if len(data) > 0 { + if _, err := w.bw.Write(data); err != nil { + return err + } + } + + return nil +} + +// SendTableDefinition sends a stick table definition message. +// This must be called before sending entry updates for a table so +// that the remote peer knows which table the updates refer to. +func (w *Writer) SendTableDefinition(def *sticktable.Definition) error { + var buf [4096]byte + n, err := def.Marshal(buf[:]) + if err != nil { + return fmt.Errorf("marshaling table definition: %w", err) + } + + if err := w.writeMessage( + MessageClassStickTableUpdates, + byte(StickTableUpdateMessageTypeStickTableDefinition), + buf[:n], + ); err != nil { + return err + } + + return w.Flush() +} + +// SendTableSwitch sends a table switch message to select a previously +// defined table by its sender table ID. +func (w *Writer) SendTableSwitch(tableID uint64) error { + var buf [10]byte + n, err := encoding.PutVarint(buf[:], tableID) + if err != nil { + return fmt.Errorf("encoding table ID: %w", err) + } + + if err := w.writeMessage( + MessageClassStickTableUpdates, + byte(StickTableUpdateMessageTypeStickTableSwitch), + buf[:n], + ); err != nil { + return err + } + + return w.Flush() +} + +// marshalEntry marshals a single entry update into buf and returns the +// byte count. The updateID is written first, followed by optional expiry, +// key and data values. +func marshalEntry(buf []byte, entry *sticktable.EntryUpdate, updateID uint32) (int, error) { + offset := 0 + + binary.BigEndian.PutUint32(buf[offset:], updateID) + offset += 4 + + if entry.WithExpiry { + binary.BigEndian.PutUint32(buf[offset:], entry.Expiry) + offset += 4 + } + + n, err := entry.Key.Marshal(buf[offset:], entry.StickTable.KeyLength) + offset += n + if err != nil { + return offset, fmt.Errorf("marshaling entry key: %w", err) + } + + for _, data := range entry.Data { + n, err := data.Marshal(buf[offset:]) + offset += n + if err != nil { + return offset, fmt.Errorf("marshaling entry data: %w", err) + } + } + + return offset, nil +} + +// SendEntry sends a stick table entry update with an automatically +// assigned update ID. The message type is chosen based on the entry's +// WithExpiry flag: +// - WithExpiry=false: Entry Update (0x80) +// - WithExpiry=true: Update Timed (0x85) +// +// Note: for bulk operations, prefer SendEntries which batches writes and flushes once. +func (w *Writer) SendEntry(entry *sticktable.EntryUpdate) error { + w.mu.Lock() + updateID := w.nextUpdateID + w.nextUpdateID++ + + msgType := StickTableUpdateMessageTypeEntryUpdate + if entry.WithExpiry { + msgType = StickTableUpdateMessageTypeUpdateTimed + } + + offset, err := marshalEntry(w.buf, entry, updateID) + if err != nil { + w.mu.Unlock() + return fmt.Errorf("marshaling entry update: %w", err) + } + + if err = w.writeMessageLocked( + MessageClassStickTableUpdates, + byte(msgType), + w.buf[:offset], + ); err != nil { + w.mu.Unlock() + return err + } + + err = w.bw.Flush() + w.mu.Unlock() + return err +} + +// SendEntries sends multiple stick table entry updates in a single +// locked batch. This is significantly faster than calling SendEntry +// in a loop because it acquires the mutex once, marshals and writes +// all entries into the buffer, then flushes once. +func (w *Writer) SendEntries(entries []*sticktable.EntryUpdate) error { + w.mu.Lock() + defer w.mu.Unlock() + + for _, entry := range entries { + updateID := w.nextUpdateID + w.nextUpdateID++ + + msgType := StickTableUpdateMessageTypeEntryUpdate + if entry.WithExpiry { + msgType = StickTableUpdateMessageTypeUpdateTimed + } + + offset, err := marshalEntry(w.buf, entry, updateID) + if err != nil { + return fmt.Errorf("marshaling entry update: %w", err) + } + + if err := w.writeMessageLocked( + MessageClassStickTableUpdates, + byte(msgType), + w.buf[:offset], + ); err != nil { + return err + } + } + + return w.bw.Flush() +} diff --git a/peers/writer_test.go b/peers/writer_test.go new file mode 100644 index 0000000..529a742 --- /dev/null +++ b/peers/writer_test.go @@ -0,0 +1,358 @@ +package peers + +import ( + "bufio" + "context" + "fmt" + "net" + "net/netip" + "sync" + "testing" + "time" + + "github.com/dropmorepackets/haproxy-go/peers/sticktable" + "github.com/google/go-cmp/cmp" +) + +// helperDialPeer performs the client-side handshake to connect to a Peer server. +// HAProxy would normally do this. In tests, we simulate HAProxy connecting to us. +func helperDialPeer(t *testing.T, addr, localPeer, remotePeer string) net.Conn { + t.Helper() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("dialing peer: %v", err) + } + + h := NewHandshake(remotePeer) + h.LocalPeerIdentifier = localPeer + if _, err = h.WriteTo(conn); err != nil { + conn.Close() + t.Fatalf("writing handshake: %v", err) + } + + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + conn.Close() + t.Fatalf("reading handshake status: %v", err) + } + + var status int + if _, err = fmt.Sscanf(line, "%d\n", &status); err != nil { + conn.Close() + t.Fatalf("parsing status %q: %v", line, err) + } + + if HandshakeStatus(status) != HandshakeStatusHandshakeSucceeded { + conn.Close() + t.Fatalf("handshake failed with status %d", status) + } + + return conn +} + +func TestWriterSendEntry(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + writerReady := make(chan *Writer, 1) + peer := &Peer{ + BaseContext: ctx, + HandlerSource: func() Handler { + return &testHandler{ + onHandshake: func(ctx context.Context, h *Handshake) { + writerReady <- WriterFromContext(ctx) + }, + } + }, + } + go peer.Serve(l) + + conn := helperDialPeer(t, l.Addr().String(), "haproxy_peer", "go_peer") + defer conn.Close() + + var w *Writer + select { + case w = <-writerReady: + case <-ctx.Done(): + t.Fatal("timeout waiting for writer") + } + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "test_table", + KeyType: sticktable.KeyTypeString, + KeyLength: 50, + DataTypes: []sticktable.DataTypeDefinition{ + {DataType: sticktable.DataTypeGPC0}, + }, + Expiry: 600000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + key := sticktable.StringKey("192.168.1.1") + gpc0 := sticktable.UnsignedIntegerData(42) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + Data: []sticktable.MapData{&gpc0}, + } + + if err := w.SendEntry(entry); err != nil { + t.Fatal(err) + } +} + +func TestWriterSendMultipleEntries(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + writerReady := make(chan *Writer, 1) + peer := &Peer{ + BaseContext: ctx, + HandlerSource: func() Handler { + return &testHandler{ + onHandshake: func(ctx context.Context, h *Handshake) { + writerReady <- WriterFromContext(ctx) + }, + } + }, + } + go peer.Serve(l) + + conn := helperDialPeer(t, l.Addr().String(), "haproxy_peer", "go_peer") + defer conn.Close() + + var w *Writer + select { + case w = <-writerReady: + case <-ctx.Done(): + t.Fatal("timeout waiting for writer") + } + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "multi_table", + KeyType: sticktable.KeyTypeString, + KeyLength: 50, + DataTypes: []sticktable.DataTypeDefinition{ + {DataType: sticktable.DataTypeConnectionsCounter}, + {DataType: sticktable.DataTypeBytesInCounter}, + }, + Expiry: 600000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + const numEntries = 10 + for i := 0; i < numEntries; i++ { + key := sticktable.StringKey(fmt.Sprintf("key_%d", i)) + connCnt := sticktable.UnsignedIntegerData(uint32(i * 10)) + bytesIn := sticktable.UnsignedLongLongData(uint64(i * 1000)) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + Data: []sticktable.MapData{&connCnt, &bytesIn}, + } + + if err := w.SendEntry(entry); err != nil { + t.Fatalf("sending entry %d: %v", i, err) + } + } + + if w.nextUpdateID != numEntries { + t.Errorf("expected nextUpdateID %d, got %d", numEntries, w.nextUpdateID) + } +} + +func TestWriterSendTimedEntry(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + writerReady := make(chan *Writer, 1) + peer := &Peer{ + BaseContext: ctx, + HandlerSource: func() Handler { + return &testHandler{ + onHandshake: func(ctx context.Context, h *Handshake) { + writerReady <- WriterFromContext(ctx) + }, + } + }, + } + go peer.Serve(l) + + conn := helperDialPeer(t, l.Addr().String(), "haproxy_peer", "go_peer") + defer conn.Close() + + var w *Writer + select { + case w = <-writerReady: + case <-ctx.Done(): + t.Fatal("timeout waiting for writer") + } + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "timed_table", + KeyType: sticktable.KeyTypeIPv4Address, + KeyLength: 4, + DataTypes: []sticktable.DataTypeDefinition{ + {DataType: sticktable.DataTypeSessionsCounter}, + }, + Expiry: 300000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + key := sticktable.IPv4AddressKey(netip.MustParseAddr("10.0.0.1")) + sessCnt := sticktable.UnsignedIntegerData(99) + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + Data: []sticktable.MapData{&sessCnt}, + WithExpiry: true, + Expiry: 60000, + } + + if err := w.SendEntry(entry); err != nil { + t.Fatal(err) + } +} + +// TestWriterRoundTrip verifies that data written by the Writer can be read +// and decoded correctly by the protocol client's message handler (full loop). +func TestWriterRoundTrip(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + updates := make(chan *sticktable.EntryUpdate, 10) + + peerB := &Peer{ + BaseContext: ctx, + Handler: HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { + updates <- u + }), + } + go peerB.Serve(l) + + conn := helperDialPeer(t, l.Addr().String(), "peer_a", "peer_b") + defer conn.Close() + + w := newWriter(conn, &sync.Mutex{}) + + tableDef := &sticktable.Definition{ + StickTableID: 0, + Name: "roundtrip_table", + KeyType: sticktable.KeyTypeString, + KeyLength: 50, + DataTypes: []sticktable.DataTypeDefinition{ + {DataType: sticktable.DataTypeGPC0}, + {DataType: sticktable.DataTypeHttpRequestsRate, Counter: 1, Period: 10000}, + }, + Expiry: 600000, + } + + if err := w.SendTableDefinition(tableDef); err != nil { + t.Fatal(err) + } + + key := sticktable.StringKey("test_key") + gpc0 := sticktable.UnsignedIntegerData(42) + rate := sticktable.FreqData{ + CurrentTick: 500, + CurrentPeriod: 10, + LastPeriod: 8, + } + entry := &sticktable.EntryUpdate{ + StickTable: tableDef, + Key: &key, + Data: []sticktable.MapData{&gpc0, &rate}, + } + + if err := w.SendEntry(entry); err != nil { + t.Fatal(err) + } + + select { + case u := <-updates: + if u.StickTable.Name != "roundtrip_table" { + t.Errorf("expected table name %q, got %q", "roundtrip_table", u.StickTable.Name) + } + if u.Key.String() != "test_key" { + t.Errorf("expected key %q, got %q", "test_key", u.Key.String()) + } + if u.LocalUpdateID != 0 { + t.Errorf("expected update ID 0, got %d", u.LocalUpdateID) + } + + gotGPC0 := u.Data[0].(*sticktable.UnsignedIntegerData) + if *gotGPC0 != 42 { + t.Errorf("expected gpc0 value 42, got %d", *gotGPC0) + } + + wantRate := &sticktable.FreqData{ + CurrentTick: 500, + CurrentPeriod: 10, + LastPeriod: 8, + } + gotRate := u.Data[1].(*sticktable.FreqData) + if diff := cmp.Diff(wantRate, gotRate); diff != "" { + t.Errorf("FreqData mismatch (-want +got):\n%s", diff) + } + case <-ctx.Done(): + t.Fatal("timeout waiting for roundtrip update") + } +} + +// testHandler is a Handler implementation for testing that allows +// overriding individual methods. +type testHandler struct { + onUpdate func(context.Context, *sticktable.EntryUpdate) + onHandshake func(context.Context, *Handshake) +} + +func (h *testHandler) HandleUpdate(ctx context.Context, u *sticktable.EntryUpdate) { + if h.onUpdate != nil { + h.onUpdate(ctx, u) + } +} + +func (h *testHandler) HandleHandshake(ctx context.Context, hs *Handshake) { + if h.onHandshake != nil { + h.onHandshake(ctx, hs) + } +} + +func (h *testHandler) Close() error { return nil } diff --git a/pkg/encoding/kvunmarshal.go b/pkg/encoding/kvunmarshal.go index 8c274d5..3b289bc 100644 --- a/pkg/encoding/kvunmarshal.go +++ b/pkg/encoding/kvunmarshal.go @@ -43,11 +43,11 @@ func (k *KVScanner) Unmarshal(v any) error { // Build a slice of field info to avoid string allocations during lookup type fieldInfo struct { - keyStr string // cached for NameEquals and error messages - fieldIdx int field reflect.Value // cached to avoid repeated rv.Field() calls - fieldKind reflect.Kind // cached to avoid repeated Kind() calls - isPointer bool // cached to avoid repeated checks + keyStr string // cached for NameEquals and error messages + fieldIdx int + fieldKind reflect.Kind // cached to avoid repeated Kind() calls + isPointer bool // cached to avoid repeated checks } fields := make([]fieldInfo, 0, rt.NumField()) pointerFieldIndices := make([]int, 0, rt.NumField()) // track pointer field indices for final cleanup