diff --git a/changelogs/fragments/248_class_level_imports.yaml b/changelogs/fragments/248_class_level_imports.yaml new file mode 100644 index 00000000..f205588d --- /dev/null +++ b/changelogs/fragments/248_class_level_imports.yaml @@ -0,0 +1,2 @@ +minor_changes: + - Remove class level imports in PFSenseModule to support the use of Mitogen (https://github.com/pfsensible/core/pull/248). diff --git a/misc/local2ansible b/misc/local2ansible index b66d21d0..eb712648 100755 --- a/misc/local2ansible +++ b/misc/local2ansible @@ -25,14 +25,12 @@ fi mkdir -p ${ANSIBLE_INSTALL}/module_utils/network/pfsense -mkdir -p ${ANSIBLE_INSTALL}/module_utils/network/pfsense/__impl mkdir -p ${ANSIBLE_INSTALL}/modules/network/pfsense # remove old modules imports rm -rf ${ANSIBLE_INSTALL}/module_utils/network/pfsense/pfense_* cp module_utils/network/pfsense/*.py ${ANSIBLE_INSTALL}/module_utils/network/pfsense/ -cp module_utils/network/pfsense/__impl/*.py ${ANSIBLE_INSTALL}/module_utils/network/pfsense/__impl/ cp library/*.py ${ANSIBLE_INSTALL}/modules/network/pfsense/ cp lookup_plugins/pfsense.py ${ANSIBLE_INSTALL}/plugins/lookup/pfsense.py diff --git a/plugins/module_utils/__impl/__init__.py b/plugins/module_utils/__impl/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/plugins/module_utils/__impl/addresses.py b/plugins/module_utils/__impl/addresses.py deleted file mode 100644 index 121357bf..00000000 --- a/plugins/module_utils/__impl/addresses.py +++ /dev/null @@ -1,167 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright: (c) 2019, Orion Poplawski -# Copyright: (c) 2019, Frederic Bor -# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) - -from __future__ import absolute_import, division, print_function -__metaclass__ = type -try: - from ipaddress import ip_address, ip_network, IPv4Address, IPv6Address, IPv4Network, IPv6Network -except ImportError: - from ansible_collections.community.general.plugins.module_utils.compat.ipaddress import ( - ip_address, IPv4Address, IPv6Address, - ip_network, IPv4Network, IPv6Network - ) -import re - - -@staticmethod -def is_ipv4_address(address): - """ test if address is a valid ipv4 address """ - try: - addr = ip_address(u'{0}'.format(address)) - return isinstance(addr, IPv4Address) - except ValueError: - pass - return False - - -@staticmethod -def is_ipv6_address(address): - """ test if address is a valid ipv6 address """ - try: - addr = ip_address(u'{0}'.format(address)) - return isinstance(addr, IPv6Address) - except ValueError: - pass - return False - - -@staticmethod -def is_ipv4_network(address, strict=True): - """ test if address is a valid ipv4 network """ - try: - addr = ip_network(u'{0}'.format(address), strict=strict) - return isinstance(addr, IPv4Network) - except ValueError: - pass - return False - - -@staticmethod -def is_ipv6_network(address, strict=True): - """ test if address is a valid ipv6 network """ - try: - addr = ip_network(u'{0}'.format(address), strict=strict) - return isinstance(addr, IPv6Network) - except ValueError: - pass - return False - - -def is_ip_network(self, address, strict=True): - """ test if address is a valid ip network """ - return self.is_ipv4_network(address, strict) or self.is_ipv6_network(address, strict) - - -def is_within_local_networks(self, address): - """ test if address is contained in our local networks """ - networks = self.get_interfaces_networks() - try: - addr = ip_address(u'{0}'.format(address)) - except ValueError: - return False - - for network in networks: - try: - net = ip_network(u'{0}'.format(network), strict=False) - if addr in net: - return True - except ValueError: - pass - return False - - -@staticmethod -def parse_ip_network(address, strict=True, returns_ip=True): - """ return cidr parts of address """ - try: - addr = ip_network(u'{0}'.format(address), strict=strict) - if strict or not returns_ip: - return (str(addr.network_address), addr.prefixlen) - else: - # we parse the address with ipaddr just for type checking - # but we use a regex to return the result as it dont kept the address bits - group = re.match(r'(.*)/(.*)', address) - if group: - return (group.group(1), group.group(2)) - except ValueError: - pass - return None - - -def parse_address(self, param, allow_self=True): - """ validate param address field and returns it as a dict """ - if self.is_ipv6_address(param) or self.is_ipv6_network(param): - addr = [param] - else: - addr = param.split(':', maxsplit=3) - if len(addr) > 3: - self.module.fail_json(msg='Cannot parse address %s' % (param)) - - address = addr[0] - - ret = dict() - # Check if the first character is "!" - if address[0] == '!': - # Invert the rule - ret['not'] = None - address = address[1:] - - if address == 'NET' or address == 'IP': - interface = addr[1] if len(addr) > 1 else None - ports = addr[2] if len(addr) > 2 else None - if interface is None or interface == '': - self.module.fail_json(msg='Cannot parse address %s' % (param)) - - ret['network'] = self.parse_interface(interface) - if address == 'IP': - ret['network'] += 'ip' - else: - ports = addr[1] if len(addr) > 1 else None - if address == 'any': - ret['any'] = None - # rule with this firewall - elif allow_self and address == '(self)': - ret['network'] = '(self)' - # rule with interface name (LAN, WAN...) - elif self.is_interface_display_name(address): - ret['network'] = self.get_interface_by_display_name(address) - else: - if not self.is_ip_or_alias(address): - self.module.fail_json(msg='Cannot parse address %s, not IP or alias' % (address)) - ret['address'] = address - - if ports is not None: - self.parse_port(ports, ret) - msg = "the :ports syntax at end of addresses is deprecated and support will be removed soon. Please use source_port and destination_port options." - self.module.warn(msg) - - return ret - - -def parse_port(self, src_ports, ret): - """ validate and parse port address field and set it in ret """ - ports = src_ports.split('-') - if len(ports) > 2 or ports[0] is None or ports[0] == '' or len(ports) == 2 and (ports[1] is None or ports[1] == ''): - self.module.fail_json(msg='Cannot parse port %s' % (src_ports)) - - if not self.is_port_or_alias(ports[0]): - self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[0])) - ret['port'] = ports[0] - - if len(ports) > 1: - if not self.is_port_or_alias(ports[1]): - self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[1])) - ret['port'] += '-' + ports[1] diff --git a/plugins/module_utils/__impl/checks.py b/plugins/module_utils/__impl/checks.py deleted file mode 100644 index 5088f1fe..00000000 --- a/plugins/module_utils/__impl/checks.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright: (c) 2019-2021, Orion Poplawski -# Copyright: (c) 2019, Frederic Bor -# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) - -from __future__ import absolute_import, division, print_function -__metaclass__ = type -import re -import socket - - -def check_name(self, name, objtype): - """ check name validity """ - - msg = None - if len(name) >= 32 or len(re.findall(r'(^_*$|^\d*$|[^a-zA-Z0-9_])', name)) > 0: - msg = "The {0} name '{1}' must be less than 32 characters long, may not consist of only numbers, may not consist of only underscores, ".format( - objtype, name) - msg += "and may only contain the following characters: a-z, A-Z, 0-9, _" - elif name in ["port", "pass"]: - msg = "The {0} name must not be either of the reserved words 'port' or 'pass'".format(objtype) - else: - try: - socket.getprotobyname(name) - msg = 'The {0} name must not be an IP protocol name such as TCP, UDP, ICMP etc.'.format(objtype) - except socket.error: - pass - - try: - socket.getservbyname(name) - msg = 'The {0} name must not be a well-known or registered TCP or UDP port name such as ssh, smtp, pop3, tftp, http, openvpn etc.'.format(objtype) - except socket.error: - pass - - if msg is not None: - self.module.fail_json(msg=msg) - - -def check_ip_address(self, address, ipprotocol, objtype, allow_networks=False, fail_ifnotip=False): - """ check address according to ipprotocol """ - if address is None: - return - if allow_networks: - ipv4 = self.is_ipv4_network(address, False) - ipv6 = self.is_ipv6_network(address, False) - else: - ipv4 = self.is_ipv4_address(address) - ipv6 = self.is_ipv6_address(address) - - if ipprotocol == 'inet': - if ipv6 or not ipv4 and fail_ifnotip: - self.module.fail_json(msg='{0} must use an IPv4 address'.format(objtype)) - elif ipprotocol == 'inet6': - if ipv4 or not ipv6 and fail_ifnotip: - self.module.fail_json(msg='{0} must use an IPv6 address'.format(objtype)) - elif ipprotocol == 'inet46': - if ipv4 or ipv6: - self.module.fail_json(msg='IPv4 and IPv6 addresses can not be used in objects that apply to both IPv4 and IPv6 (except within an alias).') - - -def validate_openvpn_tunnel_network(self, network, ipproto): - """ check openvpn tunnel network validity - based on pfSense's openvpn_validate_tunnel_network() """ - if network is not None and network != '': - alias_elt = self.find_alias(network, aliastype='network') - if alias_elt is not None: - networks = alias_elt.find('address').text.split() - if len(networks) > 1: - self.module.fail_json("The alias {0} contains more than one network".format(network)) - network = networks[0] - - if not self.is_ipv4_network(network, strict=False) and ipproto == 'ipv4': - self.module.fail_json("{0} is not a valid IPv4 network".format(network)) - if not self.is_ipv6_network(network, strict=False) and ipproto == 'ipv6': - self.module.fail_json("{0} is not a valid IPv6 network".format(network)) - return True - - return True - - -def validate_string(self, name, objtype): - """ check string validity - similar to pfSense's do_input_validate() """ - - if len(re.findall(r'[\000-\010\013\014\016-\037]', name)) > 0: - self.module.fail_json("The {0} name contains invalid characters.".format(objtype)) diff --git a/plugins/module_utils/__impl/interfaces.py b/plugins/module_utils/__impl/interfaces.py deleted file mode 100644 index adcec075..00000000 --- a/plugins/module_utils/__impl/interfaces.py +++ /dev/null @@ -1,146 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright: (c) 2019, Orion Poplawski -# Copyright: (c) 2019, Frederic Bor -# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) - -from __future__ import (absolute_import, division, print_function) -__metaclass__ = type - - -def get_interface_by_display_name(self, name): - """ return interface_id by name """ - for interface in self.interfaces: - descr_elt = interface.find('descr') - if descr_elt is not None and descr_elt.text.strip().lower() == name.lower(): - return interface.tag - return None - - -def get_interface_by_port(self, name): - """ return interface_id by port (os name) """ - for interface in self.interfaces: - if interface.find('if').text.strip() == name: - return interface.tag - return None - - -def get_interface_display_name(self, interface_id, return_none=False): - """ return interface display name if found, otherwhise return the interface_id """ - if interface_id == 'enc0': - if return_none and not self.is_ipsec_enabled(): - return None - return 'IPsec' - if interface_id == 'openvpn': - if return_none and not self.is_openvpn_enabled(): - return None - return 'OpenVPN' - - for interface in self.interfaces: - if interface.tag == interface_id: - descr_elt = interface.find('descr') - if descr_elt is not None: - return descr_elt.text.strip() - break - - if return_none: - return None - return interface_id - - -def get_interface_elt(self, interface_id): - """ return interface """ - for interface in self.interfaces: - if interface.tag == interface_id: - return interface - return None - - -def get_interface_port(self, interface_id): - """ return interface port """ - for interface in self.interfaces: - if interface.tag == interface_id: - return interface.find('if').text.strip() - return None - - -def get_interface_port_by_display_name(self, name): - """ return interface port """ - for interface in self.interfaces: - descr_elt = interface.find('descr') - if descr_elt is not None and descr_elt.text.strip().lower() == name.lower(): - return interface.find('if').text.strip() - return None - - -def get_interfaces_networks(self): - """ return interface local networks """ - ret = [] - for interface in self.interfaces: - if interface.find('enable') is None: - continue - - ipaddr_elt = interface.find('ipaddr') - subnet_elt = interface.find('subnet') - if ipaddr_elt is not None and subnet_elt is not None and ipaddr_elt.text is not None and subnet_elt.text is not None: - ret.append('{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text)) - - ipaddr_elt = interface.find('ipaddrv6') - subnet_elt = interface.find('subnetv6') - if ipaddr_elt is not None and subnet_elt is not None and ipaddr_elt.text is not None and subnet_elt.text is not None: - ret.append('{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text)) - - # TODO: add vip networks - return ret - - -def is_interface_port(self, interface_port): - """ determines if arg is a pfsense interface port or not """ - for interface in self.interfaces: - interface_elt = interface.tag.strip() - if interface_elt == interface_port: - return True - return False - - -def is_interface_display_name(self, name): - """ determines if arg is an interface name or not """ - for interface in self.interfaces: - descr_elt = interface.find('descr') - if descr_elt is not None: - if descr_elt.text.strip().lower() == name.lower(): - return True - return False - - -def is_interface_group(self, name): - """ determines if arg is an interface group name or not """ - if self.ifgroups is not None: - for interface in self.ifgroups: - ifname_elt = interface.find('ifname') - if ifname_elt is not None: - # ifgroup names appear to be case sensitive - if ifname_elt.text.strip() == name: - return True - return False - - -def parse_interface(self, interface, fail=True, with_virtual=True, with_gwgroup=False): - """ validate param interface field """ - if with_virtual and (interface == 'enc0' or interface.lower() == 'ipsec') and self.is_ipsec_enabled(): - return 'enc0' - if with_virtual and (interface == 'openvpn' or interface.lower() == 'openvpn') and self.is_openvpn_enabled(): - return 'openvpn' - if with_gwgroup and self.is_gateway_group(interface): - return interface - - if self.is_interface_display_name(interface): - return self.get_interface_by_display_name(interface) - elif self.is_interface_port(interface): - return interface - elif self.is_interface_group(interface): - return interface - - if fail: - self.module.fail_json(msg='%s is not a valid interface' % (interface)) - return None diff --git a/plugins/module_utils/openvpn_override.py b/plugins/module_utils/openvpn_override.py index e099ed09..d15192c3 100644 --- a/plugins/module_utils/openvpn_override.py +++ b/plugins/module_utils/openvpn_override.py @@ -52,8 +52,6 @@ class PFSenseOpenVPNOverrideModule(PFSenseModuleBase): """ module managing pfSense OpenVPN Client Specific Overrides """ - from ansible_collections.pfsensible.core.plugins.module_utils.__impl.checks import validate_openvpn_tunnel_network - @staticmethod def get_argument_spec(): """ return argument spec """ diff --git a/plugins/module_utils/pfsense.py b/plugins/module_utils/pfsense.py index b9b5177c..b6e877a3 100644 --- a/plugins/module_utils/pfsense.py +++ b/plugins/module_utils/pfsense.py @@ -9,12 +9,20 @@ import sys if sys.version_info >= (3, 4): import html +try: + from ipaddress import ip_address, ip_network, IPv4Address, IPv6Address, IPv4Network, IPv6Network +except ImportError: + from ansible_collections.community.general.plugins.module_utils.compat.ipaddress import ( + ip_address, IPv4Address, IPv6Address, + ip_network, IPv4Network, IPv6Network + ) import json import shutil import os import pwd import random import re +import socket import time import xml.etree.ElementTree as ET from tempfile import mkstemp @@ -32,37 +40,6 @@ def xml_find(node, elt): class PFSenseModule(object): """ class managing pfsense base configuration """ - from ansible_collections.pfsensible.core.plugins.module_utils.__impl.interfaces import ( - get_interface_display_name, - get_interface_elt, - get_interface_port, - get_interface_port_by_display_name, - get_interface_by_display_name, - get_interface_by_port, - get_interfaces_networks, - is_interface_display_name, - is_interface_group, - is_interface_port, - parse_interface, - ) - from ansible_collections.pfsensible.core.plugins.module_utils.__impl.addresses import ( - is_ipv4_address, - is_ipv6_address, - is_ipv4_network, - is_ipv6_network, - is_ip_network, - is_within_local_networks, - parse_address, - parse_ip_network, - parse_port, - ) - from ansible_collections.pfsensible.core.plugins.module_utils.__impl.checks import ( - check_name, - check_ip_address, - validate_string, - validate_openvpn_tunnel_network, - ) - def __init__(self, module, config='/cf/conf/config.xml'): self.module = module self.config = config @@ -93,6 +70,346 @@ def _scrub(self): if elt.text is not None: elt.text = html.unescape(elt.text) + def get_interface_by_display_name(self, name): + """ return interface_id by name """ + for interface in self.interfaces: + descr_elt = interface.find('descr') + if descr_elt is not None and descr_elt.text.strip().lower() == name.lower(): + return interface.tag + return None + + def get_interface_by_port(self, name): + """ return interface_id by port (os name) """ + for interface in self.interfaces: + if interface.find('if').text.strip() == name: + return interface.tag + return None + + def get_interface_display_name(self, interface_id, return_none=False): + """ return interface display name if found, otherwhise return the interface_id """ + if interface_id == 'enc0': + return 'IPsec' + if interface_id == 'openvpn': + if return_none and not self.is_openvpn_enabled(): + return None + return 'OpenVPN' + + for interface in self.interfaces: + if interface.tag == interface_id: + descr_elt = interface.find('descr') + if descr_elt is not None: + return descr_elt.text.strip() + break + + if return_none: + return None + return interface_id + + def get_interface_elt(self, interface_id): + """ return interface """ + for interface in self.interfaces: + if interface.tag == interface_id: + return interface + return None + + def get_interface_port(self, interface_id): + """ return interface port """ + for interface in self.interfaces: + if interface.tag == interface_id: + return interface.find('if').text.strip() + return None + + def get_interface_port_by_display_name(self, name): + """ return interface port """ + for interface in self.interfaces: + descr_elt = interface.find('descr') + if descr_elt is not None and descr_elt.text.strip().lower() == name.lower(): + return interface.find('if').text.strip() + return None + + def get_interfaces_networks(self): + """ return interface local networks """ + ret = [] + for interface in self.interfaces: + if interface.find('enable') is None: + continue + + ipaddr_elt = interface.find('ipaddr') + subnet_elt = interface.find('subnet') + if ipaddr_elt is not None and subnet_elt is not None and ipaddr_elt.text is not None and subnet_elt.text is not None: + ret.append('{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text)) + + ipaddr_elt = interface.find('ipaddrv6') + subnet_elt = interface.find('subnetv6') + if ipaddr_elt is not None and subnet_elt is not None and ipaddr_elt.text is not None and subnet_elt.text is not None: + ret.append('{0}/{1}'.format(ipaddr_elt.text, subnet_elt.text)) + + # TODO: add vip networks + return ret + + def is_interface_port(self, interface_port): + """ determines if arg is a pfsense interface port or not """ + for interface in self.interfaces: + interface_elt = interface.tag.strip() + if interface_elt == interface_port: + return True + return False + + def is_interface_display_name(self, name): + """ determines if arg is an interface name or not """ + for interface in self.interfaces: + descr_elt = interface.find('descr') + if descr_elt is not None: + if descr_elt.text.strip().lower() == name.lower(): + return True + return False + + def is_interface_group(self, name): + """ determines if arg is an interface group name or not """ + if self.ifgroups is not None: + for interface in self.ifgroups: + ifname_elt = interface.find('ifname') + if ifname_elt is not None: + # ifgroup names appear to be case sensitive + if ifname_elt.text.strip() == name: + return True + return False + + def parse_interface(self, interface, fail=True, with_virtual=True, with_gwgroup=False): + """ validate param interface field """ + if with_virtual and (interface == 'enc0' or interface.lower() == 'ipsec') and self.is_ipsec_enabled(): + return 'enc0' + if with_virtual and (interface == 'openvpn' or interface.lower() == 'openvpn') and self.is_openvpn_enabled(): + return 'openvpn' + if with_gwgroup and self.is_gateway_group(interface): + return interface + + if self.is_interface_display_name(interface): + return self.get_interface_by_display_name(interface) + elif self.is_interface_port(interface): + return interface + elif self.is_interface_group(interface): + return interface + + if fail: + self.module.fail_json(msg='%s is not a valid interface' % (interface)) + return None + + @staticmethod + def is_ipv4_address(address): + """ test if address is a valid ipv4 address """ + try: + addr = ip_address(u'{0}'.format(address)) + return isinstance(addr, IPv4Address) + except ValueError: + pass + return False + + @staticmethod + def is_ipv6_address(address): + """ test if address is a valid ipv6 address """ + try: + addr = ip_address(u'{0}'.format(address)) + return isinstance(addr, IPv6Address) + except ValueError: + pass + return False + + @staticmethod + def is_ipv4_network(address, strict=True): + """ test if address is a valid ipv4 network """ + try: + addr = ip_network(u'{0}'.format(address), strict=strict) + return isinstance(addr, IPv4Network) + except ValueError: + pass + return False + + @staticmethod + def is_ipv6_network(address, strict=True): + """ test if address is a valid ipv6 network """ + try: + addr = ip_network(u'{0}'.format(address), strict=strict) + return isinstance(addr, IPv6Network) + except ValueError: + pass + return False + + def is_ip_network(self, address, strict=True): + """ test if address is a valid ip network """ + return self.is_ipv4_network(address, strict) or self.is_ipv6_network(address, strict) + + def is_within_local_networks(self, address): + """ test if address is contained in our local networks """ + networks = self.get_interfaces_networks() + try: + addr = ip_address(u'{0}'.format(address)) + except ValueError: + return False + + for network in networks: + try: + net = ip_network(u'{0}'.format(network), strict=False) + if addr in net: + return True + except ValueError: + # ignore invalid networks, keep trying + pass + return False + + @staticmethod + def parse_ip_network(address, strict=True, returns_ip=True): + """ return cidr parts of address """ + try: + addr = ip_network(u'{0}'.format(address), strict=strict) + if strict or not returns_ip: + return (str(addr.network_address), addr.prefixlen) + else: + # we parse the address with ipaddr just for type checking + # but we use a regex to return the result as it dont kept the address bits + group = re.match(r'(.*)/(.*)', address) + if group: + return (group.group(1), group.group(2)) + except ValueError: + return None + return None + + def parse_address(self, param, allow_self=True): + """ validate param address field and returns it as a dict """ + if self.is_ipv6_address(param) or self.is_ipv6_network(param): + addr = [param] + else: + addr = param.split(':', maxsplit=3) + if len(addr) > 3: + self.module.fail_json(msg='Cannot parse address %s' % (param)) + + address = addr[0] + + ret = dict() + # Check if the first character is "!" + if address[0] == '!': + # Invert the rule + ret['not'] = None + address = address[1:] + + if address == 'NET' or address == 'IP': + interface = addr[1] if len(addr) > 1 else None + ports = addr[2] if len(addr) > 2 else None + if interface is None or interface == '': + self.module.fail_json(msg='Cannot parse address %s' % (param)) + + ret['network'] = self.parse_interface(interface) + if address == 'IP': + ret['network'] += 'ip' + else: + ports = addr[1] if len(addr) > 1 else None + if address == 'any': + ret['any'] = None + # rule with this firewall + elif allow_self and address == '(self)': + ret['network'] = '(self)' + # rule with interface name (LAN, WAN...) + elif self.is_interface_display_name(address): + ret['network'] = self.get_interface_by_display_name(address) + else: + if not self.is_ip_or_alias(address): + self.module.fail_json(msg='Cannot parse address %s, not IP or alias' % (address)) + ret['address'] = address + + if ports is not None: + self.parse_port(ports, ret) + msg = "the :ports syntax at end of addresses is deprecated and support will be removed soon. Please use source_port and destination_port options." + self.module.warn(msg) + + return ret + + def parse_port(self, src_ports, ret): + """ validate and parse port address field and set it in ret """ + ports = src_ports.split('-') + if len(ports) > 2 or ports[0] is None or ports[0] == '' or len(ports) == 2 and (ports[1] is None or ports[1] == ''): + self.module.fail_json(msg='Cannot parse port %s' % (src_ports)) + + if not self.is_port_or_alias(ports[0]): + self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[0])) + ret['port'] = ports[0] + + if len(ports) > 1: + if not self.is_port_or_alias(ports[1]): + self.module.fail_json(msg='Cannot parse port %s, not port number or alias' % (ports[1])) + ret['port'] += '-' + ports[1] + + def check_name(self, name, objtype): + """ check name validity """ + + msg = None + if len(name) >= 32 or len(re.findall(r'(^_*$|^\d*$|[^a-zA-Z0-9_])', name)) > 0: + msg = f"The {objtype} name '{name}' must be less than 32 characters long, may not consist of only numbers, may not consist of only underscores, " + msg += "and may only contain the following characters: a-z, A-Z, 0-9, _" + elif name in ["port", "pass"]: + msg = f"The {objtype} name must not be either of the reserved words 'port' or 'pass'" + else: + try: + socket.getprotobyname(name) + msg = f"The {objtype} name must not be an IP protocol name such as TCP, UDP, ICMP etc." + except socket.error: + # If the protocol name lookup fails, the name is not a reserved protocol and is therefore allowed. + pass + + try: + socket.getservbyname(name) + msg = f"The {objtype} name must not be a well-known or registered TCP or UDP port name such as ssh, smtp, pop3, tftp, http, openvpn etc." + except socket.error: + # If the service name lookup fails, the name is not a reserved TCP/UDP service and is therefore allowed. + pass + + if msg is not None: + self.module.fail_json(msg=msg) + + def check_ip_address(self, address, ipprotocol, objtype, allow_networks=False, fail_ifnotip=False): + """ check address according to ipprotocol """ + if address is None: + return + if allow_networks: + ipv4 = self.is_ipv4_network(address, False) + ipv6 = self.is_ipv6_network(address, False) + else: + ipv4 = self.is_ipv4_address(address) + ipv6 = self.is_ipv6_address(address) + + if ipprotocol == 'inet': + if ipv6 or not ipv4 and fail_ifnotip: + self.module.fail_json(msg='{0} must use an IPv4 address'.format(objtype)) + elif ipprotocol == 'inet6': + if ipv4 or not ipv6 and fail_ifnotip: + self.module.fail_json(msg='{0} must use an IPv6 address'.format(objtype)) + elif ipprotocol == 'inet46': + if ipv4 or ipv6: + self.module.fail_json(msg='IPv4 and IPv6 addresses can not be used in objects that apply to both IPv4 and IPv6 (except within an alias).') + + def validate_openvpn_tunnel_network(self, network, ipproto): + """ check openvpn tunnel network validity - based on pfSense's openvpn_validate_tunnel_network() """ + if network is not None and network != '': + alias_elt = self.find_alias(network, aliastype='network') + if alias_elt is not None: + networks = alias_elt.find('address').text.split() + if len(networks) > 1: + self.module.fail_json("The alias {0} contains more than one network".format(network)) + network = networks[0] + + if not self.is_ipv4_network(network, strict=False) and ipproto == 'ipv4': + self.module.fail_json("{0} is not a valid IPv4 network".format(network)) + if not self.is_ipv6_network(network, strict=False) and ipproto == 'ipv6': + self.module.fail_json("{0} is not a valid IPv6 network".format(network)) + return True + + return True + + def validate_string(self, name, objtype): + """ check string validity - similar to pfSense's do_input_validate() """ + + if len(re.findall(r'[\000-\010\013\014\016-\037]', name)) > 0: + self.module.fail_json("The {0} name contains invalid characters.".format(objtype)) + @staticmethod def addr_normalize(addr): """ return address element formatted like module argument """ diff --git a/tests/unit/plugins/modules/test_pfsense_rule.py b/tests/unit/plugins/modules/test_pfsense_rule.py index 3a69cfc5..3277fc2b 100644 --- a/tests/unit/plugins/modules/test_pfsense_rule.py +++ b/tests/unit/plugins/modules/test_pfsense_rule.py @@ -12,15 +12,13 @@ from ansible_collections.pfsensible.core.plugins.modules import pfsense_rule from ansible_collections.pfsensible.core.plugins.module_utils.rule import PFSenseRuleModule -from ansible_collections.pfsensible.core.plugins.module_utils.__impl.addresses import is_ipv6_address, is_ipv6_network +from ansible_collections.pfsensible.core.plugins.module_utils.pfsense import PFSenseModule from .pfsense_module import TestPFSenseModule class TestPFSenseRuleModule(TestPFSenseModule): module = pfsense_rule - is_ipv6_address = is_ipv6_address - is_ipv6_network = is_ipv6_network def __init__(self, *args, **kwargs): super(TestPFSenseRuleModule, self).__init__(*args, **kwargs) @@ -34,7 +32,7 @@ def runTest(): def parse_address(self, addr): """ return address parsed in dict """ - if self.is_ipv6_address(addr) or self.is_ipv6_network(addr): + if PFSenseModule.is_ipv6_address(addr) or PFSenseModule.is_ipv6_network(addr): parts = [addr] else: parts = addr.split(':')