diff --git a/daemon/command/config_unix.go b/daemon/command/config_unix.go index fa9e3c5f9f376..2119bdffeac21 100644 --- a/daemon/command/config_unix.go +++ b/daemon/command/config_unix.go @@ -27,6 +27,7 @@ func installConfigFlags(conf *config.Config, flags *pflag.FlagSet) { flags.BoolVar(&conf.BridgeConfig.DisableFilterForwardDrop, "ip-forward-no-drop", false, "Do not set the filter-FORWARD policy to DROP when enabling IP forwarding") flags.BoolVar(&conf.BridgeConfig.EnableIPMasq, "ip-masq", true, "Enable IP masquerading for the default bridge network") flags.BoolVar(&conf.BridgeConfig.EnableIPv6, "ipv6", false, "Enable IPv6 networking for the default bridge network") + flags.Var(opts.NewNamedMapOpts("bridge-nftables-priorities", conf.NftablesPriorities, nil), "bridge-nftables-priority", "Base chain priorities for bridge driver nftables") flags.StringVar(&conf.BridgeConfig.IP, "bip", "", "IPv4 address for the default bridge") flags.StringVar(&conf.BridgeConfig.IP6, "bip6", "", "IPv6 address for the default bridge") flags.StringVarP(&conf.BridgeConfig.Iface, "bridge", "b", "", "Attach containers to a network bridge") diff --git a/daemon/config/config.go b/daemon/config/config.go index 353e2e398798f..b2b8245969ebb 100644 --- a/daemon/config/config.go +++ b/daemon/config/config.go @@ -81,13 +81,14 @@ const ( // Use this to differentiate these options // with others like the ones in TLSOptions. 
type BridgeConfig struct {
	DefaultBridgeConfig

	// Firewall and forwarding switches, read from the "iptables", "ip6tables"
	// and "ip-forward" configuration keys.
	// NOTE(review): their exact effect is defined by the code consuming this
	// struct — confirm there before relying on these names.
	EnableIPTables  bool `json:"iptables,omitempty"`
	EnableIP6Tables bool `json:"ip6tables,omitempty"`
	EnableIPForward bool `json:"ip-forward,omitempty"`
	// DisableFilterForwardDrop, when set, means the filter-FORWARD policy is
	// not set to DROP when IP forwarding is enabled ("ip-forward-no-drop").
	DisableFilterForwardDrop bool `json:"ip-forward-no-drop,omitempty"`
	// EnableIPMasq enables IP masquerading for the default bridge network.
	EnableIPMasq bool `json:"ip-masq,omitempty"`
	// EnableUserlandProxy and UserlandProxyPath together decide hairpin mode:
	// the bridge driver is asked to hairpin when the proxy is disabled or no
	// proxy path is configured.
	EnableUserlandProxy bool   `json:"userland-proxy,omitempty"`
	UserlandProxyPath   string `json:"userland-proxy-path,omitempty"`
	// AllowDirectRouting is passed through to the bridge driver's
	// "AllowDirectRouting" option.
	AllowDirectRouting bool `json:"allow-direct-routing,omitempty"`
	// NftablesPriorities holds base chain priorities for the bridge driver's
	// nftables rules, keyed by base chain name. Values are integers in string
	// form; the map is checked by bridge.ValidateBaseChainPriorities during
	// platform config validation.
	NftablesPriorities map[string]string `json:"bridge-nftables-priorities,omitempty"`
}
@@ -147,6 +148,7 @@ func setPlatformDefaults(cfg *Config) error { cfg.SeccompProfile = SeccompProfileDefault cfg.IpcMode = string(DefaultIpcMode) cfg.Runtimes = make(map[string]system.Runtime) + cfg.NftablesPriorities = make(map[string]string) if cgroups.Mode() != cgroups.Unified { cfg.CgroupNamespaceMode = string(DefaultCgroupV1NamespaceMode) @@ -243,6 +245,10 @@ func validatePlatformConfig(conf *Config) error { return errors.Wrap(err, "invalid fixed-cidr-v6") } + if err := bridge.ValidateBaseChainPriorities(conf.NftablesPriorities); err != nil { + return err + } + if err := validateFirewallBackend(conf.FirewallBackend); err != nil { return errors.Wrap(err, "invalid firewall-backend") } diff --git a/daemon/daemon_unix.go b/daemon/daemon_unix.go index ae294f333908b..1e61b77631930 100644 --- a/daemon/daemon_unix.go +++ b/daemon/daemon_unix.go @@ -938,6 +938,7 @@ func networkPlatformOptions(conf *config.Config) []nwconfig.Option { "EnableIP6Tables": conf.BridgeConfig.EnableIP6Tables, "Hairpin": !conf.EnableUserlandProxy || conf.UserlandProxyPath == "", "AllowDirectRouting": conf.BridgeConfig.AllowDirectRouting, + "NftablesPriorities": conf.NftablesPriorities, }, }), } diff --git a/daemon/libnetwork/drivers/bridge/bridge_linux.go b/daemon/libnetwork/drivers/bridge/bridge_linux.go index 3c815c908b93b..6bb566534727b 100644 --- a/daemon/libnetwork/drivers/bridge/bridge_linux.go +++ b/daemon/libnetwork/drivers/bridge/bridge_linux.go @@ -9,6 +9,7 @@ import ( "net" "net/netip" "os" + "os/signal" "slices" "strconv" "strings" @@ -43,6 +44,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "golang.org/x/sys/unix" ) const ( @@ -77,6 +79,7 @@ type configuration struct { // hairpinned. 
Hairpin bool AllowDirectRouting bool + NftablesPriorities map[string]string } // networkConfiguration for network specific configuration @@ -528,7 +531,7 @@ func (d *driver) configure(option map[string]interface{}) error { Hairpin: config.Hairpin, AllowDirectRouting: config.AllowDirectRouting, WSL2Mirrored: isRunningUnderWSL2MirroredMode(context.Background()), - }) + }, config.NftablesPriorities) if err != nil { return err } @@ -542,13 +545,19 @@ func (d *driver) configure(option map[string]interface{}) error { d.configNetwork.Lock() defer d.configNetwork.Unlock() iptables.OnReloaded(d.handleFirewalldReload) + d.startFirewallReloader() return d.initStore() } -var newFirewaller = func(ctx context.Context, config firewaller.Config) (firewaller.Firewaller, error) { +// ValidateBaseChainPriorities checks nftables base chain priority configuration. +func ValidateBaseChainPriorities(prios map[string]string) error { + return nftabler.ValidateBaseChainPriorities(prios) +} + +var newFirewaller = func(ctx context.Context, config firewaller.Config, nftablesPriorities map[string]string) (firewaller.Firewaller, error) { if nftables.Enabled() { - fw, err := nftabler.NewNftabler(ctx, config) + fw, err := nftabler.NewNftabler(ctx, config, nftablesPriorities) if err != nil { return nil, err } @@ -566,6 +575,26 @@ var newFirewaller = func(ctx context.Context, config firewaller.Config) (firewal return iptabler.NewIptabler(ctx, config) } +func (d *driver) startFirewallReloader() { + r, ok := d.firewaller.(firewaller.Reloader) + if !ok { + return + } + + hupC := make(chan os.Signal, 1) + signal.Notify(hupC, unix.SIGHUP) + go func() { + for range hupC { + d.configNetwork.Lock() + log.G(context.Background()).Info("Received SIGHUP, reloading firewall rules") + if err := r.Reload(context.Background()); err != nil { + log.G(context.Background()).Errorf("Failed to reload firewall rules: %v", err) + } + d.configNetwork.Unlock() + } + }() +} + func (d *driver) getNetwork(id string) 
// Reloader is an optional interface for a Firewaller. Callers find out whether
// a Firewaller supports it with a type assertion.
type Reloader interface {
	// Reload the current firewall rules. The caller is responsible for locking,
	// there must not be any concurrent requests to modify rules.
	//
	// A non-nil error means the reload did not complete.
	// NOTE(review): what state the rules are left in after a failure is up to
	// the implementation — confirm before retrying blindly.
	Reload(ctx context.Context) error
}
diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go index 1fdac1ca072af..4e39f57d3ff90 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/endpoint.go @@ -25,18 +25,24 @@ func (n *network) DelEndpoint(ctx context.Context, epIPv4, epIPv6 netip.Addr) er func (n *network) modEndpoint(ctx context.Context, epIPv4, epIPv6 netip.Addr, enable bool) error { if n.fw.config.IPv4 && epIPv4.IsValid() { - if err := n.filterDirectAccess(ctx, n.fw.table4, n.config.Config4, epIPv4, enable); err != nil { - return err + tm := n.fw.table4.Modifier() + updater := tm.Create + if !enable { + updater = tm.Delete } - if err := nftApply(ctx, n.fw.table4); err != nil { + n.filterDirectAccess(updater, tm.Family(), n.config.Config4, epIPv4) + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } } if n.fw.config.IPv6 && epIPv6.IsValid() { - if err := n.filterDirectAccess(ctx, n.fw.table6, n.config.Config6, epIPv6, enable); err != nil { - return err + tm := n.fw.table6.Modifier() + updater := tm.Create + if !enable { + updater = tm.Delete } - if err := nftApply(ctx, n.fw.table6); err != nil { + n.filterDirectAccess(updater, tm.Family(), n.config.Config6, epIPv6) + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } } @@ -53,17 +59,21 @@ func (n *network) modEndpoint(ctx context.Context, epIPv4, epIPv6 netip.Addr, en // kernel support). // // Packets originating on the bridge's own interface and addressed directly to the -// container are allowed - the host always has direct access to its own containers -// (it doesn't need to use the port mapped to its own addresses, although it can). +// container are allowed - the host always has direct access to its own containers. 
+// (It doesn't need to use the port mapped to its own addresses, although it can.) // // "Trusted interfaces" are treated in the same way as the bridge itself. -func (n *network) filterDirectAccess(ctx context.Context, table nftables.TableRef, conf firewaller.NetworkConfigFam, epIP netip.Addr, enable bool) error { +func (n *network) filterDirectAccess(updater func(nftables.Obj), fam nftables.Family, conf firewaller.NetworkConfigFam, epIP netip.Addr) { if n.config.Internal || conf.Unprotected || conf.Routed || n.fw.config.AllowDirectRouting { - return nil + return } - updater := table.ChainUpdateFunc(ctx, rawPreroutingChain, enable) ifNames := strings.Join(n.config.TrustedHostInterfaces, ", ") - return updater(ctx, rawPreroutingPortsRuleGroup, - `%s daddr %s iifname != { %s, %s } counter drop comment "DROP DIRECT ACCESS"`, - table.Family(), epIP, n.config.IfName, ifNames) + updater(nftables.Rule{ + Chain: rawPreroutingChain, + Group: rawPreroutingPortsRuleGroup, + Rule: []string{ + string(fam), "daddr", epIP.String(), + "iifname != {", n.config.IfName, ",", ifNames, `} counter drop comment "DROP DIRECT ACCESS"`, + }, + }) } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go index d612b767a6de8..6179a7f28bf52 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/link.go @@ -7,6 +7,9 @@ import ( "errors" "fmt" "net/netip" + "strconv" + + "github.com/docker/docker/daemon/libnetwork/internal/nftables" "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/types" @@ -20,44 +23,47 @@ func (n *network) AddLink(ctx context.Context, parentIP, childIP netip.Addr, por return errors.New("cannot link to a container with an empty child IP address") } - chain := n.fw.table4.Chain(ctx, chainFilterFwdIn(n.config.IfName)) + tm := n.fw.table4.Modifier() for _, port := range ports { - for _, rule := range 
legacyLinkRules(parentIP, childIP, port) { - if err := chain.AppendRule(ctx, fwdInLegacyLinksRuleGroup, rule); err != nil { - return err - } - } + updateLegacyLinkRules(tm.Create, chainFilterFwdIn(n.config.IfName), parentIP, childIP, port) } - if err := nftApply(ctx, n.fw.table4); err != nil { + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } return nil } func (n *network) DelLink(ctx context.Context, parentIP, childIP netip.Addr, ports []types.TransportPort) { - chain := n.fw.table4.Chain(ctx, chainFilterFwdIn(n.config.IfName)) + tm := n.fw.table4.Modifier() for _, port := range ports { - for _, rule := range legacyLinkRules(parentIP, childIP, port) { - if err := chain.DeleteRule(ctx, fwdInLegacyLinksRuleGroup, rule); err != nil { - log.G(ctx).WithFields(log.Fields{ - "rule": rule, - "error": err, - }).Warn("Failed to remove link between containers") - } - } + updateLegacyLinkRules(tm.Delete, chainFilterFwdIn(n.config.IfName), parentIP, childIP, port) } - if err := nftApply(ctx, n.fw.table4); err != nil { + if err := tm.Apply(ctx); err != nil { log.G(ctx).WithError(err).Warn("Removing link, failed to update nftables") } } -func legacyLinkRules(parentIP, childIP netip.Addr, port types.TransportPort) []string { +func updateLegacyLinkRules(updater func(command nftables.Obj), chainName string, parentIP, childIP netip.Addr, port types.TransportPort) { // TODO(robmry) - could combine rules for each proto by using an anonymous set. - return []string{ - // Match the iptables implementation, but without checking iifname/oifname (not needed - // because the addresses belong to the bridge). - fmt.Sprintf("ip saddr %s ip daddr %s %s dport %d counter accept", parentIP.Unmap(), childIP.Unmap(), port.Proto, port.Port), - // Conntrack will allow responses. So, this must be to allow unsolicited packets from an exposed port. 
- fmt.Sprintf("ip daddr %s ip saddr %s %s sport %d counter accept", parentIP.Unmap(), childIP.Unmap(), port.Proto, port.Port), - } + // Match the iptables implementation, but without checking iifname/oifname (not needed + // because the addresses belong to the bridge). + updater(nftables.Rule{ + Chain: chainName, + Group: fwdInLegacyLinksRuleGroup, + Rule: []string{ + "ip saddr", parentIP.Unmap().String(), + "ip daddr", childIP.Unmap().String(), port.Proto.String(), "dport", strconv.Itoa(int(port.Port)), + "counter accept", + }, + }) + // Conntrack will allow responses. So, this must be to allow unsolicited packets from an exposed port. + updater(nftables.Rule{ + Chain: chainName, + Group: fwdInLegacyLinksRuleGroup, + Rule: []string{ + "ip daddr", parentIP.Unmap().String(), + "ip saddr", childIP.Unmap().String(), port.Proto.String(), "sport", strconv.Itoa(int(port.Port)), + "counter accept", + }, + }) } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go index e9ced708e9374..5d05db4f9c5df 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/network.go @@ -9,14 +9,14 @@ import ( "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/drivers/bridge/internal/firewaller" "github.com/docker/docker/daemon/libnetwork/internal/nftables" - "github.com/docker/docker/internal/cleanups" "go.opentelemetry.io/otel" ) type network struct { - config firewaller.NetworkConfig - cleaner func(ctx context.Context) error - fw *nftabler + config firewaller.NetworkConfig + remover4 *nftables.Modifier + remover6 *nftables.Modifier + fw *nftabler } func (nft *nftabler) NewNetwork(ctx context.Context, nc firewaller.NetworkConfig) (_ firewaller.Network, retErr error) { @@ -26,108 +26,84 @@ func (nft *nftabler) NewNetwork(ctx context.Context, nc firewaller.NetworkConfig } ctx = log.WithLogger(ctx, 
log.G(ctx).WithFields(log.Fields{"bridge": n.config.IfName})) - var cleaner cleanups.Composite - defer func() { - if err := cleaner.Call(ctx); err != nil { - log.G(ctx).WithError(err).Warn("Failed to clean up nftables rules for network") - } - }() - if nft.cleaner != nil { nft.cleaner.DelNetwork(ctx, nc) } if n.fw.config.IPv4 { - clean, err := n.configure(ctx, nft.table4, n.config.Config4) + remover, err := n.configure(ctx, nft.table4, n.config.Config4) if err != nil { return nil, err } - if clean != nil { - cleaner.Add(clean) - } + n.remover4 = remover } if n.fw.config.IPv6 { - clean, err := n.configure(ctx, nft.table6, n.config.Config6) + remover, err := n.configure(ctx, nft.table6, n.config.Config6) if err != nil { return nil, err } - if clean != nil { - cleaner.Add(clean) - } + n.remover6 = remover } - - n.cleaner = cleaner.Release() return n, nil } -func (n *network) configure(ctx context.Context, table nftables.TableRef, conf firewaller.NetworkConfigFam) (func(context.Context) error, error) { - ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".newNetwork."+string(table.Family())) - defer span.End() - +func (n *network) configure(ctx context.Context, table nftables.Table, conf firewaller.NetworkConfigFam) (*nftables.Modifier, error) { if !conf.Prefix.IsValid() { return nil, nil } + tm := table.Modifier() + ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".newNetwork."+string(tm.Family())) + defer span.End() - var cleanup cleanups.Composite - defer cleanup.Call(ctx) - var applied bool - cleanup.Add(func(ctx context.Context) error { - if applied { - return nftApply(ctx, table) - } - return nil - }) + fwdInChain := chainFilterFwdIn(n.config.IfName) + fwdOutChain := chainFilterFwdOut(n.config.IfName) + natPostRtInChain := chainNatPostRtIn(n.config.IfName) + natPostRtOutChain := chainNatPostRtOut(n.config.IfName) // Filter chain - fwdInChain := table.Chain(ctx, chainFilterFwdIn(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return 
table.DeleteChain(ctx, chainFilterFwdIn(n.config.IfName)) }) - fwdOutChain := table.Chain(ctx, chainFilterFwdOut(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return table.DeleteChain(ctx, chainFilterFwdOut(n.config.IfName)) }) + tm.Create(nftables.Chain{Name: fwdInChain}) + tm.Create(nftables.Chain{Name: fwdOutChain}) - cf, err := table.InterfaceVMap(ctx, filtFwdInVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainFilterFwdIn(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding filter-forward jump for %s to %q: %w", conf.Prefix, chainFilterFwdIn(n.config.IfName), err) - } - cleanup.Add(cf) - - cf, err = table.InterfaceVMap(ctx, filtFwdOutVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainFilterFwdOut(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding filter-forward jump for %s to %q: %w", conf.Prefix, chainFilterFwdOut(n.config.IfName), err) - } - cleanup.Add(cf) + tm.Create(nftables.VMapElement{ + VmapName: filtFwdInVMap, + Key: n.config.IfName, + Verdict: "jump " + fwdInChain, + }) + tm.Create(nftables.VMapElement{ + VmapName: filtFwdOutVMap, + Key: n.config.IfName, + Verdict: "jump " + fwdOutChain, + }) // NAT chain - natPostroutingIn := table.Chain(ctx, chainNatPostRtIn(n.config.IfName)) - cleanup.Add(func(ctx context.Context) error { return table.DeleteChain(ctx, chainNatPostRtIn(n.config.IfName)) }) - cf, err = table.InterfaceVMap(ctx, natPostroutingInVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainNatPostRtIn(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding postrouting ingress jump for %s to %q: %w", conf.Prefix, chainNatPostRtIn(n.config.IfName), err) - } - cleanup.Add(cf) + tm.Create(nftables.Chain{Name: natPostRtInChain}) + tm.Create(nftables.VMapElement{ + VmapName: natPostroutingInVMap, + Key: n.config.IfName, + Verdict: "jump " + natPostRtInChain, + }) - natPostroutingOut := table.Chain(ctx, chainNatPostRtOut(n.config.IfName)) - cleanup.Add(func(ctx context.Context) 
error { return table.DeleteChain(ctx, chainNatPostRtOut(n.config.IfName)) }) - cf, err = table.InterfaceVMap(ctx, natPostroutingOutVMap).AddElementCf(ctx, n.config.IfName, "jump "+chainNatPostRtOut(n.config.IfName)) - if err != nil { - return nil, fmt.Errorf("adding postrouting egress jump for %s to %q: %w", conf.Prefix, chainNatPostRtOut(n.config.IfName), err) - } - cleanup.Add(cf) + tm.Create(nftables.Chain{Name: chainNatPostRtOut(n.config.IfName)}) + tm.Create(nftables.VMapElement{ + VmapName: natPostroutingOutVMap, + Key: n.config.IfName, + Verdict: "jump " + chainNatPostRtOut(n.config.IfName), + }) // Conntrack - cf, err = fwdInChain.AppendRuleCf(ctx, initialRuleGroup, "ct state established,related counter accept") - if err != nil { - return nil, fmt.Errorf("adding conntrack ingress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) - - cf, err = fwdOutChain.AppendRuleCf(ctx, initialRuleGroup, "ct state established,related counter accept") - if err != nil { - return nil, fmt.Errorf("adding conntrack egress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: chainFilterFwdIn(n.config.IfName), + Group: initialRuleGroup, + Rule: []string{"ct state established,related counter accept"}, + }) + tm.Create(nftables.Rule{ + Chain: chainFilterFwdOut(n.config.IfName), + Group: initialRuleGroup, + Rule: []string{"ct state established,related counter accept"}, + }) iccVerdict := "accept" if !n.config.ICC { @@ -136,68 +112,64 @@ func (n *network) configure(ctx context.Context, table nftables.TableRef, conf f if n.config.Internal { // Drop anything that's not from this network. 
- cf, err = fwdInChain.AppendRuleCf(ctx, initialRuleGroup, - `iifname != %s counter drop comment "INTERNAL NETWORK INGRESS"`, n.config.IfName) - if err != nil { - return nil, fmt.Errorf("adding INTERNAL NETWORK ingress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) - - cf, err = fwdOutChain.AppendRuleCf(ctx, initialRuleGroup, - `oifname != %s counter drop comment "INTERNAL NETWORK EGRESS"`, n.config.IfName) - if err != nil { - return nil, fmt.Errorf("adding INTERNAL NETWORK egress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: initialRuleGroup, + Rule: []string{`iifname != `, n.config.IfName, `counter drop comment "INTERNAL NETWORK INGRESS"`}, + }) + tm.Create(nftables.Rule{ + Chain: fwdOutChain, + Group: initialRuleGroup, + Rule: []string{`oifname != `, n.config.IfName, `counter drop comment "INTERNAL NETWORK EGRESS"`}, + }) // Accept or drop Inter-Container Communication. - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInICCRuleGroup, "counter %s comment ICC", iccVerdict) - if err != nil { - return nil, fmt.Errorf("adding ICC ingress rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInICCRuleGroup, + Rule: []string{"counter", iccVerdict, "comment ICC"}, + }) } else { // Inter-Container Communication - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInICCRuleGroup, "iifname == %s counter %s comment ICC", - n.config.IfName, iccVerdict) - if err != nil { - return nil, fmt.Errorf("adding ICC rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInICCRuleGroup, + Rule: []string{"iifname ==", n.config.IfName, "counter", iccVerdict, "comment ICC"}, + }) // Outgoing traffic - cf, err = fwdOutChain.AppendRuleCf(ctx, initialRuleGroup, "counter accept comment OUTGOING") - if err != nil { - return nil, fmt.Errorf("adding OUTGOING rule for %q: %w", 
n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdOutChain, + Group: initialRuleGroup, + Rule: []string{"counter accept comment OUTGOING"}, + }) // Incoming traffic if conf.Unprotected { - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInFinalRuleGroup, `counter accept comment "UNPROTECTED"`) - if err != nil { - return nil, fmt.Errorf("adding UNPROTECTED for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInFinalRuleGroup, + Rule: []string{`counter accept comment "UNPROTECTED"`}, + }) } else { - cf, err = fwdInChain.AppendRuleCf(ctx, fwdInFinalRuleGroup, `counter drop comment "UNPUBLISHED PORT DROP"`) - if err != nil { - return nil, fmt.Errorf("adding UNPUBLISHED PORT DROP for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: fwdInFinalRuleGroup, + Rule: []string{`counter drop comment "UNPUBLISHED PORT DROP"`}, + }) } // ICMP if conf.Routed { rule := "ip protocol icmp" - if table.Family() == nftables.IPv6 { + if tm.Family() == nftables.IPv6 { rule = "meta l4proto ipv6-icmp" } - cf, err = fwdInChain.AppendRuleCf(ctx, initialRuleGroup, rule+" counter accept comment ICMP") - if err != nil { - return nil, fmt.Errorf("adding ICMP rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: fwdInChain, + Group: initialRuleGroup, + Rule: []string{rule, "counter accept comment ICMP"}, + }) } // Masquerade / SNAT - masquerade picks a source IP address based on next-hop, SNAT uses conf.HostIP. 
@@ -208,34 +180,34 @@ func (n *network) configure(ctx context.Context, table nftables.TableRef, conf f natPostroutingComment = "SNAT" } if n.config.Masquerade && !conf.Routed { - cf, err = natPostroutingOut.AppendRuleCf(ctx, initialRuleGroup, `oifname != %s %s saddr %s counter %s comment "%s"`, - n.config.IfName, table.Family(), conf.Prefix, natPostroutingVerdict, natPostroutingComment) - if err != nil { - return nil, fmt.Errorf("adding NAT rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: natPostRtOutChain, + Group: initialRuleGroup, + Rule: []string{ + "oifname !=", n.config.IfName, string(tm.Family()), "saddr", conf.Prefix.String(), "counter", + natPostroutingVerdict, "comment", natPostroutingComment, + }, + }) } if n.fw.config.Hairpin { - // Masquerade/SNAT traffic from localhost. - cf, err = natPostroutingIn.AppendRuleCf(ctx, initialRuleGroup, `fib saddr type local counter %s comment "%s FROM HOST"`, - natPostroutingVerdict, natPostroutingComment) - if err != nil { - return nil, fmt.Errorf("adding NAT local rule for %q: %w", n.config.IfName, err) - } - cleanup.Add(cf) + tm.Create(nftables.Rule{ + Chain: natPostRtInChain, + Group: initialRuleGroup, + Rule: []string{ + `fib saddr type local counter`, natPostroutingVerdict, `comment "` + natPostroutingComment + ` FROM HOST"`, + }, + }) } } ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{ "bridge": n.config.IfName, - "family": table.Family(), + "family": tm.Family(), })) - if err := nftApply(ctx, table); err != nil { + if err := tm.Apply(ctx); err != nil { return nil, fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } - applied = true - - return cleanup.Release(), nil + return tm.Reverse(), nil } func (n *network) ReapplyNetworkLevelRules(ctx context.Context) error { @@ -245,13 +217,18 @@ func (n *network) ReapplyNetworkLevelRules(ctx context.Context) error { } func (n *network) DelNetworkLevelRules(ctx context.Context) error { - if 
n.cleaner != nil { - ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"bridge": n.config.IfName})) - if err := n.cleaner(ctx); err != nil { - log.G(ctx).WithError(err).Warn("Failed to remove network rules for network") + remove := func(remover *nftables.Modifier) { + if remover != nil { + ctx := log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"bridge": n.config.IfName})) + if err := remover.Apply(ctx); err != nil { + log.G(ctx).WithError(err).Warn("Failed to remove network rules for network") + } } - n.cleaner = nil } + remove(n.remover4) + n.remover4 = nil + remove(n.remover6) + n.remover6 = nil return nil } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go index 1be24340cf92f..39203af1eec99 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler.go @@ -4,12 +4,13 @@ package nftabler import ( "context" + "errors" "fmt" + "strconv" "github.com/containerd/log" "github.com/docker/docker/daemon/libnetwork/drivers/bridge/internal/firewaller" "github.com/docker/docker/daemon/libnetwork/internal/nftables" - "go.opentelemetry.io/otel" ) // Prefix for OTEL span names. 
@@ -44,47 +45,78 @@ const ( rawPreroutingPortsRuleGroup = iota + initialRuleGroup + 1 ) +var baseChainNames = map[string]struct{}{ + forwardChain: {}, + postroutingChain: {}, + preroutingChain: {}, + outputChain: {}, + rawPreroutingChain: {}, +} + type nftabler struct { config firewaller.Config cleaner firewaller.FirewallCleaner - table4 nftables.TableRef - table6 nftables.TableRef + table4 nftables.Table + table6 nftables.Table } -func NewNftabler(ctx context.Context, config firewaller.Config) (firewaller.Firewaller, error) { +func NewNftabler(ctx context.Context, config firewaller.Config, baseChainPriorities map[string]string) (firewaller.Firewaller, error) { nft := &nftabler{config: config} + // Convert base chain priorities to integers, assuming the daemon has called + // ValidateBaseChainPriorities, so errors don't need to be handled. + bcps := map[string]int{} + for chain, prio := range baseChainPriorities { + if p, err := strconv.Atoi(prio); err == nil { + bcps[chain] = p + } + } + if nft.config.IPv4 { var err error - nft.table4, err = nft.init(ctx, nftables.IPv4) + nft.table4, err = nft.init(ctx, nftables.IPv4, bcps) if err != nil { return nil, err } - if err := nftApply(ctx, nft.table4); err != nil { - return nil, fmt.Errorf("IPv4 initialisation: %w", err) - } } if nft.config.IPv6 { var err error - nft.table6, err = nft.init(ctx, nftables.IPv6) + nft.table6, err = nft.init(ctx, nftables.IPv6, bcps) if err != nil { return nil, err } + } + + return nft, nil +} - if err := nftApply(ctx, nft.table6); err != nil { - // Perhaps the kernel has no IPv6 support. It won't be possible to create IPv6 - // networks without enabling ip6_tables in the kernel, or disabling ip6tables in - // the daemon config. But, allow the daemon to start because IPv4 will work. So, - // log the problem, and continue. 
- log.G(ctx).WithError(err).Warn("ip6tables is enabled, but cannot set up IPv6 nftables table") +// ValidateBaseChainPriorities checks nftables base chain priority configuration. +func ValidateBaseChainPriorities(prios map[string]string) error { + var errs []error + for c, p := range prios { + if _, ok := baseChainNames[c]; !ok { + errs = append(errs, fmt.Errorf("%q is not a valid base chain name", c)) + } + if _, ok := strconv.Atoi(p); ok != nil { + errs = append(errs, fmt.Errorf("priority %q for base chain %q is not an integer", p, c)) } } + return errors.Join(errs...) +} - return nft, nil +func (nft *nftabler) Reload(ctx context.Context) error { + var errs []error + if nft.config.IPv4 { + errs = append(errs, nft.table4.Reload(ctx)) + } + if nft.config.IPv6 { + errs = append(errs, nft.table6.Reload(ctx)) + } + return errors.Join(errs...) } -func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.TableRef { +func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.Table { if ipv == firewaller.IPv4 { return nft.table4 } @@ -92,21 +124,30 @@ func (nft *nftabler) getTable(ipv firewaller.IPVersion) nftables.TableRef { } func (nft *nftabler) FilterForwardDrop(ctx context.Context, ipv firewaller.IPVersion) error { - table := nft.getTable(ipv) - if err := table.Chain(ctx, forwardChain).SetPolicy("drop"); err != nil { - return err + if err := nft.getTable(ipv).SetBaseChainPolicy(ctx, forwardChain, nftables.BaseChainPolicyDrop); err != nil { + return fmt.Errorf("setting IPv%d filter-forward drop: %w", ipv, err) } - return nftApply(ctx, table) + return nil } // init creates the bridge driver's nftables table for IPv4 or IPv6. -func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables.TableRef, error) { +func (nft *nftabler) init(ctx context.Context, family nftables.Family, baseChainPriorities map[string]int) (nftables.Table, error) { // Instantiate the table. 
table, err := nftables.NewTable(family, dockerTable) if err != nil { return table, err } + // Reload the table while it's got no elements to clear an old table if one + // exists. This is necessary because, if base chain priorities have changed and + // the old table isn't removed, nft produces an error message for the base chain + // (but seems to apply the change anyway). + if err := table.Reload(ctx); err != nil { + return nftables.Table{}, err + } + + tm := table.Modifier() + // Set up the filter forward chain. // // This base chain only contains two rules that use verdict maps: @@ -116,65 +157,89 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables // So, packets that aren't related to docker don't need to traverse any per-network filter forward // rules - and packets that are entering or leaving docker networks only need to traverse rules // related to those networks. - fwdChain, err := table.BaseChain(ctx, forwardChain, - nftables.BaseChainTypeFilter, - nftables.BaseChainHookForward, - nftables.BaseChainPriorityFilter) - if err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.BaseChain{ + Name: forwardChain, + ChainType: nftables.BaseChainTypeFilter, + Hook: nftables.BaseChainHookForward, + Priority: baseChainPriority(forwardChain, nftables.BaseChainPriorityFilter, baseChainPriorities), + }) // Instantiate the verdict maps and add the jumps. 
- _ = table.InterfaceVMap(ctx, filtFwdInVMap) - if err := fwdChain.AppendRule(ctx, initialRuleGroup, "oifname vmap @"+filtFwdInVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } - _ = table.InterfaceVMap(ctx, filtFwdOutVMap) - if err := fwdChain.AppendRule(ctx, initialRuleGroup, "iifname vmap @"+filtFwdOutVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.VMap{ + Name: filtFwdInVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: forwardChain, + Group: initialRuleGroup, + Rule: []string{"oifname vmap @", filtFwdInVMap}, + }) + + tm.Create(nftables.VMap{ + Name: filtFwdOutVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: forwardChain, + Group: initialRuleGroup, + Rule: []string{"iifname vmap @", filtFwdOutVMap}, + }) // Set up the NAT postrouting base chain. // // Like the filter-forward chain, its only rules are jumps to network-specific ingress and egress chains. 
- natPostRtChain, err := table.BaseChain(ctx, postroutingChain, - nftables.BaseChainTypeNAT, - nftables.BaseChainHookPostrouting, - nftables.BaseChainPrioritySrcNAT) - if err != nil { - return nftables.TableRef{}, err - } - _ = table.InterfaceVMap(ctx, natPostroutingOutVMap) - if err := natPostRtChain.AppendRule(ctx, initialRuleGroup, "iifname vmap @"+natPostroutingOutVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } - _ = table.InterfaceVMap(ctx, natPostroutingInVMap) - if err := natPostRtChain.AppendRule(ctx, initialRuleGroup, "oifname vmap @"+natPostroutingInVMap); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.BaseChain{ + Name: postroutingChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookPostrouting, + Priority: baseChainPriority(postroutingChain, nftables.BaseChainPrioritySrcNAT, baseChainPriorities), + }) + + tm.Create(nftables.VMap{ + Name: natPostroutingOutVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: postroutingChain, + Group: initialRuleGroup, + Rule: []string{"iifname vmap @", natPostroutingOutVMap}, + }) + + tm.Create(nftables.VMap{ + Name: natPostroutingInVMap, + ElementType: nftables.NftTypeIfname, + }) + tm.Create(nftables.Rule{ + Chain: postroutingChain, + Group: initialRuleGroup, + Rule: []string{"oifname vmap @", natPostroutingInVMap}, + }) // Instantiate natChain, for the NAT prerouting and output base chains to jump to. - _ = table.Chain(ctx, natChain) + tm.Create(nftables.Chain{ + Name: natChain, + }) // Set up the NAT prerouting base chain. 
- natPreRtChain, err := table.BaseChain(ctx, preroutingChain, - nftables.BaseChainTypeNAT, - nftables.BaseChainHookPrerouting, - nftables.BaseChainPriorityDstNAT) - if err != nil { - return nftables.TableRef{}, err - } - if err := natPreRtChain.AppendRule(ctx, initialRuleGroup, "fib daddr type local counter jump "+natChain); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.BaseChain{ + Name: preroutingChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookPrerouting, + Priority: baseChainPriority(preroutingChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities), + }) + tm.Create(nftables.Rule{ + Chain: preroutingChain, + Group: initialRuleGroup, + Rule: []string{"fib daddr type local counter jump", natChain}, + }) // Set up the NAT output base chain - natOutputChain, err := table.BaseChain(ctx, outputChain, - nftables.BaseChainTypeNAT, - nftables.BaseChainHookOutput, - nftables.BaseChainPriorityDstNAT) - if err != nil { - return nftables.TableRef{}, err - } + tm.Create(nftables.BaseChain{ + Name: outputChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookOutput, + Priority: baseChainPriority(outputChain, nftables.BaseChainPriorityDstNAT, baseChainPriorities), + }) + // For output, don't jump to the NAT chain if hairpin is enabled (no userland proxy). 
var skipLoopback string if !nft.config.Hairpin { @@ -184,32 +249,41 @@ func (nft *nftabler) init(ctx context.Context, family nftables.Family) (nftables skipLoopback = "ip6 daddr != ::1 " } } - if err := natOutputChain.AppendRule(ctx, initialRuleGroup, skipLoopback+"fib daddr type local counter jump "+natChain); err != nil { - return nftables.TableRef{}, fmt.Errorf("initialising nftables: %w", err) - } + tm.Create(nftables.Rule{ + Chain: outputChain, + Group: initialRuleGroup, + Rule: []string{skipLoopback, "fib daddr type local counter jump", natChain}, + }) // Set up the raw prerouting base chain - if _, err := table.BaseChain(ctx, rawPreroutingChain, - nftables.BaseChainTypeFilter, - nftables.BaseChainHookPrerouting, - nftables.BaseChainPriorityRaw); err != nil { - return nftables.TableRef{}, err - } + tm.Create(nftables.BaseChain{ + Name: rawPreroutingChain, + ChainType: nftables.BaseChainTypeFilter, + Hook: nftables.BaseChainHookPrerouting, + Priority: baseChainPriority(rawPreroutingChain, nftables.BaseChainPriorityRaw, baseChainPriorities), + }) if !nft.config.Hairpin && nft.config.WSL2Mirrored { - if err := mirroredWSL2Workaround(ctx, table); err != nil { - return nftables.TableRef{}, err - } + mirroredWSL2Workaround(tm) } + if err := tm.Apply(ctx); err != nil { + if family == nftables.IPv4 { + return nftables.Table{}, err + } + // Perhaps the kernel has no IPv6 support. It won't be possible to create IPv6 + // networks without enabling ip6_tables in the kernel, or disabling ip6tables in + // the daemon config. But, allow the daemon to start because IPv4 will work. So, + // log the problem, and continue. 
+ log.G(ctx).WithError(err).Warn("ip6tables is enabled, but cannot set up IPv6 nftables table") + return nftables.Table{}, nil + } return table, nil } -func nftApply(ctx context.Context, table nftables.TableRef) error { - ctx, span := otel.Tracer("").Start(ctx, spanPrefix+".nftApply."+string(table.Family())) - defer span.End() - if err := table.Apply(ctx); err != nil { - return fmt.Errorf("applying nftables rules: %w", err) +func baseChainPriority(chainName string, def int, overrides map[string]int) int { + if p, ok := overrides[chainName]; ok { + return p } - return nil + return def } diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go index 0dd380b922eba..81ef070663d4a 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/nftabler_test.go @@ -125,7 +125,7 @@ func testNftabler(t *testing.T, tn string, config firewaller.Config, netConfig f // Initialise iptables, check the iptables config looks like it should look at the // end of the test (after deleting per-network and per-port rules). 
- fw, err := NewNftabler(context.Background(), config) + fw, err := NewNftabler(context.Background(), config, nil) assert.NilError(t, err) checkResults("ip", rnWSL2Mirrored(fmt.Sprintf("%s_cleaned,hairpin=%v", tn, config.Hairpin)), config.IPv4) checkResults("ip6", fmt.Sprintf("%s_cleaned,hairpin=%v", tn, config.Hairpin), config.IPv6) diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go index 8447f23759576..a6276daa988c1 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/port.go @@ -5,7 +5,6 @@ package nftabler import ( "context" - "errors" "fmt" "net" "strconv" @@ -17,9 +16,10 @@ import ( ) type pbContext struct { - table nftables.TableRef - conf firewaller.NetworkConfigFam - ipv firewaller.IPVersion + table nftables.Table + updater func(nftables.Obj) + conf firewaller.NetworkConfigFam + ipv nftables.Family } func (n *network) AddPorts(ctx context.Context, pbs []types.PortBinding) error { @@ -43,13 +43,13 @@ func (n *network) modPorts(ctx context.Context, pbs []types.PortBinding, enable pbs4, pbs6 := splitByContainerFam(pbs) if n.fw.config.IPv4 && n.config.Config4.Prefix.IsValid() { - pbc := pbContext{table: n.fw.table4, conf: n.config.Config4, ipv: firewaller.IPv4} + pbc := pbContext{table: n.fw.table4, conf: n.config.Config4, ipv: nftables.IPv4} if err := n.setPerPortRules(ctx, pbs4, pbc, enable); err != nil { return err } } if n.fw.config.IPv6 && n.config.Config6.Prefix.IsValid() { - pbc := pbContext{table: n.fw.table6, conf: n.config.Config6, ipv: firewaller.IPv6} + pbc := pbContext{table: n.fw.table6, conf: n.config.Config6, ipv: nftables.IPv6} if err := n.setPerPortRules(ctx, pbs6, pbc, enable); err != nil { return err } @@ -70,29 +70,34 @@ func splitByContainerFam(pbs []types.PortBinding) ([]types.PortBinding, []types. 
} func (n *network) setPerPortRules(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { - if err := n.setPerPortForwarding(ctx, pbs, pbc, enable); err != nil { + tm := pbc.table.Modifier() + pbc.updater = tm.Create + if !enable { + pbc.updater = tm.Delete + } + if err := n.setPerPortForwarding(ctx, pbs, pbc); err != nil { return err } - if err := n.setPerPortDNAT(ctx, pbs, pbc, enable); err != nil { + if err := n.setPerPortDNAT(ctx, pbs, pbc); err != nil { return err } - if err := n.setPerPortHairpinMasq(ctx, pbs, pbc, enable); err != nil { + if err := n.setPerPortHairpinMasq(ctx, pbs, pbc); err != nil { return err } - if err := n.filterPortMappedOnLoopback(ctx, pbs, pbc, enable); err != nil { + if err := n.filterPortMappedOnLoopback(ctx, pbs, pbc); err != nil { return err } - if err := nftApply(ctx, pbc.table); err != nil { + if err := tm.Apply(ctx); err != nil { return fmt.Errorf("adding rules for bridge %s: %w", n.config.IfName, err) } return nil } -func (n *network) setPerPortForwarding(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { +func (n *network) setPerPortForwarding(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { if pbc.conf.Unprotected { return nil } - updateFwdIn := pbc.table.ChainUpdateFunc(ctx, chainFilterFwdIn(n.config.IfName), enable) + chainName := chainFilterFwdIn(n.config.IfName) for _, pb := range pbs { // When more than one host port is mapped to a single container port, this will // generate the same rule for each host port. So, ignore duplicates when adding, @@ -102,24 +107,26 @@ func (n *network) setPerPortForwarding(ctx context.Context, pbs []types.PortBind // also be deleted more than once.) // // TODO(robmry) - track port mappings, use that to edit nftables sets when bindings are added/removed. 
- rule := fmt.Sprintf("%s daddr %s %s dport %d counter accept", pbc.table.Family(), pb.IP, pb.Proto, pb.Port) - if err := updateFwdIn(ctx, fwdInPortsRuleGroup, rule); err != nil && - !errors.Is(err, nftables.ErrRuleExist) && !errors.Is(err, nftables.ErrRuleNotExist) { - return fmt.Errorf("updating forwarding rule for port %s %s:%d/%s on %s, enable=%v: %w", - pbc.table.Family(), pb.IP, pb.Port, pb.Proto, n.config.IfName, enable, err) - } + pbc.updater(nftables.Rule{ + Chain: chainName, + Group: fwdInPortsRuleGroup, + Rule: []string{ + string(pbc.ipv), "daddr", pb.IP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.Port)), "counter accept", + }, + IgnoreExist: true, + }) } return nil } -func (n *network) setPerPortDNAT(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { - updater := pbc.table.ChainUpdateFunc(ctx, natChain, enable) +func (n *network) setPerPortDNAT(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { var proxySkip string if !n.fw.config.Hairpin { proxySkip = fmt.Sprintf("iifname != %s ", n.config.IfName) } var v6LLSkip string - if pbc.ipv == firewaller.IPv6 { + if pbc.ipv == nftables.IPv6 { v6LLSkip = "ip6 saddr != fe80::/10 " } for _, pb := range pbs { @@ -134,26 +141,27 @@ func (n *network) setPerPortDNAT(ctx context.Context, pbs []types.PortBinding, p } var daddrMatch string if !pb.HostIP.IsUnspecified() { - daddrMatch = fmt.Sprintf("%s daddr %s ", pbc.table.Family(), pb.HostIP) - } - rule := fmt.Sprintf("%s%s%s%s dport %d counter dnat to %s comment DNAT", - proxySkip, v6LLSkip, daddrMatch, pb.Proto, pb.HostPort, - net.JoinHostPort(pb.IP.String(), strconv.Itoa(int(pb.Port)))) - if err := updater(ctx, initialRuleGroup, rule); err != nil { - return fmt.Errorf("adding DNAT for %s %s:%d -> %s:%d/%s on %s: %w", - pbc.table.Family(), pb.HostIP, pb.HostPort, pb.IP, pb.Port, pb.Proto, n.config.IfName, err) + daddrMatch = fmt.Sprintf("%s daddr %s ", pbc.ipv, pb.HostIP) } + pbc.updater(nftables.Rule{ + 
Chain: natChain, + Group: initialRuleGroup, + Rule: []string{ + proxySkip, v6LLSkip, daddrMatch, pb.Proto.String(), "dport", strconv.Itoa(int(pb.HostPort)), "counter dnat to", + net.JoinHostPort(pb.IP.String(), strconv.Itoa(int(pb.Port))), "comment DNAT", + }, + }) } return nil } // setPerPortHairpinMasq allows containers to access their own published ports on the host // when hairpin is enabled (no docker-proxy), by masquerading. -func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { +func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { if !n.fw.config.Hairpin { return nil } - updater := pbc.table.ChainUpdateFunc(ctx, chainNatPostRtIn(n.config.IfName), enable) + chainName := chainNatPostRtIn(n.config.IfName) for _, pb := range pbs { // Nothing to do if NAT is disabled. if pb.HostPort == 0 { @@ -172,13 +180,16 @@ func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBin // than once.) // // TODO(robmry) - track port mappings, use that to edit nftables sets when bindings are added/removed. 
- rule := fmt.Sprintf(`%s saddr %s %s daddr %s %s dport %d counter masquerade comment "MASQ TO OWN PORT"`, - pbc.table.Family(), pb.IP, pbc.table.Family(), pb.IP, pb.Proto, pb.Port) - if err := updater(ctx, initialRuleGroup, rule); err != nil && - !errors.Is(err, nftables.ErrRuleExist) && !errors.Is(err, nftables.ErrRuleNotExist) { - return fmt.Errorf("adding MASQ TO OWN PORT for %d -> %s:%d/%s: %w", - pb.Port, pb.IP, pb.Port, pb.Proto, err) - } + pbc.updater(nftables.Rule{ + Chain: chainName, + Group: initialRuleGroup, + Rule: []string{ + string(pbc.ipv), "saddr", pb.IP.String(), string(pbc.ipv), + "daddr", pb.IP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.Port)), + `counter masquerade comment "MASQ TO OWN PORT"`, + }, + }) } return nil } @@ -189,11 +200,10 @@ func (n *network) setPerPortHairpinMasq(ctx context.Context, pbs []types.PortBin // This is a no-op if the portBinding is for IPv6 (IPv6 loopback address is // non-routable), or over a network with gw_mode=routed (PBs in routed mode // don't map ports on the host). -func (n *network) filterPortMappedOnLoopback(ctx context.Context, pbs []types.PortBinding, pbc pbContext, enable bool) error { - if pbc.ipv == firewaller.IPv6 { +func (n *network) filterPortMappedOnLoopback(ctx context.Context, pbs []types.PortBinding, pbc pbContext) error { + if pbc.ipv == nftables.IPv6 { return nil } - updater := pbc.table.ChainUpdateFunc(ctx, rawPreroutingChain, enable) for _, pb := range pbs { // Nothing to do if not binding to the loopback address. 
if pb.HostPort == 0 || !pb.HostIP.IsLoopback() { @@ -204,17 +214,25 @@ func (n *network) filterPortMappedOnLoopback(ctx context.Context, pbs []types.Po continue } if n.fw.config.WSL2Mirrored { - if err := updater(ctx, rawPreroutingPortsRuleGroup, - `iifname loopback0 ip daddr %s %s dport %d counter accept comment "%s"`, - pb.HostIP, pb.Proto, pb.HostPort, "ACCEPT WSL2 LOOPBACK"); err != nil { - return fmt.Errorf("adding WSL2 loopback rule for %d: %w", pb.HostPort, err) - } - } - if err := updater(ctx, rawPreroutingPortsRuleGroup, - `iifname != lo ip daddr %s %s dport %d counter drop comment "DROP REMOTE LOOPBACK"`, - pb.HostIP, pb.Proto, pb.HostPort); err != nil { - return fmt.Errorf("adding loopback filter rule for %d: %w", pb.HostPort, err) + pbc.updater(nftables.Rule{ + Chain: rawPreroutingChain, + Group: rawPreroutingPortsRuleGroup, + Rule: []string{ + "iifname loopback0 ip daddr", pb.HostIP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.HostPort)), + `counter accept comment "ACCEPT WSL2 LOOPBACK"`, + }, + }) } + pbc.updater(nftables.Rule{ + Chain: rawPreroutingChain, + Group: rawPreroutingPortsRuleGroup, + Rule: []string{ + `iifname != lo ip daddr`, pb.HostIP.String(), pb.Proto.String(), + "dport", strconv.Itoa(int(pb.HostPort)), + `counter drop comment "DROP REMOTE LOOPBACK"`, + }, + }) } return nil diff --git a/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go b/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go index 6606b8d79360c..d7f77ad55b8fd 100644 --- a/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go +++ b/daemon/libnetwork/drivers/bridge/internal/nftabler/wsl2.go @@ -3,8 +3,6 @@ package nftabler import ( - "context" - "github.com/docker/docker/daemon/libnetwork/internal/nftables" ) @@ -39,11 +37,14 @@ import ( // arriving from any other bridge network. Similarly, this function adds (or // removes) a rule to RETURN early for packets delivered via loopback0 with // destination 127.0.0.0/8. 
-func mirroredWSL2Workaround(ctx context.Context, table nftables.TableRef) error { +func mirroredWSL2Workaround(tm *nftables.Modifier) { // WSL2 does not (currently) support Windows<->Linux communication via ::1. - if table.Family() != nftables.IPv4 { - return nil + if tm.Family() != nftables.IPv4 { + return } - return table.Chain(ctx, natChain).AppendRule(ctx, - initialRuleGroup, `iifname "loopback0" ip daddr 127.0.0.0/8 counter return`) + tm.Create(nftables.Rule{ + Chain: natChain, + Group: initialRuleGroup, + Rule: []string{`iifname "loopback0" ip daddr 127.0.0.0/8 counter return`}, + }) } diff --git a/daemon/libnetwork/internal/nftables/nftables_linux.go b/daemon/libnetwork/internal/nftables/nftables_linux.go index ca99134f92e40..d6f17f1b6cdba 100644 --- a/daemon/libnetwork/internal/nftables/nftables_linux.go +++ b/daemon/libnetwork/internal/nftables/nftables_linux.go @@ -4,15 +4,26 @@ // Package nftables provides methods to create an nftables table and manage its maps, sets, // chains, and rules. // -// To use it, the first step is to create a [TableRef] using [NewTable]. The table can -// then be populated and managed using that ref. +// To use it, the first step is to create a [Table] using [NewTable]. Then, retrieve +// a [Modifier], add commands to it, and apply the updates. // -// Modifications to the table are only applied (sent to "nft") when [TableRef.Apply] is -// called. This means a number of updates can be made, for example, adding all the -// rules needed for a docker network - and those rules will then be applied atomically -// in a single "nft" run. +// For example: // -// [TableRef.Apply] can only be called after [Enable], and only if [Enable] returns +// t, _ := NewTable(...) +// tm := t.Modifier() +// // Then a sequence of ... +// tm.Create() +// tm.Delete() +// // Apply the updates with ... 
+// err := tm.Apply(ctx) +// +// The objects are any of: [BaseChain], [Chain], [Rule], [VMap], [VMapElement], +// [Set], [SetElement] +// +// The modifier can be reused to apply the same set of commands again or, more +// usefully, reversed in order to revert its changes. See [Modifier.Reverse]. +// +// [Modifier.Apply] can only be called after [Enable], and only if [Enable] returns // true (meaning an "nft" executable was found). [Enabled] can be called to check // whether nftables has been enabled. // @@ -20,9 +31,6 @@ // - The implementation is far from complete, only functionality needed so-far has // been included. Currently, there's only a limited set of chain/map/set types, // there's no way to delete sets/maps etc. -// - There's no rollback so, once changes have been made to a TableRef, if the -// Apply fails there is no way to undo changes. The TableRef will be out-of-sync -// with the actual state of nftables. // - This is a thin layer between code and "nft", it doesn't do much error checking. So, // for example, if you get the syntax of a rule wrong the issue won't be reported // until Apply is called. @@ -45,6 +53,7 @@ import ( "fmt" "io" "os/exec" + "runtime" "slices" "strconv" "strings" @@ -72,15 +81,6 @@ var ( enableOnce sync.Once ) -var ( - // ErrRuleExist is returned when a rule is added, but it already exists in the same - // rule group of a chain. - ErrRuleExist = errors.New("rule exists") - // ErrRuleNotExist is returned when a rule is removed, but does not exist in the - // rule group of a chain. - ErrRuleNotExist = errors.New("rule does not exist") -) - // BaseChainType enumerates the base chain types. // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_types type BaseChainType string @@ -115,6 +115,15 @@ const ( BaseChainPrioritySrcNAT = 100 ) +// BaseChainPolicy enumerates base chain policies. 
+// See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Base_chain_policy +type BaseChainPolicy string + +const ( + BaseChainPolicyAccept BaseChainPolicy = "accept" + BaseChainPolicyDrop BaseChainPolicy = "drop" +) + // Family enumerates address families. type Family string @@ -123,17 +132,17 @@ const ( IPv6 Family = "ip6" ) -// nftType enumerates nft types that can be used to define maps/sets etc. -type nftType string +// NftType enumerates nft types that can be used to define maps/sets etc. +type NftType string const ( - nftTypeIPv4Addr nftType = "ipv4_addr" - nftTypeIPv6Addr nftType = "ipv6_addr" - nftTypeEtherAddr nftType = "ether_addr" - nftTypeInetProto nftType = "inet_proto" - nftTypeInetService nftType = "inet_service" - nftTypeMark nftType = "mark" - nftTypeIfname nftType = "ifname" + NftTypeIPv4Addr NftType = "ipv4_addr" + NftTypeIPv6Addr NftType = "ipv6_addr" + NftTypeEtherAddr NftType = "ether_addr" + NftTypeInetProto NftType = "inet_proto" + NftTypeInetService NftType = "inet_service" + NftTypeMark NftType = "mark" + NftTypeIfname NftType = "ifname" ) // Enable tries once to initialise nftables. @@ -182,43 +191,58 @@ type table struct { Sets map[string]*set Chains map[string]*chain - Dirty bool // Set when the table is new, not when its elements change. - DeleteChainCommands []string + DeleteCommands []string + MustFlush bool + + applyLock sync.Mutex } -// TableRef is a handle for an nftables table. -type TableRef struct { +// Table is a handle for an nftables table. +type Table struct { t *table } -// NewTable creates a new nftables table and returns a [TableRef] +// IsValid returns true if t is a valid reference to a table. +func (t Table) IsValid() bool { + return t.t != nil +} + +// NewTable creates a new nftables table and returns a [Table] // // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_tables // -// The table will be created and flushed when [TableRef.Apply] is next called. 
+// To modify the table, get a [Modifier] using [Table.Modifier], add commands +// to it, and call [Modifier.Apply]. +// // It's flushed in case it already exists in the host's nftables - when that // happens, rules in its chains will be deleted but not the chains themselves, // maps, sets, or elements of maps or sets. But, those un-flushed items can't do // anything disruptive unless referred to by rules, and they will be flushed if -// they get re-created via the [TableRef], when [TableRef.Apply] is next called +// they get re-created via the [Table], when [Modifier.Apply] is next called // (so, before they can be used by a new rule). -func NewTable(family Family, name string) (TableRef, error) { - t := TableRef{ +// +// To fully delete an underlying nftables table, if one already exists, +// use [Reload] after creating the table. +func NewTable(family Family, name string) (Table, error) { + t := Table{ t: &table{ - Name: name, - Family: family, - VMaps: map[string]*vMap{}, - Sets: map[string]*set{}, - Chains: map[string]*chain{}, - Dirty: true, + Name: name, + Family: family, + VMaps: map[string]*vMap{}, + Sets: map[string]*set{}, + Chains: map[string]*chain{}, + MustFlush: true, }, } return t, nil } -// Family returns the address family of the nftables table described by [TableRef]. -func (t TableRef) Family() Family { - return t.t.Family +// Name returns the name of the table, or an empty string if t is not valid. 
+func (t Table) Name() string { + if !t.IsValid() { + return "" + } + return t.t.Name } // incrementalUpdateTemplText is used with text/template to generate an nftables command file @@ -248,25 +272,25 @@ table {{$family}} {{$tableName}} { {{if len .Flags}}flags{{range .Flags}} {{.}}{{end}}{{end}} } {{end}} - {{range .Chains}}{{if .Dirty}}chain {{.Name}} { + {{range .Chains}}{{if .MustFlush}}chain {{.Name}} { {{if .ChainType}}type {{.ChainType}} hook {{.Hook}} priority {{.Priority}}; policy {{.Policy}}{{end}} } ; {{end}}{{end}} } -{{if .Dirty}}flush table {{$family}} {{$tableName}}{{end}} -{{range .VMaps}}{{if .Dirty}}flush map {{$family}} {{$tableName}} {{.Name}} +{{if .MustFlush}}flush table {{$family}} {{$tableName}}{{end}} +{{range .VMaps}}{{if .MustFlush}}flush map {{$family}} {{$tableName}} {{.Name}} {{end}}{{end}} -{{range .Sets}}{{if .Dirty}}flush set {{$family}} {{$tableName}} {{.Name}} +{{range .Sets}}{{if .MustFlush}}flush set {{$family}} {{$tableName}} {{.Name}} {{end}}{{end}} -{{range .Chains}}{{if .Dirty}}flush chain {{$family}} {{$tableName}} {{.Name}} +{{range .Chains}}{{if .MustFlush}}flush chain {{$family}} {{$tableName}} {{.Name}} {{end}}{{end}} {{range .VMaps}}{{if .DeletedElements}}delete element {{$family}} {{$tableName}} {{.Name}} { {{range $k,$v := .DeletedElements}}{{$k}}, {{end}} } {{end}}{{end}} {{range .Sets}}{{if .DeletedElements}}delete element {{$family}} {{$tableName}} {{.Name}} { {{range $k,$v := .DeletedElements}}{{$k}}, {{end}} } {{end}}{{end}} -{{range .DeleteChainCommands}}{{.}} +{{range .DeleteCommands}}{{.}} {{end}} table {{$family}} {{$tableName}} { - {{range .Chains}}{{if .Dirty}}chain {{.Name}} { + {{range .Chains}}{{if .MustFlush}}chain {{.Name}} { {{if .ChainType}}type {{.ChainType}} hook {{.Hook}} priority {{.Priority}}; policy {{.Policy}}{{end}} {{range .Rules}}{{.}} {{end}} @@ -315,16 +339,159 @@ table {{$family}} {{$tableName}} { } ` -// Apply makes incremental updates to nftables, corresponding to changes to the 
[TableRef] -// since Apply was last called. -func (t TableRef) Apply(ctx context.Context) error { +// SetBaseChainPolicy sets the default policy for a base chain. The update +// is applied immediately, unlike creation/deletion of objects via a [Modifier] +// which are not applied until [Modifier.Apply] is called. +func (t Table) SetBaseChainPolicy(ctx context.Context, chainName string, policy BaseChainPolicy) error { + if !t.IsValid() { + return errors.New("invalid table") + } + c := t.t.Chains[chainName] + if c == nil { + return fmt.Errorf("cannot set base chain policy for '%s', it does not exist", chainName) + } + if c.ChainType == "" { + return fmt.Errorf("cannot set base chain policy for '%s', it is not a base chain", chainName) + } + oldPolicy := c.Policy + c.Policy = policy + c.MustFlush = true + + if err := t.Modifier().Apply(ctx); err != nil { + c.Policy = oldPolicy + return err + } + return nil +} + +// Modifier retrieves an object that can be used to manipulate the table. +func (t Table) Modifier() *Modifier { + return &Modifier{t: t.t} +} + +// Obj is an object that can be given to a [Modifier], representing an +// nftables object for it to create or delete. +type Obj interface { + create(context.Context, *table) (bool, error) + delete(context.Context, *table) (bool, error) +} + +// Modifier is used to apply changes to a Table. +type Modifier struct { + t *table + cmds []command +} + +// IsValid returns true if tm is valid. +func (tm *Modifier) IsValid() bool { + return tm.t != nil +} + +// Name returns the name of the table tm will modify, or the +// empty string if tm is not valid. +func (tm *Modifier) Name() string { + if !tm.IsValid() { + return "" + } + return tm.t.Name +} + +// Family returns the address family of the nftables table to be modified, +// or the empty string if tm is not valid. 
+func (tm *Modifier) Family() Family { + if !tm.IsValid() { + return "" + } + return tm.t.Family +} + +// Create enqueues creation of object o, to be applied by tm.Apply. +func (tm *Modifier) Create(o Obj) { + _, f, l, _ := runtime.Caller(1) + tm.cmds = append(tm.cmds, command{ + obj: o, + callerFile: f, + callerLine: l, + }) +} + +// Delete enqueues deletion of object o, to be applied by tm.Apply. +func (tm *Modifier) Delete(o Obj) { + _, f, l, _ := runtime.Caller(1) + tm.cmds = append(tm.cmds, command{ + obj: o, + delete: true, + callerFile: f, + callerLine: l, + }) +} + +// Reverse returns a Modifier that will undo the actions of tm. +// Its operations are performed in reverse order, creates become +// deletes, and deletes become creates. +// +// Most operations are fully reversible (chains/maps/sets must be +// empty before they're deleted, so no information is lost). But, +// there are exceptions, noted in comments in the object definitions. +// +// Applying the updates in a reversed modifier may not work if +// any of the objects have been removed or modified since they +// were added. For example, if a Modifier creates a chain then another +// Modifier adds rules, the reversed Modifier will not be able to +// delete the chain as it is not empty. +func (tm *Modifier) Reverse() *Modifier { + rtm := &Modifier{ + t: tm.t, + cmds: make([]command, len(tm.cmds)), + } + for i, cmd := range tm.cmds { + cmd.delete = !cmd.delete + rtm.cmds[len(tm.cmds)-i-1] = cmd + } + return rtm +} + +// Apply makes incremental updates to nftables. If there's a validation +// error in any of the enqueued objects, or an error applying the updates +// to the underlying nftables, the [Table] will be unmodified. 
+func (tm *Modifier) Apply(ctx context.Context) (retErr error) { if !Enabled() { return errors.New("nftables is not enabled") } + if !tm.IsValid() { + return errors.New("table modifier is not valid") + } + tm.t.applyLock.Lock() + defer tm.t.applyLock.Unlock() + + var rollback []command + defer func() { + if retErr == nil { + return + } + slices.Reverse(rollback) + for _, c := range rollback { + if _, err := c.rollback(ctx, tm.t); err != nil { + log.G(ctx).WithError(err).Error("Failed to roll back nftables updates") + } + } + tm.t.updatesApplied() + }() + + // Apply tm's updates to the Table. + for _, cmd := range tm.cmds { + applied, err := cmd.apply(ctx, tm.t) + if err != nil { + return fmt.Errorf("rule from %s:%d: %w", cmd.callerFile, cmd.callerLine, err) + } + if applied { + rollback = append(rollback, cmd) + } + } // Update nftables. var buf bytes.Buffer - if err := incrementalUpdateTempl.Execute(&buf, t.t); err != nil { + if err := incrementalUpdateTempl.Execute(&buf, tm.t); err != nil { return fmt.Errorf("failed to execute template nft ruleset: %w", err) } @@ -345,26 +512,32 @@ func (t TableRef) Apply(ctx context.Context) error { // behind networks for the test infrastructure to clean up between tests. Starting // a daemon flushes the "docker-bridges" table, so the cleanup fails to delete a // rule that's been flushed. So, try reloading the whole table to get back in-sync. - return t.Reload(ctx) + return tm.t.reload(ctx) } // Note that updates have been applied. - t.t.updatesApplied() + tm.t.updatesApplied() return nil } // Reload deletes the table, then re-creates it, atomically. 
-func (t TableRef) Reload(ctx context.Context) error { +func (t Table) Reload(ctx context.Context) error { if !Enabled() { return errors.New("nftables is not enabled") } + if !t.IsValid() { + return errors.New("invalid table") + } + return t.t.reload(ctx) +} - ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"table": t.t.Name, "family": t.t.Family})) +func (t *table) reload(ctx context.Context) error { + ctx = log.WithLogger(ctx, log.G(ctx).WithFields(log.Fields{"table": t.Name, "family": t.Family})) log.G(ctx).Warn("nftables: reloading table") // Build the update. var buf bytes.Buffer - if err := reloadTempl.Execute(&buf, t.t); err != nil { + if err := reloadTempl.Execute(&buf, t); err != nil { return fmt.Errorf("failed to execute reload template: %w", err) } @@ -382,7 +555,7 @@ func (t TableRef) Reload(ctx context.Context) error { } // Note that updates have been applied. - t.t.updatesApplied() + t.updatesApplied() return nil } @@ -404,170 +577,176 @@ type chain struct { ChainType BaseChainType Hook BaseChainHook Priority int - Policy string - Dirty bool + Policy BaseChainPolicy + MustFlush bool ruleGroups map[RuleGroup][]string } -// ChainRef is a handle for an nftables chain. -type ChainRef struct { - c *chain -} - // BaseChain constructs a new nftables base chain and returns a [ChainRef]. // // See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_base_chains // // It is an error to create a base chain that already exists. // If the underlying chain already exists, it will be flushed by the -// next [TableRef.Apply] before new rules are added. -func (t TableRef) BaseChain(ctx context.Context, name string, chainType BaseChainType, hook BaseChainHook, priority int) (ChainRef, error) { - if _, ok := t.t.Chains[name]; ok { - return ChainRef{}, fmt.Errorf("chain %q already exists", name) +// next [Table.Apply] before new rules are added. 
+type BaseChain struct { + Name string + ChainType BaseChainType + Hook BaseChainHook + Priority int + Policy BaseChainPolicy // Defaults to BaseChainPolicyAccept +} + +func (cd BaseChain) create(ctx context.Context, t *table) (bool, error) { + if _, ok := t.Chains[cd.Name]; ok { + return false, fmt.Errorf("base chain '%s' already exists", cd.Name) + } + if cd.Name == "" { + return false, errors.New("base chain must have a name") + } + if cd.ChainType == "" || cd.Hook == "" { + return false, fmt.Errorf("chain '%s': fields ChainType and Hook are required", cd.Name) + } + if cd.Policy == "" { + // nftables will default to "accept" if unspecified, but the text/template + // requires a policy string. + cd.Policy = BaseChainPolicyAccept } c := &chain{ - table: t.t, - Name: name, - ChainType: chainType, - Hook: hook, - Priority: priority, - Policy: "accept", - Dirty: true, + table: t, + Name: cd.Name, + ChainType: cd.ChainType, + Hook: cd.Hook, + Priority: cd.Priority, + Policy: cd.Policy, + MustFlush: true, ruleGroups: map[RuleGroup][]string{}, } - t.t.Chains[name] = c + t.Chains[c.Name] = c log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "chain": name, - "type": chainType, - "hook": hook, - "prio": priority, + "family": t.Family, + "table": t.Name, + "chain": c.Name, + "type": c.ChainType, + "hook": c.Hook, + "prio": c.Priority, }).Debug("nftables: created base chain") - return ChainRef{c: c}, nil + return true, nil } -// Chain returns a [ChainRef] for an existing chain (which may be a base chain). -// If there is no existing chain, a regular chain is added and its [ChainRef] is -// returned. -// -// See https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_regular_chains -// -// If a new [ChainRef] is created and the underlying chain already exists, it -// will be flushed by the next [TableRef.Apply] before new rules are added. 
-func (t TableRef) Chain(ctx context.Context, name string) ChainRef { - c, ok := t.t.Chains[name] - if !ok { - c = &chain{ - table: t.t, - Name: name, - Dirty: true, - ruleGroups: map[RuleGroup][]string{}, - } - t.t.Chains[name] = c - } - log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "chain": name, - }).Debug("nftables: created chain") - return ChainRef{c: c} +func (cd BaseChain) delete(ctx context.Context, t *table) (bool, error) { + return t.deleteChain(ctx, cd.Name) } -// ChainUpdateFunc is a function that can add rules to a chain, or remove rules from it. -type ChainUpdateFunc func(context.Context, RuleGroup, string, ...interface{}) error - -// ChainUpdateFunc returns a [ChainUpdateFunc] to add rules to the named chain if -// enable is true, or to remove rules from the chain if enable is false. -// (Written as a convenience function to ease migration of iptables functions -// originally written with an enable flag.) -func (t TableRef) ChainUpdateFunc(ctx context.Context, name string, enable bool) ChainUpdateFunc { - c := t.Chain(ctx, name) - if enable { - return c.AppendRule - } - return c.DeleteRule +// Chain implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a chain. +type Chain struct { + Name string } -// DeleteChain deletes a chain. It is an error to delete a chain that does not exist. 
-func (t TableRef) DeleteChain(ctx context.Context, name string) error { - if _, ok := t.t.Chains[name]; !ok { - return fmt.Errorf("chain %q does not exist", name) +func (cd Chain) create(ctx context.Context, t *table) (bool, error) { + if _, ok := t.Chains[cd.Name]; ok { + return false, fmt.Errorf("chain '%s' already exists", cd.Name) + } + if cd.Name == "" { + return false, errors.New("chain must have a name") + } + c := &chain{ + table: t, + Name: cd.Name, + MustFlush: true, + ruleGroups: map[RuleGroup][]string{}, } - delete(t.t.Chains, name) - t.t.DeleteChainCommands = append(t.t.DeleteChainCommands, - fmt.Sprintf("delete chain %s %s %s", t.t.Family, t.t.Name, name)) + t.Chains[c.Name] = c log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "chain": name, - }).Debug("nftables: deleted chain") - return nil + "family": t.Family, + "table": t.Name, + "chain": cd.Name, + }).Debug("nftables: created chain") + return true, nil } -// SetPolicy sets the default policy for a base chain. It is an error to call this -// for a non-base [ChainRef]. -func (c ChainRef) SetPolicy(policy string) error { - if c.c.ChainType == "" { - return errors.New("not a base chain") - } - c.c.Policy = policy - c.c.Dirty = true - return nil +func (cd Chain) delete(ctx context.Context, t *table) (bool, error) { + return t.deleteChain(ctx, cd.Name) +} + +// Rule implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a rule in a chain. +type Rule struct { + Chain string + Group RuleGroup + Rule []string + // IgnoreExist suppresses errors about deleting a rule that does not exist + // or creating a rule that does already exist. + // + // Note that, when set, reversing the [Modifier] may not do what you want! For + // example, if the original modifier deleted a rule that did not exist, the + // reversed modifier will create that rule. + IgnoreExist bool } -// AppendRule appends a rule to a [RuleGroup] in a [ChainRef]. 
-func (c ChainRef) AppendRule(ctx context.Context, group RuleGroup, rule string, args ...interface{}) error { - if len(args) > 0 { - rule = fmt.Sprintf(rule, args...) +func (ru Rule) create(ctx context.Context, t *table) (bool, error) { + c := t.Chains[ru.Chain] + if c == nil { + return false, fmt.Errorf("chain '%s' does not exist", ru.Chain) } - if rg, ok := c.c.ruleGroups[group]; ok && slices.Contains(rg, rule) { - return ErrRuleExist + if len(ru.Rule) == 0 { + return false, fmt.Errorf("chain '%s', cannot add empty rule", ru.Chain) } - c.c.ruleGroups[group] = append(c.c.ruleGroups[group], rule) - c.c.Dirty = true + rule := strings.Join(ru.Rule, " ") + if rg, ok := c.ruleGroups[ru.Group]; ok && slices.Contains(rg, rule) { + if !ru.IgnoreExist { + return false, fmt.Errorf("adding rule:'%s' chain:'%s' group:%d: rule exists", rule, ru.Chain, ru.Group) + } + return false, nil + } + c.ruleGroups[ru.Group] = append(c.ruleGroups[ru.Group], rule) + c.MustFlush = true log.G(ctx).WithFields(log.Fields{ - "family": c.c.table.Family, - "table": c.c.table.Name, - "chain": c.c.Name, - "group": group, + "family": t.Family, + "table": t.Name, + "chain": c.Name, + "group": ru.Group, "rule": rule, }).Debug("nftables: appended rule") - return nil + return true, nil } -// AppendRuleCf calls AppendRule and returns a cleanup function or an error. -func (c ChainRef) AppendRuleCf(ctx context.Context, group RuleGroup, rule string, args ...interface{}) (func(context.Context) error, error) { - if err := c.AppendRule(ctx, group, rule, args...); err != nil { - return nil, err +func (ru Rule) delete(ctx context.Context, t *table) (bool, error) { + rule := strings.Join(ru.Rule, " ") + c := t.Chains[ru.Chain] + if c == nil { + return false, fmt.Errorf("deleting rule:'%s' - chain '%s' does not exist", rule, ru.Chain) } - return func(ctx context.Context) error { return c.DeleteRule(ctx, group, rule, args...) }, nil -} - -// DeleteRule deletes a rule from a [RuleGroup] in a [ChainRef]. 
It is an error -// to delete from a group that does not exist, or to delete a rule that does not -// exist. -func (c ChainRef) DeleteRule(ctx context.Context, group RuleGroup, rule string, args ...interface{}) error { - if len(args) > 0 { - rule = fmt.Sprintf(rule, args...) + if rule == "" { + return false, fmt.Errorf("chain '%s', cannot delete empty rule", ru.Chain) } - rg, ok := c.c.ruleGroups[group] + rg, ok := c.ruleGroups[ru.Group] if !ok { - return fmt.Errorf("rule group %d does not exist", group) + if !ru.IgnoreExist { + return false, fmt.Errorf("deleting rule:'%s' chain:'%s' rule group:%d does not exist", rule, ru.Chain, ru.Group) + } + return false, nil } origLen := len(rg) - c.c.ruleGroups[group] = slices.DeleteFunc(rg, func(r string) bool { return r == rule }) - if len(c.c.ruleGroups[group]) == origLen { - return ErrRuleNotExist + c.ruleGroups[ru.Group] = slices.DeleteFunc(rg, func(r string) bool { return r == rule }) + if len(c.ruleGroups[ru.Group]) == origLen { + if !ru.IgnoreExist { + return false, fmt.Errorf("deleting rule:'%s' chain:'%s' group:%d: rule does not exist", rule, ru.Chain, ru.Group) + } + return false, nil + } + if len(c.ruleGroups[ru.Group]) == 0 { + delete(c.ruleGroups, ru.Group) } - c.c.Dirty = true + c.MustFlush = true log.G(ctx).WithFields(log.Fields{ - "family": c.c.table.Family, - "table": c.c.table.Name, - "chain": c.c.Name, + "family": t.Family, + "table": t.Name, + "chain": c.Name, "rule": rule, }).Debug("nftables: deleted rule") - return nil + return true, nil } // //////////////////////////// @@ -579,89 +758,129 @@ func (c ChainRef) DeleteRule(ctx context.Context, group RuleGroup, rule string, type vMap struct { table *table Name string - ElementType nftType + ElementType NftType Flags []string Elements map[string]string - Dirty bool // New vMap, needs to be flushed (not set when elements are added/deleted). 
AddedElements map[string]string - DeletedElements map[string]struct{} + DeletedElements map[string]string + MustFlush bool } -// VMapRef is a handle for an nftables verdict map. -type VMapRef struct { - v *vMap +// VMap implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a verdict map. +type VMap struct { + Name string + ElementType NftType + Flags []string } -// InterfaceVMap creates a map from interface name to a verdict and returns a [VMapRef], -// or returns an existing [VMapRef] if it has already been created. -// -// See https://wiki.nftables.org/wiki-nftables/index.php/Verdict_Maps_(vmaps) -// -// If a [VMapRef] is created and the underlying map already exists, it will be flushed -// by the next [TableRef.Apply] before new elements are added. -func (t TableRef) InterfaceVMap(ctx context.Context, name string) VMapRef { - if vmap, ok := t.t.VMaps[name]; ok { - return VMapRef{vmap} - } - vmap := &vMap{ - table: t.t, - Name: name, - ElementType: nftTypeIfname, +func (vm VMap) create(ctx context.Context, t *table) (bool, error) { + if vm.Name == "" { + return false, errors.New("vmap must have a name") + } + if _, ok := t.VMaps[vm.Name]; ok { + return false, fmt.Errorf("vmap '%s' already exists", vm.Name) + } + if vm.ElementType == "" { + return false, fmt.Errorf("vmap '%s' has no element type", vm.Name) + } + v := &vMap{ + table: t, + Name: vm.Name, + ElementType: vm.ElementType, + Flags: slices.Clone(vm.Flags), Elements: map[string]string{}, AddedElements: map[string]string{}, - DeletedElements: map[string]struct{}{}, - Dirty: true, + DeletedElements: map[string]string{}, + MustFlush: true, } - t.t.VMaps[name] = vmap + t.VMaps[v.Name] = v log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "vmap": name, + "family": t.Family, + "table": t.Name, + "vmap": v.Name, }).Debug("nftables: created interface vmap") - return VMapRef{vmap} + return true, nil } -// AddElement adds an element to a verdict map. 
The caller must ensure the key has -// the correct type. It is an error to add a key that already exists. -func (v VMapRef) AddElement(ctx context.Context, key string, verdict string) error { - if _, ok := v.v.Elements[key]; ok { - return fmt.Errorf("verdict map already contains element %q", key) +func (vm VMap) delete(ctx context.Context, t *table) (bool, error) { + v := t.VMaps[vm.Name] + if v == nil { + return false, fmt.Errorf("cannot delete vmap '%s', it does not exist", vm.Name) + } + if len(v.Elements) != 0 { + return false, fmt.Errorf("cannot delete vmap '%s', it contains %d elements", v.Name, len(v.Elements)) } - v.v.Elements[key] = verdict - v.v.AddedElements[key] = verdict + delete(t.VMaps, v.Name) + t.DeleteCommands = append(t.DeleteCommands, + fmt.Sprintf("delete map %s %s %s", t.Family, t.Name, v.Name)) log.G(ctx).WithFields(log.Fields{ - "family": v.v.table.Family, - "table": v.v.table.Name, - "vmap": v.v.Name, - "key": key, - "verdict": verdict, - }).Debug("nftables: added vmap element") - return nil + "family": t.Family, + "table": t.Name, + "vmap": v.Name, + }).Debug("nftables: deleted vmap") + return true, nil } -// AddElementCf calls AddElement and returns a cleanup function or an error. -func (v VMapRef) AddElementCf(ctx context.Context, key string, verdict string) (func(context.Context) error, error) { - if err := v.AddElement(ctx, key, verdict); err != nil { - return nil, err +// VMapElement implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete an entry in a verdict map. 
+type VMapElement struct { + VmapName string + Key string + Verdict string +} + +func (ve VMapElement) create(ctx context.Context, t *table) (bool, error) { + if ve.VmapName == "" { + return false, errors.New("cannot add element to unnamed vmap") + } + v := t.VMaps[ve.VmapName] + if v == nil { + return false, fmt.Errorf("cannot add to vmap '%s', it does not exist", ve.VmapName) + } + if ve.Key == "" || ve.Verdict == "" { + return false, fmt.Errorf("cannot add to vmap '%s', element must have key and verdict", ve.VmapName) + } + if _, ok := v.Elements[ve.Key]; ok { + return false, fmt.Errorf("verdict map '%s' already contains element '%s'", ve.VmapName, ve.Key) } - return func(ctx context.Context) error { return v.DeleteElement(ctx, key) }, nil + v.Elements[ve.Key] = ve.Verdict + v.AddedElements[ve.Key] = ve.Verdict + delete(v.DeletedElements, ve.Key) + log.G(ctx).WithFields(log.Fields{ + "family": t.Family, + "table": t.Name, + "vmap": ve.VmapName, + "key": ve.Key, + "verdict": ve.Verdict, + }).Debug("nftables: added vmap element") + return true, nil } -// DeleteElement deletes an element from a verdict map. It is an error to delete -// an element that does not exist. 
-func (v VMapRef) DeleteElement(ctx context.Context, key string) error { - if _, ok := v.v.Elements[key]; !ok { - return fmt.Errorf("verdict map does not contain element %q", key) +func (ve VMapElement) delete(ctx context.Context, t *table) (bool, error) { + v := t.VMaps[ve.VmapName] + if v == nil { + return false, fmt.Errorf("cannot delete from vmap '%s', it does not exist", ve.VmapName) + } + oldVerdict, ok := v.Elements[ve.Key] + if !ok { + return false, fmt.Errorf("verdict map '%s' does not contain element '%s'", ve.VmapName, ve.Key) + } + if oldVerdict != ve.Verdict { + return false, fmt.Errorf("cannot delete verdict map '%s' element '%s', verdict was '%s', not '%s'", + ve.VmapName, ve.Key, oldVerdict, ve.Verdict) } - delete(v.v.Elements, key) - v.v.DeletedElements[key] = struct{}{} + delete(v.Elements, ve.Key) + delete(v.AddedElements, ve.Key) + v.DeletedElements[ve.Key] = ve.Verdict log.G(ctx).WithFields(log.Fields{ - "family": v.v.table.Family, - "table": v.v.table.Name, - "vmap": v.v.Name, - "key": key, + "family": t.Family, + "table": t.Name, + "vmap": ve.VmapName, + "key": ve.Key, + "verdict": ve.Verdict, }).Debug("nftables: deleted vmap element") - return nil + return true, nil } // //////////////////////////// @@ -673,108 +892,180 @@ func (v VMapRef) DeleteElement(ctx context.Context, key string) error { type set struct { table *table Name string - ElementType nftType + ElementType NftType Flags []string Elements map[string]struct{} - Dirty bool // New set, needs to be flushed (not set when elements are added/deleted). AddedElements map[string]struct{} DeletedElements map[string]struct{} + MustFlush bool } -// SetRef is a handle for an nftables named set. -type SetRef struct { - s *set +// Set implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete a set. 
+type Set struct { + Name string + ElementType NftType + Flags []string } -// PrefixSet creates a new named nftables set for IPv4 or IPv6 addresses (depending -// on the address family of the [TableRef]), and returns its [SetRef]. Or, if the -// set has already been created, its [SetRef] is returned. -// -// ([TableRef] does not support "inet", only "ip" or "ip6". So the element type can -// always be determined. But, there's no "inet" element type, so this will need to -// change if we need an "inet" table.) -// // See https://wiki.nftables.org/wiki-nftables/index.php/Sets#Named_sets -func (t TableRef) PrefixSet(ctx context.Context, name string) SetRef { - if s, ok := t.t.Sets[name]; ok { - return SetRef{s} +func (sd Set) create(ctx context.Context, t *table) (bool, error) { + if sd.Name == "" { + return false, errors.New("set must have a name") + } + if _, ok := t.Sets[sd.Name]; ok { + return false, fmt.Errorf("set '%s' already exists", sd.Name) + } + if sd.ElementType == "" { + return false, fmt.Errorf("set '%s' must have a type", sd.Name) } s := &set{ - table: t.t, - Name: name, + table: t, + Name: sd.Name, Elements: map[string]struct{}{}, - ElementType: nftTypeIPv4Addr, - Flags: []string{"interval"}, - Dirty: true, + ElementType: sd.ElementType, + Flags: slices.Clone(sd.Flags), AddedElements: map[string]struct{}{}, DeletedElements: map[string]struct{}{}, + MustFlush: true, } - if t.t.Family == IPv6 { - s.ElementType = nftTypeIPv6Addr - } - t.t.Sets[name] = s + t.Sets[sd.Name] = s log.G(ctx).WithFields(log.Fields{ - "family": t.t.Family, - "table": t.t.Name, - "set": name, + "family": t.Family, + "table": t.Name, + "set": s.Name, }).Debug("nftables: created set") - return SetRef{s} + return true, nil +} + +func (sd Set) delete(ctx context.Context, t *table) (bool, error) { + s := t.Sets[sd.Name] + if s == nil { + return false, fmt.Errorf("cannot delete set '%s', it does not exist", sd.Name) + } + if len(s.Elements) != 0 { + return false, fmt.Errorf("cannot 
delete set '%s', it contains %d elements", s.Name, len(s.Elements)) + } + delete(t.Sets, sd.Name) + t.DeleteCommands = append(t.DeleteCommands, + fmt.Sprintf("delete set %s %s %s", t.Family, t.Name, s.Name)) + log.G(ctx).WithFields(log.Fields{ + "family": t.Family, + "table": t.Name, + "set": sd.Name, + }).Debug("nftables: deleted set") + return true, nil +} + +// SetElement implements the [Obj] interface, it can be passed to a +// [Modifier] to create or delete an entry in a set. +type SetElement struct { + SetName string + Element string } -// AddElement adds an element to a set. It is the caller's responsibility to make sure -// the element has the correct type. It is an error to add an element that is already -// in the set. -func (s SetRef) AddElement(ctx context.Context, element string) error { - if _, ok := s.s.Elements[element]; ok { - return fmt.Errorf("set already contains element %q", element) +func (se SetElement) create(ctx context.Context, t *table) (bool, error) { + s := t.Sets[se.SetName] + if s == nil { + return false, fmt.Errorf("cannot add to set '%s', it does not exist", se.SetName) + } + if se.Element == "" { + return false, fmt.Errorf("cannot add to set '%s', element not specified", se.SetName) + } + if _, ok := s.Elements[se.Element]; ok { + return false, fmt.Errorf("set '%s' already contains element '%s'", s.Name, se.Element) } - s.s.Elements[element] = struct{}{} - s.s.AddedElements[element] = struct{}{} + s.Elements[se.Element] = struct{}{} + s.AddedElements[se.Element] = struct{}{} + delete(s.DeletedElements, se.Element) log.G(ctx).WithFields(log.Fields{ - "family": s.s.table.Family, - "table": s.s.table.Name, - "set": s.s.Name, - "element": element, + "family": t.Family, + "table": t.Name, + "set": s.Name, + "element": se.Element, }).Debug("nftables: added set element") - return nil + return true, nil } -// DeleteElement deletes an element from the set. It is an error to delete an -// element that is not in the set. 
-func (s SetRef) DeleteElement(ctx context.Context, element string) error { - if _, ok := s.s.Elements[element]; !ok { - return fmt.Errorf("set does not contain element %q", element) +func (se SetElement) delete(ctx context.Context, t *table) (bool, error) { + s := t.Sets[se.SetName] + if s == nil { + return false, fmt.Errorf("cannot delete from set '%s', it does not exist", se.SetName) + } + if _, ok := s.Elements[se.Element]; !ok { + return false, fmt.Errorf("cannot delete '%s' from set '%s', it does not exist", se.Element, s.Name) } - delete(s.s.Elements, element) - s.s.DeletedElements[element] = struct{}{} + delete(s.Elements, se.Element) + delete(s.AddedElements, se.Element) + s.DeletedElements[se.Element] = struct{}{} log.G(ctx).WithFields(log.Fields{ - "family": s.s.table.Family, - "table": s.s.table.Name, - "set": s.s.Name, - "element": element, - }).Debug("nftables: deleted set element") - return nil + "family": t.Family, + "table": t.Name, + "set": s.Name, + "element": se.Element, + }).Debug("nftables: added set element") + return true, nil } // //////////////////////////// // Internal +func (t *table) deleteChain(ctx context.Context, name string) (bool, error) { + c := t.Chains[name] + if c == nil { + return false, fmt.Errorf("cannot delete chain '%s', it does not exist", name) + } + if len(c.ruleGroups) != 0 { + return false, fmt.Errorf("cannot delete chain '%s', it is not empty", name) + } + delete(t.Chains, name) + t.DeleteCommands = append(t.DeleteCommands, + fmt.Sprintf("delete chain %s %s %s", t.Family, t.Name, name)) + log.G(ctx).WithFields(log.Fields{ + "family": t.Family, + "table": t.Name, + "chain": name, + }).Debug("nftables: deleted chain") + return true, nil +} + +type command struct { + obj Obj + callerFile string + callerLine int + delete bool +} + +func (c command) apply(ctx context.Context, t *table) (bool, error) { + if c.delete { + return c.obj.delete(ctx, t) + } + return c.obj.create(ctx, t) +} + +func (c command) rollback(ctx 
context.Context, t *table) (bool, error) { + if c.delete { + return c.obj.create(ctx, t) + } + return c.obj.delete(ctx, t) +} + func (t *table) updatesApplied() { - t.DeleteChainCommands = t.DeleteChainCommands[:0] + t.DeleteCommands = t.DeleteCommands[:0] for _, c := range t.Chains { - c.Dirty = false + c.MustFlush = false } for _, m := range t.VMaps { - m.Dirty = false m.AddedElements = map[string]string{} - m.DeletedElements = map[string]struct{}{} + m.DeletedElements = map[string]string{} + m.MustFlush = false } for _, s := range t.Sets { - s.Dirty = false s.AddedElements = map[string]struct{}{} s.DeletedElements = map[string]struct{}{} + s.MustFlush = false } - t.Dirty = false + t.MustFlush = false } /* Can't make text/template range over this, not sure why ... diff --git a/daemon/libnetwork/internal/nftables/nftables_linux_test.go b/daemon/libnetwork/internal/nftables/nftables_linux_test.go index 526fe89956593..95fc4689f532a 100644 --- a/daemon/libnetwork/internal/nftables/nftables_linux_test.go +++ b/daemon/libnetwork/internal/nftables/nftables_linux_test.go @@ -31,11 +31,20 @@ func testSetup(t *testing.T) func() { } } -func applyAndCheck(t *testing.T, tbl TableRef, goldenFilename string) { +func applyAndCheck(t *testing.T, tm *Modifier, goldenFilename string) { t.Helper() - err := tbl.Apply(context.Background()) + err := tm.Apply(context.Background()) assert.Check(t, err) - res := icmd.RunCommand("nft", "list", "ruleset") + res := icmd.RunCommand("nft", "list", "table", string(tm.Family()), tm.Name()) + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), goldenFilename) +} + +func reloadAndCheck(t *testing.T, tbl Table, ipv Family, goldenFilename string) { + t.Helper() + err := tbl.Reload(context.Background()) + assert.Check(t, err) + res := icmd.RunCommand("nft", "list", "table", string(ipv), tbl.Name()) res.Assert(t, icmd.Success) golden.Assert(t, res.Combined(), goldenFilename) } @@ -48,17 +57,19 @@ func TestTable(t *testing.T) { tbl6, err := 
NewTable(IPv6, "ipv6_table") assert.NilError(t, err) - assert.Check(t, is.Equal(tbl4.Family(), IPv4)) - assert.Check(t, is.Equal(tbl6.Family(), IPv6)) + tm4 := tbl4.Modifier() + tm6 := tbl6.Modifier() + + assert.Check(t, is.Equal(tm4.Family(), IPv4)) + assert.Check(t, is.Equal(tm6.Family(), IPv6)) // Update nftables and check what happened. - applyAndCheck(t, tbl4, t.Name()+"_created4.golden") - applyAndCheck(t, tbl6, t.Name()+"_created46.golden") + applyAndCheck(t, tm4, t.Name()+"_created4.golden") + applyAndCheck(t, tm6, t.Name()+"_created6.golden") } func TestChain(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create a table. tbl, err := NewTable(IPv4, "this_is_a_table") @@ -66,143 +77,176 @@ func TestChain(t *testing.T) { // Create a base chain. const bcName = "this_is_a_base_chain" - bc1, err := tbl.BaseChain(ctx, bcName, BaseChainTypeFilter, BaseChainHookForward, BaseChainPriorityFilter+10) - assert.NilError(t, err) - - // Check that it's an error to add a new base chain with the same name. - _, err = tbl.BaseChain(ctx, bcName, BaseChainTypeNAT, BaseChainHookPrerouting, BaseChainPriorityDstNAT) - assert.Check(t, is.ErrorContains(err, "already exists")) - - // Add a rule. - err = bc1.AppendRule(ctx, 0, "counter") - assert.NilError(t, err) + tm := tbl.Modifier() + bcDesc := BaseChain{ + Name: bcName, + ChainType: BaseChainTypeFilter, + Hook: BaseChainHookForward, + Priority: BaseChainPriorityFilter + 10, + Policy: "accept", + } + tm.Create(bcDesc) + // Add a rule to the base chain. + bcCounterRule := Rule{Chain: bcName, Group: 0, Rule: []string{"counter"}} + tm.Create(bcCounterRule) // Add a regular chain. const regularChainName = "this_is_a_regular_chain" - _ = tbl.Chain(ctx, regularChainName) + cDesc := Chain{Name: regularChainName} + tm.Create(cDesc) + // Add a rule to the regular chain. 
+ cRule := Rule{Chain: regularChainName, Group: 0, Rule: []string{"counter", "accept"}} + tm.Create(cRule) - // Add a rule to the regular chain, use string formatting and a func retrieved - // from the table. - f := tbl.ChainUpdateFunc(ctx, regularChainName, true) - err = f(ctx, 0, "counter %s", "accept") - assert.Check(t, err) + // Add another rule to the base chain. + bcJumpRule := Rule{Chain: bcName, Group: 0, Rule: []string{"jump", regularChainName}} + tm.Create(bcJumpRule) - // Fetch the base chain by name. - bc1 = tbl.Chain(ctx, bcName) + // Update nftables and check what happened. + applyAndCheck(t, tm, t.Name()+"_created.golden") - // Add another rule to the base chain, using the newly-retrieved handle. - err = bc1.AppendRule(ctx, 0, "jump %s", regularChainName) - assert.Check(t, err) + // Delete a rule from the base chain. + tm = tbl.Modifier() + tm.Delete(bcCounterRule) // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_created.golden") + applyAndCheck(t, tm, t.Name()+"_modified.golden") - // Delete a rule from the base chain. - f = tbl.ChainUpdateFunc(ctx, bcName, false) - err = f(ctx, 0, "counter") - assert.Check(t, err) + // Delete the base chain. + tm = tbl.Modifier() + tm.Delete(bcJumpRule) + tm.Delete(bcDesc) + tm.Delete(cRule) + tm.Delete(cDesc) - // Check it's an error to delete that rule again. This time, call the delete - // function directly on a newly retrieved handle. - err = tbl.Chain(ctx, bcName).DeleteRule(ctx, 0, "counter") - assert.Check(t, is.ErrorContains(err, "does not exist")) + // Update nftables and check what happened. + applyAndCheck(t, tm, t.Name()+"_deleted.golden") +} - // Update the base chain's policy. - err = tbl.Chain(ctx, bcName).SetPolicy("drop") - assert.Check(t, err) +func TestSetBaseChainPolicy(t *testing.T) { + defer testSetup(t)() + baseChainName := "aBaseChain" + chainName := "aChain" - // Check it's an error to set a policy on a regular chain. 
- err = tbl.Chain(ctx, regularChainName).SetPolicy("drop") - assert.Check(t, is.ErrorContains(err, "not a base chain")) + tbl := Table{} + err := tbl.SetBaseChainPolicy(context.Background(), baseChainName, "accept") + assert.Check(t, is.ErrorContains(err, "invalid table")) - // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_modified.golden") + tbl, err = NewTable(IPv4, "this_is_a_table") + assert.NilError(t, err) - // Delete the base chain. - err = tbl.DeleteChain(ctx, bcName) - assert.Check(t, err) + err = tbl.SetBaseChainPolicy(context.Background(), baseChainName, "accept") + assert.Check(t, is.Error(err, "cannot set base chain policy for '"+baseChainName+"', it does not exist")) + + tm := tbl.Modifier() + tm.Create(BaseChain{Name: baseChainName, ChainType: BaseChainTypeFilter, Hook: BaseChainHookForward, Priority: BaseChainPriorityFilter}) + tm.Create(Chain{Name: chainName}) + applyAndCheck(t, tm, t.Name()+"_accept.golden") + + err = tbl.SetBaseChainPolicy(context.Background(), chainName, "accept") + assert.Check(t, is.Error(err, "cannot set base chain policy for '"+chainName+"', it is not a base chain")) - // Delete the regular chain. - err = tbl.DeleteChain(ctx, regularChainName) + err = tbl.SetBaseChainPolicy(context.Background(), baseChainName, "drop") assert.Check(t, err) + res := icmd.RunCommand("nft", "list", "table", string(IPv4), "this_is_a_table") + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_drop.golden") - // Check that it's an error to delete it again. 
- err = tbl.DeleteChain(ctx, regularChainName) - assert.Check(t, is.ErrorContains(err, "does not exist")) + err = tbl.SetBaseChainPolicy(context.Background(), baseChainName, "badpolicy") + assert.Check(t, err != nil, "Expected an error") + res = icmd.RunCommand("nft", "list", "table", string(IPv4), "this_is_a_table") + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_drop.golden") - // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_deleted.golden") + err = tbl.Reload(context.Background()) + assert.Check(t, err) + res = icmd.RunCommand("nft", "list", "table", string(IPv4), "this_is_a_table") + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_drop.golden") } func TestChainRuleGroups(t *testing.T) { defer testSetup(t)() - ctx := context.Background() tbl, err := NewTable(IPv4, "testtable") assert.NilError(t, err) - c := tbl.Chain(ctx, "testchain") - err = c.AppendRule(ctx, 100, "hello100") - assert.Check(t, err) - err = c.AppendRule(ctx, 200, "hello200") - assert.Check(t, err) - err = c.AppendRule(ctx, 100, "hello101") - assert.Check(t, err) - err = c.AppendRule(ctx, 200, "hello201") - assert.Check(t, err) - err = c.AppendRule(ctx, 100, "hello102") - assert.Check(t, err) + tm := tbl.Modifier() + chainName := "testchain" + tm.Create(Chain{Name: chainName}) + tm.Create(Rule{Chain: chainName, Group: 100, Rule: []string{"iifname hello100 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 200, Rule: []string{"iifname hello200 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 100, Rule: []string{"iifname hello101 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 200, Rule: []string{"iifname hello201 counter"}}) + tm.Create(Rule{Chain: chainName, Group: 100, Rule: []string{"iifname hello102 counter"}}) + applyAndCheck(t, tm, t.Name()+".golden") +} - assert.Check(t, is.DeepEqual(c.c.Rules(), []string{ - "hello100", "hello101", "hello102", - "hello200", "hello201", - })) +func 
TestIgnoreExist(t *testing.T) { + defer testSetup(t)() + tbl, err := NewTable(IPv4, "this_is_a_table") + assert.NilError(t, err) + tm := tbl.Modifier() + + // Create a chain with a single rule, add the rule again but drop the duplicate. + const chainName = "this_is_a_chain" + tm.Create(Chain{Name: chainName}) + tm.Create(Rule{Chain: chainName, Rule: []string{"counter"}}) + tm.Create(Rule{Chain: chainName, Rule: []string{"counter"}, IgnoreExist: true}) + applyAndCheck(t, tm, t.Name()+"_created.golden") + + // Add the rule again, ignoring the duplicate, but in a modifier that has an + // error - check that the existing rule isn't removed by rollback of this modifier. + tmErr := tbl.Modifier() + tmErr.Create(Rule{Chain: chainName, Rule: []string{"counter"}, IgnoreExist: true}) + tmErr.Create(Rule{Chain: chainName}) + err = tmErr.Apply(context.Background()) + assert.Check(t, err != nil, "Expected an error") + + // Reload, to flush table state. + reloadAndCheck(t, tbl, IPv4, t.Name()+"_created.golden") + + // Delete the rule. + tmDel := tbl.Modifier() + tmDel.Delete(Rule{Chain: chainName, Rule: []string{"counter"}}) + applyAndCheck(t, tmDel, t.Name()+"_deleted.golden") + + // Delete it again, in another modifier that will roll back, to check it's not resurrected. + tmReDel := tbl.Modifier() + tmReDel.Delete(Rule{Chain: chainName, Rule: []string{"counter"}, IgnoreExist: true}) + tmReDel.Create(Rule{Chain: chainName}) + err = tmReDel.Apply(context.Background()) + assert.Check(t, err != nil, "Expected an error") + + // Reload, to flush table state. + reloadAndCheck(t, tbl, IPv4, t.Name()+"_deleted.golden") } func TestVMap(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create a table. tbl, err := NewTable(IPv6, "this_is_a_table") assert.NilError(t, err) + tm := tbl.Modifier() // Create a verdict map. const mapName = "this_is_a_vmap" - m := tbl.InterfaceVMap(ctx, mapName) - - // Add an element.
- err = m.AddElement(ctx, "eth0", "return") - assert.Check(t, err) - - // Check that it's an error to add the element again. - err = m.AddElement(ctx, "eth0", "return") - assert.Check(t, is.ErrorContains(err, "already contains element")) - - // Fetch the existing vmap. - m = tbl.InterfaceVMap(ctx, mapName) - - // Add another element. - err = m.AddElement(ctx, "eth1", "drop") - assert.Check(t, err) + tm.Create(VMap{Name: mapName, ElementType: NftTypeIfname}) + tm.Create(VMapElement{VmapName: mapName, Key: "eth0", Verdict: "return"}) + tm.Create(VMapElement{VmapName: mapName, Key: "eth1", Verdict: "drop"}) // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_created.golden") + applyAndCheck(t, tm, t.Name()+"_created.golden") - // Delete an element. - err = m.DeleteElement(ctx, "eth1") - assert.Check(t, err) - - // Check it's an error to delete it again. - err = m.DeleteElement(ctx, "eth1") - assert.Check(t, is.ErrorContains(err, "does not contain element")) + // Undo those changes by reversing the commands. + tmRev := tm.Reverse() // Update nftables and check what happened. - applyAndCheck(t, tbl, t.Name()+"_deleted.golden") + applyAndCheck(t, tmRev, t.Name()+"_deleted.golden") } func TestSet(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create v4 and v6 tables. tbl4, err := NewTable(IPv4, "table4") @@ -211,62 +255,55 @@ func TestSet(t *testing.T) { assert.NilError(t, err) // Create a set in each table. - s4 := tbl4.PrefixSet(ctx, "set4") - s6 := tbl6.PrefixSet(ctx, "set6") + const set4Name = "set4" + tm4 := tbl4.Modifier() + tm4.Create(Set{Name: set4Name, ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}) + const set6Name = "set6" + tm6 := tbl6.Modifier() + tm6.Create(Set{Name: set6Name, ElementType: NftTypeIPv6Addr, Flags: []string{"interval"}}) // Add elements to each set. 
- err = s4.AddElement(ctx, "192.0.2.1/24") - assert.Check(t, err) - err = s6.AddElement(ctx, "2001:db8::1/64") - assert.Check(t, err) - - // Check it's an error to add those elements again. - err = s4.AddElement(ctx, "192.0.2.1/24") - assert.Check(t, is.ErrorContains(err, "already contains element")) - err = s6.AddElement(ctx, "2001:db8::1/64") - assert.Check(t, is.ErrorContains(err, "already contains element")) + tm4.Create(SetElement{SetName: set4Name, Element: "192.0.2.0/24"}) + tm6.Create(SetElement{SetName: set6Name, Element: "2001:db8::/64"}) // Update nftables and check what happened. - applyAndCheck(t, tbl4, t.Name()+"_created4.golden") - applyAndCheck(t, tbl6, t.Name()+"_created46.golden") + applyAndCheck(t, tm4, t.Name()+"_created4.golden") + applyAndCheck(t, tm6, t.Name()+"_created6.golden") // Delete elements. - err = s4.DeleteElement(ctx, "192.0.2.1/24") - assert.Check(t, err) - err = s6.DeleteElement(ctx, "2001:db8::1/64") - assert.Check(t, err) - - // Check it's an error to delete those elements again. - err = s4.DeleteElement(ctx, "192.0.2.1/24") - assert.Check(t, is.ErrorContains(err, "does not contain element")) - err = s6.DeleteElement(ctx, "2001:db8::1/64") - assert.Check(t, is.ErrorContains(err, "does not contain element")) - - // Update nftables and check what happened. - applyAndCheck(t, tbl4, t.Name()+"_deleted4.golden") - applyAndCheck(t, tbl6, t.Name()+"_deleted46.golden") + applyAndCheck(t, tm4.Reverse(), t.Name()+"_deleted4.golden") + applyAndCheck(t, tm6.Reverse(), t.Name()+"_deleted6.golden") } func TestReload(t *testing.T) { defer testSetup(t)() - ctx := context.Background() // Create a table with some stuff in it. 
const tableName = "this_is_a_table" tbl, err := NewTable(IPv4, tableName) assert.NilError(t, err) - bc, err := tbl.BaseChain(ctx, "a_base_chain", BaseChainTypeFilter, BaseChainHookForward, BaseChainPriorityFilter) - assert.NilError(t, err) - err = bc.AppendRule(ctx, 0, "counter") - assert.NilError(t, err) - m := tbl.InterfaceVMap(ctx, "this_is_a_vmap") - err = m.AddElement(ctx, "eth0", "return") - assert.Check(t, err) - err = m.AddElement(ctx, "eth1", "return") - assert.Check(t, err) - err = tbl.PrefixSet(ctx, "set4").AddElement(ctx, "192.0.2.0/24") - assert.Check(t, err) - applyAndCheck(t, tbl, t.Name()+"_created.golden") + tm := tbl.Modifier() + + const bcName = "a_base_chain" + tm.Create(BaseChain{ + Name: bcName, + ChainType: BaseChainTypeFilter, + Hook: BaseChainHookForward, + Priority: BaseChainPriorityFilter, + Policy: "accept", + }) + tm.Create(Rule{Chain: bcName, Group: 0, Rule: []string{"counter"}}) + + const vmapName = "this_is_a_vmap" + tm.Create(VMap{Name: vmapName, ElementType: NftTypeIfname}) + tm.Create(VMapElement{VmapName: vmapName, Key: "eth0", Verdict: "return"}) + tm.Create(VMapElement{VmapName: vmapName, Key: "eth1", Verdict: "return"}) + + const setName = "this_is_a_set" + tm.Create(Set{Name: setName, ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}) + tm.Create(SetElement{SetName: setName, Element: "192.0.2.0/24"}) + + applyAndCheck(t, tm, t.Name()+"_created.golden") // Delete the underlying nftables table. deleteTable := func() { @@ -282,14 +319,371 @@ func TestReload(t *testing.T) { // Reconstruct the nftables table. err = tbl.Reload(context.Background()) assert.Check(t, err) - applyAndCheck(t, tbl, t.Name()+"_reloaded.golden") + res := icmd.RunCommand("nft", "list", "table", string(tm.Family()), tm.Name()) + res.Assert(t, icmd.Success) + golden.Assert(t, res.Combined(), t.Name()+"_created.golden") // Delete again. 
deleteTable() // Check implicit/recovery reload - only deleting something that's gone missing // from a vmap/set will trigger this. - err = m.DeleteElement(ctx, "eth1") - assert.Check(t, err) - applyAndCheck(t, tbl, t.Name()+"_recovered.golden") + tm = tbl.Modifier() + tm.Delete(SetElement{SetName: setName, Element: "192.0.2.0/24"}) + applyAndCheck(t, tm, t.Name()+"_recovered.golden") +} + +func TestValidation(t *testing.T) { + testcases := []struct { + name string + cmds []command + expErr string + }{ + // BaseChain + { + name: "create with missing base chain name", + cmds: []command{ + {obj: BaseChain{ChainType: BaseChainTypeNAT, Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT}}, + }, + expErr: "base chain must have a name", + }, + { + name: "create with missing base chain type", + cmds: []command{ + {obj: BaseChain{Name: "achain", Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT}}, + }, + expErr: "chain 'achain': fields ChainType and Hook are required", + }, + { + name: "create with missing base chain hook", + cmds: []command{ + {obj: BaseChain{Name: "achain", ChainType: BaseChainTypeNAT, Priority: BaseChainPrioritySrcNAT}}, + }, + expErr: "chain 'achain': fields ChainType and Hook are required", + }, + { + name: "delete non-empty base chain", + cmds: []command{ + {obj: BaseChain{ + Name: "achain", ChainType: BaseChainTypeNAT, Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT, + }}, + {obj: Rule{Chain: "achain", Group: 0, Rule: []string{"counter"}}}, + { + obj: BaseChain{ + Name: "achain", ChainType: BaseChainTypeNAT, Hook: BaseChainHookPostrouting, Priority: BaseChainPrioritySrcNAT, + }, + delete: true, + }, + }, + expErr: "cannot delete chain 'achain', it is not empty", + }, + // Chain + { + name: "duplicate chain", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Chain{Name: "achain"}}, + }, + expErr: "already exists", + }, + { + name: "delete missing chain", + cmds: []command{ + {obj: Chain{Name: 
"achain"}, delete: true}, + }, + expErr: "does not exist", + }, + { + name: "missing chain name", + cmds: []command{ + {obj: Chain{}}, + }, + expErr: "chain must have a name", + }, + { + name: "delete non-empty chain", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Chain{Name: "achain"}, delete: true}, + }, + expErr: "cannot delete chain 'achain', it is not empty", + }, + // Rule + { + name: "bad rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"this is nonsense"}}}, + }, + expErr: "syntax error", + }, + { + name: "duplicate rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + }, + expErr: "rule exists", + }, + { + name: "delete missing rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + }, + expErr: "does not exist", + }, + { + name: "duplicate rule delete", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + }, + expErr: "does not exist", + }, + { + name: "create rule with missing chain name", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Rule: []string{"counter"}}}, + }, + expErr: "chain '' does not exist", + }, + { + name: "delete rule with missing chain name", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Rule: []string{"counter"}}, delete: true}, + }, + expErr: "chain '' does not exist", + }, + { + name: "create rule with nonexistent chain", + cmds: []command{ + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + }, + expErr: "chain 'achain' does not exist", + }, + { + name: "delete rule 
with nonexistent chain", + cmds: []command{ + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + }, + expErr: "chain 'achain' does not exist", + }, + { + name: "create rule with no rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain"}}, + }, + expErr: "cannot add empty rule", + }, + { + name: "delete rule with no rule", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain"}, delete: true}, + }, + expErr: "cannot delete empty rule", + }, + { + name: "bad rule mid-sequence", + cmds: []command{ + {obj: Chain{Name: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}, delete: true}, + {obj: Rule{Chain: "achain"}}, + {obj: Rule{Chain: "achain", Rule: []string{"counter"}}}, + }, + expErr: "chain 'achain', cannot add empty rule", + }, + // VMap + { + name: "duplicate vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + }, + expErr: "vmap 'avmap' already exists", + }, + { + name: "delete nonexistent vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}, delete: true}, + }, + expErr: "cannot delete vmap 'avmap', it does not exist", + }, + { + name: "missing vmap name", + cmds: []command{{obj: VMap{ElementType: NftTypeIfname}}}, + expErr: "vmap must have a name", + }, + { + name: "missing vmap element type", + cmds: []command{{obj: VMap{Name: "avmap"}}}, + expErr: "vmap 'avmap' has no element type", + }, + { + name: "delete non-empty vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}, delete: true}, + }, + expErr: "cannot delete vmap 'avmap', it contains 1 elements", + }, + // VMapElement + { + name: "duplicate vmap element", + cmds: []command{ + 
{obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + }, + expErr: "verdict map 'avmap' already contains element 'eth0'", + }, + { + name: "add to vmap that does not exist", + cmds: []command{ + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}}, + }, + expErr: "cannot add to vmap 'avmap', it does not exist", + }, + { + name: "delete nonexistent vmap element", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0", Verdict: "drop"}, delete: true}, + }, + expErr: "verdict map 'avmap' does not contain element 'eth0'", + }, + { + name: "vmap element with no named vmap", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{Key: "eth0", Verdict: "drop"}}, + }, + expErr: "cannot add element to unnamed vmap", + }, + { + name: "vmap element with no key", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Verdict: "drop"}}, + }, + expErr: "cannot add to vmap 'avmap', element must have key and verdict", + }, + { + name: "vmap element with no verdict", + cmds: []command{ + {obj: VMap{Name: "avmap", ElementType: NftTypeIfname}}, + {obj: VMapElement{VmapName: "avmap", Key: "eth0"}}, + }, + expErr: "cannot add to vmap 'avmap', element must have key and verdict", + }, + // Set + { + name: "duplicate set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + }, + expErr: "set 'aset' already exists", + }, + { + name: "delete nonexistent set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}, delete: true}, + }, + expErr: "cannot delete set 'aset', it 
does not exist", + }, + { + name: "missing set name", + cmds: []command{ + {obj: Set{ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + }, + expErr: "set must have a name", + }, + { + name: "missing set element type", + cmds: []command{ + {obj: Set{Name: "aset", Flags: []string{"interval"}}}, + }, + expErr: "set 'aset' must have a type", + }, + { + name: "delete non-empty set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}}, + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}, delete: true}, + }, + expErr: "cannot delete set 'aset', it contains 1 elements", + }, + // SetElement + { + name: "duplicate set element", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}}, + }, + expErr: "set 'aset' already contains element '192.0.2.0/24'", + }, + { + name: "delete nonexistent set element", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "192.0.2.0/24"}, delete: true}, + }, + expErr: "cannot delete '192.0.2.0/24' from set 'aset', it does not exist", + }, + { + name: "add set element to unnamed set", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{Element: "192.0.2.0/24"}}, + }, + expErr: "cannot add to set '', it does not exist", + }, + { + name: "add set element with no element", + cmds: []command{ + {obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset"}}, + }, + expErr: "cannot add to set 'aset', element not specified", + }, + { + name: "mismatched set element type", + cmds: []command{ + 
{obj: Set{Name: "aset", ElementType: NftTypeIPv4Addr, Flags: []string{"interval"}}}, + {obj: SetElement{SetName: "aset", Element: "2001:db8::/64"}}, + }, + expErr: "Address family for hostname not supported", + }, + } + + testName := t.Name() + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + defer testSetup(t)() + tbl, err := NewTable(IPv4, "tablename") + assert.NilError(t, err) + tm := Modifier{t: tbl.t, cmds: tc.cmds} + err = tm.Apply(context.Background()) + assert.Check(t, err != nil, "expected error containing '%s'", tc.expErr) + assert.Check(t, is.ErrorContains(err, tc.expErr)) + // Check the table wasn't created. + res := icmd.RunCommand("nft", "list", "table", string(IPv4), "tablename") + res.Assert(t, icmd.Expected{ExitCode: 1}) + // Check the empty table can be created (the Table structure is still healthy). + applyAndCheck(t, tbl.Modifier(), testName+"_empty.golden") + }) + } } diff --git a/daemon/libnetwork/internal/nftables/testdata/.gitattributes b/daemon/libnetwork/internal/nftables/testdata/.gitattributes new file mode 100644 index 0000000000000..fab3c980014f5 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/.gitattributes @@ -0,0 +1 @@ +*.golden linguist-generated=true diff --git a/daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden b/daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden new file mode 100644 index 0000000000000..cf5befeb3e6f5 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestChainRuleGroups.golden @@ -0,0 +1,9 @@ +table ip testtable { + chain testchain { + iifname "hello100" counter packets 0 bytes 0 + iifname "hello101" counter packets 0 bytes 0 + iifname "hello102" counter packets 0 bytes 0 + iifname "hello200" counter packets 0 bytes 0 + iifname "hello201" counter packets 0 bytes 0 + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden 
b/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden index efbf1dfe142d7..d9a59eb2eae13 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestChain_modified.golden @@ -1,6 +1,6 @@ table ip this_is_a_table { chain this_is_a_base_chain { - type filter hook forward priority filter + 10; policy drop; + type filter hook forward priority filter + 10; policy accept; jump this_is_a_regular_chain } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden new file mode 100644 index 0000000000000..41c0aa7251679 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_created.golden @@ -0,0 +1,5 @@ +table ip this_is_a_table { + chain this_is_a_chain { + counter packets 0 bytes 0 + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden new file mode 100644 index 0000000000000..148748dc0c14e --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestIgnoreExist_deleted.golden @@ -0,0 +1,4 @@ +table ip this_is_a_table { + chain this_is_a_chain { + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden b/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden index f69a290c706c9..dec6dc28cc3a6 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestReload_created.golden @@ -5,7 +5,7 @@ table ip this_is_a_table { "eth1" : return } } - set set4 { + set this_is_a_set { type ipv4_addr flags interval elements = { 192.0.2.0/24 } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden b/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden index 
98d703606128a..a112403abf807 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestReload_recovered.golden @@ -1,13 +1,13 @@ table ip this_is_a_table { map this_is_a_vmap { type ifname : verdict - elements = { "eth0" : return } + elements = { "eth0" : return, + "eth1" : return } } - set set4 { + set this_is_a_set { type ipv4_addr flags interval - elements = { 192.0.2.0/24 } } chain a_base_chain { diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden new file mode 100644 index 0000000000000..2de16c86a3325 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_accept.golden @@ -0,0 +1,8 @@ +table ip this_is_a_table { + chain aBaseChain { + type filter hook forward priority filter; policy accept; + } + + chain aChain { + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden new file mode 100644 index 0000000000000..f5d80a60229f3 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSetBaseChainPolicy_drop.golden @@ -0,0 +1,8 @@ +table ip this_is_a_table { + chain aBaseChain { + type filter hook forward priority filter; policy drop; + } + + chain aChain { + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden index ea8d0cd04a59f..f6fbbb281be1b 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_created46.golden @@ -1,10 +1,3 @@ -table ip table4 { - set set4 { - type ipv4_addr - flags interval - elements = { 192.0.2.0/24 } - } -} table ip6 table6 { set set6 { type ipv6_addr diff --git 
a/daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden new file mode 100644 index 0000000000000..f6fbbb281be1b --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_created6.golden @@ -0,0 +1,7 @@ +table ip6 table6 { + set set6 { + type ipv6_addr + flags interval + elements = { 2001:db8::/64 } + } +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden index 54f481f747604..f9b03350baee0 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted4.golden @@ -1,13 +1,2 @@ table ip table4 { - set set4 { - type ipv4_addr - flags interval - } -} -table ip6 table6 { - set set6 { - type ipv6_addr - flags interval - elements = { 2001:db8::/64 } - } } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden new file mode 100644 index 0000000000000..813114005b892 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestSet_deleted6.golden @@ -0,0 +1,2 @@ +table ip6 table6 { +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden b/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden index 3bb18c8761575..d15df4247e30c 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestTable_created46.golden @@ -1,4 +1,2 @@ -table ip ipv4_table { -} table ip6 ipv6_table { } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden b/daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden new file mode 100644 index 0000000000000..d15df4247e30c --- /dev/null +++ 
b/daemon/libnetwork/internal/nftables/testdata/TestTable_created6.golden @@ -0,0 +1,2 @@ +table ip6 ipv6_table { +} diff --git a/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden b/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden index 432c30d38afda..9a1af3597868c 100644 --- a/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden +++ b/daemon/libnetwork/internal/nftables/testdata/TestVMap_deleted.golden @@ -1,6 +1,2 @@ table ip6 this_is_a_table { - map this_is_a_vmap { - type ifname : verdict - elements = { "eth0" : return } - } } diff --git a/daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden b/daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden new file mode 100644 index 0000000000000..593e68402a3f8 --- /dev/null +++ b/daemon/libnetwork/internal/nftables/testdata/TestValidation_empty.golden @@ -0,0 +1,2 @@ +table ip tablename { +} diff --git a/daemon/libnetwork/resolver_unix.go b/daemon/libnetwork/resolver_unix.go index f7de9defa5449..57037bc16b042 100644 --- a/daemon/libnetwork/resolver_unix.go +++ b/daemon/libnetwork/resolver_unix.go @@ -97,32 +97,47 @@ func (r *Resolver) setupNftablesNAT(ctx context.Context, laddr, ltcpaddr, resolv if err != nil { return err } + tm := table.Modifier() - dnatChain, err := table.BaseChain(ctx, "dns-dnat", nftables.BaseChainTypeNAT, nftables.BaseChainHookOutput, nftables.BaseChainPriorityDstNAT) - if err != nil { - return err - } - if err := dnatChain.AppendRule(ctx, 0, "ip daddr %s udp dport %s counter dnat to %s", resolverIP, dnsPort, laddr); err != nil { - return err - } - if err := dnatChain.AppendRule(ctx, 0, "ip daddr %s tcp dport %s counter dnat to %s", resolverIP, dnsPort, ltcpaddr); err != nil { - return err - } + const dnatChain = "dns-dnat" + tm.Create(nftables.BaseChain{ + Name: dnatChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookOutput, + Priority: nftables.BaseChainPriorityDstNAT, + }) + 
tm.Create(nftables.Rule{ + Chain: dnatChain, + Rule: []string{"ip daddr", resolverIP, "udp dport", dnsPort, "counter dnat to", laddr}, + IgnoreExist: false, + }) + tm.Create(nftables.Rule{ + Chain: dnatChain, + Rule: []string{"ip daddr", resolverIP, "tcp dport", dnsPort, "counter dnat to", ltcpaddr}, + IgnoreExist: false, + }) - snatChain, err := table.BaseChain(ctx, "dns-snat", nftables.BaseChainTypeNAT, nftables.BaseChainHookPostrouting, nftables.BaseChainPrioritySrcNAT) - if err != nil { - return err - } - if err := snatChain.AppendRule(ctx, 0, "ip saddr %s udp sport %s counter snat to :%s", resolverIP, ipPort, dnsPort); err != nil { - return err - } - if err := snatChain.AppendRule(ctx, 0, "ip saddr %s tcp sport %s counter snat to :%s", resolverIP, tcpPort, dnsPort); err != nil { - return err - } + const snatChain = "dns-snat" + tm.Create(nftables.BaseChain{ + Name: snatChain, + ChainType: nftables.BaseChainTypeNAT, + Hook: nftables.BaseChainHookPostrouting, + Priority: nftables.BaseChainPrioritySrcNAT, + }) + tm.Create(nftables.Rule{ + Chain: snatChain, + Rule: []string{"ip saddr", resolverIP, "udp sport", ipPort, "counter snat to :" + dnsPort}, + IgnoreExist: false, + }) + tm.Create(nftables.Rule{ + Chain: snatChain, + Rule: []string{"ip saddr", resolverIP, "tcp sport", tcpPort, "counter snat to :" + dnsPort}, + IgnoreExist: false, + }) var setupErr error if err := r.backend.ExecFunc(func() { - setupErr = table.Apply(ctx) + setupErr = tm.Apply(ctx) }); err != nil { return err } diff --git a/integration/network/bridge/bridge_linux_test.go b/integration/network/bridge/bridge_linux_test.go index 1825800bbe258..d284b0309da6c 100644 --- a/integration/network/bridge/bridge_linux_test.go +++ b/integration/network/bridge/bridge_linux_test.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "net/netip" + "regexp" + "strconv" "strings" "testing" "time" @@ -939,3 +941,97 @@ func TestFirewallBackendSwitch(t *testing.T) { assert.Check(t, len(dockerChains) > 0, "Expected iptables to 
have at least one docker chain")
 	assert.Check(t, !nftablesTablesExist(), "nftables tables exist after running with iptables")
 }
+
+// TestBaseChainPriorityOverride checks that priorities of nftables base chains can be configured.
+func TestBaseChainPriorityOverride(t *testing.T) {
+	skip.If(t, !strings.Contains(testEnv.FirewallBackendDriver(), "nftables"),
+		"test is nftables specific, and nftables isn't in use")
+	_ = setupTest(t)
+	d := daemon.New(t)
+	defer d.Stop(t)
+
+	baseChains := []string{"filter-FORWARD", "nat-POSTROUTING", "nat-PREROUTING", "nat-OUTPUT", "raw-PREROUTING"}
+
+	re := regexp.MustCompile("priority ([-]?[0-9]+)")
+	chainPriority := func(name string) int {
+		t.Helper()
+		res := icmd.RunCommand("nft", "-n", "list chain "+name)
+		assert.Assert(t, res.ExitCode == 0, "failed to list chain %s", name)
+		m := re.FindStringSubmatch(res.Stdout())
+		assert.Assert(t, is.Len(m, 2), "failed to find chain priority in %s", res.Stdout())
+		i, err := strconv.Atoi(m[1])
+		assert.NilError(t, err)
+		return i
+	}
+	chainPriorities := func(fam string) map[string]int {
+		t.Helper()
+		res := map[string]int{}
+		for _, chain := range baseChains {
+			res[chain] = chainPriority(fam + " docker-bridges " + chain)
+		}
+		return res
+	}
+
+	d.Start(t)
+	defaults4 := chainPriorities("ip")
+	defaults6 := chainPriorities("ip6")
+	assert.Check(t, is.DeepEqual(defaults4, defaults6), "Expected ip/ip6 base chain priorities to be the same")
+
+	args := make([]string, 0, len(baseChains))
+	for _, chain := range baseChains {
+		args = append(args, fmt.Sprintf("--bridge-nftables-priority=%s=%d", chain, defaults4[chain]+1))
+	}
+
+	d.Stop(t)
+	d.Start(t, args...)
+
+	modified4 := chainPriorities("ip")
+	modified6 := chainPriorities("ip6")
+	assert.Check(t, is.DeepEqual(modified4, modified6), "Expected ip/ip6 base chain priorities to be the same")
+	for _, chain := range baseChains {
+		assert.Check(t, is.Equal(modified4[chain], defaults4[chain]+1), "unexpected priority for base chain %s", chain) // name the chain so a failure is attributable
+	}
+}
+
+// TestBaseChainPriorityValidation checks that nftables priority overrides are validated on startup.
+func TestBaseChainPriorityValidation(t *testing.T) {
+	skip.If(t, !strings.Contains(testEnv.FirewallBackendDriver(), "nftables"),
+		"test is nftables specific, and nftables isn't in use")
+	_ = setupTest(t)
+
+	testcases := []struct {
+		name   string
+		opt    string
+		expErr string
+	}{
+		{
+			name:   "bad base chain name",
+			opt:    "nosuchchain=0",
+			expErr: `"nosuchchain" is not a valid base chain name`,
+		},
+		{
+			name:   "non-integer priority",
+			opt:    "filter-FORWARD=filter+1",
+			expErr: `priority "filter+1" for base chain "filter-FORWARD" is not an integer`,
+		},
+		{
+			name:   "missing priority",
+			opt:    "filter-FORWARD",
+			expErr: `priority "" for base chain "filter-FORWARD" is not an integer`,
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.name, func(t *testing.T) {
+			d := daemon.New(t)
+			defer d.Stop(t)
+
+			err := d.StartWithError("--bridge-nftables-priority=" + tc.opt)
+			assert.Check(t, err != nil, "expected startup error")
+			logLine, err := d.TailLogs(1)
+			assert.NilError(t, err) // fatal: logLine is unusable if the log could not be read
+			assert.Assert(t, is.Len(logLine, 1)) // fatal: logLine[0] below would panic on an empty result
+			assert.Check(t, is.Contains(string(logLine[0]), tc.expErr))
+		})
+	}
+}