Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions app/ldap_protocol/policies/network/use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,38 +148,58 @@ async def update(
) -> NetworkPolicyDTO:
"""Update network policy."""
policy = await self._network_policy_gateway.get_with_for_update(dto.id)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

может стоит внутри метода get_with_for_update сделать # NOTE: с небольшим описанием того, зачем используется .with_for_update() ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

расписать что делает алхимический with_for_update? ИМХО избыточно

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

нужна блокировка на запись(в целом with_for_update для этого и используют)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

я про то, зачем нам здесь нужна блокировка. может несколько запрсоов на изменение прийти?


await self._apply_field_updates(policy, dto)
await self._apply_netmask_updates(policy, dto)
await self._apply_group_updates(policy, dto)

if await self._network_policy_gateway.check_policy_exists(policy):
raise NetworkPolicyAlreadyExistsError("Entry already exists")

await self._session.commit()

return _convert_model_to_dto(policy)

async def _apply_field_updates(
self,
policy: NetworkPolicy,
dto: NetworkPolicyUpdateDTO,
) -> None:
"""Apply regular field updates."""
for field in dto.fields_to_update:
value = getattr(dto, field)
if value is not None:
setattr(policy, field, value)

async def _apply_netmask_updates(
self,
policy: NetworkPolicy,
dto: NetworkPolicyUpdateDTO,
) -> None:
"""Apply netmask updates."""
if dto.netmasks and dto.raw:
policy.netmasks = dto.netmasks
policy.raw = dto.raw

if (
dto.groups is not None
and len(dto.groups) > 0
and len(dto.groups) != 0
):
policy.groups = await self._network_policy_gateway.get_groups(
dto.groups,
async def _apply_group_updates(
self,
policy: NetworkPolicy,
dto: NetworkPolicyUpdateDTO,
) -> None:
"""Apply group updates."""
if dto.groups is not None:
policy.groups = (
await self._network_policy_gateway.get_groups(dto.groups)
if dto.groups
else []
)

if (
dto.mfa_groups is not None
and len(dto.mfa_groups) > 0
and len(dto.mfa_groups) != 0
):
policy.mfa_groups = await self._network_policy_gateway.get_groups(
dto.mfa_groups,
)
if await self._network_policy_gateway.check_policy_exists(policy):
raise NetworkPolicyAlreadyExistsError(
"Entry already exists",
if dto.mfa_groups is not None:
policy.mfa_groups = (
await self._network_policy_gateway.get_groups(dto.mfa_groups)
if dto.mfa_groups
else []
)
await self._session.commit()
return _convert_model_to_dto(policy)

async def swap_priorities(self, id1: int, id2: int) -> SwapPrioritiesDTO:
"""Swap priorities for network policies."""
Expand Down
2 changes: 1 addition & 1 deletion interface
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,15 @@ async def network_policy_gateway(
yield await container.get(NetworkPolicyGateway)


@pytest_asyncio.fixture(scope="function")
async def network_policy_use_case(
container: AsyncContainer,
) -> AsyncIterator[NetworkPolicyUseCase]:
"""Get network policy gateway."""
async with container(scope=Scope.REQUEST) as container:
yield await container.get(NetworkPolicyUseCase)


@pytest_asyncio.fixture(scope="function")
async def network_policy_validator(
container: AsyncContainer,
Expand Down
71 changes: 71 additions & 0 deletions tests/test_ldap/policies/test_network/test_use_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Test network policy use case with empty groups.

Copyright (c) 2025 MultiFactor
License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
"""

from ipaddress import IPv4Network

import pytest

from enums import MFAFlags
from ldap_protocol.policies.network import NetworkPolicyUseCase
from ldap_protocol.policies.network.dto import (
NetworkPolicyDTO,
NetworkPolicyUpdateDTO,
)


@pytest.mark.asyncio
async def test_create_policy(
network_policy_use_case: NetworkPolicyUseCase,
) -> None:
"""Test creating policy with empty groups and mfa_groups."""
dto = NetworkPolicyDTO[None](
id=None,
name="Test Empty Groups",
netmasks=[IPv4Network("192.168.1.0/24")],
raw=["192.168.1.0/24"],
priority=2,
mfa_status=MFAFlags.DISABLED,
groups=[],
mfa_groups=[],
)

result = await network_policy_use_case.create(dto)
poicy = await network_policy_use_case.get(result.id)
assert poicy.groups == []
assert poicy.mfa_groups == []


@pytest.mark.asyncio
@pytest.mark.usefixtures("setup_session")
async def test_update_policy_to_empty_groups(
network_policy_use_case: NetworkPolicyUseCase,
) -> None:
"""Test updating policy from groups to empty."""
dto = NetworkPolicyDTO[None](
id=None,
name="Test Update Groups",
netmasks=[IPv4Network("172.16.0.0/12")],
raw=["172.16.0.0/12"],
priority=3,
mfa_status=MFAFlags.DISABLED,
groups=["cn=domain admins,cn=Groups,dc=md,dc=test"],
mfa_groups=["cn=domain admins,cn=Groups,dc=md,dc=test"],
)

created = await network_policy_use_case.create(dto)
assert created.groups
assert created.mfa_groups

update_dto = NetworkPolicyUpdateDTO(
id=created.id,
groups=[],
mfa_groups=[],
)

updated = await network_policy_use_case.update(update_dto)

assert updated.groups == []
assert updated.mfa_groups == []