From 1ced044b3aa68ea0ba8fab3b293f5ebe54b6d00f Mon Sep 17 00:00:00 2001 From: Rob Murray Date: Mon, 16 Jun 2025 14:37:57 +0100 Subject: [PATCH 1/3] nftables: configurable base chain priorities Make base chain priorities in the bridge's nftables tables configurable. For example, in daemon.json: "bridge-nftables-priorities": { "filter-FORWARD": "3", "nat-POSTROUTING": "101", "nat-PREROUTING": "-101", "nat-OUTPUT": "-102", "raw-PREROUTING": "-301" }, Or, on the command line: dockerd --bridge-nftables-priority filter-FORWARD=3 ... Signed-off-by: Rob Murray --- daemon/command/config_unix.go | 1 + daemon/config/config.go | 15 +-- daemon/config/config_linux.go | 22 +++-- daemon/daemon_unix.go | 1 + .../libnetwork/drivers/bridge/bridge_linux.go | 12 ++- .../drivers/bridge/bridge_linux_test.go | 2 +- .../bridge/internal/nftabler/nftabler.go | 66 +++++++++++-- .../bridge/internal/nftabler/nftabler_test.go | 2 +- .../network/bridge/bridge_linux_test.go | 96 +++++++++++++++++++ 9 files changed, 188 insertions(+), 29 deletions(-) diff --git a/daemon/command/config_unix.go b/daemon/command/config_unix.go index fa9e3c5f9f376..2119bdffeac21 100644 --- a/daemon/command/config_unix.go +++ b/daemon/command/config_unix.go @@ -27,6 +27,7 @@ func installConfigFlags(conf *config.Config, flags *pflag.FlagSet) { flags.BoolVar(&conf.BridgeConfig.DisableFilterForwardDrop, "ip-forward-no-drop", false, "Do not set the filter-FORWARD policy to DROP when enabling IP forwarding") flags.BoolVar(&conf.BridgeConfig.EnableIPMasq, "ip-masq", true, "Enable IP masquerading for the default bridge network") flags.BoolVar(&conf.BridgeConfig.EnableIPv6, "ipv6", false, "Enable IPv6 networking for the default bridge network") + flags.Var(opts.NewNamedMapOpts("bridge-nftables-priorities", conf.NftablesPriorities, nil), "bridge-nftables-priority", "Base chain priorities for bridge driver nftables") flags.StringVar(&conf.BridgeConfig.IP, "bip", "", "IPv4 address for the default bridge") flags.StringVar(&conf.BridgeConfig.IP6, "bip6", "", "IPv6 address for the default bridge") flags.StringVarP(&conf.BridgeConfig.Iface, "bridge", "b", "", "Attach containers to a network bridge") diff --git a/daemon/config/config.go b/daemon/config/config.go index 353e2e398798f..b2b8245969ebb 100644 --- a/daemon/config/config.go +++ b/daemon/config/config.go @@ -81,13 +81,14 @@ const ( // Use this to differentiate these options // with others like the ones in TLSOptions. var flatOptions = map[string]bool{ - "cluster-store-opts": true, - "default-network-opts": true, - "log-opts": true, - "runtimes": true, - "default-ulimits": true, - "features": true, - "builder": true, + "cluster-store-opts": true, + "default-network-opts": true, + "bridge-nftables-priorities": true, + "log-opts": true, + "runtimes": true, + "default-ulimits": true, + "features": true, + "builder": true, } // skipValidateOptions contains configuration keys diff --git a/daemon/config/config_linux.go b/daemon/config/config_linux.go index 1d11a3d3ec318..2cf3d356e6424 100644 --- a/daemon/config/config_linux.go +++ b/daemon/config/config_linux.go @@ -41,14 +41,15 @@ const ( type BridgeConfig struct { DefaultBridgeConfig - EnableIPTables bool `json:"iptables,omitempty"` - EnableIP6Tables bool `json:"ip6tables,omitempty"` - EnableIPForward bool `json:"ip-forward,omitempty"` - DisableFilterForwardDrop bool `json:"ip-forward-no-drop,omitempty"` - EnableIPMasq bool `json:"ip-masq,omitempty"` - EnableUserlandProxy bool `json:"userland-proxy,omitempty"` - UserlandProxyPath string `json:"userland-proxy-path,omitempty"` - AllowDirectRouting bool `json:"allow-direct-routing,omitempty"` + EnableIPTables bool `json:"iptables,omitempty"` + EnableIP6Tables bool `json:"ip6tables,omitempty"` + EnableIPForward bool `json:"ip-forward,omitempty"` + DisableFilterForwardDrop bool `json:"ip-forward-no-drop,omitempty"` + EnableIPMasq bool `json:"ip-masq,omitempty"` + EnableUserlandProxy bool `json:"userland-proxy,omitempty"` + UserlandProxyPath string `json:"userland-proxy-path,omitempty"` + AllowDirectRouting bool `json:"allow-direct-routing,omitempty"` + NftablesPriorities map[string]string `json:"bridge-nftables-priorities,omitempty"` } // DefaultBridgeConfig stores all the parameters for the default bridge network. @@ -147,6 +148,7 @@ func setPlatformDefaults(cfg *Config) error { cfg.SeccompProfile = SeccompProfileDefault cfg.IpcMode = string(DefaultIpcMode) cfg.Runtimes = make(map[string]system.Runtime) + cfg.NftablesPriorities = make(map[string]string) if cgroups.Mode() != cgroups.Unified { cfg.CgroupNamespaceMode = string(DefaultCgroupV1NamespaceMode) @@ -243,6 +245,10 @@ func validatePlatformConfig(conf *Config) error { return errors.Wrap(err, "invalid fixed-cidr-v6") } + if err := bridge.ValidateBaseChainPriorities(conf.NftablesPriorities); err != nil { + return err + } + if err := validateFirewallBackend(conf.FirewallBackend); err != nil { return errors.Wrap(err, "invalid firewall-backend") } diff --git a/daemon/daemon_unix.go b/daemon/daemon_unix.go index ae294f333908b..1e61b77631930 100644 --- a/daemon/daemon_unix.go +++ b/daemon/daemon_unix.go @@ -938,6 +938,7 @@ func networkPlatformOptions(conf *config.Config) []nwconfig.Option { "EnableIP6Tables": conf.BridgeConfig.EnableIP6Tables, "Hairpin": !conf.EnableUserlandProxy || conf.UserlandProxyPath == "", "AllowDirectRouting": conf.BridgeConfig.AllowDirectRouting, + "NftablesPriorities": conf.NftablesPriorities, }, }), } diff --git a/daemon/libnetwork/drivers/bridge/bridge_linux.go b/daemon/libnetwork/drivers/bridge/bridge_linux.go index 3c815c908b93b..d71002850ab3f 100644 --- a/daemon/libnetwork/drivers/bridge/bridge_linux.go +++ b/daemon/libnetwork/drivers/bridge/bridge_linux.go @@ -77,6 +77,7 @@ type configuration struct { // hairpinned. Hairpin bool AllowDirectRouting bool + NftablesPriorities map[string]string } // networkConfiguration for network specific configuration @@ -528,7 +529,7 @@ func (d *driver) configure(option map[string]interface{}) error { Hairpin: config.Hairpin, AllowDirectRouting: config.AllowDirectRouting, WSL2Mirrored: isRunningUnderWSL2MirroredMode(context.Background()), - }) + }, config.NftablesPriorities) if err != nil { return err } @@ -546,9 +547,14 @@ func (d *driver) configure(option map[string]interface{}) error { return d.initStore() } -var newFirewaller = func(ctx context.Context, config firewaller.Config) (firewaller.Firewaller, error) { +// ValidateBaseChainPriorities checks nftables base chain priority configuration. +func ValidateBaseChainPriorities(prios map[string]string) error { + return nftabler.ValidateBaseChainPriorities(prios) +} + +var newFirewaller = func(ctx context.Context, config firewaller.Config, nftablesPriorities map[string]string) (firewaller.Firewaller, error) { if nftables.Enabled() { - fw, err := nftabler.NewNftabler(ctx, config) + fw, err := nftabler.NewNftabler(ctx, config, nftablesPriorities) if err != nil { return nil, err } diff --git a/daemon/libnetwork/drivers/bridge/bridge_linux_test.go b/daemon/libnetwork/drivers/bridge/bridge_linux_test.go index f2daa3c744076..cb28925236b87 100644 --- a/daemon/libnetwork/drivers/bridge/bridge_linux_test.go +++ b/daemon/libnetwork/drivers/bridge/bridge_linux_test.go @@ -1337,7 +1337,7 @@ func TestCreateParallel(t *testing.T) { func useStubFirewaller(t *testing.T) { origNewFirewaller := newFirewaller - newFirewaller = func(_ context.Context, config firewaller.Config) (firewaller.Firewaller, error) { + newFirewaller = func(_ context.Context, config firewaller.Config, _ map[string]string) (firewaller.Firewaller, error) { return firewaller.NewStubFirewaller(config), nil } t.Cleanup(func() { newFirewaller = origNewFirewaller }) diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go index 1be24340cf92f..b1434cf38e915 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go @@ -4,7 +4,9 @@ package nftabler import ( "context" + "errors" "fmt" + "strconv" "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/drivers/bridge/internal/firewaller" @@ -44,6 +46,14 @@ const ( rawPreroutingPortsRuleGroup = iota + initialRuleGroup + 1 ) +var baseChainNames = map[string]struct{}{ + forwardChain: {}, + postroutingChain: {}, + preroutingChain: {}, + outputChain: {}, + rawPreroutingChain: {}, +} + type nftabler struct { config firewaller.Config cleaner firewaller.FirewallCleaner @@ -51,12 +61,21 @@ type nftabler struct { table6 nftables.TableRef } -func NewNftabler(ctx context.Context, config firewaller.Config) (firewaller.Firewaller, error) { +func NewNftabler(ctx context.Context, config firewaller.Config, baseChainPriorities map[string]string) (firewaller.Firewaller, error) { nft := &nftabler{config: config} + // Convert base chain priorities to integers, assuming the daemon has called + // ValidateBaseChainPriorities, so errors don't need to be handled. + bcps := map[string]int{} + for chain, prio := range baseChainPriorities { + if p, err := strconv.Atoi(prio); err == nil { + bcps[chain] = p + } + } + if nft.config.IPv4 { var err error - nft.table4, err = nft.init(ctx, nftables.IPv4) + nft.table4, err = nft.init(ctx, nftables.IPv4, bcps) if err != nil { return nil, err } @@ -67,7 +86,7 @@ func NewNftabler(ctx context.Context, config firewaller.Config) (firewaller.Fire if nft.config.IPv6 { var err error - nft.table6, err = nft.init(ctx, nftables.IPv6) + nft.table6, err = nft.init(ctx, nftables.IPv6, bcps) if err != nil { return nil, err } @@ -84,6 +103,20 @@ func NewNftabler(ctx context.Context, config firewaller.Config) (firewaller.Fire return nft, nil } +// ValidateBaseChainPriorities checks nftables base chain priority configuration. +func ValidateBaseChainPriorities(prios map[string]string) error { + var errs []error + for c, p := range prios { + if _, ok := baseChainNames[c]; !ok { + errs = append(errs, fmt.Errorf("%q is not a valid base chain name", c)) + } + if _, ok := strconv.Atoi(p); ok != nil { + errs = append(errs, fmt.Errorf("priority %q for base chain %q is not an integer", p, c)) + } + } + return errors.Join(errs...) +} + func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.TableRef { if ipv == firewaller.IPv4 { return nft.table4 @@ -100,13 +133,21 @@ func (nft *nftabler) FilterForwardDrop(ctx context.Context, ipv firewaller.IPVer } // init creates the bridge driver's nftables table for IPv4 or IPv6. -func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables.TableRef, error) { +func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChainPriorities map[string]int) (nftables.TableRef, error) { // Instantiate the table. table, err := nftables.NewTable(family, dockerTable) if err != nil { return table, err } + // Reload the table while it's got no elements to clear an old table if one + // exists. This is necessary because, if base chain priorities have changed and + // the old table isn't removed, nft produces an error message for the base chain + // (but seems to apply the change anyway). + if err := table.Reload(ctx); err != nil { + return nftables.TableRef{}, err + } + // Set up the filter forward chain. // // This base chain only contains two rules that use verdict maps: @@ -119,7 +160,7 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables fwdChain, err := table.BaseChain(ctx, forwardChain, nftables.BaseChainTypeFilter, nftables.BaseChainHookForward, - nftables.BaseChainPriorityFilter) + baseChainPriority(forwardChain, nftables.BaseChainPriorityFilter, baseChainPriorities)) if err != nil { return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) } @@ -139,7 +180,7 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables natPostRtChain, err := table.BaseChain(ctx, postroutingChain, nftables.BaseChainTypeNAT, nftables.BaseChainHookPostrouting, - nftables.BaseChainPrioritySrcNAT) + baseChainPriority(postroutingChain, nftables.BaseChainPrioritySrcNAT, baseChainPriorities)) if err != nil { return nftables.TableRef{}, err } @@ -159,7 +200,7 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables natPreRtChain, err := table.BaseChain(ctx, preroutingChain, nftables.BaseChainTypeNAT, nftables.BaseChainHookPrerouting, - nftables.BaseChainPriorityDstNAT) + baseChainPriority(preroutingChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities)) if err != nil { return nftables.TableRef{}, err } @@ -171,7 +212,7 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables natOutputChain, err := table.BaseChain(ctx, outputChain, nftables.BaseChainTypeNAT, nftables.BaseChainHookOutput, - nftables.BaseChainPriorityDstNAT) + baseChainPriority(outputChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities)) if err != nil { return nftables.TableRef{}, err } @@ -192,7 +233,7 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables if _, err := table.BaseChain(ctx, rawPreroutingChain, nftables.BaseChainTypeFilter, nftables.BaseChainHookPrerouting, - nftables.BaseChainPriorityRaw); err != nil { + baseChainPriority(rawPreroutingChain, nftables.BaseChainPriorityRaw, baseChainPriorities)); err != nil { return nftables.TableRef{}, err } @@ -213,3 +254,10 @@ func nftApply(ctx context.Context, table nftables.TableRef) error { } return nil } + +func baseChainPriority(chainName string, def int, overrides map[string]int) int { + if p, ok := overrides[chainName]; ok { + return p + } + return def +} diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go index 0dd380b922eba..81ef070663d4a 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go @@ -125,7 +125,7 @@ func testNftabler(t *testing.T, tn string, config firewaller.Config, netConfig f // Initialise iptables, check the iptables config looks like it should look at the // end of the test (after deleting per-network and per-port rules). - fw, err := NewNftabler(context.Background(), config) + fw, err := NewNftabler(context.Background(), config, nil) assert.NilError(t, err) checkResults("ip", rnWSL2Mirrored(fmt.Sprintf("%s_cleaned,hairpin=%v", tn, config.Hairpin)), config.IPv4) checkResults("ip6", fmt.Sprintf("%s_cleaned,hairpin=%v", tn, config.Hairpin), config.IPv6) diff --git a/integration/network/bridge/bridge_linux_test.go b/integration/network/bridge/bridge_linux_test.go index 1825800bbe258..d284b0309da6c 100644 --- a/integration/network/bridge/bridge_linux_test.go +++ b/integration/network/bridge/bridge_linux_test.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "net/netip" + "regexp" + "strconv" "strings" "testing" "time" @@ -939,3 +941,97 @@ func TestFirewallBackendSwitch(t *testing.T) { assert.Check(t, len(dockerChains) > 0, "Expected iptables to have at least one docker chain") assert.Check(t, !nftablesTablesExist(), "nftables tables exist after running with iptables") } + +// TestBaseChainPriorityOverride checks that priorities of nftables base chains can be configured. +func TestBaseChainPriorityOverride(t *testing.T) { + skip.If(t, !strings.Contains(testEnv.FirewallBackendDriver(), "nftables"), + "test is nftables specific, and nftables isn't in use") + _ = setupTest(t) + d := daemon.New(t) + defer d.Stop(t) + + baseChains := []string{"filter-FORWARD", "nat-POSTROUTING", "nat-PREROUTING", "nat-OUTPUT", "raw-PREROUTING"} + + re := regexp.MustCompile("priority ([-]?[0-9]+)") + chainPriority := func(name string) int { + t.Helper() + res := icmd.RunCommand("nft", "-n", "list chain "+name) + assert.Assert(t, res.ExitCode == 0, "failed to list chain %s", name) + m := re.FindStringSubmatch(res.Stdout()) + assert.Assert(t, is.Len(m, 2), "failed to find chain priority in %s", res.Stdout()) + i, err := strconv.Atoi(m[1]) + assert.NilError(t, err) + return i + } + chainPriorities := func(fam string) map[string]int { + t.Helper() + res := map[string]int{} + for _, chain := range baseChains { + res[chain] = chainPriority(fam + " docker-bridges " + chain) + } + return res + } + + d.Start(t) + defaults4 := chainPriorities("ip") + defaults6 := chainPriorities("ip6") + assert.Check(t, is.DeepEqual(defaults4, defaults6), "Expected ip/ip6 base chain priorities to be the same") + + args := make([]string, 0, len(baseChains)) + for _, chain := range baseChains { + args = append(args, fmt.Sprintf("--bridge-nftables-priority=%s=%d", chain, defaults4[chain]+1)) + } + + d.Stop(t) + d.Start(t, args...) + + modified4 := chainPriorities("ip") + modified6 := chainPriorities("ip6") + assert.Check(t, is.DeepEqual(modified4, modified6), "Expected ip/ip6 base chain priorities to be the same") + for _, chain := range baseChains { + assert.Check(t, is.Equal(modified4[chain], defaults4[chain]+1)) + } +} + +// TestBaseChainPriorityValidation checks that nftables priority overrides are validated on startup. +func TestBaseChainPriorityValidation(t *testing.T) { + skip.If(t, !strings.Contains(testEnv.FirewallBackendDriver(), "nftables"), + "test is nftables specific, and nftables isn't in use") + _ = setupTest(t) + + testcases := []struct { + name string + opt string + expErr string + }{ + { + name: "bad base chain name", + opt: "nosuchchain=0", + expErr: `"nosuchchain" is not a valid base chain name`, + }, + { + name: "non-integer priority", + opt: "filter-FORWARD=filter+1", + expErr: `priority "filter+1" for base chain "filter-FORWARD" is not an integer`, + }, + { + name: "missing priority", + opt: "filter-FORWARD", + expErr: `priority "" for base chain "filter-FORWARD" is not an integer`, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + d := daemon.New(t) + defer d.Stop(t) + + err := d.StartWithError("--bridge-nftables-priority=" + tc.opt) + assert.Check(t, err != nil, "expected startup error") + logLine, err := d.TailLogs(1) + assert.Check(t, err) + assert.Check(t, len(logLine) == 1) + assert.Check(t, is.Contains(string(logLine[0]), tc.expErr)) + }) + } +} From 00a768933887022cfdf3e466b368880800933df7 Mon Sep 17 00:00:00 2001 From: Rob Murray Date: Thu, 24 Apr 2025 17:13:57 +0100 Subject: [PATCH 2/3] Reload nftables on SIGHUP Signed-off-by: Rob Murray --- .../libnetwork/drivers/bridge/bridge_linux.go | 23 +++++++++++++++++++ .../bridge/internal/firewaller/firewaller.go | 7 ++++++ .../bridge/internal/nftabler/nftabler.go | 11 +++++++++ 3 files changed, 41 insertions(+) diff --git a/daemon/libnetwork/drivers/bridge/bridge_linux.go b/daemon/libnetwork/drivers/bridge/bridge_linux.go index d71002850ab3f..6bb566534727b 100644 --- a/daemon/libnetwork/drivers/bridge/bridge_linux.go +++ b/daemon/libnetwork/drivers/bridge/bridge_linux.go @@ -9,6 +9,7 @@ import ( "net" "net/netip" "os" + "os/signal" "slices" "strconv" "strings" @@ -43,6 +44,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "golang.org/x/sys/unix" ) const ( @@ -543,6 +545,7 @@ func (d *driver) configure(option map[string]interface{}) error { d.configNetwork.Lock() defer d.configNetwork.Unlock() iptables.OnReloaded(d.handleFirewalldReload) + d.startFirewallReloader() return d.initStore() } @@ -572,6 +575,26 @@ var newFirewaller = func(ctx context.Context, config firewaller.Config, nftables return iptabler.NewIptabler(ctx, config) } +func (d *driver) startFirewallReloader() { + r, ok := d.firewaller.(firewaller.Reloader) + if !ok { + return + } + + hupC := make(chan os.Signal, 1) + signal.Notify(hupC, unix.SIGHUP) + go func() { + for range hupC { + d.configNetwork.Lock() + log.G(context.Background()).Info("Received SIGHUP, reloading firewall rules") + if err := r.Reload(context.Background()); err != nil { + log.G(context.Background()).Errorf("Failed to reload firewall rules: %v", err) + } + d.configNetwork.Unlock() + } + }() +} + func (d *driver) getNetwork(id string) (*bridgeNetwork, error) { d.Lock() defer d.Unlock() diff --git a/daemon/libnetwork/drivers/bridge/internal/firewaller/firewaller.go b/daemon/libnetwork/drivers/bridge/internal/firewaller/firewaller.go index 89622cae33085..d9deb31e308e3 100644 --- a/daemon/libnetwork/drivers/bridge/internal/firewaller/firewaller.go +++ b/daemon/libnetwork/drivers/bridge/internal/firewaller/firewaller.go @@ -106,6 +106,13 @@ type Network interface { DelLink(ctx context.Context, parentIP, childIP netip.Addr, ports []types.TransportPort) } +// Reloader is an optional interface for a Firewaller. +type Reloader interface { + // Reload the current firewall rules. The caller is responsible for locking, + // there must not be any concurrent requests to modify rules. + Reload(ctx context.Context) error +} + // FirewallCleanerSetter is an optional interface for a Firewaller. type FirewallCleanerSetter interface { // SetFirewallCleaner replaces the FirewallCleaner (possibly with 'nil'). diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go index b1434cf38e915..5be4e378b81f4 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go @@ -117,6 +117,17 @@ func ValidateBaseChainPriorities(prios map[string]string) error { return errors.Join(errs...) } +func (nft *nftabler) Reload(ctx context.Context) error { + var errs []error + if nft.config.IPv4 { + errs = append(errs, nft.table4.Reload(ctx)) + } + if nft.config.IPv6 { + errs = append(errs, nft.table6.Reload(ctx)) + } + return errors.Join(errs...) +} + func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.TableRef { if ipv == firewaller.IPv4 { return nft.table4 From c837bb0abf4119e45715b6dcc1d1023155990a27 Mon Sep 17 00:00:00 2001 From: Rob Murray Date: Tue, 8 Jul 2025 11:40:52 +0100 Subject: [PATCH 3/3] Rework the interface to libnet/internal/nftables Add nftables.Modifier, to hold a queue of commands that can be applied using Modifier.Apply. No updates are made to the underlying Table until Apply is called, errors in the queue if commands are deferred until Apply. This has the advantages that: - less error handling is needed in code that generates update commands - it's transactional, without needing explicit transactions Minor disadvantages are that it's slightly more difficult to debug updates, as it's no longer possible to step through the call making an update to the Table manipulation in a debugger - and errors in the command, and errors like trying to update a nonexistent chain/set/vmap, deleting an object that doesn't exist or creating a duplicate are not reported until the updates are applied (so, it's a little less clear where the update came from). Signed-off-by: Rob Murray --- .../bridge/internal/nftabler/endpoint.go | 38 +- .../drivers/bridge/internal/nftabler/link.go | 56 +- .../bridge/internal/nftabler/network.go | 261 +++-- .../bridge/internal/nftabler/nftabler.go | 195 ++-- .../drivers/bridge/internal/nftabler/port.go | 122 ++- .../drivers/bridge/internal/nftabler/wsl2.go | 15 +- .../internal/nftables/nftables_linux.go | 899 ++++++++++++------ .../internal/nftables/nftables_linux_test.go | 682 ++++++++++--- .../internal/nftables/testdata/.gitattributes | 1 + .../testdata/TestChainRuleGroups.golden | 9 + .../testdata/TestChain_modified.golden | 2 +- .../testdata/TestIgnoreExist_created.golden | 5 + .../testdata/TestIgnoreExist_deleted.golden | 4 + .../testdata/TestReload_created.golden | 2 +- .../testdata/TestReload_recovered.golden | 6 +- .../TestSetBaseChainPolicy_accept.golden | 8 + .../TestSetBaseChainPolicy_drop.golden | 8 + .../testdata/TestSet_created46.golden | 7 - .../nftables/testdata/TestSet_created6.golden | 7 + .../nftables/testdata/TestSet_deleted4.golden | 11 - .../nftables/testdata/TestSet_deleted6.golden | 2 + .../testdata/TestTable_created46.golden | 2 - .../testdata/TestTable_created6.golden | 2 + .../nftables/testdata/TestVMap_deleted.golden | 4 - .../testdata/TestValidation_empty.golden | 2 + daemon/libnetwork/resolver_unix.go | 57 +- 26 files changed, 1579 insertions(+), 828 deletions(-) create mode 100644 daemon/libnetwork/internal/nftables/testdata/.gitattributes create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden create mode 100644 daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go index 1fdac1ca072af..4e39f57d3ff90 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go @@ -25,18 +25,24 @@ func (n *network) DelEndpoint(ctx context.Context, epIPv4, epIPv6 netip.Addr) er func (n *network) modEndpoint(ctx context.Context, epIPv4, epIPv6 netip.Addr, enable bool) error { if n.fw.config.IPv4 && epIPv4.IsValid() { - if err := n.filterDirectAccess(ctx, n.fw.table4, n.config.Config4, epIPv4, enable); err != nil { - return err + tm := n.fw.table4.Modifier() + updater := tm.Create + if !enable { + updater = tm.Delete } - if err := nftApply(ctx, n.fw.table4); err != nil { + n.filterDirectAccess(updater, tm.Family(), n.config.Config4, epIPv4) + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } } if n.fw.config.IPv6 && epIPv6.IsValid() { - if err := n.filterDirectAccess(ctx, n.fw.table6, n.config.Config6, epIPv6, enable); err != nil { - return err + tm := n.fw.table6.Modifier() + updater := tm.Create + if !enable { + updater = tm.Delete } - if err := nftApply(ctx, n.fw.table6); err != nil { + n.filterDirectAccess(updater, tm.Family(), n.config.Config6, epIPv6) + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } } @@ -53,17 +59,21 @@ func (n *network) modEndpoint(ctx context.Context, epIPv4, epIPv6 netip.Addr, en // kernel support). // // Packets originating on the bridge's own interface and addressed directly to the -// container are allowed - the host always has direct access to its own containers -// (it doesn't need to use the port mapped to its own addresses, although it can). +// container are allowed - the host always has direct access to its own containers. +// (It doesn't need to use the port mapped to its own addresses, although it can.) // // "Trusted interfaces" are treated in the same way as the bridge itself. -func (n *network) filterDirectAccess(ctx context.Context, table nftables.TableRef, conf firewaller.NetworkConfigFam, epIP netip.Addr, enable bool) error { +func (n *network) filterDirectAccess(updater func(nftables.Obj), fam nftables.Family, conf firewaller.NetworkConfigFam, epIP netip.Addr) { if n.config.Internal || conf.Unprotected || conf.Routed || n.fw.config.AllowDirectRouting { - return nil + return } - updater := table.ChainUpdateFunc(ctx, rawPreroutingChain, enable) ifNames := strings.Join(n.config.TrustedHostInterfaces, ", ") - return updater(ctx, rawPreroutingPortsRuleGroup, - `%s daddr %s iifname != { %s, %s } counter drop comment "DROP DIRECT ACCESS"`, - table.Family(), epIP, n.config.IfName, ifNames) + updater(nftables.Rule{ + Chain: rawPreroutingChain, + Group: rawPreroutingPortsRuleGroup, + Rule: []string{ + string(fam), "daddr", epIP.String(), + "iifname != {", n.config.IfName, ",", ifNames, `} counter drop comment "DROP DIRECT ACCESS"`, + }, + }) } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go index d612b767a6de8..6179a7f28bf52 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go @@ -7,6 +7,9 @@ import ( "errors" "fmt" "net/netip" + "strconv" + + "github.com/docker/docker/daemon/libnetwork/internal/nftables" "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/types" @@ -20,44 +23,47 @@ func (n *network) AddLink(ctx context.Context, parentIP, childIP netip.Addr, por return errors.New("cannot link to a container with an empty child IP address") } - chain := n.fw.table4.Chain(ctx, chainFilterFwdIn(n.config.IfName)) + tm := n.fw.table4.Modifier() for _, port := range ports { - for _, rule := range legacyLinkRules(parentIP, childIP, port) { - if err := chain.AppendRule(ctx, fwdInLegacyLinksRuleGroup, rule); err != nil { - return err - } - } + updateLegacyLinkRules(tm.Create, chainFilterFwdIn(n.config.IfName), parentIP, childIP, port) } - if err := nftApply(ctx, n.fw.table4); err != nil { + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } return nil } func (n *network) DelLink(ctx context.Context, parentIP, childIP netip.Addr, ports []types.TransportPort) { - chain := n.fw.table4.Chain(ctx, chainFilterFwdIn(n.config.IfName)) + tm := n.fw.table4.Modifier() for _, port := range ports { - for _, rule := range legacyLinkRules(parentIP, childIP, port) { - if err := chain.DeleteRule(ctx, fwdInLegacyLinksRuleGroup, rule); err != nil { - log.G(ctx).WithFields(log.Fields{ - "rule": rule, - "error": err, - }).Warn("Failed to remove link between containers") - } - } + updateLegacyLinkRules(tm.Delete, chainFilterFwdIn(n.config.IfName), parentIP, childIP, port) } - if err := nftApply(ctx, n.fw.table4); err != nil { + if err := tm.Apply(ctx); err != nil { log.G(ctx).WithError(err).Warn("Removing link, failed to update nftables") } } -func legacyLinkRules(parentIP, childIP netip.Addr, port types.TransportPort) []string { +func updateLegacyLinkRules(updater func(command nftables.Obj), chainName string, parentIP, childIP netip.Addr, port types.TransportPort) { // TODO(robmry) - could combine rules for each proto by using an anonymous set. - return []string{ - // Match the iptables implementation, but without checking iifname/oifname (not needed - // because the addresses belong to the bridge). - fmt.Sprintf("ip saddr %s ip daddr %s %s dport %d counter accept", parentIP.Unmap(), childIP.Unmap(), port.Proto, port.Port), - // Conntrack will allow responses. So, this must be to allow unsolicited packets from an exposed port. - fmt.Sprintf("ip daddr %s ip saddr %s %s sport %d counter accept", parentIP.Unmap(), childIP.Unmap(), port.Proto, port.Port), - } + // Match the iptables implementation, but without checking iifname/oifname (not needed + // because the addresses belong to the bridge). + updater(nftables.Rule{ + Chain: chainName, + Group: fwdInLegacyLinksRuleGroup, + Rule: []string{ + "ip saddr", parentIP.Unmap().String(), + "ip daddr", childIP.Unmap().String(), port.Proto.String(), "dport", strconv.Itoa(int(port.Port)), + "counter accept", + }, + }) + // Conntrack will allow responses. So, this must be to allow unsolicited packets from an exposed port. + updater(nftables.Rule{ + Chain: chainName, + Group: fwdInLegacyLinksRuleGroup, + Rule: []string{ + "ip daddr", parentIP.Unmap().String(), + "ip saddr", childIP.Unmap().String(), port.Proto.String(), "sport", strconv.Itoa(int(port.Port)), + "counter accept", + }, + }) } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go index e9ced708e9374..5d05db4f9c5df 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go @@ -9,14 +9,14 @@ import ( "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/drivers/bridge/internal/firewaller" "github.com/docker/docker/daemon/libnetwork/internal/nftables" - "github.com/docker/docker/internal/cleanups" "go.opentelemetry.io/otel" ) type network struct { - config firewaller.NetworkConfig - cleaner func(ctx context.Context) error - fw *nftabler + config firewaller.NetworkConfig + remover4 *nftables.Modifier + remover6 *nftables.Modifier + fw *nftabler } func (nft *nftabler) NewNetwork(ctx context.Context, nc firewaller.NetworkConfig) (_ firewaller.Network, retErr error) { @@ -26,108 +26,84 @@ func (nft *nftabler) NewNetwork(ctx context.Context, nc firewaller.NetworkConfig } ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"bridge": n.config.IfName})) - var cleaner cleanups.Composite - defer func() { - if err := cleaner.Call(ctx); err != nil { - log.G(ctx).WithError(err).Warn("Failed to clean up nftables rules for network") - } - }() - if nft.cleaner != nil { nft.cleaner.DelNetwork(ctx, nc) } if n.fw.config.IPv4 { - clean, err := n.configure(ctx, nft.table4, n.config.Config4) + remover, err := n.configure(ctx, nft.table4, n.config.Config4) if err != nil { return nil, err } - if clean != nil { - cleaner.Add(clean) - } + n.remover4 = remover } if n.fw.config.IPv6 { - clean, err := n.configure(ctx, nft.table6, n.config.Config6) + remover, err := n.configure(ctx, nft.table6, n.config.Config6) if err != nil { return nil, err } - if clean != nil { - cleaner.Add(clean) - } + n.remover6 = remover } - - n.cleaner = cleaner.Release() return n, nil } -func (n *network) configure(ctx context.Context, table nftables.TableRef, conf firewaller.NetworkConfigFam) (func(context.Context) error, error) { - ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".newNetwork."+string(table.Family())) - defer span.End() - +func (n *network) configure(ctx context.Context, table nftables.Table, conf firewaller.NetworkConfigFam) (*nftables.Modifier, error) { if !conf.Prefix.IsValid() { return nil, nil } + tm := table.Modifier() + ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".newNetwork."+string(tm.Family())) + defer span.End() - var cleanup cleanups.Composite - defer cleanup.Call(ctx) - var applied bool - cleanup.Add(func(ctx context.Context) error { - if applied { - return nftApply(ctx, table) - } - return nil - }) + fwdInChain := chainFilterFwdIn(n.config.IfName) + fwdOutChain := chainFilterFwdOut(n.config.IfName) + natPostRtInChain := chainNatPostRtIn(n.config.IfName) + natPostRtOutChain := chainNatPostRtOut(n.config.IfName) // Filter chain - fwdInChain := table.Chain(ctx, chainFilterFwdIn(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return table.DeleteChain(ctx, chainFilterFwdIn(n.config.IfName)) }) - fwdOutChain := table.Chain(ctx, chainFilterFwdOut(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return table.DeleteChain(ctx, chainFilterFwdOut(n.config.IfName)) }) + tm.Create(nftables.Chain{Name: fwdInChain}) + tm.Create(nftables.Chain{Name: fwdOutChain}) - cf, err := table.InterfaceVMap(ctx, filtFwdInVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainFilterFwdIn(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding filter-forward jump for %s to %q: %w", conf.Prefix, chainFilterFwdIn(n.config.IfName), err) - } - cleanup.Add(cf) - - cf, err = table.InterfaceVMap(ctx, filtFwdOutVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainFilterFwdOut(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding filter-forward jump for %s to %q: %w", conf.Prefix, chainFilterFwdOut(n.config.IfName), err) - } - cleanup.Add(cf) + tm.Create(nftables.VMapElement{ + VmapName: filtFwdInVMap, + Key: n.config.IfName, + Verdict: "jump " + fwdInChain, + }) + tm.Create(nftables.VMapElement{ + VmapName: filtFwdOutVMap, + Key: n.config.IfName, + Verdict: "jump " + fwdOutChain, + }) // NAT chain - natPostroutingIn := table.Chain(ctx, chainNatPostRtIn(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return table.DeleteChain(ctx, chainNatPostRtIn(n.config.IfName)) }) - cf, err = table.InterfaceVMap(ctx, natPostroutingInVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainNatPostRtIn(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding postrouting ingress jump for %s to %q: %w", conf.Prefix, chainNatPostRtIn(n.config.IfName), err) - } - cleanup.Add(cf) + tm.Create(nftables.Chain{Name: natPostRtInChain}) + tm.Create(nftables.VMapElement{ + VmapName: natPostroutingInVMap, + Key: n.config.IfName, + Verdict: "jump " + natPostRtInChain, + }) - natPostroutingOut := table.Chain(ctx, chainNatPostRtOut(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return table.DeleteChain(ctx, chainNatPostRtOut(n.config.IfName)) }) - cf, err = table.InterfaceVMap(ctx, natPostroutingOutVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainNatPostRtOut(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding postrouting egress jump for %s to %q: %w", conf.Prefix, chainNatPostRtOut(n.config.IfName), err) - } - cleanup.Add(cf) + tm.Create(nftables.Chain{Name: chainNatPostRtOut(n.config.IfName)}) + tm.Create(nftables.VMapElement{ + VmapName: natPostroutingOutVMap, + Key: n.config.IfName, + Verdict: "jump " + chainNatPostRtOut(n.config.IfName), + }) // Conntrack - cf, err = fwdInChain.AppendRuleCf(ctx, initialRuleGroup, "ct state established,related counter accept") - if err != nil { - return nil, fmt.Errorf("adding conntrack ingress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) - - cf, err = fwdOutChain.AppendRuleCf(ctx, initialRuleGroup, "ct state established,related counter accept") - if err != nil { - return nil, fmt.Errorf("adding conntrack egress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: chainFilterFwdIn(n.config.IfName), + Group: initialRuleGroup, + Rule: []string{"ct state established,related counter accept"}, + }) + tm.Create(nftables.Rule{ + Chain: chainFilterFwdOut(n.config.IfName), + Group: initialRuleGroup, + Rule: []string{"ct state established,related counter accept"}, + }) iccVerdict := "accept" if !n.config.ICC { @@ -136,68 +112,64 @@ func (n *network) configure(ctx context.Context, table nftables.TableRef, conf f if n.config.Internal { // Drop anything that's not from this network. - cf, err = fwdInChain.AppendRuleCf(ctx, initialRuleGroup, - `iifname != %s counter drop comment "INTERNAL NETWORK INGRESS"`, n.config.IfName) - if err != nil { - return nil, fmt.Errorf("adding INTERNAL NETWORK ingress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) - - cf, err = fwdOutChain.AppendRuleCf(ctx, initialRuleGroup, - `oifname != %s counter drop comment "INTERNAL NETWORK EGRESS"`, n.config.IfName) - if err != nil { - return nil, fmt.Errorf("adding INTERNAL NETWORK egress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: initialRuleGroup, + Rule: []string{`iifname != `, n.config.IfName, `counter drop comment "INTERNAL NETWORK INGRESS"`}, + }) + tm.Create(nftables.Rule{ + Chain: fwdOutChain, + Group: initialRuleGroup, + Rule: []string{`oifname != `, n.config.IfName, `counter drop comment "INTERNAL NETWORK EGRESS"`}, + }) // Accept or drop Inter-Container Communication. - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInICCRuleGroup, "counter %s comment ICC", iccVerdict) - if err != nil { - return nil, fmt.Errorf("adding ICC ingress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInICCRuleGroup, + Rule: []string{"counter", iccVerdict, "comment ICC"}, + }) } else { // Inter-Container Communication - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInICCRuleGroup, "iifname == %s counter %s comment ICC", - n.config.IfName, iccVerdict) - if err != nil { - return nil, fmt.Errorf("adding ICC rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInICCRuleGroup, + Rule: []string{"iifname ==", n.config.IfName, "counter", iccVerdict, "comment ICC"}, + }) // Outgoing traffic - cf, err = fwdOutChain.AppendRuleCf(ctx, initialRuleGroup, "counter accept comment OUTGOING") - if err != nil { - return nil, fmt.Errorf("adding OUTGOING rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdOutChain, + Group: initialRuleGroup, + Rule: []string{"counter accept comment OUTGOING"}, + }) // Incoming traffic if conf.Unprotected { - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInFinalRuleGroup, `counter accept comment "UNPROTECTED"`) - if err != nil { - return nil, fmt.Errorf("adding UNPROTECTED for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInFinalRuleGroup, + Rule: []string{`counter accept comment "UNPROTECTED"`}, + }) } else { - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInFinalRuleGroup, `counter drop comment "UNPUBLISHED PORT DROP"`) - if err != nil { - return nil, fmt.Errorf("adding UNPUBLISHED PORT DROP for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInFinalRuleGroup, + Rule: []string{`counter drop comment "UNPUBLISHED PORT DROP"`}, + }) } // ICMP if conf.Routed { rule := "ip protocol icmp" - if table.Family() == nftables.IPv6 { + if tm.Family() == nftables.IPv6 { rule = "meta l4proto ipv6-icmp" } - cf, err = fwdInChain.AppendRuleCf(ctx, initialRuleGroup, rule+" counter accept comment ICMP") - if err != nil { - return nil, fmt.Errorf("adding ICMP rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: initialRuleGroup, + Rule: []string{rule, "counter accept comment ICMP"}, + }) } // Masquerade / SNAT - masquerade picks a source IP address based on next-hop, SNAT uses conf.HostIP. @@ -208,34 +180,34 @@ func (n *network) configure(ctx context.Context, table nftables.TableRef, conf f natPostroutingComment = "SNAT" } if n.config.Masquerade && !conf.Routed { - cf, err = natPostroutingOut.AppendRuleCf(ctx, initialRuleGroup, `oifname != %s %s saddr %s counter %s comment "%s"`, - n.config.IfName, table.Family(), conf.Prefix, natPostroutingVerdict, natPostroutingComment) - if err != nil { - return nil, fmt.Errorf("adding NAT rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: natPostRtOutChain, + Group: initialRuleGroup, + Rule: []string{ + "oifname !=", n.config.IfName, string(tm.Family()), "saddr", conf.Prefix.String(), "counter", + natPostroutingVerdict, "comment", natPostroutingComment, + }, + }) } if n.fw.config.Hairpin { - // Masquerade/SNAT traffic from localhost. - cf, err = natPostroutingIn.AppendRuleCf(ctx, initialRuleGroup, `fib saddr type local counter %s comment "%s FROM HOST"`, - natPostroutingVerdict, natPostroutingComment) - if err != nil { - return nil, fmt.Errorf("adding NAT local rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: natPostRtInChain, + Group: initialRuleGroup, + Rule: []string{ + `fib saddr type local counter`, natPostroutingVerdict, `comment "` + natPostroutingComment + ` FROM HOST"`, + }, + }) } } ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{ "bridge": n.config.IfName, - "family": table.Family(), + "family": tm.Family(), })) - if err := nftApply(ctx, table); err != nil { + if err := tm.Apply(ctx); err != nil { return nil, fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } - applied = true - - return cleanup.Release(), nil + return tm.Reverse(), nil } func (n *network) ReapplyNetworkLevelRules(ctx context.Context) error { @@ -245,13 +217,18 @@ func (n *network) ReapplyNetworkLevelRules(ctx context.Context) error { } func (n *network) DelNetworkLevelRules(ctx context.Context) error { - if n.cleaner != nil { - ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"bridge": n.config.IfName})) - if err := n.cleaner(ctx); err != nil { - log.G(ctx).WithError(err).Warn("Failed to remove network rules for network") + remove := func(remover *nftables.Modifier) { + if remover != nil { + ctx := log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"bridge": n.config.IfName})) + if err := remover.Apply(ctx); err != nil { + log.G(ctx).WithError(err).Warn("Failed to remove network rules for network") + } } - n.cleaner = nil } + remove(n.remover4) + n.remover4 = nil + remove(n.remover6) + n.remover6 = nil return nil } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go index 5be4e378b81f4..39203af1eec99 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go @@ -11,7 +11,6 @@ import ( "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/drivers/bridge/internal/firewaller" "github.com/docker/docker/daemon/libnetwork/internal/nftables" - "go.opentelemetry.io/otel" ) // Prefix for OTEL span names. @@ -57,8 +56,8 @@ var baseChainNames = map[string]struct{}{ type nftabler struct { config firewaller.Config cleaner firewaller.FirewallCleaner - table4 nftables.TableRef - table6 nftables.TableRef + table4 nftables.Table + table6 nftables.Table } func NewNftabler(ctx context.Context, config firewaller.Config, baseChainPriorities map[string]string) (firewaller.Firewaller, error) { @@ -79,9 +78,6 @@ func NewNftabler(ctx context.Context, config firewaller.Config, baseChainPriorit if err != nil { return nil, err } - if err := nftApply(ctx, nft.table4); err != nil { - return nil, fmt.Errorf("IPv4 initialisation: %w", err) - } } if nft.config.IPv6 { @@ -90,14 +86,6 @@ func NewNftabler(ctx context.Context, config firewaller.Config, baseChainPriorit if err != nil { return nil, err } - - if err := nftApply(ctx, nft.table6); err != nil { - // Perhaps the kernel has no IPv6 support. It won't be possible to create IPv6 - // networks without enabling ip6_tables in the kernel, or disabling ip6tables in - // the daemon config. But, allow the daemon to start because IPv4 will work. So, - // log the problem, and continue. - log.G(ctx).WithError(err).Warn("ip6tables is enabled, but cannot set up IPv6 nftables table") - } } return nft, nil @@ -128,7 +116,7 @@ func (nft *nftabler) Reload(ctx context.Context) error { return errors.Join(errs...) } -func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.TableRef { +func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.Table { if ipv == firewaller.IPv4 { return nft.table4 } @@ -136,15 +124,14 @@ func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.TableRef { } func (nft *nftabler) FilterForwardDrop(ctx context.Context, ipv firewaller.IPVersion) error { - table := nft.getTable(ipv) - if err := table.Chain(ctx, forwardChain).SetPolicy("drop"); err != nil { - return err + if err := nft.getTable(ipv).SetBaseChainPolicy(ctx, forwardChain, nftables.BaseChainPolicyDrop); err != nil { + return fmt.Errorf("setting IPv%d filter-forward drop: %w", ipv, err) } - return nftApply(ctx, table) + return nil } // init creates the bridge driver's nftables table for IPv4 or IPv6. -func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChainPriorities map[string]int) (nftables.TableRef, error) { +func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChainPriorities map[string]int) (nftables.Table, error) { // Instantiate the table. table, err := nftables.NewTable(family, dockerTable) if err != nil { @@ -156,9 +143,11 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChain // the old table isn't removed, nft produces an error message for the base chain // (but seems to apply the change anyway). if err := table.Reload(ctx); err != nil { - return nftables.TableRef{}, err + return nftables.Table{}, err } + tm := table.Modifier() + // Set up the filter forward chain. // // This base chain only contains two rules that use verdict maps: @@ -168,65 +157,89 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChain // So, packets that aren't related to docker don't need to traverse any per-network filter forward // rules - and packets that are entering or leaving docker networks only need to traverse rules // related to those networks. - fwdChain, err := table.BaseChain(ctx, forwardChain, - nftables.BaseChainTypeFilter, - nftables.BaseChainHookForward, - baseChainPriority(forwardChain, nftables.BaseChainPriorityFilter, baseChainPriorities)) - if err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.BaseChain{ + Name: forwardChain, + ChainType: nftables.BaseChainTypeFilter, + Hook: nftables.BaseChainHookForward, + Priority: baseChainPriority(forwardChain, nftables.BaseChainPriorityFilter, baseChainPriorities), + }) // Instantiate the verdict maps and add the jumps. - _ = table.InterfaceVMap(ctx, filtFwdInVMap) - if err := fwdChain.AppendRule(ctx, initialRuleGroup, "oifname vmap @"+filtFwdInVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } - _ = table.InterfaceVMap(ctx, filtFwdOutVMap) - if err := fwdChain.AppendRule(ctx, initialRuleGroup, "iifname vmap @"+filtFwdOutVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.VMap{ + Name: filtFwdInVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: forwardChain, + Group: initialRuleGroup, + Rule: []string{"oifname vmap @", filtFwdInVMap}, + }) + + tm.Create(nftables.VMap{ + Name: filtFwdOutVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: forwardChain, + Group: initialRuleGroup, + Rule: []string{"iifname vmap @", filtFwdOutVMap}, + }) // Set up the NAT postrouting base chain. // // Like the filter-forward chain, its only rules are jumps to network-specific ingress and egress chains. - natPostRtChain, err := table.BaseChain(ctx, postroutingChain, - nftables.BaseChainTypeNAT, - nftables.BaseChainHookPostrouting, - baseChainPriority(postroutingChain, nftables.BaseChainPrioritySrcNAT, baseChainPriorities)) - if err != nil { - return nftables.TableRef{}, err - } - _ = table.InterfaceVMap(ctx, natPostroutingOutVMap) - if err := natPostRtChain.AppendRule(ctx, initialRuleGroup, "iifname vmap @"+natPostroutingOutVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } - _ = table.InterfaceVMap(ctx, natPostroutingInVMap) - if err := natPostRtChain.AppendRule(ctx, initialRuleGroup, "oifname vmap @"+natPostroutingInVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.BaseChain{ + Name: postroutingChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookPostrouting, + Priority: baseChainPriority(postroutingChain, nftables.BaseChainPrioritySrcNAT, baseChainPriorities), + }) + + tm.Create(nftables.VMap{ + Name: natPostroutingOutVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: postroutingChain, + Group: initialRuleGroup, + Rule: []string{"iifname vmap @", natPostroutingOutVMap}, + }) + + tm.Create(nftables.VMap{ + Name: natPostroutingInVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: postroutingChain, + Group: initialRuleGroup, + Rule: []string{"oifname vmap @", natPostroutingInVMap}, + }) // Instantiate natChain, for the NAT prerouting and output base chains to jump to. - _ = table.Chain(ctx, natChain) + tm.Create(nftables.Chain{ + Name: natChain, + }) // Set up the NAT prerouting base chain. - natPreRtChain, err := table.BaseChain(ctx, preroutingChain, - nftables.BaseChainTypeNAT, - nftables.BaseChainHookPrerouting, - baseChainPriority(preroutingChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities)) - if err != nil { - return nftables.TableRef{}, err - } - if err := natPreRtChain.AppendRule(ctx, initialRuleGroup, "fib daddr type local counter jump "+natChain); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.BaseChain{ + Name: preroutingChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookPrerouting, + Priority: baseChainPriority(preroutingChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities), + }) + tm.Create(nftables.Rule{ + Chain: preroutingChain, + Group: initialRuleGroup, + Rule: []string{"fib daddr type local counter jump", natChain}, + }) // Set up the NAT output base chain - natOutputChain, err := table.BaseChain(ctx, outputChain, - nftables.BaseChainTypeNAT, - nftables.BaseChainHookOutput, - baseChainPriority(outputChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities)) - if err != nil { - return nftables.TableRef{}, err - } + tm.Create(nftables.BaseChain{ + Name: outputChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookOutput, + Priority: baseChainPriority(outputChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities), + }) + // For output, don't jump to the NAT chain if hairpin is enabled (no userland proxy). var skipLoopback string if !nft.config.Hairpin { @@ -236,34 +249,36 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChain skipLoopback = "ip6 daddr != ::1 " } } - if err := natOutputChain.AppendRule(ctx, initialRuleGroup, skipLoopback+"fib daddr type local counter jump "+natChain); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.Rule{ + Chain: outputChain, + Group: initialRuleGroup, + Rule: []string{skipLoopback, "fib daddr type local counter jump", natChain}, + }) // Set up the raw prerouting base chain - if _, err := table.BaseChain(ctx, rawPreroutingChain, - nftables.BaseChainTypeFilter, - nftables.BaseChainHookPrerouting, - baseChainPriority(rawPreroutingChain, nftables.BaseChainPriorityRaw, baseChainPriorities)); err != nil { - return nftables.TableRef{}, err - } + tm.Create(nftables.BaseChain{ + Name: rawPreroutingChain, + ChainType: nftables.BaseChainTypeFilter, + Hook: nftables.BaseChainHookPrerouting, + Priority: baseChainPriority(rawPreroutingChain, nftables.BaseChainPriorityRaw, baseChainPriorities), + }) if !nft.config.Hairpin && nft.config.WSL2Mirrored { - if err := mirroredWSL2Workaround(ctx, table); err != nil { - return nftables.TableRef{}, err - } + mirroredWSL2Workaround(tm) } - return table, nil -} - -func nftApply(ctx context.Context, table nftables.TableRef) error { - ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply."+string(table.Family())) - defer span.End() - if err := table.Apply(ctx); err != nil { - return fmt.Errorf("applying nftables rules: %w", err) + if err := tm.Apply(ctx); err != nil { + if family == nftables.IPv4 { + return nftables.Table{}, err + } + // Perhaps the kernel has no IPv6 support. It won't be possible to create IPv6 + // networks without enabling ip6_tables in the kernel, or disabling ip6tables in + // the daemon config. But, allow the daemon to start because IPv4 will work. So, + // log the problem, and continue. + log.G(ctx).WithError(err).Warn("ip6tables is enabled, but cannot set up IPv6 nftables table") + return nftables.Table{}, nil } - return nil + return table, nil } func baseChainPriority(chainName string, def int, overrides map[string]int) int { diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go index 8447f23759576..a6276daa988c1 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go @@ -5,7 +5,6 @@ package nftabler import ( "context" - "errors" "fmt" "net" "strconv" @@ -17,9 +16,10 @@ import ( ) type pbContext struct { - table nftables.TableRef - conf firewaller.NetworkConfigFam - ipv firewaller.IPVersion + table nftables.Table + updater func(nftables.Obj) + conf firewaller.NetworkConfigFam + ipv nftables.Family } func (n *network) AddPorts(ctx context.Context, pbs []types.PortBinding) error { @@ -43,13 +43,13 @@ func (n *network) modPorts(ctx context.Context, pbs []types.PortBinding, enable pbs4, pbs6 := splitByContainerFam(pbs) if n.fw.config.IPv4 && n.config.Config4.Prefix.IsValid() { - pbc := pbContext{table: n.fw.table4, conf: n.config.Config4, ipv: firewaller.IPv4} + pbc := pbContext{table: n.fw.table4, conf: n.config.Config4, ipv: nftables.IPv4} if err := n.setPerPortRules(ctx, pbs4, pbc, enable); err != nil { return err } } if n.fw.config.IPv6 && n.config.Config6.Prefix.IsValid() { - pbc := pbContext{table: n.fw.table6, conf: n.config.Config6, ipv: firewaller.IPv6} + pbc := pbContext{table: n.fw.table6, conf: n.config.Config6, ipv: nftables.IPv6} if err := n.setPerPortRules(ctx, pbs6, pbc, enable); err != nil { return err } @@ -70,29 +70,34 @@ func splitByContainerFam(pbs []types.PortBinding) ([]types.PortBinding, []types. } func (n *network) setPerPortRules(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { - if err := n.setPerPortForwarding(ctx, pbs, pbc, enable); err != nil { + tm := pbc.table.Modifier() + pbc.updater = tm.Create + if !enable { + pbc.updater = tm.Delete + } + if err := n.setPerPortForwarding(ctx, pbs, pbc); err != nil { return err } - if err := n.setPerPortDNAT(ctx, pbs, pbc, enable); err != nil { + if err := n.setPerPortDNAT(ctx, pbs, pbc); err != nil { return err } - if err := n.setPerPortHairpinMasq(ctx, pbs, pbc, enable); err != nil { + if err := n.setPerPortHairpinMasq(ctx, pbs, pbc); err != nil { return err } - if err := n.filterPortMappedOnLoopback(ctx, pbs, pbc, enable); err != nil { + if err := n.filterPortMappedOnLoopback(ctx, pbs, pbc); err != nil { return err } - if err := nftApply(ctx, pbc.table); err != nil { + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } return nil } -func (n *network) setPerPortForwarding(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { +func (n *network) setPerPortForwarding(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { if pbc.conf.Unprotected { return nil } - updateFwdIn := pbc.table.ChainUpdateFunc(ctx, chainFilterFwdIn(n.config.IfName), enable) + chainName := chainFilterFwdIn(n.config.IfName) for _, pb := range pbs { // When more than one host port is mapped to a single container port, this will // generate the same rule for each host port. So, ignore duplicates when adding, @@ -102,24 +107,26 @@ func (n *network) setPerPortForwarding(ctx context.Context, pbs []types.PortBind // also be deleted more than once.) // // TODO(robmry) - track port mappings, use that to edit nftables sets when bindings are added/removed. - rule := fmt.Sprintf("%s daddr %s %s dport %d counter accept", pbc.table.Family(), pb.IP, pb.Proto, pb.Port) - if err := updateFwdIn(ctx, fwdInPortsRuleGroup, rule); err != nil && - !errors.Is(err, nftables.ErrRuleExist) && !errors.Is(err, nftables.ErrRuleNotExist) { - return fmt.Errorf("updating forwarding rule for port %s %s:%d/%s on %s, enable=%v: %w", - pbc.table.Family(), pb.IP, pb.Port, pb.Proto, n.config.IfName, enable, err) - } + pbc.updater(nftables.Rule{ + Chain: chainName, + Group: fwdInPortsRuleGroup, + Rule: []string{ + string(pbc.ipv), "daddr", pb.IP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.Port)), "counter accept", + }, + IgnoreExist: true, + }) } return nil } -func (n *network) setPerPortDNAT(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { - updater := pbc.table.ChainUpdateFunc(ctx, natChain, enable) +func (n *network) setPerPortDNAT(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { var proxySkip string if !n.fw.config.Hairpin { proxySkip = fmt.Sprintf("iifname != %s ", n.config.IfName) } var v6LLSkip string - if pbc.ipv == firewaller.IPv6 { + if pbc.ipv == nftables.IPv6 { v6LLSkip = "ip6 saddr != fe80::/10 " } for _, pb := range pbs { @@ -134,26 +141,27 @@ func (n *network) setPerPortDNAT(ctx context.Context, pbs []types.PortBinding, p } var daddrMatch string if !pb.HostIP.IsUnspecified() { - daddrMatch = fmt.Sprintf("%s daddr %s ", pbc.table.Family(), pb.HostIP) - } - rule := fmt.Sprintf("%s%s%s%s dport %d counter dnat to %s comment DNAT", - proxySkip, v6LLSkip, daddrMatch, pb.Proto, pb.HostPort, - net.JoinHostPort(pb.IP.String(), strconv.Itoa(int(pb.Port)))) - if err := updater(ctx, initialRuleGroup, rule); err != nil { - return fmt.Errorf("adding DNAT for %s %s:%d -> %s:%d/%s on %s: %w", - pbc.table.Family(), pb.HostIP, pb.HostPort, pb.IP, pb.Port, pb.Proto, n.config.IfName, err) + daddrMatch = fmt.Sprintf("%s daddr %s ", pbc.ipv, pb.HostIP) } + pbc.updater(nftables.Rule{ + Chain: natChain, + Group: initialRuleGroup, + Rule: []string{ + proxySkip, v6LLSkip, daddrMatch, pb.Proto.String(), "dport", strconv.Itoa(int(pb.HostPort)), "counter dnat to", + net.JoinHostPort(pb.IP.String(), strconv.Itoa(int(pb.Port))), "comment DNAT", + }, + }) } return nil } // setPerPortHairpinMasq allows containers to access their own published ports on the host // when hairpin is enabled (no docker-proxy), by masquerading. -func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { +func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { if !n.fw.config.Hairpin { return nil } - updater := pbc.table.ChainUpdateFunc(ctx, chainNatPostRtIn(n.config.IfName), enable) + chainName := chainNatPostRtIn(n.config.IfName) for _, pb := range pbs { // Nothing to do if NAT is disabled. if pb.HostPort == 0 { @@ -172,13 +180,16 @@ func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBin // than once.) // // TODO(robmry) - track port mappings, use that to edit nftables sets when bindings are added/removed. - rule := fmt.Sprintf(`%s saddr %s %s daddr %s %s dport %d counter masquerade comment "MASQ TO OWN PORT"`, - pbc.table.Family(), pb.IP, pbc.table.Family(), pb.IP, pb.Proto, pb.Port) - if err := updater(ctx, initialRuleGroup, rule); err != nil && - !errors.Is(err, nftables.ErrRuleExist) && !errors.Is(err, nftables.ErrRuleNotExist) { - return fmt.Errorf("adding MASQ TO OWN PORT for %d -> %s:%d/%s: %w", - pb.Port, pb.IP, pb.Port, pb.Proto, err) - } + pbc.updater(nftables.Rule{ + Chain: chainName, + Group: initialRuleGroup, + Rule: []string{ + string(pbc.ipv), "saddr", pb.IP.String(), string(pbc.ipv), + "daddr", pb.IP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.Port)), + `counter masquerade comment "MASQ TO OWN PORT"`, + }, + }) } return nil } @@ -189,11 +200,10 @@ func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBin // This is a no-op if the portBinding is for IPv6 (IPv6 loopback address is // non-routable), or over a network with gw_mode=routed (PBs in routed mode // don't map ports on the host). -func (n *network) filterPortMappedOnLoopback(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { - if pbc.ipv == firewaller.IPv6 { +func (n *network) filterPortMappedOnLoopback(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { + if pbc.ipv == nftables.IPv6 { return nil } - updater := pbc.table.ChainUpdateFunc(ctx, rawPreroutingChain, enable) for _, pb := range pbs { // Nothing to do if not binding to the loopback address. if pb.HostPort == 0 || !pb.HostIP.IsLoopback() { @@ -204,17 +214,25 @@ func (n *network) filterPortMappedOnLoopback(ctx context.Context, pbs []types.Po continue } if n.fw.config.WSL2Mirrored { - if err := updater(ctx, rawPreroutingPortsRuleGroup, - `iifname loopback0 ip daddr %s %s dport %d counter accept comment "%s"`, - pb.HostIP, pb.Proto, pb.HostPort, "ACCEPT WSL2 LOOPBACK"); err != nil { - return fmt.Errorf("adding WSL2 loopback rule for %d: %w", pb.HostPort, err) - } - } - if err := updater(ctx, rawPreroutingPortsRuleGroup, - `iifname != lo ip daddr %s %s dport %d counter drop comment "DROP REMOTE LOOPBACK"`, - pb.HostIP, pb.Proto, pb.HostPort); err != nil { - return fmt.Errorf("adding loopback filter rule for %d: %w", pb.HostPort, err) + pbc.updater(nftables.Rule{ + Chain: rawPreroutingChain, + Group: rawPreroutingPortsRuleGroup, + Rule: []string{ + "iifname loopback0 ip daddr", pb.HostIP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.HostPort)), + `counter accept comment "ACCEPT WSL2 LOOPBACK"`, + }, + }) } + pbc.updater(nftables.Rule{ + Chain: rawPreroutingChain, + Group: rawPreroutingPortsRuleGroup, + Rule: []string{ + `iifname != lo ip daddr`, pb.HostIP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.HostPort)), + `counter drop comment "DROP REMOTE LOOPBACK"`, + }, + }) } return nil diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go index 6606b8d79360c..d7f77ad55b8fd 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go @@ -3,8 +3,6 @@ package nftabler import ( - "context" - "github.com/docker/docker/daemon/libnetwork/internal/nftables" ) @@ -39,11 +37,14 @@ import ( // arriving from any other bridge network. Similarly, this function adds (or // removes) a rule to RETURN early for packets delivered via loopback0 with // destination 127.0.0.0/8. -func mirroredWSL2Workaround(ctx context.Context, table nftables.TableRef) error { +func mirroredWSL2Workaround(tm *nftables.Modifier) { // WSL2 does not (currently) support Windows<->Linux communication via ::1. - if table.Family() != nftables.IPv4 { - return nil + if tm.Family() != nftables.IPv4 { + return } - return table.Chain(ctx, natChain).AppendRule(ctx, - initialRuleGroup, `iifname "loopback0" ip daddr 127.0.0.0/8 counter return`) + tm.Create(nftables.Rule{ + Chain: natChain, + Group: initialRuleGroup, + Rule: []string{`iifname "loopback0" ip daddr 127.0.0.0/8 counter return`}, + }) } diff --git a/daemon/libnetwork/internal/nftables/nftables_linux.go b/daemon/libnetwork/internal/nftables/nftables_linux.go index ca99134f92e40..d6f17f1b6cdba 100644 --- a/daemon/libnetwork/internal/nftables/nftables_linux.go +++ b/daemon/libnetwork/internal/nftables/nftables_linux.go @@ -4,15 +4,26 @@ // Package nftables provides methods to create an nftables table and manage its maps, sets, // chains, and rules. // -// To use it, the first step is to create a [TableRef] using [NewTable]. The table can -// then be populated and managed using that ref. +// To use it, the first step is to create a [Table] using [NewTable]. Then, retrieve +// a [Modifier], add commands to it, and apply the updates. // -// Modifications to the table are only applied (sent to "nft") when [TableRef.Apply] is -// called. This means a number of updates can be made, for example, adding all the -// rules needed for a docker network - and those rules will then be applied atomically -// in a single "nft" run. +// For example: // -// [TableRef.Apply] can only be called after [Enable], and only if [Enable] returns +// t, _ := NewTable(...) +// tm := t.Modifier() +// // Then a sequence of ... +// tm.Create() +// tm.Delete() +// // Apply the updates with ... +// err := tm.Apply(ctx) +// +// The objects are any of: [BaseChain], [Chain], [Rule], [VMap], [VMapElement], +// [Set], [SetElement] +// +// The modifier can be reused to apply the same set of commands again or, more +// usefully, reversed in order to revert its changes. See [Modifier.Reverse]. +// +// [Modifier.Apply] can only be called after [Enable], and only if [Enable] returns // true (meaning an "nft" executable was found). [Enabled] can be called to check // whether nftables has been enabled. // @@ -20,9 +31,6 @@ // - The implementation is far from complete, only functionality needed so-far has // been included. Currently, there's only a limited set of chain/map/set types, // there's no way to delete sets/maps etc. -// - There's no rollback so, once changes have been made to a TableRef, if the -// Apply fails there is no way to undo changes. The TableRef will be out-of-sync -// with the actual state of nftables. // - This is a thin layer between code and "nft", it doesn't do much error checking. So, // for example, if you get the syntax of a rule wrong the issue won't be reported // until Apply is called. @@ -45,6 +53,7 @@ import ( "fmt" "io" "os/exec" + "runtime" "slices" "strconv" "strings" @@ -72,15 +81,6 @@ var ( enableOnce sync.Once ) -var ( - // ErrRuleExist is returned when a rule is added, but it already exists in the same - // rule group of a chain. - ErrRuleExist = errors.New("rule exists") - // ErrRuleNotExist is returned when a rule is removed, but does not exist in the - // rule group of a chain. - ErrRuleNotExist = errors.New("rule does not exist") -) - // BaseChainType enumerates the base chain types. // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_types type BaseChainType string @@ -115,6 +115,15 @@ const ( BaseChainPrioritySrcNAT = 100 ) +// BaseChainPolicy enumerates base chain policies. +// See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_policy +type BaseChainPolicy string + +const ( + BaseChainPolicyAccept BaseChainPolicy = "accept" + BaseChainPolicyDrop BaseChainPolicy = "drop" +) + // Family enumerates address families. type Family string @@ -123,17 +132,17 @@ const ( IPv6 Family = "ip6" ) -// nftType enumerates nft types that can be used to define maps/sets etc. -type nftType string +// NftType enumerates nft types that can be used to define maps/sets etc. +type NftType string const ( - nftTypeIPv4Addr nftType = "ipv4_addr" - nftTypeIPv6Addr nftType = "ipv6_addr" - nftTypeEtherAddr nftType = "ether_addr" - nftTypeInetProto nftType = "inet_proto" - nftTypeInetService nftType = "inet_service" - nftTypeMark nftType = "mark" - nftTypeIfname nftType = "ifname" + NftTypeIPv4Addr NftType = "ipv4_addr" + NftTypeIPv6Addr NftType = "ipv6_addr" + NftTypeEtherAddr NftType = "ether_addr" + NftTypeInetProto NftType = "inet_proto" + NftTypeInetService NftType = "inet_service" + NftTypeMark NftType = "mark" + NftTypeIfname NftType = "ifname" ) // Enable tries once to initialise nftables. @@ -182,43 +191,58 @@ type table struct { Sets map[string]*set Chains map[string]*chain - Dirty bool // Set when the table is new, not when its elements change. - DeleteChainCommands []string + DeleteCommands []string + MustFlush bool + + applyLock sync.Mutex } -// TableRef is a handle for an nftables table. -type TableRef struct { +// Table is a handle for an nftables table. +type Table struct { t *table } -// NewTable creates a new nftables table and returns a [TableRef] +// IsValid returns true if t is a valid reference to a table. +func (t Table) IsValid() bool { + return t.t != nil +} + +// NewTable creates a new nftables table and returns a [Table] // // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables // -// The table will be created and flushed when [TableRef.Apply] is next called. +// To modify the table, get a [Modifier] using [Table.Modifier], add commands +// to it, and call [Modifier.Apply]. +// // It's flushed in case it already exists in the host's nftables - when that // happens, rules in its chains will be deleted but not the chains themselves, // maps, sets, or elements of maps or sets. But, those un-flushed items can't do // anything disruptive unless referred to by rules, and they will be flushed if -// they get re-created via the [TableRef], when [TableRef.Apply] is next called +// they get re-created via the [Table], when [Modifier.Apply] is next called // (so, before they can be used by a new rule). -func NewTable(family Family, name string) (TableRef, error) { - t := TableRef{ +// +// To fully delete an underlying nftables table, if one already exists, +// use [Reload] after creating the table. +func NewTable(family Family, name string) (Table, error) { + t := Table{ t: &table{ - Name: name, - Family: family, - VMaps: map[string]*vMap{}, - Sets: map[string]*set{}, - Chains: map[string]*chain{}, - Dirty: true, + Name: name, + Family: family, + VMaps: map[string]*vMap{}, + Sets: map[string]*set{}, + Chains: map[string]*chain{}, + MustFlush: true, }, } return t, nil } -// Family returns the address family of the nftables table described by [TableRef]. -func (t TableRef) Family() Family { - return t.t.Family +// Name returns the name of the table, or an empty string if t is not valid. +func (t Table) Name() string { + if !t.IsValid() { + return "" + } + return t.t.Name } // incrementalUpdateTemplText is used with text/template to generate an nftables command file @@ -248,25 +272,25 @@ table {{$family}} {{$tableName}} { {{if len .Flags}}flags{{range .Flags}} {{.}}{{end}}{{end}} } {{end}} - {{range .Chains}}{{if .Dirty}}chain {{.Name}} { + {{range .Chains}}{{if .MustFlush}}chain {{.Name}} { {{if .ChainType}}type {{.ChainType}} hook {{.Hook}} priority {{.Priority}}; policy {{.Policy}}{{end}} } ; {{end}}{{end}} } -{{if .Dirty}}flush table {{$family}} {{$tableName}}{{end}} -{{range .VMaps}}{{if .Dirty}}flush map {{$family}} {{$tableName}} {{.Name}} +{{if .MustFlush}}flush table {{$family}} {{$tableName}}{{end}} +{{range .VMaps}}{{if .MustFlush}}flush map {{$family}} {{$tableName}} {{.Name}} {{end}}{{end}} -{{range .Sets}}{{if .Dirty}}flush set {{$family}} {{$tableName}} {{.Name}} +{{range .Sets}}{{if .MustFlush}}flush set {{$family}} {{$tableName}} {{.Name}} {{end}}{{end}} -{{range .Chains}}{{if .Dirty}}flush chain {{$family}} {{$tableName}} {{.Name}} +{{range .Chains}}{{if .MustFlush}}flush chain {{$family}} {{$tableName}} {{.Name}} {{end}}{{end}} {{range .VMaps}}{{if .DeletedElements}}delete element {{$family}} {{$tableName}} {{.Name}} { {{range $k,$v := .DeletedElements}}{{$k}}, {{end}} } {{end}}{{end}} {{range .Sets}}{{if .DeletedElements}}delete element {{$family}} {{$tableName}} {{.Name}} { {{range $k,$v := .DeletedElements}}{{$k}}, {{end}} } {{end}}{{end}} -{{range .DeleteChainCommands}}{{.}} +{{range .DeleteCommands}}{{.}} {{end}} table {{$family}} {{$tableName}} { - {{range .Chains}}{{if .Dirty}}chain {{.Name}} { + {{range .Chains}}{{if .MustFlush}}chain {{.Name}} { {{if .ChainType}}type {{.ChainType}} hook {{.Hook}} priority {{.Priority}}; policy {{.Policy}}{{end}} {{range .Rules}}{{.}} {{end}} @@ -315,16 +339,159 @@ table {{$family}} {{$tableName}} { } ` -// Apply makes incremental updates to nftables, corresponding to changes to the [TableRef] -// since Apply was last called. -func (t TableRef) Apply(ctx context.Context) error { +// SetBaseChainPolicy sets the default policy for a base chain. The update +// is applied immediately, unlike creation/deletion of objects via a [Modifier] +// which are not applied until [Modifier.Apply] is called. +func (t Table) SetBaseChainPolicy(ctx context.Context, chainName string, policy BaseChainPolicy) error { + if !t.IsValid() { + return errors.New("invalid table") + } + c := t.t.Chains[chainName] + if c == nil { + return fmt.Errorf("cannot set base chain policy for '%s', it does not exist", chainName) + } + if c.ChainType == "" { + return fmt.Errorf("cannot set base chain policy for '%s', it is not a base chain", chainName) + } + oldPolicy := c.Policy + c.Policy = policy + c.MustFlush = true + + if err := t.Modifier().Apply(ctx); err != nil { + c.Policy = oldPolicy + return err + } + return nil +} + +// Modifier retrieves an object that can be used to manipulate the table. +func (t Table) Modifier() *Modifier { + return &Modifier{t: t.t} +} + +// Obj is an object that can be given to a [Modifier], representing an +// nftables object for it to create or delete. +type Obj interface { + create(context.Context, *table) (bool, error) + delete(context.Context, *table) (bool, error) +} + +// Modifier is used to apply changes to a Table. +type Modifier struct { + t *table + cmds []command +} + +// IsValid returns true if tm is valid. +func (tm *Modifier) IsValid() bool { + return tm.t != nil +} + +// Name returns the name of the table tm will modify, or the +// empty string if tm is not valid. +func (tm *Modifier) Name() string { + if !tm.IsValid() { + return "" + } + return tm.t.Name +} + +// Family returns the address family of the nftables table to be modified, +// or the empty string if tm is not valid. +func (tm *Modifier) Family() Family { + if !tm.IsValid() { + return "" + } + return tm.t.Family +} + +// Create enqueues creation of object o, to be applied by tm.Apply. +func (tm *Modifier) Create(o Obj) { + _, f, l, _ := runtime.Caller(1) + tm.cmds = append(tm.cmds, command{ + obj: o, + callerFile: f, + callerLine: l, + }) +} + +// Delete enqueues deletion of object o, to be applied by tm.Apply. +func (tm *Modifier) Delete(o Obj) { + _, f, l, _ := runtime.Caller(1) + tm.cmds = append(tm.cmds, command{ + obj: o, + delete: true, + callerFile: f, + callerLine: l, + }) +} + +// Reverse returns a Modifier that will undo the actions of tm. +// Its operations are performed in reverse order, creates become +// deletes, and deletes become creates. +// +// Most operations are fully reversible (chains/maps/sets must be +// empty before they're deleted, so no information is lost). But, +// there are exceptions, noted in comments in the object definitions. +// +// Applying the updates in a reversed modifier may not work if +// any of the objects have been removed or modified since they +// were added. For example, if a Modifier creates a chain then another +// Modifier adds rules, the reversed Modifier will not be able to +// delete the chain as it is not empty. +func (tm *Modifier) Reverse() *Modifier { + rtm := &Modifier{ + t: tm.t, + cmds: make([]command, len(tm.cmds)), + } + for i, cmd := range tm.cmds { + cmd.delete = !cmd.delete + rtm.cmds[len(tm.cmds)-i-1] = cmd + } + return rtm +} + +// Apply makes incremental updates to nftables. If there's a validation +// error in any of the enqueued objects, or an error applying the updates +// to the underlying nftables, the [Table] will be unmodified. +func (tm *Modifier) Apply(ctx context.Context) (retErr error) { if !Enabled() { return errors.New("nftables is not enabled") } + if !tm.IsValid() { + return errors.New("table modifier is not valid") + } + tm.t.applyLock.Lock() + defer tm.t.applyLock.Unlock() + + var rollback []command + defer func() { + if retErr == nil { + return + } + slices.Reverse(rollback) + for _, c := range rollback { + if _, err := c.rollback(ctx, tm.t); err != nil { + log.G(ctx).WithError(err).Error("Failed to roll back nftables updates") + } + } + tm.t.updatesApplied() + }() + + // Apply tm's updates to the Table. + for _, cmd := range tm.cmds { + applied, err := cmd.apply(ctx, tm.t) + if err != nil { + return fmt.Errorf("rule from %s:%d: %w", cmd.callerFile, cmd.callerLine, err) + } + if applied { + rollback = append(rollback, cmd) + } + } // Update nftables. var buf bytes.Buffer - if err := incrementalUpdateTempl.Execute(&buf, t.t); err != nil { + if err := incrementalUpdateTempl.Execute(&buf, tm.t); err != nil { return fmt.Errorf("failed to execute template nft ruleset: %w", err) } @@ -345,26 +512,32 @@ func (t TableRef) Apply(ctx context.Context) error { // behind networks for the test infrastructure to clean up between tests. Starting // a daemon flushes the "docker-bridges" table, so the cleanup fails to delete a // rule that's been flushed. So, try reloading the whole table to get back in-sync. - return t.Reload(ctx) + return tm.t.reload(ctx) } // Note that updates have been applied. - t.t.updatesApplied() + tm.t.updatesApplied() return nil } // Reload deletes the table, then re-creates it, atomically. -func (t TableRef) Reload(ctx context.Context) error { +func (t Table) Reload(ctx context.Context) error { if !Enabled() { return errors.New("nftables is not enabled") } + if !t.IsValid() { + return errors.New("invalid table") + } + return t.t.reload(ctx) +} - ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"table": t.t.Name, "family": t.t.Family})) +func (t *table) reload(ctx context.Context) error { + ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"table": t.Name, "family": t.Family})) log.G(ctx).Warn("nftables: reloading table") // Build the update. var buf bytes.Buffer - if err := reloadTempl.Execute(&buf, t.t); err != nil { + if err := reloadTempl.Execute(&buf, t); err != nil { return fmt.Errorf("failed to execute reload template: %w", err) } @@ -382,7 +555,7 @@ func (t TableRef) Reload(ctx context.Context) error { } // Note that updates have been applied. - t.t.updatesApplied() + t.updatesApplied() return nil } @@ -404,170 +577,176 @@ type chain struct { ChainType BaseChainType Hook BaseChainHook Priority int - Policy string - Dirty bool + Policy BaseChainPolicy + MustFlush bool ruleGroups map[RuleGroup][]string } -// ChainRef is a handle for an nftables chain. -type ChainRef struct { - c *chain -} - // BaseChain constructs a new nftables base chain and returns a [ChainRef]. // // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains // // It is an error to create a base chain that already exists. // If the underlying chain already exists, it will be flushed by the -// next [TableRef.Apply] before new rules are added. -func (t TableRef) BaseChain(ctx context.Context, name string, chainType BaseChainType, hook BaseChainHook, priority int) (ChainRef, error) { - if _, ok := t.t.Chains[name]; ok { - return ChainRef{}, fmt.Errorf("chain %q already exists", name) +// next [Table.Apply] before new rules are added. +type BaseChain struct { + Name string + ChainType BaseChainType + Hook BaseChainHook + Priority int + Policy BaseChainPolicy // Defaults to BaseChainPolicyAccept +} + +func (cd BaseChain) create(ctx context.Context, t *table) (bool, error) { + if _, ok := t.Chains[cd.Name]; ok { + return false, fmt.Errorf("base chain '%s' already exists", cd.Name) + } + if cd.Name == "" { + return false, errors.New("base chain must have a name") + } + if cd.ChainType == "" || cd.Hook == "" { + return false, fmt.Errorf("chain '%s': fields ChainType and Hook are required", cd.Name) + } + if cd.Policy == "" { + // nftables will default to "accept" if unspecified, but the text/template + // requires a policy string. + cd.Policy = BaseChainPolicyAccept } c := &chain{ - table: t.t, - Name: name, - ChainType: chainType, - Hook: hook, - Priority: priority, - Policy: "accept", - Dirty: true, + table: t, + Name: cd.Name, + ChainType: cd.ChainType, + Hook: cd.Hook, + Priority: cd.Priority, + Policy: cd.Policy, + MustFlush: true, ruleGroups: map[RuleGroup][]string{}, } - t.t.Chains[name] = c + t.Chains[c.Name] = c log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "chain": name, - "type": chainType, - "hook": hook, - "prio": priority, + "family": t.Family, + "table": t.Name, + "chain": c.Name, + "type": c.ChainType, + "hook": c.Hook, + "prio": c.Priority, }).Debug("nftables: created base chain") - return ChainRef{c: c}, nil + return true, nil } -// Chain returns a [ChainRef] for an existing chain (which may be a base chain). -// If there is no existing chain, a regular chain is added and its [ChainRef] is -// returned. -// -// See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_regular_chains -// -// If a new [ChainRef] is created and the underlying chain already exists, it -// will be flushed by the next [TableRef.Apply] before new rules are added. -func (t TableRef) Chain(ctx context.Context, name string) ChainRef { - c, ok := t.t.Chains[name] - if !ok { - c = &chain{ - table: t.t, - Name: name, - Dirty: true, - ruleGroups: map[RuleGroup][]string{}, - } - t.t.Chains[name] = c - } - log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "chain": name, - }).Debug("nftables: created chain") - return ChainRef{c: c} +func (cd BaseChain) delete(ctx context.Context, t *table) (bool, error) { + return t.deleteChain(ctx, cd.Name) } -// ChainUpdateFunc is a function that can add rules to a chain, or remove rules from it. -type ChainUpdateFunc func(context.Context, RuleGroup, string, ...interface{}) error - -// ChainUpdateFunc returns a [ChainUpdateFunc] to add rules to the named chain if -// enable is true, or to remove rules from the chain if enable is false. -// (Written as a convenience function to ease migration of iptables functions -// originally written with an enable flag.) -func (t TableRef) ChainUpdateFunc(ctx context.Context, name string, enable bool) ChainUpdateFunc { - c := t.Chain(ctx, name) - if enable { - return c.AppendRule - } - return c.DeleteRule +// Chain implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a chain. +type Chain struct { + Name string } -// DeleteChain deletes a chain. It is an error to delete a chain that does not exist. -func (t TableRef) DeleteChain(ctx context.Context, name string) error { - if _, ok := t.t.Chains[name]; !ok { - return fmt.Errorf("chain %q does not exist", name) +func (cd Chain) create(ctx context.Context, t *table) (bool, error) { + if _, ok := t.Chains[cd.Name]; ok { + return false, fmt.Errorf("chain '%s' already exists", cd.Name) + } + if cd.Name == "" { + return false, errors.New("chain must have a name") + } + c := &chain{ + table: t, + Name: cd.Name, + MustFlush: true, + ruleGroups: map[RuleGroup][]string{}, } - delete(t.t.Chains, name) - t.t.DeleteChainCommands = append(t.t.DeleteChainCommands, - fmt.Sprintf("delete chain %s %s %s", t.t.Family, t.t.Name, name)) + t.Chains[c.Name] = c log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "chain": name, - }).Debug("nftables: deleted chain") - return nil + "family": t.Family, + "table": t.Name, + "chain": cd.Name, + }).Debug("nftables: created chain") + return true, nil } -// SetPolicy sets the default policy for a base chain. It is an error to call this -// for a non-base [ChainRef]. -func (c ChainRef) SetPolicy(policy string) error { - if c.c.ChainType == "" { - return errors.New("not a base chain") - } - c.c.Policy = policy - c.c.Dirty = true - return nil +func (cd Chain) delete(ctx context.Context, t *table) (bool, error) { + return t.deleteChain(ctx, cd.Name) +} + +// Rule implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a rule in a chain. +type Rule struct { + Chain string + Group RuleGroup + Rule []string + // IgnoreExist suppresses errors about deleting a rule that does not exist + // or creating a rule that does already exist. + // + // Note that, when set, reversing the [Modifier] may not do what you want! For + // example, if the original modifier deleted a rule that did not exist, the + // reversed modifier will create that rule. + IgnoreExist bool } -// AppendRule appends a rule to a [RuleGroup] in a [ChainRef]. -func (c ChainRef) AppendRule(ctx context.Context, group RuleGroup, rule string, args ...interface{}) error { - if len(args) > 0 { - rule = fmt.Sprintf(rule, args...) +func (ru Rule) create(ctx context.Context, t *table) (bool, error) { + c := t.Chains[ru.Chain] + if c == nil { + return false, fmt.Errorf("chain '%s' does not exist", ru.Chain) } - if rg, ok := c.c.ruleGroups[group]; ok && slices.Contains(rg, rule) { - return ErrRuleExist + if len(ru.Rule) == 0 { + return false, fmt.Errorf("chain '%s', cannot add empty rule", ru.Chain) } - c.c.ruleGroups[group] = append(c.c.ruleGroups[group], rule) - c.c.Dirty = true + rule := strings.Join(ru.Rule, " ") + if rg, ok := c.ruleGroups[ru.Group]; ok && slices.Contains(rg, rule) { + if !ru.IgnoreExist { + return false, fmt.Errorf("adding rule:'%s' chain:'%s' group:%d: rule exists", rule, ru.Chain, ru.Group) + } + return false, nil + } + c.ruleGroups[ru.Group] = append(c.ruleGroups[ru.Group], rule) + c.MustFlush = true log.G(ctx).WithFields(log.Fields{ - "family": c.c.table.Family, - "table": c.c.table.Name, - "chain": c.c.Name, - "group": group, + "family": t.Family, + "table": t.Name, + "chain": c.Name, + "group": ru.Group, "rule": rule, }).Debug("nftables: appended rule") - return nil + return true, nil } -// AppendRuleCf calls AppendRule and returns a cleanup function or an error. -func (c ChainRef) AppendRuleCf(ctx context.Context, group RuleGroup, rule string, args ...interface{}) (func(context.Context) error, error) { - if err := c.AppendRule(ctx, group, rule, args...); err != nil { - return nil, err +func (ru Rule) delete(ctx context.Context, t *table) (bool, error) { + rule := strings.Join(ru.Rule, " ") + c := t.Chains[ru.Chain] + if c == nil { + return false, fmt.Errorf("deleting rule:'%s' - chain '%s' does not exist", rule, ru.Chain) } - return func(ctx context.Context) error { return c.DeleteRule(ctx, group, rule, args...) }, nil -} - -// DeleteRule deletes a rule from a [RuleGroup] in a [ChainRef]. It is an error -// to delete from a group that does not exist, or to delete a rule that does not -// exist. -func (c ChainRef) DeleteRule(ctx context.Context, group RuleGroup, rule string, args ...interface{}) error { - if len(args) > 0 { - rule = fmt.Sprintf(rule, args...) + if rule == "" { + return false, fmt.Errorf("chain '%s', cannot delete empty rule", ru.Chain) } - rg, ok := c.c.ruleGroups[group] + rg, ok := c.ruleGroups[ru.Group] if !ok { - return fmt.Errorf("rule group %d does not exist", group) + if !ru.IgnoreExist { + return false, fmt.Errorf("deleting rule:'%s' chain:'%s' rule group:%d does not exist", rule, ru.Chain, ru.Group) + } + return false, nil } origLen := len(rg) - c.c.ruleGroups[group] = slices.DeleteFunc(rg, func(r string) bool { return r == rule }) - if len(c.c.ruleGroups[group]) == origLen { - return ErrRuleNotExist + c.ruleGroups[ru.Group] = slices.DeleteFunc(rg, func(r string) bool { return r == rule }) + if len(c.ruleGroups[ru.Group]) == origLen { + if !ru.IgnoreExist { + return false, fmt.Errorf("deleting rule:'%s' chain:'%s' group:%d: rule does not exist", rule, ru.Chain, ru.Group) + } + return false, nil + } + if len(c.ruleGroups[ru.Group]) == 0 { + delete(c.ruleGroups, ru.Group) } - c.c.Dirty = true + c.MustFlush = true log.G(ctx).WithFields(log.Fields{ - "family": c.c.table.Family, - "table": c.c.table.Name, - "chain": c.c.Name, + "family": t.Family, + "table": t.Name, + "chain": c.Name, "rule": rule, }).Debug("nftables: deleted rule") - return nil + return true, nil } // //////////////////////////// @@ -579,89 +758,129 @@ func (c ChainRef) DeleteRule(ctx context.Context, group RuleGroup, rule string, type vMap struct { table *table Name string - ElementType nftType + ElementType NftType Flags []string Elements map[string]string - Dirty bool // New vMap, needs to be flushed (not set when elements are added/deleted). AddedElements map[string]string - DeletedElements map[string]struct{} + DeletedElements map[string]string + MustFlush bool } -// VMapRef is a handle for an nftables verdict map. -type VMapRef struct { - v *vMap +// VMap implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a verdict map. +type VMap struct { + Name string + ElementType NftType + Flags []string } -// InterfaceVMap creates a map from interface name to a verdict and returns a [VMapRef], -// or returns an existing [VMapRef] if it has already been created. -// -// See https://wiki.nftables.org/wiki-nftables/index.php/Verdict_Maps_(vmaps) -// -// If a [VMapRef] is created and the underlying map already exists, it will be flushed -// by the next [TableRef.Apply] before new elements are added. -func (t TableRef) InterfaceVMap(ctx context.Context, name string) VMapRef { - if vmap, ok := t.t.VMaps[name]; ok { - return VMapRef{vmap} - } - vmap := &vMap{ - table: t.t, - Name: name, - ElementType: nftTypeIfname, +func (vm VMap) create(ctx context.Context, t *table) (bool, error) { + if vm.Name == "" { + return false, errors.New("vmap must have a name") + } + if _, ok := t.VMaps[vm.Name]; ok { + return false, fmt.Errorf("vmap '%s' already exists", vm.Name) + } + if vm.ElementType == "" { + return false, fmt.Errorf("vmap '%s' has no element type", vm.Name) + } + v := &vMap{ + table: t, + Name: vm.Name, + ElementType: vm.ElementType, + Flags: slices.Clone(vm.Flags), Elements: map[string]string{}, AddedElements: map[string]string{}, - DeletedElements: map[string]struct{}{}, - Dirty: true, + DeletedElements: map[string]string{}, + MustFlush: true, } - t.t.VMaps[name] = vmap + t.VMaps[v.Name] = v log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "vmap": name, + "family": t.Family, + "table": t.Name, + "vmap": v.Name, }).Debug("nftables: created interface vmap") - return VMapRef{vmap} + return true, nil } -// AddElement adds an element to a verdict map. The caller must ensure the key has -// the correct type. It is an error to add a key that already exists. -func (v VMapRef) AddElement(ctx context.Context, key string, verdict string) error { - if _, ok := v.v.Elements[key]; ok { - return fmt.Errorf("verdict map already contains element %q", key) +func (vm VMap) delete(ctx context.Context, t *table) (bool, error) { + v := t.VMaps[vm.Name] + if v == nil { + return false, fmt.Errorf("cannot delete vmap '%s', it does not exist", vm.Name) + } + if len(v.Elements) != 0 { + return false, fmt.Errorf("cannot delete vmap '%s', it contains %d elements", v.Name, len(v.Elements)) } - v.v.Elements[key] = verdict - v.v.AddedElements[key] = verdict + delete(t.VMaps, v.Name) + t.DeleteCommands = append(t.DeleteCommands, + fmt.Sprintf("delete map %s %s %s", t.Family, t.Name, v.Name)) log.G(ctx).WithFields(log.Fields{ - "family": v.v.table.Family, - "table": v.v.table.Name, - "vmap": v.v.Name, - "key": key, - "verdict": verdict, - }).Debug("nftables: added vmap element") - return nil + "family": t.Family, + "table": t.Name, + "vmap": v.Name, + }).Debug("nftables: deleted vmap") + return true, nil } -// AddElementCf calls AddElement and returns a cleanup function or an error. -func (v VMapRef) AddElementCf(ctx context.Context, key string, verdict string) (func(context.Context) error, error) { - if err := v.AddElement(ctx, key, verdict); err != nil { - return nil, err +// VMapElement implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete an entry in a verdict map. +type VMapElement struct { + VmapName string + Key string + Verdict string +} + +func (ve VMapElement) create(ctx context.Context, t *table) (bool, error) { + if ve.VmapName == "" { + return false, errors.New("cannot add element to unnamed vmap") + } + v := t.VMaps[ve.VmapName] + if v == nil { + return false, fmt.Errorf("cannot add to vmap '%s', it does not exist", ve.VmapName) + } + if ve.Key == "" || ve.Verdict == "" { + return false, fmt.Errorf("cannot add to vmap '%s', element must have key and verdict", ve.VmapName) + } + if _, ok := v.Elements[ve.Key]; ok { + return false, fmt.Errorf("verdict map '%s' already contains element '%s'", ve.VmapName, ve.Key) } - return func(ctx context.Context) error { return v.DeleteElement(ctx, key) }, nil + v.Elements[ve.Key] = ve.Verdict + v.AddedElements[ve.Key] = ve.Verdict + delete(v.DeletedElements, ve.Key) + log.G(ctx).WithFields(log.Fields{ + "family": t.Family, + "table": t.Name, + "vmap": ve.VmapName, + "key": ve.Key, + "verdict": ve.Verdict, + }).Debug("nftables: added vmap element") + return true, nil } -// DeleteElement deletes an element from a verdict map. It is an error to delete -// an element that does not exist. -func (v VMapRef) DeleteElement(ctx context.Context, key string) error { - if _, ok := v.v.Elements[key]; !ok { - return fmt.Errorf("verdict map does not contain element %q", key) +func (ve VMapElement) delete(ctx context.Context, t *table) (bool, error) { + v := t.VMaps[ve.VmapName] + if v == nil { + return false, fmt.Errorf("cannot delete from vmap '%s', it does not exist", ve.VmapName) + } + oldVerdict, ok := v.Elements[ve.Key] + if !ok { + return false, fmt.Errorf("verdict map '%s' does not contain element '%s'", ve.VmapName, ve.Key) + } + if oldVerdict != ve.Verdict { + return false, fmt.Errorf("cannot delete verdict map '%s' element '%s', verdict was '%s', not '%s'", + ve.VmapName, ve.Key, oldVerdict, ve.Verdict) } - delete(v.v.Elements, key) - v.v.DeletedElements[key] = struct{}{} + delete(v.Elements, ve.Key) + delete(v.AddedElements, ve.Key) + v.DeletedElements[ve.Key] = ve.Verdict log.G(ctx).WithFields(log.Fields{ - "family": v.v.table.Family, - "table": v.v.table.Name, - "vmap": v.v.Name, - "key": key, + "family": t.Family, + "table": t.Name, + "vmap": ve.VmapName, + "key": ve.Key, + "verdict": ve.Verdict, }).Debug("nftables: deleted vmap element") - return nil + return true, nil } // //////////////////////////// @@ -673,108 +892,180 @@ func (v VMapRef) DeleteElement(ctx context.Context, key string) error { type set struct { table *table Name string - ElementType nftType + ElementType NftType Flags []string Elements map[string]struct{} - Dirty bool // New set, needs to be flushed (not set when elements are added/deleted). AddedElements map[string]struct{} DeletedElements map[string]struct{} + MustFlush bool } -// SetRef is a handle for an nftables named set. -type SetRef struct { - s *set +// Set implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a set. +type Set struct { + Name string + ElementType NftType + Flags []string } -// PrefixSet creates a new named nftables set for IPv4 or IPv6 addresses (depending -// on the address family of the [TableRef]), and returns its [SetRef]. Or, if the -// set has already been created, its [SetRef] is returned. -// -// ([TableRef] does not support "inet", only "ip" or "ip6". So the element type can -// always be determined. But, there's no "inet" element type, so this will need to -// change if we need an "inet" table.) -// // See https://wiki.nftables.org/wiki-nftables/index.php/Sets#Named_sets -func (t TableRef) PrefixSet(ctx context.Context, name string) SetRef { - if s, ok := t.t.Sets[name]; ok { - return SetRef{s} +func (sd Set) create(ctx context.Context, t *table) (bool, error) { + if sd.Name == "" { + return false, errors.New("set must have a name") + } + if _, ok := t.Sets[sd.Name]; ok { + return false, fmt.Errorf("set '%s' already exists", sd.Name) + } + if sd.ElementType == "" { + return false, fmt.Errorf("set '%s' must have a type", sd.Name) } s := &set{ - table: t.t, - Name: name, + table: t, + Name: sd.Name, Elements: map[string]struct{}{}, - ElementType: nftTypeIPv4Addr, - Flags: []string{"interval"}, - Dirty: true, + ElementType: sd.ElementType, + Flags: slices.Clone(sd.Flags), AddedElements: map[string]struct{}{}, DeletedElements: map[string]struct{}{}, + MustFlush: true, } - if t.t.Family == IPv6 { - s.ElementType = nftTypeIPv6Addr - } - t.t.Sets[name] = s + t.Sets[sd.Name] = s log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "set": name, + "family": t.Family, + "table": t.Name, + "set": s.Name, }).Debug("nftables: created set") - return SetRef{s} + return true, nil +} + +func (sd Set) delete(ctx context.Context, t *table) (bool, error) { + s := t.Sets[sd.Name] + if s == nil { + return false, fmt.Errorf("cannot delete set '%s', it does not exist", sd.Name) + } + if len(s.Elements) != 0 { + return false, fmt.Errorf("cannot delete set '%s', it contains %d elements", s.Name, len(s.Elements)) + } + delete(t.Sets, sd.Name) + t.DeleteCommands = append(t.DeleteCommands, + fmt.Sprintf("delete set %s %s %s", t.Family, t.Name, s.Name)) + log.G(ctx).WithFields(log.Fields{ + "family": t.Family, + "table": t.Name, + "set": sd.Name, + }).Debug("nftables: deleted set") + return true, nil +} + +// SetElement implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete an entry in a set. +type SetElement struct { + SetName string + Element string } -// AddElement adds an element to a set. It is the caller's responsibility to make sure -// the element has the correct type. It is an error to add an element that is already -// in the set. -func (s SetRef) AddElement(ctx context.Context, element string) error { - if _, ok := s.s.Elements[element]; ok { - return fmt.Errorf("set already contains element %q", element) +func (se SetElement) create(ctx context.Context, t *table) (bool, error) { + s := t.Sets[se.SetName] + if s == nil { + return false, fmt.Errorf("cannot add to set '%s', it does not exist", se.SetName) + } + if se.Element == "" { + return false, fmt.Errorf("cannot add to set '%s', element not specified", se.SetName) + } + if _, ok := s.Elements[se.Element]; ok { + return false, fmt.Errorf("set '%s' already contains element '%s'", s.Name, se.Element) } - s.s.Elements[element] = struct{}{} - s.s.AddedElements[element] = struct{}{} + s.Elements[se.Element] = struct{}{} + s.AddedElements[se.Element] = struct{}{} + delete(s.DeletedElements, se.Element) log.G(ctx).WithFields(log.Fields{ - "family": s.s.table.Family, - "table": s.s.table.Name, - "set": s.s.Name, - "element": element, + "family": t.Family, + "table": t.Name, + "set": s.Name, + "element": se.Element, }).Debug("nftables: added set element") - return nil + return true, nil } -// DeleteElement deletes an element from the set. It is an error to delete an -// element that is not in the set. -func (s SetRef) DeleteElement(ctx context.Context, element string) error { - if _, ok := s.s.Elements[element]; !ok { - return fmt.Errorf("set does not contain element %q", element) +func (se SetElement) delete(ctx context.Context, t *table) (bool, error) { + s := t.Sets[se.SetName] + if s == nil { + return false, fmt.Errorf("cannot delete from set '%s', it does not exist", se.SetName) + } + if _, ok := s.Elements[se.Element]; !ok { + return false, fmt.Errorf("cannot delete '%s' from set '%s', it does not exist", se.Element, s.Name) } - delete(s.s.Elements, element) - s.s.DeletedElements[element] = struct{}{} + delete(s.Elements, se.Element) + delete(s.AddedElements, se.Element) + s.DeletedElements[se.Element] = struct{}{} log.G(ctx).WithFields(log.Fields{ - "family": s.s.table.Family, - "table": s.s.table.Name, - "set": s.s.Name, - "element": element, - }).Debug("nftables: deleted set element") - return nil + "family": t.Family, + "table": t.Name, + "set": s.Name, + "element": se.Element, + }).Debug("nftables: added set element") + return true, nil } // //////////////////////////// // Internal +func (t *table) deleteChain(ctx context.Context, name string) (bool, error) { + c := t.Chains[name] + if c == nil { + return false, fmt.Errorf("cannot delete chain '%s', it does not exist", name) + } + if len(c.ruleGroups) != 0 { + return false, fmt.Errorf("cannot delete chain '%s', it is not empty", name) + } + delete(t.Chains, name) + t.DeleteCommands = append(t.DeleteCommands, + fmt.Sprintf("delete chain %s %s %s", t.Family, t.Name, name)) + log.G(ctx).WithFields(log.Fields{ + "family": t.Family, + "table": t.Name, + "chain": name, + }).Debug("nftables: deleted chain") + return true, nil +} + +type command struct { + obj Obj + callerFile string + callerLine int + delete bool +} + +func (c command) apply(ctx context.Context, t *table) (bool, error) { + if c.delete { + return c.obj.delete(ctx, t) + } + return c.obj.create(ctx, t) +} + +func (c command) rollback(ctx context.Context, t *table) (bool, error) { + if c.delete { + return c.obj.create(ctx, t) + } + return c.obj.delete(ctx, t) +} + func (t *table) updatesApplied() { - t.DeleteChainCommands = t.DeleteChainCommands[:0] + t.DeleteCommands = t.DeleteCommands[:0] for _, c := range t.Chains { - c.Dirty = false + c.MustFlush = false } for _, m := range t.VMaps { - m.Dirty = false m.AddedElements = map[string]string{} - m.DeletedElements = map[string]struct{}{} + m.DeletedElements = map[string]string{} + m.MustFlush = false } for _, s := range t.Sets { - s.Dirty = false s.AddedElements = map[string]struct{}{} s.DeletedElements = map[string]struct{}{} + s.MustFlush = false } - t.Dirty = false + t.MustFlush = false } /* Can't make text/template range over this, not sure why ... diff --git a/daemon/libnetwork/internal/nftables/nftables_linux_test.go b/daemon/libnetwork/internal/nftables/nftables_linux_test.go index 526fe89956593..95fc4689f532a 100644 --- a/daemon/libnetwork/internal/nftables/nftables_linux_test.go +++ b/daemon/libnetwork/internal/nftables/nftables_linux_test.go @@ -31,11 +31,20 @@ func testSetup(t *testing.T) func() { } } -func applyAndCheck(t *testing.T, tbl TableRef, goldenFilename string) { +func applyAndCheck(t *testing.T, tm *Modifier, goldenFilename string) { t.Helper() - err := tbl.Apply(context.Background()) + err := tm.Apply(context.Background()) assert.Check(t, err) - res := icmd.RunCommand("nft", "list", "ruleset") + res := icmd.RunCommand("nft", "list", "table", string(tm.Family()), tm.Name()) + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), goldenFilename) +} + +func reloadAndCheck(t *testing.T, tbl Table, ipv Family, goldenFilename string) { + t.Helper() + err := tbl.Reload(context.Background()) + assert.Check(t, err) + res := icmd.RunCommand("nft", "list", "table", string(ipv), tbl.Name()) res.Assert(t, icmd.Success) golden.Assert(t, res.Combined(), goldenFilename) } @@ -48,17 +57,19 @@ func TestTable(t *testing.T) { tbl6, err := NewTable(IPv6, "ipv6_table") assert.NilError(t, err) - assert.Check(t, is.Equal(tbl4.Family(), IPv4)) - assert.Check(t, is.Equal(tbl6.Family(), IPv6)) + tm4 := tbl4.Modifier() + tm6 := tbl6.Modifier() + + assert.Check(t, is.Equal(tm4.Family(), IPv4)) + assert.Check(t, is.Equal(tm6.Family(), IPv6)) // Update nftables and check what happened. - applyAndCheck(t, tbl4, t.Name()+"_created4.golden") - applyAndCheck(t, tbl6, t.Name()+"_created46.golden") + applyAndCheck(t, tm4, t.Name()+"_created4.golden") + applyAndCheck(t, tm6, t.Name()+"_created6.golden") } func TestChain(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create a table. tbl, err := NewTable(IPv4, "this_is_a_table") @@ -66,143 +77,176 @@ func TestChain(t *testing.T) { // Create a base chain. const bcName = "this_is_a_base_chain" - bc1, err := tbl.BaseChain(ctx, bcName, BaseChainTypeFilter, BaseChainHookForward, BaseChainPriorityFilter+10) - assert.NilError(t, err) - - // Check that it's an error to add a new base chain with the same name. - _, err = tbl.BaseChain(ctx, bcName, BaseChainTypeNAT, BaseChainHookPrerouting, BaseChainPriorityDstNAT) - assert.Check(t, is.ErrorContains(err, "already exists")) - - // Add a rule. - err = bc1.AppendRule(ctx, 0, "counter") - assert.NilError(t, err) + tm := tbl.Modifier() + bcDesc := BaseChain{ + Name: bcName, + ChainType: BaseChainTypeFilter, + Hook: BaseChainHookForward, + Priority: BaseChainPriorityFilter + 10, + Policy: "accept", + } + tm.Create(bcDesc) + // Add a rule to the base chain. + bcCounterRule := Rule{Chain: bcName, Group: 0, Rule: []string{"counter"}} + tm.Create(bcCounterRule) // Add a regular chain. const regularChainName = "this_is_a_regular_chain" - _ = tbl.Chain(ctx, regularChainName) + cDesc := Chain{Name: regularChainName} + tm.Create(cDesc) + // Add a rule to the regular chain. + cRule := Rule{Chain: regularChainName, Group: 0, Rule: []string{"counter", "accept"}} + tm.Create(cRule) - // Add a rule to the regular chain, use string formatting and a func retrieved - // from the table. - f := tbl.ChainUpdateFunc(ctx, regularChainName, true) - err = f(ctx, 0, "counter %s", "accept") - assert.Check(t, err) + // Add another rule to the base chain. + bcJumpRule := Rule{Chain: bcName, Group: 0, Rule: []string{"jump", regularChainName}} + tm.Create(bcJumpRule) - // Fetch the base chain by name. - bc1 = tbl.Chain(ctx, bcName) + // Update nftables and check what happened. + applyAndCheck(t, tm, t.Name()+"_created.golden") - // Add another rule to the base chain, using the newly-retrieved handle. - err = bc1.AppendRule(ctx, 0, "jump %s", regularChainName) - assert.Check(t, err) + // Delete a rule from the base chain. + tm = tbl.Modifier() + tm.Delete(bcCounterRule) // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_created.golden") + applyAndCheck(t, tm, t.Name()+"_modified.golden") - // Delete a rule from the base chain. - f = tbl.ChainUpdateFunc(ctx, bcName, false) - err = f(ctx, 0, "counter") - assert.Check(t, err) + // Delete the base chain. + tm = tbl.Modifier() + tm.Delete(bcJumpRule) + tm.Delete(bcDesc) + tm.Delete(cRule) + tm.Delete(cDesc) - // Check it's an error to delete that rule again. This time, call the delete - // function directly on a newly retrieved handle. - err = tbl.Chain(ctx, bcName).DeleteRule(ctx, 0, "counter") - assert.Check(t, is.ErrorContains(err, "does not exist")) + // Update nftables and check what happened. + applyAndCheck(t, tm, t.Name()+"_deleted.golden") +} - // Update the base chain's policy. - err = tbl.Chain(ctx, bcName).SetPolicy("drop") - assert.Check(t, err) +func TestSetBaseChainPolicy(t *testing.T) { + defer testSetup(t)() + baseChainName := "aBaseChain" + chainName := "aChain" - // Check it's an error to set a policy on a regular chain. - err = tbl.Chain(ctx, regularChainName).SetPolicy("drop") - assert.Check(t, is.ErrorContains(err, "not a base chain")) + tbl := Table{} + err := tbl.SetBaseChainPolicy(context.Background(), baseChainName, "accept") + assert.Check(t, is.ErrorContains(err, "invalid table")) - // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_modified.golden") + tbl, err = NewTable(IPv4, "this_is_a_table") + assert.NilError(t, err) - // Delete the base chain. - err = tbl.DeleteChain(ctx, bcName) - assert.Check(t, err) + err = tbl.SetBaseChainPolicy(context.Background(), baseChainName, "accept") + assert.Check(t, is.Error(err, "cannot set base chain policy for '"+baseChainName+"', it does not exist")) + + tm := tbl.Modifier() + tm.Create(BaseChain{Name: baseChainName, ChainType: BaseChainTypeFilter, Hook: BaseChainHookForward, Priority: BaseChainPriorityFilter}) + tm.Create(Chain{Name: chainName}) + applyAndCheck(t, tm, t.Name()+"_accept.golden") + + err = tbl.SetBaseChainPolicy(context.Background(), chainName, "accept") + assert.Check(t, is.Error(err, "cannot set base chain policy for '"+chainName+"', it is not a base chain")) - // Delete the regular chain. - err = tbl.DeleteChain(ctx, regularChainName) + err = tbl.SetBaseChainPolicy(context.Background(), baseChainName, "drop") assert.Check(t, err) + res := icmd.RunCommand("nft", "list", "table", string(IPv4), "this_is_a_table") + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_drop.golden") - // Check that it's an error to delete it again. - err = tbl.DeleteChain(ctx, regularChainName) - assert.Check(t, is.ErrorContains(err, "does not exist")) + err = tbl.SetBaseChainPolicy(context.Background(), baseChainName, "badpolicy") + assert.Check(t, err != nil, "Expected an error") + res = icmd.RunCommand("nft", "list", "table", string(IPv4), "this_is_a_table") + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_drop.golden") - // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_deleted.golden") + err = tbl.Reload(context.Background()) + assert.Check(t, err) + res = icmd.RunCommand("nft", "list", "table", string(IPv4), "this_is_a_table") + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_drop.golden") } func TestChainRuleGroups(t *testing.T) { defer testSetup(t)() - ctx := context.Background() tbl, err := NewTable(IPv4, "testtable") assert.NilError(t, err) - c := tbl.Chain(ctx, "testchain") - err = c.AppendRule(ctx, 100, "hello100") - assert.Check(t, err) - err = c.AppendRule(ctx, 200, "hello200") - assert.Check(t, err) - err = c.AppendRule(ctx, 100, "hello101") - assert.Check(t, err) - err = c.AppendRule(ctx, 200, "hello201") - assert.Check(t, err) - err = c.AppendRule(ctx, 100, "hello102") - assert.Check(t, err) + tm := tbl.Modifier() + chainName := "testchain" + tm.Create(Chain{Name: chainName}) + tm.Create(Rule{Chain: chainName, Group: 100, Rule: []string{"iifname hello100 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 200, Rule: []string{"iifname hello200 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 100, Rule: []string{"iifname hello101 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 200, Rule: []string{"iifname hello201 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 100, Rule: []string{"iifname hello102 counter"}}) + applyAndCheck(t, tm, t.Name()+".golden") +} - assert.Check(t, is.DeepEqual(c.c.Rules(), []string{ - "hello100", "hello101", "hello102", - "hello200", "hello201", - })) +func TestIgnoreExist(t *testing.T) { + defer testSetup(t)() + tbl, err := NewTable(IPv4, "this_is_a_table") + assert.NilError(t, err) + tm := tbl.Modifier() + + // Create a chain with a single rule, add the rule again but drop the duplicate. + const chainName = "this_is_a_chain" + tm.Create(Chain{Name: chainName}) + tm.Create(Rule{Chain: chainName, Rule: []string{"counter"}}) + tm.Create(Rule{Chain: chainName, Rule: []string{"counter"}, IgnoreExist: true}) + applyAndCheck(t, tm, t.Name()+"_created.golden") + + // Add the rule again, ignoring the duplicate, but in a modifier that has an + // error - check that the existing rule isn't removed by rollback of this modifier. + tmErr := tbl.Modifier() + tmErr.Create(Rule{Chain: chainName, Rule: []string{"counter"}, IgnoreExist: true}) + tmErr.Create(Rule{Chain: chainName}) + err = tmErr.Apply(context.Background()) + assert.Check(t, err != nil, "Expected an error") + + // Reload, to flush table state. + reloadAndCheck(t, tbl, IPv4, t.Name()+"_created.golden") + + // Delete the rule. + tmDel := tbl.Modifier() + tmDel.Delete(Rule{Chain: chainName, Rule: []string{"counter"}}) + applyAndCheck(t, tmDel, t.Name()+"_deleted.golden") + + // Delete it again, in another chain that will roll back, to check it's not resurrected. + tmReDel := tbl.Modifier() + tmReDel.Delete(Rule{Chain: chainName, Rule: []string{"counter"}, IgnoreExist: true}) + tmReDel.Create(Rule{Chain: chainName}) + err = tmReDel.Apply(context.Background()) + assert.Check(t, err != nil, "Expected an error") + + // Reload, to flush table state. + reloadAndCheck(t, tbl, IPv4, t.Name()+"_deleted.golden") } func TestVMap(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create a table. tbl, err := NewTable(IPv6, "this_is_a_table") assert.NilError(t, err) + tm := tbl.Modifier() // Create a verdict map. const mapName = "this_is_a_vmap" - m := tbl.InterfaceVMap(ctx, mapName) - - // Add an element. - err = m.AddElement(ctx, "eth0", "return") - assert.Check(t, err) - - // Check that it's an error to add the element again. - err = m.AddElement(ctx, "eth0", "return") - assert.Check(t, is.ErrorContains(err, "already contains element")) - - // Fetch the existing vmap. - m = tbl.InterfaceVMap(ctx, mapName) - - // Add another element. - err = m.AddElement(ctx, "eth1", "drop") - assert.Check(t, err) + tm.Create(VMap{Name: mapName, ElementType: NftTypeIfname}) + tm.Create(VMapElement{VmapName: mapName, Key: "eth0", Verdict: "return"}) + tm.Create(VMapElement{VmapName: mapName, Key: "eth1", Verdict: "drop"}) // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_created.golden") + applyAndCheck(t, tm, t.Name()+"_created.golden") - // Delete an element. - err = m.DeleteElement(ctx, "eth1") - assert.Check(t, err) - - // Check it's an error to delete it again. - err = m.DeleteElement(ctx, "eth1") - assert.Check(t, is.ErrorContains(err, "does not contain element")) + // Undo those changes by reversing the commands. + tmRev := tm.Reverse() // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_deleted.golden") + applyAndCheck(t, tmRev, t.Name()+"_deleted.golden") } func TestSet(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create v4 and v6 tables. tbl4, err := NewTable(IPv4, "table4") @@ -211,62 +255,55 @@ func TestSet(t *testing.T) { assert.NilError(t, err) // Create a set in each table. - s4 := tbl4.PrefixSet(ctx, "set4") - s6 := tbl6.PrefixSet(ctx, "set6") + const set4Name = "set4" + tm4 := tbl4.Modifier() + tm4.Create(Set{Name: set4Name, ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}) + const set6Name = "set6" + tm6 := tbl6.Modifier() + tm6.Create(Set{Name: set6Name, ElementType: NftTypeIPv6Addr, Flags: []string{"interval"}}) // Add elements to each set. - err = s4.AddElement(ctx, "192.0.2.1/24") - assert.Check(t, err) - err = s6.AddElement(ctx, "2001:db8::1/64") - assert.Check(t, err) - - // Check it's an error to add those elements again. - err = s4.AddElement(ctx, "192.0.2.1/24") - assert.Check(t, is.ErrorContains(err, "already contains element")) - err = s6.AddElement(ctx, "2001:db8::1/64") - assert.Check(t, is.ErrorContains(err, "already contains element")) + tm4.Create(SetElement{SetName: set4Name, Element: "192.0.2.0/24"}) + tm6.Create(SetElement{SetName: set6Name, Element: "2001:db8::/64"}) // Update nftables and check what happened. - applyAndCheck(t, tbl4, t.Name()+"_created4.golden") - applyAndCheck(t, tbl6, t.Name()+"_created46.golden") + applyAndCheck(t, tm4, t.Name()+"_created4.golden") + applyAndCheck(t, tm6, t.Name()+"_created6.golden") // Delete elements. - err = s4.DeleteElement(ctx, "192.0.2.1/24") - assert.Check(t, err) - err = s6.DeleteElement(ctx, "2001:db8::1/64") - assert.Check(t, err) - - // Check it's an error to delete those elements again. - err = s4.DeleteElement(ctx, "192.0.2.1/24") - assert.Check(t, is.ErrorContains(err, "does not contain element")) - err = s6.DeleteElement(ctx, "2001:db8::1/64") - assert.Check(t, is.ErrorContains(err, "does not contain element")) - - // Update nftables and check what happened. - applyAndCheck(t, tbl4, t.Name()+"_deleted4.golden") - applyAndCheck(t, tbl6, t.Name()+"_deleted46.golden") + applyAndCheck(t, tm4.Reverse(), t.Name()+"_deleted4.golden") + applyAndCheck(t, tm6.Reverse(), t.Name()+"_deleted6.golden") } func TestReload(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create a table with some stuff in it. const tableName = "this_is_a_table" tbl, err := NewTable(IPv4, tableName) assert.NilError(t, err) - bc, err := tbl.BaseChain(ctx, "a_base_chain", BaseChainTypeFilter, BaseChainHookForward, BaseChainPriorityFilter) - assert.NilError(t, err) - err = bc.AppendRule(ctx, 0, "counter") - assert.NilError(t, err) - m := tbl.InterfaceVMap(ctx, "this_is_a_vmap") - err = m.AddElement(ctx, "eth0", "return") - assert.Check(t, err) - err = m.AddElement(ctx, "eth1", "return") - assert.Check(t, err) - err = tbl.PrefixSet(ctx, "set4").AddElement(ctx, "192.0.2.0/24") - assert.Check(t, err) - applyAndCheck(t, tbl, t.Name()+"_created.golden") + tm := tbl.Modifier() + + const bcName = "a_base_chain" + tm.Create(BaseChain{ + Name: bcName, + ChainType: BaseChainTypeFilter, + Hook: BaseChainHookForward, + Priority: BaseChainPriorityFilter, + Policy: "accept", + }) + tm.Create(Rule{Chain: bcName, Group: 0, Rule: []string{"counter"}}) + + const vmapName = "this_is_a_vmap" + tm.Create(VMap{Name: vmapName, ElementType: NftTypeIfname}) + tm.Create(VMapElement{VmapName: vmapName, Key: "eth0", Verdict: "return"}) + tm.Create(VMapElement{VmapName: vmapName, Key: "eth1", Verdict: "return"}) + + const setName = "this_is_a_set" + tm.Create(Set{Name: setName, ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}) + tm.Create(SetElement{SetName: setName, Element: "192.0.2.0/24"}) + + applyAndCheck(t, tm, t.Name()+"_created.golden") // Delete the underlying nftables table. deleteTable := func() { @@ -282,14 +319,371 @@ func TestReload(t *testing.T) { // Reconstruct the nftables table. err = tbl.Reload(context.Background()) assert.Check(t, err) - applyAndCheck(t, tbl, t.Name()+"_reloaded.golden") + res := icmd.RunCommand("nft", "list", "table", string(tm.Family()), tm.Name()) + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_created.golden") // Delete again. deleteTable() // Check implicit/recovery reload - only deleting something that's gone missing // from a vmap/set will trigger this. - err = m.DeleteElement(ctx, "eth1") - assert.Check(t, err) - applyAndCheck(t, tbl, t.Name()+"_recovered.golden") + tm = tbl.Modifier() + tm.Delete(SetElement{SetName: setName, Element: "192.0.2.0/24"}) + applyAndCheck(t, tm, t.Name()+"_recovered.golden") +} + +func TestValidation(t *testing.T) { + testcases := []struct { + name string + cmds []command + expErr string + }{ + // BaseChain + { + name: "create with missing base chain name", + cmds: []command{ + {obj: BaseChain{ChainType: BaseChainTypeNAT, Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT}}, + }, + expErr: "base chain must have a name", + }, + { + name: "create with missing base chain type", + cmds: []command{ + {obj: BaseChain{Name: "achain", Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT}}, + }, + expErr: "chain 'achain': fields ChainType and Hook are required", + }, + { + name: "create with missing base chain hook", + cmds: []command{ + {obj: BaseChain{Name: "achain", ChainType: BaseChainTypeNAT, Priority: BaseChainPrioritySrcNAT}}, + }, + expErr: "chain 'achain': fields ChainType and Hook are required", + }, + { + name: "delete non-empty base chain", + cmds: []command{ + {obj: BaseChain{ + Name: "achain", ChainType: BaseChainTypeNAT, Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT, + }}, + {obj: Rule{Chain: "achain", Group: 0, Rule: []string{"counter"}}}, + { + obj: BaseChain{ + Name: "achain", ChainType: BaseChainTypeNAT, Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT, + }, + delete: true, + }, + }, + expErr: "cannot delete chain 'achain', it is not empty", + }, + // Chain + { + name: "duplicate chain", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Chain{Name: "achain"}}, + }, + expErr: "already exists", + }, + { + name: "delete missing chain", + cmds: []command{ + {obj: Chain{Name: "achain"}, delete: true}, + }, + expErr: "does not exist", + }, + { + name: "missing chain name", + cmds: []command{ + {obj: Chain{}}, + }, + expErr: "chain must have a name", + }, + { + name: "delete non-empty chain", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Chain{Name: "achain"}, delete: true}, + }, + expErr: "cannot delete chain 'achain', it is not empty", + }, + // Rule + { + name: "bad rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"this is nonsense"}}}, + }, + expErr: "syntax error", + }, + { + name: "duplicate rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + }, + expErr: "rule exists", + }, + { + name: "delete missing rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + }, + expErr: "does not exist", + }, + { + name: "duplicate rule delete", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + }, + expErr: "does not exist", + }, + { + name: "create rule with missing chain name", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Rule: []string{"counter"}}}, + }, + expErr: "chain '' does not exist", + }, + { + name: "delete rule with missing chain name", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Rule: []string{"counter"}}, delete: true}, + }, + expErr: "chain '' does not exist", + }, + { + name: "create rule with nonexistent chain", + cmds: []command{ + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + }, + expErr: "chain 'achain' does not exist", + }, + { + name: "delete rule with nonexistent chain", + cmds: []command{ + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + }, + expErr: "chain 'achain' does not exist", + }, + { + name: "create rule with no rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain"}}, + }, + expErr: "cannot add empty rule", + }, + { + name: "delete rule with no rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain"}, delete: true}, + }, + expErr: "cannot delete empty rule", + }, + { + name: "bad rule mid-sequence", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + {obj: Rule{Chain: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + }, + expErr: "chain 'achain', cannot add empty rule", + }, + // VMap + { + name: "duplicate vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + }, + expErr: "vmap 'avmap' already exists", + }, + { + name: "delete nonexistent vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}, delete: true}, + }, + expErr: "cannot delete vmap 'avmap', it does not exist", + }, + { + name: "missing vmap name", + cmds: []command{{obj: VMap{ElementType: NftTypeIfname}}}, + expErr: "vmap must have a name", + }, + { + name: "missing vmap element type", + cmds: []command{{obj: VMap{Name: "avmap"}}}, + expErr: "vmap 'avmap' has no element type", + }, + { + name: "delete non-empty vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}, delete: true}, + }, + expErr: "cannot delete vmap 'avmap', it contains 1 elements", + }, + // VMapElement + { + name: "duplicate vmap element", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + }, + expErr: "verdict map 'avmap' already contains element 'eth0'", + }, + { + name: "add to vmap that does not exist", + cmds: []command{ + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + }, + expErr: "cannot add to vmap 'avmap', it does not exist", + }, + { + name: "delete nonexistent vmap element", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}, delete: true}, + }, + expErr: "verdict map 'avmap' does not contain element 'eth0'", + }, + { + name: "vmap element with no named vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{Key: "eth0", Verdict: "drop"}}, + }, + expErr: "cannot add element to unnamed vmap", + }, + { + name: "vmap element with no key", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Verdict: "drop"}}, + }, + expErr: "cannot add to vmap 'avmap', element must have key and verdict", + }, + { + name: "vmap element with no verdict", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0"}}, + }, + expErr: "cannot add to vmap 'avmap', element must have key and verdict", + }, + // Set + { + name: "duplicate set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + }, + expErr: "set 'aset' already exists", + }, + { + name: "delete nonexistent set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}, delete: true}, + }, + expErr: "cannot delete set 'aset', it does not exist", + }, + { + name: "missing set name", + cmds: []command{ + {obj: Set{ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + }, + expErr: "set must have a name", + }, + { + name: "missing set element type", + cmds: []command{ + {obj: Set{Name: "aset", Flags: []string{"interval"}}}, + }, + expErr: "set 'aset' must have a type", + }, + { + name: "delete non-empty set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}}, + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}, delete: true}, + }, + expErr: "cannot delete set 'aset', it contains 1 elements", + }, + // SetElement + { + name: "duplicate set element", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}}, + }, + expErr: "set 'aset' already contains element '192.0.2.0/24'", + }, + { + name: "delete nonexistent set element", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}, delete: true}, + }, + expErr: "cannot delete '192.0.2.0/24' from set 'aset', it does not exist", + }, + { + name: "add set element to unnamed set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{Element: "192.0.2.0/24"}}, + }, + expErr: "cannot add to set '', it does not exist", + }, + { + name: "add set element with no element", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset"}}, + }, + expErr: "cannot add to set 'aset', element not specified", + }, + { + name: "mismatched set element type", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "2001:db8::/64"}}, + }, + expErr: "Address family for hostname not supported", + }, + } + + testName := t.Name() + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + defer testSetup(t)() + tbl, err := NewTable(IPv4, "tablename") + assert.NilError(t, err) + tm := Modifier{t: tbl.t, cmds: tc.cmds} + err = tm.Apply(context.Background()) + assert.Check(t, err != nil, "expected error containing '%s'", tc.expErr) + assert.Check(t, is.ErrorContains(err, tc.expErr)) + // Check the table wasn't created. + res := icmd.RunCommand("nft", "list", "table", string(IPv4), "tablename") + res.Assert(t, icmd.Expected{ExitCode: 1}) + // Check the empty table can be created (the Table structure is still healthy). + applyAndCheck(t, tbl.Modifier(), testName+"_empty.golden") + }) + } } diff --git a/daemon/libnetwork/internal/nftables/testdata/.gitattributes b/daemon/libnetwork/internal/nftables/testdata/.gitattributes new file mode 100644 index 0000000000000..fab3c980014f5 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/.gitattributes @@ -0,0 +1 @@ +*.golden linguist-generated=true diff --git a/daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden b/daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden new file mode 100644 index 0000000000000..cf5befeb3e6f5 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden @@ -0,0 +1,9 @@ +table ip testtable { + chain testchain { + iifname "hello100" counter packets 0 bytes 0 + iifname "hello101" counter packets 0 bytes 0 + iifname "hello102" counter packets 0 bytes 0 + iifname "hello200" counter packets 0 bytes 0 + iifname "hello201" counter packets 0 bytes 0 + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden b/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden index efbf1dfe142d7..d9a59eb2eae13 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden @@ -1,6 +1,6 @@ table ip this_is_a_table { chain this_is_a_base_chain { - type filter hook forward priority filter + 10; policy drop; + type filter hook forward priority filter + 10; policy accept; jump this_is_a_regular_chain } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden new file mode 100644 index 0000000000000..41c0aa7251679 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden @@ -0,0 +1,5 @@ +table ip this_is_a_table { + chain this_is_a_chain { + counter packets 0 bytes 0 + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden new file mode 100644 index 0000000000000..148748dc0c14e --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden @@ -0,0 +1,4 @@ +table ip this_is_a_table { + chain this_is_a_chain { + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden b/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden index f69a290c706c9..dec6dc28cc3a6 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden @@ -5,7 +5,7 @@ table ip this_is_a_table { "eth1" : return } } - set set4 { + set this_is_a_set { type ipv4_addr flags interval elements = { 192.0.2.0/24 } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden b/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden index 98d703606128a..a112403abf807 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden @@ -1,13 +1,13 @@ table ip this_is_a_table { map this_is_a_vmap { type ifname : verdict - elements = { "eth0" : return } + elements = { "eth0" : return, + "eth1" : return } } - set set4 { + set this_is_a_set { type ipv4_addr flags interval - elements = { 192.0.2.0/24 } } chain a_base_chain { diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden new file mode 100644 index 0000000000000..2de16c86a3325 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden @@ -0,0 +1,8 @@ +table ip this_is_a_table { + chain aBaseChain { + type filter hook forward priority filter; policy accept; + } + + chain aChain { + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden new file mode 100644 index 0000000000000..f5d80a60229f3 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden @@ -0,0 +1,8 @@ +table ip this_is_a_table { + chain aBaseChain { + type filter hook forward priority filter; policy drop; + } + + chain aChain { + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden index ea8d0cd04a59f..f6fbbb281be1b 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden @@ -1,10 +1,3 @@ -table ip table4 { - set set4 { - type ipv4_addr - flags interval - elements = { 192.0.2.0/24 } - } -} table ip6 table6 { set set6 { type ipv6_addr diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden new file mode 100644 index 0000000000000..f6fbbb281be1b --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden @@ -0,0 +1,7 @@ +table ip6 table6 { + set set6 { + type ipv6_addr + flags interval + elements = { 2001:db8::/64 } + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden index 54f481f747604..f9b03350baee0 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden @@ -1,13 +1,2 @@ table ip table4 { - set set4 { - type ipv4_addr - flags interval - } -} -table ip6 table6 { - set set6 { - type ipv6_addr - flags interval - elements = { 2001:db8::/64 } - } } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden new file mode 100644 index 0000000000000..813114005b892 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden @@ -0,0 +1,2 @@ +table ip6 table6 { +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden b/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden index 3bb18c8761575..d15df4247e30c 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden @@ -1,4 +1,2 @@ -table ip ipv4_table { -} table ip6 ipv6_table { } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden b/daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden new file mode 100644 index 0000000000000..d15df4247e30c --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden @@ -0,0 +1,2 @@ +table ip6 ipv6_table { +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden b/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden index 432c30d38afda..9a1af3597868c 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden @@ -1,6 +1,2 @@ table ip6 this_is_a_table { - map this_is_a_vmap { - type ifname : verdict - elements = { "eth0" : return } - } } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden b/daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden new file mode 100644 index 0000000000000..593e68402a3f8 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden @@ -0,0 +1,2 @@ +table ip tablename { +} diff --git a/daemon/libnetwork/resolver_unix.go b/daemon/libnetwork/resolver_unix.go index f7de9defa5449..57037bc16b042 100644 --- a/daemon/libnetwork/resolver_unix.go +++ b/daemon/libnetwork/resolver_unix.go @@ -97,32 +97,47 @@ func (r *Resolver) setupNftablesNAT(ctx context.Context, laddr, ltcpaddr, resolv if err != nil { return err } + tm := table.Modifier() - dnatChain, err := table.BaseChain(ctx, "dns-dnat", nftables.BaseChainTypeNAT, nftables.BaseChainHookOutput, nftables.BaseChainPriorityDstNAT) - if err != nil { - return err - } - if err := dnatChain.AppendRule(ctx, 0, "ip daddr %s udp dport %s counter dnat to %s", resolverIP, dnsPort, laddr); err != nil { - return err - } - if err := dnatChain.AppendRule(ctx, 0, "ip daddr %s tcp dport %s counter dnat to %s", resolverIP, dnsPort, ltcpaddr); err != nil { - return err - } + const dnatChain = "dns-dnat" + tm.Create(nftables.BaseChain{ + Name: dnatChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookOutput, + Priority: nftables.BaseChainPriorityDstNAT, + }) + tm.Create(nftables.Rule{ + Chain: dnatChain, + Rule: []string{"ip daddr", resolverIP, "udp dport", dnsPort, "counter dnat to", laddr}, + IgnoreExist: false, + }) + tm.Create(nftables.Rule{ + Chain: dnatChain, + Rule: []string{"ip daddr", resolverIP, "tcp dport", dnsPort, "counter dnat to", ltcpaddr}, + IgnoreExist: false, + }) - snatChain, err := table.BaseChain(ctx, "dns-snat", nftables.BaseChainTypeNAT, nftables.BaseChainHookPostrouting, nftables.BaseChainPrioritySrcNAT) - if err != nil { - return err - } - if err := snatChain.AppendRule(ctx, 0, "ip saddr %s udp sport %s counter snat to :%s", resolverIP, ipPort, dnsPort); err != nil { - return err - } - if err := snatChain.AppendRule(ctx, 0, "ip saddr %s tcp sport %s counter snat to :%s", resolverIP, tcpPort, dnsPort); err != nil { - return err - } + const snatChain = "dns-snat" + tm.Create(nftables.BaseChain{ + Name: snatChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookPostrouting, + Priority: nftables.BaseChainPrioritySrcNAT, + }) + tm.Create(nftables.Rule{ + Chain: snatChain, + Rule: []string{"ip saddr", resolverIP, "udp sport", ipPort, "counter snat to :" + dnsPort}, + IgnoreExist: false, + }) + tm.Create(nftables.Rule{ + Chain: snatChain, + Rule: []string{"ip saddr", resolverIP, "tcp sport", tcpPort, "counter snat to :" + dnsPort}, + IgnoreExist: false, + }) var setupErr error if err := r.backend.ExecFunc(func() { - setupErr = table.Apply(ctx) + setupErr = tm.Apply(ctx) }); err != nil { return err }