diff --git a/app/ldap_protocol/policies/network/use_cases.py b/app/ldap_protocol/policies/network/use_cases.py index cde4294d6..38dbc11cc 100644 --- a/app/ldap_protocol/policies/network/use_cases.py +++ b/app/ldap_protocol/policies/network/use_cases.py @@ -148,38 +148,58 @@ async def update( ) -> NetworkPolicyDTO: """Update network policy.""" policy = await self._network_policy_gateway.get_with_for_update(dto.id) + + await self._apply_field_updates(policy, dto) + await self._apply_netmask_updates(policy, dto) + await self._apply_group_updates(policy, dto) + + if await self._network_policy_gateway.check_policy_exists(policy): + raise NetworkPolicyAlreadyExistsError("Entry already exists") + + await self._session.commit() + + return _convert_model_to_dto(policy) + + async def _apply_field_updates( + self, + policy: NetworkPolicy, + dto: NetworkPolicyUpdateDTO, + ) -> None: + """Apply regular field updates.""" for field in dto.fields_to_update: value = getattr(dto, field) if value is not None: setattr(policy, field, value) + async def _apply_netmask_updates( + self, + policy: NetworkPolicy, + dto: NetworkPolicyUpdateDTO, + ) -> None: + """Apply netmask updates.""" if dto.netmasks and dto.raw: policy.netmasks = dto.netmasks policy.raw = dto.raw - if ( - dto.groups is not None - and len(dto.groups) > 0 - and len(dto.groups) != 0 - ): - policy.groups = await self._network_policy_gateway.get_groups( - dto.groups, + async def _apply_group_updates( + self, + policy: NetworkPolicy, + dto: NetworkPolicyUpdateDTO, + ) -> None: + """Apply group updates.""" + if dto.groups is not None: + policy.groups = ( + await self._network_policy_gateway.get_groups(dto.groups) + if dto.groups + else [] ) - if ( - dto.mfa_groups is not None - and len(dto.mfa_groups) > 0 - and len(dto.mfa_groups) != 0 - ): - policy.mfa_groups = await self._network_policy_gateway.get_groups( - dto.mfa_groups, - ) - if await self._network_policy_gateway.check_policy_exists(policy): - raise NetworkPolicyAlreadyExistsError( - "Entry already exists", + if dto.mfa_groups is not None: + policy.mfa_groups = ( + await self._network_policy_gateway.get_groups(dto.mfa_groups) + if dto.mfa_groups + else [] ) - await self._session.commit() - return _convert_model_to_dto(policy) async def swap_priorities(self, id1: int, id2: int) -> SwapPrioritiesDTO: """Swap priorities for network policies.""" diff --git a/interface b/interface index 95ed5e191..f31962020 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 95ed5e191cdafa07b1dfac96a1659926679ead97 +Subproject commit f31962020a6689e6a4c61fb3349db5b5c7895f92 diff --git a/tests/conftest.py b/tests/conftest.py index c9ba0f8ff..4f95140f8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1064,6 +1064,15 @@ async def network_policy_gateway( yield await container.get(NetworkPolicyGateway) +@pytest_asyncio.fixture(scope="function") +async def network_policy_use_case( + container: AsyncContainer, +) -> AsyncIterator[NetworkPolicyUseCase]: + """Get network policy gateway.""" + async with container(scope=Scope.REQUEST) as container: + yield await container.get(NetworkPolicyUseCase) + + @pytest_asyncio.fixture(scope="function") async def network_policy_validator( container: AsyncContainer, diff --git a/tests/test_ldap/policies/test_network/test_use_case.py b/tests/test_ldap/policies/test_network/test_use_case.py new file mode 100644 index 000000000..d9714f881 --- /dev/null +++ b/tests/test_ldap/policies/test_network/test_use_case.py @@ -0,0 +1,71 @@ +"""Test network policy use case with empty groups. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from ipaddress import IPv4Network + +import pytest + +from enums import MFAFlags +from ldap_protocol.policies.network import NetworkPolicyUseCase +from ldap_protocol.policies.network.dto import ( + NetworkPolicyDTO, + NetworkPolicyUpdateDTO, +) + + +@pytest.mark.asyncio +async def test_create_policy( + network_policy_use_case: NetworkPolicyUseCase, +) -> None: + """Test creating policy with empty groups and mfa_groups.""" + dto = NetworkPolicyDTO[None]( + id=None, + name="Test Empty Groups", + netmasks=[IPv4Network("192.168.1.0/24")], + raw=["192.168.1.0/24"], + priority=2, + mfa_status=MFAFlags.DISABLED, + groups=[], + mfa_groups=[], + ) + + result = await network_policy_use_case.create(dto) + poicy = await network_policy_use_case.get(result.id) + assert poicy.groups == [] + assert poicy.mfa_groups == [] + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +async def test_update_policy_to_empty_groups( + network_policy_use_case: NetworkPolicyUseCase, +) -> None: + """Test updating policy from groups to empty.""" + dto = NetworkPolicyDTO[None]( + id=None, + name="Test Update Groups", + netmasks=[IPv4Network("172.16.0.0/12")], + raw=["172.16.0.0/12"], + priority=3, + mfa_status=MFAFlags.DISABLED, + groups=["cn=domain admins,cn=Groups,dc=md,dc=test"], + mfa_groups=["cn=domain admins,cn=Groups,dc=md,dc=test"], + ) + + created = await network_policy_use_case.create(dto) + assert created.groups + assert created.mfa_groups + + update_dto = NetworkPolicyUpdateDTO( + id=created.id, + groups=[], + mfa_groups=[], + ) + + updated = await network_policy_use_case.update(update_dto) + + assert updated.groups == [] + assert updated.mfa_groups == []