diff --git a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py index 947bbba5c..a9e38f219 100644 --- a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py +++ b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py @@ -148,9 +148,9 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None: # noqa: A query = ( select(Directory) .options( - selectinload(qa(Directory.groups)).selectinload( - qa(Group.directory), - ), + selectinload(qa(Directory.groups)) + .selectinload(qa(Group.directory)) + .selectinload(qa(Directory.attributes)), ) .where( qa(Directory.entity_type_id).in_(entity_type_ids), diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py new file mode 100644 index 000000000..db6546bcf --- /dev/null +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -0,0 +1,485 @@ +"""Add rIDManager and rIDSet objectClasses to LDAP schema. + +Revision ID: 552b4eafb1aa +Revises: 1b71cafba681 +Create Date: 2026-02-17 09:24:57.906080 + +""" + +import secrets + +import sqlalchemy as sa +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from constants import ( + COMPUTERS_CONTAINER_NAME, + CONFIGURATION_DIR_NAME, + DOMAIN_ADMIN_GROUP_NAME, + DOMAIN_COMPUTERS_GROUP_NAME, + DOMAIN_CONTROLLERS_OU_NAME, + DOMAIN_USERS_GROUP_NAME, + GROUPS_CONTAINER_NAME, + READ_ONLY_GROUP_NAME, + SYSTEM_CONTAINER_NAME, + USERS_CONTAINER_NAME, +) +from entities import Attribute, Directory, EntityType +from enums import EntityTypeNames, SecurityPrincipalRid +from ldap_protocol.ldap_schema.dto import EntityTypeDTO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.rid_manager import ( + RIDManagerGateway, + RIDManagerSetupGateway, + RIDManagerSetupUseCase, + RIDManagerUseCase, + RIDSetUseCase, +) +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerNotFoundError, + RIDManagerRidSetNotFoundError, +) +from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway +from ldap_protocol.rid_manager.utils import from_qword, to_qword +from ldap_protocol.roles.role_use_case import RoleUseCase +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + +# revision identifiers, used by Alembic. +revision: None | str = "552b4eafb1aa" +down_revision: None | str = "1b71cafba681" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +async def _directory_ids_skipped_for_object_sid_migration( + session: AsyncSession, + domain: Directory, +) -> set[int]: + """Directory ids for which objectSid is not copied into Attributes. + + Top-level peer containers (System, OU DC, Users, Computers, Groups) and + the full subtree under ``Configuration``. + """ + peer_container_names = ( + SYSTEM_CONTAINER_NAME, + DOMAIN_CONTROLLERS_OU_NAME, + USERS_CONTAINER_NAME, + COMPUTERS_CONTAINER_NAME, + GROUPS_CONTAINER_NAME, + ) + peer_rows = await session.scalars( + select(qa(Directory.id)).where( + qa(Directory.parent_id) == domain.id, + qa(Directory.name).in_(peer_container_names), + ), + ) + skip_ids: set[int] = set(peer_rows.all()) + configuration_id = await session.scalar( + select(qa(Directory.id)).where( + qa(Directory.parent_id) == domain.id, + qa(Directory.name) == CONFIGURATION_DIR_NAME, + ), + ) + if configuration_id is None: + return skip_ids + + subtree = ( + select(qa(Directory.id)) + .where(qa(Directory.id) == configuration_id) + .cte(name="subtree", recursive=True) + ) + subtree = subtree.union_all( + select(qa(Directory.id)).where( + qa(Directory.parent_id) == subtree.c.id, + ), + ) + cfg_rows = await session.execute(select(subtree.c.id)) + skip_ids |= {row[0] for row in cfg_rows.all()} + return skip_ids + + +def upgrade(container: AsyncContainer) -> None: # noqa: C901 + """Add rIDManager and rIDSet objectClasses to LDAP schema.""" + + async def _create_entity_types( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Create rIDManager and rIDSet Entity Types.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_use_case = await cnt.get(EntityTypeUseCase) + + if not await get_base_directories(session): + return + + await entity_type_use_case.create( + EntityTypeDTO( + name=EntityTypeNames.RID_MANAGER, + object_class_names=[ + "top", + "rIDManager", + ], + is_system=True, + ), + ) + + await entity_type_use_case.create( + EntityTypeDTO( + name=EntityTypeNames.RID_SET, + object_class_names=[ + "top", + "rIDSet", + ], + is_system=True, + ), + ) + + await session.commit() + + op.run_async(_create_entity_types) + + async def _migrate_object_sids( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Move Directory.objectSid values into Attributes table. + + Add ``DomainIdentifier`` on the domain (from ``Directory.objectSid`` + column when present). Do not store domain ``objectSid`` in Attributes. + Normalize built-in group / administrator SIDs once. + """ + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + base_dn_list = await get_base_directories(session) + if not base_dn_list: + return + domain = base_dn_list[0] + + skip_object_sid_ids = ( + await _directory_ids_skipped_for_object_sid_migration( + session, + domain, + ) + ) + + directory_table = sa.table( + "Directory", + sa.column("id", sa.Integer), + sa.column("parentId", sa.Integer), + sa.column("objectSid", sa.String), + ) + + domain_sid_from_column = await session.scalar( + select(directory_table.c.objectSid).where( + directory_table.c.id == domain.id, + ), + ) + + identifier: str | None = None + if domain_sid_from_column: + parts = domain_sid_from_column.split("-") + # "S-1-5-21-AAA-BBB-CCC" -> "AAA-BBB-CCC" + if len(parts) >= 7 and domain_sid_from_column.startswith( + "S-1-5-21-", + ): + identifier = "-".join(parts[4:7]) + + if identifier is None: + identifier = ( + f"{secrets.randbits(32)}-" + f"{secrets.randbits(32)}-" + f"{secrets.randbits(32)}" + ) + + session.add( + Attribute( + name="DomainIdentifier", + value=identifier, + directory_id=domain.id, + ), + ) + result = ( + await session.execute( + select( + directory_table.c.id, + directory_table.c.parentId, + directory_table.c.objectSid, + ), + ) + ).all() + for directory_id, parent_id, object_sid in result: + if not object_sid: + continue + if parent_id is None: + continue + if directory_id in skip_object_sid_ids: + continue + + session.add( + Attribute( + name="objectSid", + value=object_sid, + directory_id=directory_id, + ), + ) + + built_in_sid_prefix = "S-1-5-32" + for dir_name, rid in ( + (DOMAIN_ADMIN_GROUP_NAME, SecurityPrincipalRid.DOMAIN_ADMINS), + (DOMAIN_USERS_GROUP_NAME, SecurityPrincipalRid.DOMAIN_USERS), + ( + DOMAIN_COMPUTERS_GROUP_NAME, + SecurityPrincipalRid.DOMAIN_COMPUTERS, + ), + (READ_ONLY_GROUP_NAME, SecurityPrincipalRid.DOMAIN_READ_ONLY), + ): + await session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.directory_id).in_( + select(qa(Directory.id)).where( + qa(Directory.name) == dir_name, + ), + ), + ) + .values( + value=f"{built_in_sid_prefix}-{int(rid)}", + ), + ) + + await session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).like( + f"S-1-5-21-%-{int(SecurityPrincipalRid.ADMINISTRATOR)}", + ), + ) + .values( + value=( + f"{built_in_sid_prefix}" + f"-{int(SecurityPrincipalRid.ADMINISTRATOR)}" + ), + ), + ) + + await session.commit() + + op.run_async(_migrate_object_sids) + + async def _init_rid_manager( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Initialize RID Manager and RID Set for existing data.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + rid_setup_gateway = await cnt.get(RIDManagerSetupGateway) + rid_gateway = await cnt.get(RIDManagerGateway) + rid_manager_use_case = await cnt.get(RIDManagerUseCase) + rid_set_gateway = await cnt.get(RIDSetGateway) + rid_set_use_case = await cnt.get(RIDSetUseCase) + role_use_case = await cnt.get(RoleUseCase) + + if not await get_base_directories(session): + return + + try: + rid_manager_dir = await rid_gateway.get_rid_manager() + except RIDManagerNotFoundError: + rid_manager_dir = await rid_setup_gateway.set_rid_manager() + + base_dn_list = await get_base_directories(session) + if not base_dn_list: + return + domain = base_dn_list[0] + + domain_identifier = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "DomainIdentifier", + ), + ) + if not (domain_identifier and domain_identifier.value): + return + + sid_prefix = f"S-1-5-21-{domain_identifier.value}-" + + sid_values = await session.scalars( + select(Attribute).where( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).like(f"{sid_prefix}%"), + ), + ) + + max_rid = 0 + for sid_value in sid_values: + if not sid_value or not sid_value.value: + continue + try: + parts = sid_value.value.split("-") + rid = int(parts[-1]) + except (ValueError, IndexError): + continue + if rid > max_rid: + max_rid = rid + + start_rid = max(max_rid, RIDManagerSetupUseCase.RID_MIN) + + qword = to_qword(start_rid, RIDManagerSetupUseCase.RID_AVAILABLE_MAX) + + await rid_setup_gateway.set_rid_available_pool(rid_manager_dir, qword) + + system_container = await rid_setup_gateway.get_system_container() + await role_use_case.inherit_parent_aces( + parent_directory=system_container, + directory=rid_manager_dir, + ) + + domain_controller = await rid_gateway.get_domain_controller() + rid_set_dir: Directory | None = None + try: + rid_set_dir = await rid_set_gateway.get(domain_controller) + except RIDManagerRidSetNotFoundError: + rid_set_dir = None + + if rid_set_dir is None: + previous_allocation_pool = ( + await rid_manager_use_case.allocate_pool() + ) + allocation_pool = await rid_manager_use_case.allocate_pool() + lower, _ = from_qword(previous_allocation_pool) + + rid_set_dir = await rid_set_use_case.add( + domain_controller, + RIDSetAllocationParamsDTO( + next_rid=lower, + allocation_pool=allocation_pool, + previous_allocation_pool=previous_allocation_pool, + ), + ) + + await session.commit() + return + + await session.commit() + + op.run_async(_init_rid_manager) + + op.drop_column("Directory", "objectSid") + + +def downgrade(container: AsyncContainer) -> None: + """Remove rIDManager and rIDSet objectClasses from LDAP schema.""" + op.add_column( + "Directory", + sa.Column("objectSid", sa.String(), nullable=True), + ) + + async def _delete_entity_types( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Delete rIDManager and rIDSet Entity Types.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + if not await get_base_directories(session): + return + + await session.execute( + delete(EntityType).where( + qa(EntityType.name).in_( + [ + EntityTypeNames.RID_MANAGER, + EntityTypeNames.RID_SET, + ], + ), + ), + ) + + await session.commit() + + op.run_async(_delete_entity_types) + + async def _delete_rid_manager_dirs( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Delete RID Manager and RID Set directories.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + if not await get_base_directories(session): + return + + await session.execute( + delete(Directory).where( + qa(Directory.name).in_( + [ + "RID Manager$", + "RID Set", + ], + ), + ), + ) + await session.commit() + + op.run_async(_delete_rid_manager_dirs) + + async def _rollback_object_sids( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Restore Directory.objectSid values from Attributes. + + Also removes the DomainIdentifier attribute that was introduced in + upgrade for domain directories. + """ + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + directory_table = sa.table( + "Directory", + sa.column("id", sa.Integer), + sa.column("objectSid", sa.String), + ) + + result = await session.execute(select(directory_table.c.id)) + + for (directory_id,) in result: + await session.execute( + delete(Attribute).where( + qa(Attribute.directory_id) == directory_id, + qa(Attribute.name) == "DomainIdentifier", + ), + ) + + attr = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == directory_id, + qa(Attribute.name) == "objectSid", + ), + ) + + if not attr or not attr.value: + continue + + await session.execute( + update(directory_table) + .where(directory_table.c.id == directory_id) + .values(objectSid=attr.value), + ) + + await session.execute( + delete(Attribute).where( + qa(Attribute.directory_id) == directory_id, + qa(Attribute.name) == "objectSid", + ), + ) + + await session.commit() + + op.run_async(_rollback_object_sids) diff --git a/app/alembic/versions/6f8fe2548893_fix_read_only.py b/app/alembic/versions/6f8fe2548893_fix_read_only.py index f28264704..2a5f51ef4 100644 --- a/app/alembic/versions/6f8fe2548893_fix_read_only.py +++ b/app/alembic/versions/6f8fe2548893_fix_read_only.py @@ -31,6 +31,12 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 bind = op.get_bind() session = Session(bind=bind) + directory_table = sa.table( + "Directory", + sa.column("id", sa.Integer), + sa.column("objectSid", sa.String), + ) + ro_dir = session.scalar( select(Directory) .filter_by(name="readonly domain controllers"), @@ -82,8 +88,18 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 ), ) - domain_sid = "-".join(ro_dir.object_sid.split("-")[:-1]) - ro_dir.object_sid = domain_sid + "-521" + ro_object_sid = session.scalar( + select(directory_table.c.objectSid).where( + directory_table.c.id == ro_dir.id, + ), + ) + if ro_object_sid: + domain_sid = "-".join(ro_object_sid.split("-")[:-1]) + session.execute( + update(directory_table) + .where(directory_table.c.id == ro_dir.id) + .values(objectSid=domain_sid + "-521"), + ) session.commit() diff --git a/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py b/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py index e8d480d94..1f9f4d9d4 100644 --- a/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py +++ b/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from constants import SYSTEM_CONTAINER_NAME from entities import Attribute, Directory from repo.pg.tables import queryable_attr as qa @@ -111,7 +112,7 @@ async def _rename_system_to_services(connection: AsyncConnection) -> None: # no system_dir = await session.scalar( select(Directory).where( - qa(Directory.name) == "System", + qa(Directory.name) == SYSTEM_CONTAINER_NAME, qa(Directory.is_system).is_(True), ), ) diff --git a/app/api/main/schema.py b/app/api/main/schema.py index 5ea6545a8..bc559de4d 100644 --- a/app/api/main/schema.py +++ b/app/api/main/schema.py @@ -38,7 +38,7 @@ def _cast_filter(self) -> UnaryExpression | ColumnElement: ) @staticmethod - def get_directory_sid(directory: Directory) -> str: # type: ignore + def get_directory_sid(directory: Directory) -> str | None: # type: ignore return directory.object_sid @staticmethod diff --git a/app/constants.py b/app/constants.py index a6192f314..5bece61a8 100644 --- a/app/constants.py +++ b/app/constants.py @@ -4,13 +4,14 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from enums import EntityTypeNames, SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.ldap_schema.dto import EntityTypeDTO CONFIGURATION_DIR_NAME = "Configuration" GROUPS_CONTAINER_NAME = "Groups" COMPUTERS_CONTAINER_NAME = "Computers" USERS_CONTAINER_NAME = "Users" +SYSTEM_CONTAINER_NAME = "System" DOMAIN_CONTROLLERS_OU_NAME = "Domain Controllers" READ_ONLY_GROUP_NAME = "read-only" @@ -324,6 +325,15 @@ "object_class": "container", "attributes": {"objectClass": ["top", "configuration"]}, }, + { + "name": SYSTEM_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, + "object_class": "organizationalUnit", + "attributes": { + "objectClass": ["top", "container"], + }, + "children": [], + }, { "name": GROUPS_CONTAINER_NAME, "entity_type_name": EntityTypeNames.CONTAINER, @@ -347,7 +357,7 @@ ], "gidNumber": ["512"], }, - "objectSid": 512, + "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": DOMAIN_USERS_GROUP_NAME, @@ -363,7 +373,7 @@ ], "gidNumber": ["513"], }, - "objectSid": 513, + "objectSid": SecurityPrincipalRid.DOMAIN_USERS, }, { "name": READ_ONLY_GROUP_NAME, @@ -379,7 +389,7 @@ ], "gidNumber": ["521"], }, - "objectSid": 521, + "objectSid": SecurityPrincipalRid.DOMAIN_READ_ONLY, }, { "name": DOMAIN_COMPUTERS_GROUP_NAME, @@ -395,7 +405,7 @@ ], "gidNumber": ["515"], }, - "objectSid": 515, + "objectSid": SecurityPrincipalRid.DOMAIN_COMPUTERS, }, ], }, diff --git a/app/entities.py b/app/entities.py index 5a1ec2adb..dee8fea63 100644 --- a/app/entities.py +++ b/app/entities.py @@ -101,7 +101,6 @@ class Directory: id: int = field(init=False) name: str is_system: bool = field(default=False) - object_sid: str = field(default="") object_guid: uuid.UUID = field(default_factory=uuid.uuid4) parent_id: int | None = None entity_type_id: int | None = None @@ -140,7 +139,6 @@ class Directory: search_fields: ClassVar[dict[str, str]] = { "name": "name", "objectguid": "objectGUID", - "objectsid": "objectSid", } ro_fields: ClassVar[set[str]] = { "uid", @@ -184,15 +182,18 @@ def create_path( self.depth = len(self.path) self.rdname = dn + @property + def object_sid(self) -> str: + """Get objectSid attribute value.""" + return self.attributes_dict.get("objectSid", [""])[0] + @property def relative_id(self) -> str: - """Get RID from objectSid. + """Get RID from objectSid attribute. Relative Identifier (RID) is the last sub-authority value of a SID. """ - if "-" in self.object_sid: - return self.object_sid.split("-")[-1] - return "" + return self.object_sid.split("-")[-1] if self.object_sid else "" @property def attributes_dict(self) -> defaultdict[str, list[str]]: diff --git a/app/enums.py b/app/enums.py index e9bdd8f1f..cd2b786cb 100644 --- a/app/enums.py +++ b/app/enums.py @@ -72,6 +72,8 @@ class EntityTypeNames(StrEnum): KRB_CONTAINER = "KRB Container" KRB_PRINCIPAL = "KRB Principal" KRB_REALM_CONTAINER = "KRB Realm Container" + RID_MANAGER = "RID Manager" + RID_SET = "RID Set" class KindType(StrEnum): @@ -279,3 +281,19 @@ class SamAccountTypeCodes(IntEnum): def to_hex(value: int) -> str: """Convert decimal value to hex string.""" return hex(value) + + +class SidPrefix(StrEnum): + """SID prefix.""" + + DOMAIN_IDENTIFIER = "S-1-5-21" + BUILT_IN_DOMAIN = "S-1-5-32" + + +class SecurityPrincipalRid(IntEnum): + ADMINISTRATOR = 500 + GUESTS = 501 + DOMAIN_ADMINS = 512 + DOMAIN_USERS = 513 + DOMAIN_COMPUTERS = 515 + DOMAIN_READ_ONLY = 521 diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 21d3bbaed..a04c9be9c 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -16,9 +16,8 @@ EntityTypeUseCase, ) from ldap_protocol.objects import UserAccountControlFlag +from ldap_protocol.rid_manager import ObjectSIDUseCase, RIDSetUseCase from ldap_protocol.roles.role_use_case import RoleUseCase -from ldap_protocol.utils.helpers import create_object_sid -from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -27,8 +26,9 @@ async def _add_domain_controller( role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, settings: Settings, - domain: Directory, dc_ou_dir: Directory, + object_sid_use_case: ObjectSIDUseCase, + rid_set_use_case: RIDSetUseCase, ) -> None: dc_directory = Directory( object_class="", @@ -40,9 +40,17 @@ async def _add_domain_controller( await session.flush() dc_directory.parent_id = dc_ou_dir.id - dc_directory.object_sid = create_object_sid(domain, dc_directory.id) await session.flush() + await rid_set_use_case.add( + domain_controller=dc_directory, + allocation_params=await rid_set_use_case.generate_rid_set_attrs(), + ) + await session.flush() + await object_sid_use_case.add( + directory_id=dc_directory.id, + ) + attributes = [ Attribute( name="objectClass", @@ -103,14 +111,11 @@ async def add_domain_controller( settings: Settings, role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, + object_sid_use_case: ObjectSIDUseCase, + rid_set_use_case: RIDSetUseCase, ) -> None: logger.info("Adding domain controller.") - domains = await get_base_directories(session) - if not domains: - logger.debug("Cannot get base directory") - return - domain_controllers_ou = await session.scalar( select(Directory).where( qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, @@ -140,8 +145,9 @@ async def add_domain_controller( role_use_case=role_use_case, entity_type_use_case=entity_type_use_case, settings=settings, - domain=domains[0], dc_ou_dir=domain_controllers_ou, + object_sid_use_case=object_sid_use_case, + rid_set_use_case=rid_set_use_case, ) logger.debug("Domain controller added.") diff --git a/app/ioc.py b/app/ioc.py index 019a2f6f3..aec45105d 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -178,6 +178,16 @@ PasswordBanWordUseCases, UserPasswordHistoryUseCases, ) +from ldap_protocol.rid_manager import ( + ObjectSIDGateway, + ObjectSIDUseCase, + RIDManagerGateway, + RIDManagerSetupGateway, + RIDManagerSetupUseCase, + RIDManagerUseCase, + RIDSetGateway, + RIDSetUseCase, +) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.migrations_ace_dao import ( @@ -641,6 +651,21 @@ def get_object_class_use_case_legacy( rootdse_reader = provide(RootDSEReader, scope=Scope.REQUEST) dcinfo_reader = provide(DCInfoReader, scope=Scope.REQUEST) + rid_manager_gateway = provide(RIDManagerGateway, scope=Scope.REQUEST) + rid_manager_setup_gateway = provide( + RIDManagerSetupGateway, + scope=Scope.REQUEST, + ) + rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST) + rid_manager_setup_use_case = provide( + RIDManagerSetupUseCase, + scope=Scope.REQUEST, + ) + object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) + rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) + rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) + class LDAPContextProvider(Provider): """Context provider.""" diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index 9d79c80f8..8f4dbf009 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, Directory, Group, NetworkPolicy, User -from enums import EntityTypeNames +from enums import EntityTypeNames, SidPrefix from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) @@ -20,8 +20,8 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.utils.async_cache import base_directories_cache -from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa @@ -37,6 +37,7 @@ def __init__( entity_type_use_case: EntityTypeUseCase, attribute_value_validator: AttributeValueValidator, directory_dao: DirectoryDAO, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Initialize Setup use case. @@ -49,6 +50,7 @@ def __init__( self._entity_type_use_case = entity_type_use_case self._attribute_value_validator = attribute_value_validator self._directory_dao = directory_dao + self._object_sid_use_case = object_sid_use_case async def is_setup(self) -> bool: """Check if setup is performed. @@ -67,23 +69,10 @@ async def setup_enviroment( *, data: list, is_system: bool = True, - dn: str = "multifactor.dev", + domain: Directory, ) -> None: """Create directories and users for enviroment.""" - cat_result = await self._session.execute(select(Directory)) - if cat_result.scalar_one_or_none(): - logger.warning("dev data already set up") - return - - domain = Directory(name=dn, object_class="domain") - domain.is_system = True - domain.object_sid = generate_domain_sid() - domain.path = [f"dc={path}" for path in reversed(dn.split("."))] - domain.depth = len(domain.path) - domain.rdname = "" - async with self._session.begin_nested(): - self._session.add(domain) self._session.add( NetworkPolicy( name="Default open policy", @@ -132,6 +121,20 @@ async def setup_enviroment( logger.error(traceback.format_exc()) raise + async def create_base_domain( + self, + dn: str = "multifactor.dev", + ) -> Directory: + """Create base domain.""" + domain = Directory(name=dn, object_class="domain") + domain.is_system = True + domain.path = [f"dc={path}" for path in reversed(dn.split("."))] + domain.depth = len(domain.path) + domain.rdname = "" + self._session.add(domain) + await self._session.flush() + return domain + async def create_dir( self, data: dict, @@ -161,11 +164,12 @@ async def create_dir( ), ) - dir_.object_sid = create_object_sid( - domain, - rid=data.get("objectSid", dir_.id), - reserved="objectSid" in data, - ) + if "objectSid" in data: + await self._object_sid_use_case.add( + directory_id=dir_.id, + rid=int(data["objectSid"]), + sid_prefix=SidPrefix.BUILT_IN_DOMAIN, + ) if dir_.object_class == "group": group = Group(directory_id=dir_.id) diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index b6691df00..39abddd9d 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -19,7 +19,7 @@ FIRST_SETUP_DATA, USERS_CONTAINER_NAME, ) -from enums import EntityTypeNames, SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.auth.dto import SetupDTO from ldap_protocol.auth.setup_gateway import SetupGateway from ldap_protocol.identity.exceptions import ( @@ -44,6 +44,11 @@ from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases +from ldap_protocol.rid_manager import ( + ObjectSIDUseCase, + RIDManagerSetupUseCase, + RIDManagerUseCase, +) from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.helpers import create_integer_hash, ft_now @@ -64,6 +69,9 @@ def __init__( audit_use_case: AuditUseCase, session: AsyncSession, settings: Settings, + rid_manager_setup_use_case: RIDManagerSetupUseCase, + rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Initialize Setup manager. @@ -82,6 +90,9 @@ def __init__( self._object_class_use_case_legacy = object_class_use_case_legacy self._object_class_use_case = object_class_use_case self._settings = settings + self._rid_manager_setup_use_case = rid_manager_setup_use_case + self._rid_manager_use_case = rid_manager_use_case + self._object_sid_use_case = object_sid_use_case async def setup(self, dto: SetupDTO) -> None: """Perform the initial setup of structure and policies. @@ -186,7 +197,7 @@ def _create_user_data(self, dto: SetupDTO) -> dict: str(SamAccountTypeCodes.SAM_USER_OBJECT), ], }, - "objectSid": 500, + "objectSid": SecurityPrincipalRid.ADMINISTRATOR, }, ], } @@ -199,10 +210,14 @@ async def _create(self, dto: SetupDTO, data: list) -> None: :return: None. """ try: + domain = await self._setup_gateway.create_base_domain(dto.domain) + await self._rid_manager_setup_use_case.create_domain_identifier( + domain.id, + ) await self._setup_gateway.setup_enviroment( data=data, - dn=dto.domain, is_system=True, + domain=domain, ) attrs = await self._attribute_type_use_case_legacy.get_all() @@ -237,6 +252,12 @@ async def _create(self, dto: SetupDTO, data: list) -> None: await self._role_use_case.create_domain_admins_role() await self._role_use_case.create_read_only_role() await self._audit_use_case.create_policies() + await self._rid_manager_setup_use_case.setup() + dc = await self._rid_manager_use_case.get_domain_controller() + await self._object_sid_use_case.add( + directory_id=dc.id, + ) + await self._session.commit() except IntegrityError: await self._session.rollback() diff --git a/app/ldap_protocol/kerberos/dtos.py b/app/ldap_protocol/kerberos/dtos.py index d01775aee..ce11b6e2f 100644 --- a/app/ldap_protocol/kerberos/dtos.py +++ b/app/ldap_protocol/kerberos/dtos.py @@ -24,7 +24,6 @@ class AddRequestsDTO: """AddRequestsDTO for Kerberos admin structure.""" group: AddRequest - services: AddRequest krb_user: AddRequest diff --git a/app/ldap_protocol/kerberos/ldap_structure.py b/app/ldap_protocol/kerberos/ldap_structure.py index fec8741c0..d501fe858 100644 --- a/app/ldap_protocol/kerberos/ldap_structure.py +++ b/app/ldap_protocol/kerberos/ldap_structure.py @@ -39,28 +39,17 @@ def __init__( async def create_kerberos_structure( self, group: AddRequest, - services: AddRequest, krb_user: AddRequest, ctx: LDAPAddRequestContext, ) -> None: """Create Kerberos structure in the LDAP directory. :param AddRequest group: AddRequest for Kerberos group. - :param AddRequest services: AddRequest for services container. :param AddRequest krb_user: AddRequest for Kerberos admin user. - :param LDAPSession ldap_session: LDAP session. - :param AbstractKadmin kadmin: Kerberos admin interface. - :param EntityTypeDAO entity_type_dao: DAO for entity types. - :param str services_container: DN for services container. - :param str krbgroup: DN for Kerberos group. + :param LDAPAddRequestContext ctx: LDAP request context. :raises Exception: On structure creation error. :return None. """ - async with self._session.begin_nested(): - service_result = await anext(services.handle(ctx)) - if service_result.result_code != 0: - raise KerberosConflictError("Service error") - async with self._session.begin_nested(): group_result = await anext(group.handle(ctx)) if group_result.result_code != 0: @@ -76,20 +65,17 @@ async def create_kerberos_structure( async def rollback_kerberos_structure( self, krbadmin: str, - services_container: str, krbgroup: str, ) -> None: """Rollback Kerberos structure in the LDAP directory. :param str krbadmin: DN for Kerberos admin user. - :param str services_container: DN for services container. :param str krbgroup: DN for Kerberos group. :return None. """ directories_query = select(Directory).where( or_( get_filter_from_path(krbadmin), - get_filter_from_path(services_container), get_filter_from_path(krbgroup), ), ) diff --git a/app/ldap_protocol/kerberos/service.py b/app/ldap_protocol/kerberos/service.py index fa838abb9..9a6d331a9 100644 --- a/app/ldap_protocol/kerberos/service.py +++ b/app/ldap_protocol/kerberos/service.py @@ -121,14 +121,12 @@ async def setup_krb_catalogue( try: await self._ldap_manager.create_kerberos_structure( add_requests.group, - add_requests.services, add_requests.krb_user, ctx, ) except Exception: await self._ldap_manager.rollback_kerberos_structure( dns.krbadmin_dn, - dns.services_container_dn, dns.krbadmin_group_dn, ) await self._session.commit() @@ -188,11 +186,6 @@ def _build_add_requests( }, is_system=True, ) - services = AddRequest.from_dict( - dns.services_container_dn, - {"objectClass": ["organizationalUnit", "top", "container"]}, - is_system=True, - ) krb_user = AddRequest.from_dict( dns.krbadmin_dn, password=krbadmin_password.get_secret_value(), @@ -229,7 +222,6 @@ def _build_add_requests( ) return AddRequestsDTO( group=group, - services=services, krb_user=krb_user, ) @@ -283,7 +275,6 @@ async def setup_kdc( ) as err: await self._ldap_manager.rollback_kerberos_structure( context.krbadmin, - context.services_container, context.krbgroup, ) await self._kadmin.reset_setup() diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 9ea2eccaf..ff625d2a5 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -35,7 +35,6 @@ is_dn_in_base_directory, ) from ldap_protocol.utils.queries import ( - create_object_sid, get_base_directories, get_group, get_groups, @@ -214,8 +213,9 @@ async def handle( # noqa: C901 ctx.session.add(new_dir) await ctx.session.flush() - - new_dir.object_sid = create_object_sid(base_dn, new_dir.id) + await ctx.object_sid_use_case.add( + directory_id=new_dir.id, + ) await ctx.session.flush() except IntegrityError: await ctx.session.rollback() diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index d94b92af8..f81b3f113 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -27,6 +27,7 @@ from ldap_protocol.multifactor import LDAPMultiFactorAPI from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.rootdse.reader import RootDSEReader @@ -47,6 +48,7 @@ class LDAPAddRequestContext: access_manager: AccessManager role_use_case: RoleUseCase attribute_value_validator: AttributeValueValidator + object_sid_use_case: ObjectSIDUseCase @dataclass @@ -63,6 +65,7 @@ class LDAPModifyRequestContext: password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils attribute_value_validator: AttributeValueValidator + object_sid_use_case: ObjectSIDUseCase @dataclass diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index f25ab9996..2c74e680a 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -53,6 +53,7 @@ get_directory_by_rid, get_filter_from_path, get_groups, + groups_include_primary_rid, remove_disallowed_group_members, remove_from_group_membership, ) @@ -386,13 +387,17 @@ def _get_primary_group_id(self, directory: Directory) -> str | None: None, ) - def _contain_primary_group( + async def _contain_primary_group( self, groups: list[Group], primary_group_id: str, + session: AsyncSession, ) -> bool: - return any( - group.directory.relative_id == primary_group_id for group in groups + """Check whether membership includes the group for this RID.""" + return await groups_include_primary_rid( + session, + groups, + primary_group_id, ) async def _get_directories_with_primary_group_id( @@ -433,16 +438,25 @@ async def _get_members_with_primary_group_id( ) return list(await session.scalars(query)) - def _is_primary_group_deleted( + async def _is_primary_group_deleted( self, groups: list[Group], primary_group_id: str, operation: Operation, + session: AsyncSession, ) -> bool: if operation == Operation.REPLACE: - return not self._contain_primary_group(groups, primary_group_id) + return not await self._contain_primary_group( + groups, + primary_group_id, + session, + ) elif operation == Operation.DELETE: - return self._contain_primary_group(groups, primary_group_id) + return await self._contain_primary_group( + groups, + primary_group_id, + session, + ) return False async def _can_delete_group_from_directory( @@ -451,6 +465,7 @@ async def _can_delete_group_from_directory( user: UserSchema, groups: list[Group], operation: Operation, + session: AsyncSession, ) -> None: """Check if the request can delete group from directory.""" if operation == Operation.REPLACE: @@ -478,7 +493,12 @@ async def _can_delete_group_from_directory( if not primary_group_id: return - if self._is_primary_group_deleted(groups, primary_group_id, operation): + if await self._is_primary_group_deleted( + groups, + primary_group_id, + operation, + session, + ): raise ModifyForbiddenError( "Can't delete primary group from user.", ) @@ -553,6 +573,7 @@ async def _delete_memberof( user=user, groups=groups, operation=change.operation, + session=session, ) if not change.modification.vals: @@ -716,7 +737,7 @@ async def _add_primary_group_attribute( rid = str(change.modification.vals[0]) - if self._contain_primary_group(directory.groups, rid): + if await self._contain_primary_group(directory.groups, rid, session): session.add( Attribute( name="primaryGroupID", diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index bc5d67ca9..7c46f6948 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -489,7 +489,7 @@ async def paginate_query( return query, int(ceil(count / float(self.size_limit))), count - async def _fill_attrs( + async def _fill_attrs( # noqa: C901 self, directory: Directory, obj_classes: list[str], @@ -541,17 +541,24 @@ async def _fill_attrs( if group_directories is not None: async for directory_ in group_directories: - attrs["tokenGroups"].append( - string_to_sid(directory_.object_sid), # type: ignore - ) + sid_bytes = self.get_directory_sid(directory_) + if sid_bytes is not None: + attrs["tokenGroups"].append( + sid_bytes, # type: ignore + ) if self.member and "group" in obj_classes and directory.group: for member in directory.group.members: attrs["member"].append(member.path_dn) @staticmethod - def get_directory_sid(directory: Directory) -> bytes: - return string_to_sid(directory.object_sid) + def get_directory_sid(directory: Directory) -> bytes | None: + """Get objectSid as bytes from directory attributes.""" + return ( + string_to_sid(directory.object_sid) + if directory.object_sid + else None + ) @staticmethod def get_directory_guid(directory: Directory) -> bytes: @@ -600,6 +607,13 @@ async def tree_view( # noqa: C901 attrs[attr.name].append(value) continue + if ( + attr.name + and attr.name.lower() == "objectsid" + and self.is_sid_requested + ): + continue + attrs[attr.name].append(value) distinguished_name = directory.path_dn @@ -670,8 +684,11 @@ async def tree_view( # noqa: C901 attrs[directory.search_fields["objectguid"]].append(guid) # type: ignore if self.is_sid_requested: - guid = self.get_directory_sid(directory) - attrs[directory.search_fields["objectsid"]].append(guid) # type: ignore + sid_bytes = self.get_directory_sid(directory) + if sid_bytes is not None: + attrs["objectSid"].append( + sid_bytes, # type: ignore + ) if self.entity_type_name: attrs["entityTypeName"].append(directory.entity_type.name) diff --git a/app/ldap_protocol/ldap_schema/directory_create_use_case.py b/app/ldap_protocol/ldap_schema/directory_create_use_case.py index 9e2e4a033..5477c530e 100644 --- a/app/ldap_protocol/ldap_schema/directory_create_use_case.py +++ b/app/ldap_protocol/ldap_schema/directory_create_use_case.py @@ -14,7 +14,6 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) -from ldap_protocol.ldap_schema.exceptions import CantCreateDirectoryError from ldap_protocol.roles.role_use_case import RoleUseCase if TYPE_CHECKING: @@ -68,25 +67,12 @@ async def create_dir( parent_dir: "Directory", ) -> None: """Create.""" - base_directory_paths_and_sids = ( - await self.__directory_dao.get_base_directory_paths_with_sid() - ) - dir_ = await self.__directory_dao.create_directory( name=dto.name, is_system=dto.is_system, parent_dir=parent_dir, ) - for _path, _sid in base_directory_paths_and_sids: - if _is_dn_in_base_directory(_path, dir_.path_dn): - base_dn_sid = _sid - break - else: - raise CantCreateDirectoryError("Cannot create a directory.") - - dir_.object_sid = _get_object_sid(base_dn_sid, dir_.id) - attr_dto = AttributeDTO(name=dir_.rdname, values=[dir_.name]) await self.__attribute_dao.add_directory_name_attribute( dir_.id, diff --git a/app/ldap_protocol/ldap_schema/directory_dao.py b/app/ldap_protocol/ldap_schema/directory_dao.py index 095da77ae..ec0f72e13 100644 --- a/app/ldap_protocol/ldap_schema/directory_dao.py +++ b/app/ldap_protocol/ldap_schema/directory_dao.py @@ -10,7 +10,6 @@ from constants import CONFIGURATION_DIR_NAME from entities import Directory, EntityType -from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -64,13 +63,6 @@ async def get_all_without_entity_type(self) -> list[Directory]: ) return list(result.all()) - async def get_base_directory_paths_with_sid(self) -> list[tuple[str, str]]: - """Get all base directory paths.""" - base_dirs = await get_base_directories(self.__session) - return [ - (base_dir.path_dn, base_dir.object_sid) for base_dir in base_dirs - ] - async def get_configuration_dir(self) -> Directory: """Get configuration directory.""" result = await self.__session.execute( diff --git a/app/ldap_protocol/rid_manager/__init__.py b/app/ldap_protocol/rid_manager/__init__.py new file mode 100644 index 000000000..204bbef53 --- /dev/null +++ b/app/ldap_protocol/rid_manager/__init__.py @@ -0,0 +1,25 @@ +"""RID Manager module. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from .object_sid_gateway import ObjectSIDGateway +from .object_sid_use_case import ObjectSIDUseCase +from .rid_manager_gateway import RIDManagerGateway +from .rid_manager_use_case import RIDManagerUseCase +from .rid_set_gateway import RIDSetGateway +from .rid_set_use_case import RIDSetUseCase +from .setup_gateway import RIDManagerSetupGateway +from .setup_use_case import RIDManagerSetupUseCase + +__all__ = [ + "ObjectSIDGateway", + "ObjectSIDUseCase", + "RIDManagerGateway", + "RIDManagerSetupGateway", + "RIDManagerSetupUseCase", + "RIDManagerUseCase", + "RIDSetGateway", + "RIDSetUseCase", +] diff --git a/app/ldap_protocol/rid_manager/dtos.py b/app/ldap_protocol/rid_manager/dtos.py new file mode 100644 index 000000000..12e324cd0 --- /dev/null +++ b/app/ldap_protocol/rid_manager/dtos.py @@ -0,0 +1,16 @@ +"""RID Manager DTOs. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dataclasses import dataclass + + +@dataclass +class RIDSetAllocationParamsDTO: + """RID Set DTO.""" + + next_rid: int + previous_allocation_pool: int + allocation_pool: int diff --git a/app/ldap_protocol/rid_manager/exceptions.py b/app/ldap_protocol/rid_manager/exceptions.py new file mode 100644 index 000000000..40b75bbb9 --- /dev/null +++ b/app/ldap_protocol/rid_manager/exceptions.py @@ -0,0 +1,101 @@ +"""RID Manager exceptions.""" + +from enum import IntEnum + +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + RID_MANAGER_NOT_FOUND_ERROR = 1 + RID_AVAILABLE_POOL_NOT_FOUND_ERROR = 2 + RID_NEXT_RID_NOT_FOUND_ERROR = 3 + RID_SET_NOT_FOUND_ERROR = 4 + RID_DOMAIN_IDENTIFIER_NOT_FOUND_ERROR = 5 + RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR = 6 + RID_BASE_DOMAIN_NOT_FOUND_ERROR = 7 + RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR = 8 + RID_ALLOCATION_POOL_NOT_FOUND_ERROR = 9 + RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR = 10 + RID_POOL_EXCEEDED_ERROR = 11 + RID_OBJECT_SID_NOT_FOUND_ERROR = 12 + + +class RIDManagerError(BaseDomainException): + """RID Manager error.""" + + code: ErrorCodes = ErrorCodes.BASE_ERROR + + +class RIDManagerNotFoundError(RIDManagerError): + """RID Manager not found error.""" + + code = ErrorCodes.RID_MANAGER_NOT_FOUND_ERROR + + +class RIDManagerAvailablePoolNotFoundError(RIDManagerError): + """RID Manager available pool not found error.""" + + code = ErrorCodes.RID_AVAILABLE_POOL_NOT_FOUND_ERROR + + +class RIDManagerRidNextRIDNotFoundError(RIDManagerError): + """RID Manager next RID not found error.""" + + code = ErrorCodes.RID_NEXT_RID_NOT_FOUND_ERROR + + +class RIDManagerRidSetNotFoundError(RIDManagerError): + """RID Manager RID Set not found error.""" + + code = ErrorCodes.RID_SET_NOT_FOUND_ERROR + + +class RIDManagerDomainIdentifierNotFoundError(RIDManagerError): + """RID Manager domain identifier not found error.""" + + code = ErrorCodes.RID_DOMAIN_IDENTIFIER_NOT_FOUND_ERROR + + +class RIDManagerDomainControllerNotFoundError(RIDManagerError): + """RID Manager domain controller not found error.""" + + code = ErrorCodes.RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR + + +class RIDManagerSystemContainerNotFoundError(RIDManagerError): + """RID Manager system container not found error.""" + + code = ErrorCodes.RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR + + +class RIDManagerRidAllocationPoolNotFoundError(RIDManagerError): + """RID Manager RID allocation pool not found error.""" + + code = ErrorCodes.RID_ALLOCATION_POOL_NOT_FOUND_ERROR + + +class RIDManagerRidPreviousAllocationPoolNotFoundError(RIDManagerError): + """RID Manager RID previous allocation pool not found error.""" + + code = ErrorCodes.RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR + + +class RIDManagerPoolExceededError(RIDManagerError): + """RID Manager pool exceeded error.""" + + code = ErrorCodes.RID_POOL_EXCEEDED_ERROR + + +class RIDManagerBaseDomainNotFoundError(RIDManagerError): + """RID Manager base domain not found error.""" + + code = ErrorCodes.RID_BASE_DOMAIN_NOT_FOUND_ERROR + + +class RIDManagerObjectSIDNotFoundError(RIDManagerError): + """RID Manager object SID not found error.""" + + code = ErrorCodes.RID_OBJECT_SID_NOT_FOUND_ERROR diff --git a/app/ldap_protocol/rid_manager/object_sid_gateway.py b/app/ldap_protocol/rid_manager/object_sid_gateway.py new file mode 100644 index 000000000..f63cd6cf9 --- /dev/null +++ b/app/ldap_protocol/rid_manager/object_sid_gateway.py @@ -0,0 +1,72 @@ +"""Object SID gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Attribute, Directory +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerDomainIdentifierNotFoundError, + RIDManagerObjectSIDNotFoundError, +) +from ldap_protocol.utils.async_cache import domain_identifier_cache +from repo.pg.tables import queryable_attr as qa + + +class ObjectSIDGateway: + """Object SID gateway.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize Object SID gateway.""" + self._session = session + + async def get(self, directory_id: int) -> str: + """Get object SID.""" + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == directory_id, + qa(Attribute.name) == "objectSid", + ), + ) + if not (query and query.value): + raise RIDManagerObjectSIDNotFoundError("object SID not found") + + return query.value + + async def add(self, directory_id: int, object_sid: str) -> None: + """Add object SID.""" + self._session.add( + Attribute( + name="objectSid", + value=object_sid, + directory_id=directory_id, + ), + ) + + @domain_identifier_cache + async def get_domain_identifier(self) -> str: + """Get domain identifier (cached ``Attribute.value`` string).""" + return await self._load_domain_identifier_value() + + async def _load_domain_identifier_value(self) -> str: + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "DomainIdentifier", + select(Directory) + .where( + qa(Directory.id) == qa(Attribute.directory_id), + qa(Directory.parent_id).is_(None), + ) + .exists(), + ), + ) + + if not query or not query.value: + raise RIDManagerDomainIdentifierNotFoundError( + "domain identifier not found", + ) + + return query.value diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py new file mode 100644 index 000000000..d46a4b9cb --- /dev/null +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -0,0 +1,52 @@ +"""Object SID use case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from enums import SidPrefix +from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase + + +class ObjectSIDUseCase: + """Object SID use case.""" + + def __init__( + self, + gateway: ObjectSIDGateway, + rid_set_use_case: RIDSetUseCase, + session: AsyncSession, + rid_manager_use_case: RIDManagerUseCase, + ) -> None: + """Initialize Object SID use case.""" + self._gateway = gateway + self._rid_set_use_case = rid_set_use_case + self._session = session + self._rid_manager_use_case = rid_manager_use_case + + async def get(self, directory_id: int) -> str: + """Get object SID.""" + return await self._gateway.get(directory_id) + + async def add( + self, + directory_id: int, + rid: int | None = None, + sid_prefix: SidPrefix = SidPrefix.DOMAIN_IDENTIFIER, + ) -> None: + """Add object SID.""" + if rid is None: + rid_set_id = await self._rid_set_use_case.get_rid_set_id() + rid = await self._rid_set_use_case.allocate_next_rid(rid_set_id) + + if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: + object_sid = f"{sid_prefix}-{rid}" + elif sid_prefix == SidPrefix.DOMAIN_IDENTIFIER: + domain_identifier = await self._gateway.get_domain_identifier() + object_sid = f"{sid_prefix}-{domain_identifier}-{rid}" + + await self._gateway.add(directory_id, object_sid) diff --git a/app/ldap_protocol/rid_manager/rid_manager_gateway.py b/app/ldap_protocol/rid_manager/rid_manager_gateway.py new file mode 100644 index 000000000..059312069 --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_manager_gateway.py @@ -0,0 +1,95 @@ +"""RID Manager gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings +from constants import DOMAIN_CONTROLLERS_OU_NAME +from entities import Attribute, Directory +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerAvailablePoolNotFoundError, + RIDManagerDomainControllerNotFoundError, + RIDManagerNotFoundError, +) +from repo.pg.tables import queryable_attr as qa + + +class RIDManagerGateway: + """RID Manager gateway.""" + + def __init__(self, session: AsyncSession, settings: Settings) -> None: + """Initialize RID Manager gateway.""" + self._session = session + self._settings = settings + + async def get_rid_manager(self) -> Directory: + """Get RID Manager directory.""" + rid_manager = await self._session.scalar( + select(Directory).where(qa(Directory.name) == "RID Manager$"), + ) + if not rid_manager: + raise RIDManagerNotFoundError("RID Manager directory not found") + return rid_manager + + async def get_rid_available_pool(self) -> int: + """Get RID available pool.""" + rid_available_pool = await self._session.scalar( + select(Attribute) + .where(qa(Attribute.name) == "rIDAvailablePool") + .with_for_update(), + ) + if not (rid_available_pool and rid_available_pool.value): + raise RIDManagerAvailablePoolNotFoundError( + "RID available pool not found", + ) + return int(rid_available_pool.value) + + async def update_rid_available_pool(self, available_pool: int) -> None: + """Update RID available pool.""" + await self._session.execute( + update(Attribute) + .where(qa(Attribute.name) == "rIDAvailablePool") + .values(value=str(available_pool)), + ) + + async def get_domain_controller( + self, + ) -> Directory: + """Get domain controller.""" + domain = await self._session.scalar( + select(Directory).where( + qa(Directory.object_class) == "domain", + qa(Directory.parent_id).is_(None), + ), + ) + if not domain: + raise RIDManagerDomainControllerNotFoundError( + "Domain controller not found", + ) + + domain_controllers_ou = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, + qa(Directory.parent_id) == domain.id, + ), + ) + if not domain_controllers_ou: + raise RIDManagerDomainControllerNotFoundError( + "Domain controllers OU not found", + ) + + domain_controller = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == self._settings.HOST_MACHINE_SHORT_NAME, + qa(Directory.parent_id) == domain_controllers_ou.id, + ), + ) + if not domain_controller: + raise RIDManagerDomainControllerNotFoundError( + "Domain controller not found", + ) + return domain_controller diff --git a/app/ldap_protocol/rid_manager/rid_manager_use_case.py b/app/ldap_protocol/rid_manager/rid_manager_use_case.py new file mode 100644 index 000000000..62fc74e58 --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_manager_use_case.py @@ -0,0 +1,47 @@ +"""RID Manager use case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory +from ldap_protocol.rid_manager.exceptions import RIDManagerPoolExceededError +from ldap_protocol.rid_manager.rid_manager_gateway import RIDManagerGateway +from ldap_protocol.rid_manager.utils import from_qword, to_qword + + +class RIDManagerUseCase: + """RID Manager use case.""" + + RID_BLOCK_SIZE = 500 + # NOTE Domain Controller(with role Rid Master) attr + # replace and change logic, when super DC is introduced + + def __init__( + self, + gateway: RIDManagerGateway, + session: AsyncSession, + ) -> None: + """Initialize RID Manager use case.""" + self._gateway = gateway + self._session = session + + async def allocate_pool(self) -> int: + """Allocate pool.""" + async with self._session.begin_nested(): + available_pool = await self._gateway.get_rid_available_pool() + lower, upper = from_qword(available_pool) + + if lower + self.RID_BLOCK_SIZE > upper: + raise RIDManagerPoolExceededError("Available pool exceeded") + + new_available_pool = to_qword(lower + self.RID_BLOCK_SIZE, upper) + await self._gateway.update_rid_available_pool(new_available_pool) + + return to_qword(lower, lower + self.RID_BLOCK_SIZE) + + async def get_domain_controller(self) -> Directory: + """Locate best Domain Controller via DNS SRV records.""" + return await self._gateway.get_domain_controller() diff --git a/app/ldap_protocol/rid_manager/rid_set_gateway.py b/app/ldap_protocol/rid_manager/rid_set_gateway.py new file mode 100644 index 000000000..29b70985c --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_set_gateway.py @@ -0,0 +1,252 @@ +"""RID Set gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased + +from config import Settings +from constants import DOMAIN_CONTROLLERS_OU_NAME +from entities import Attribute, Directory +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerRidAllocationPoolNotFoundError, + RIDManagerRidNextRIDNotFoundError, + RIDManagerRidPreviousAllocationPoolNotFoundError, + RIDManagerRidSetNotFoundError, +) +from ldap_protocol.utils.async_cache import rid_set_id_cache +from repo.pg.tables import queryable_attr as qa + + +class RIDSetGateway: + """RID Set gateway.""" + + def __init__(self, session: AsyncSession, settings: Settings) -> None: + """Initialize RID Set gateway.""" + self._session = session + self._settings = settings + + @rid_set_id_cache + async def get_rid_set_id(self) -> int: + """Get RID Set ID.""" + return await self.get_rid_set_id_value() + + async def get_rid_set_id_value(self) -> int: + """Get RID Set ID.""" + domain = aliased(Directory) + domain_controllers_ou = aliased(Directory) + domain_controller = aliased(Directory) + rid_set = aliased(Directory) + + rid_set_id = await self._session.scalar( + select(qa(rid_set.id)) + .select_from(domain) + .join( + domain_controllers_ou, + qa(domain_controllers_ou.parent_id) == qa(domain.id), + ) + .join( + domain_controller, + qa(domain_controller.parent_id) + == qa(domain_controllers_ou.id), + ) + .join( + rid_set, + qa(rid_set.parent_id) == qa(domain_controller.id), + ) + .where( + qa(domain.object_class) == "domain", + qa(domain.parent_id).is_(None), + qa(domain_controllers_ou.name) == DOMAIN_CONTROLLERS_OU_NAME, + qa(domain_controller.name) + == self._settings.HOST_MACHINE_SHORT_NAME, + qa(rid_set.name) == "RID Set", + ), + ) + if rid_set_id is None: + raise RIDManagerRidSetNotFoundError("RID Set directory not found") + return int(rid_set_id) + + async def get(self, domain_controller: Directory) -> Directory: + """Get RID Set directory.""" + rid_set = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == "RID Set", + qa(Directory.parent_id) == domain_controller.id, + ), + ) + if not rid_set: + raise RIDManagerRidSetNotFoundError("RID Set directory not found") + + return rid_set + + async def add(self, domain_controller: Directory) -> Directory: + """Add RID Set directory.""" + rid_set_dir = Directory( + is_system=True, + name="RID Set", + ) + rid_set_dir.create_path(domain_controller, "cn") + + self._session.add(rid_set_dir) + await self._session.flush() + + rid_set_dir.parent_id = domain_controller.id + await self._session.refresh(rid_set_dir, ["id"]) + + self._session.add( + Attribute( + name="cn", + value="RID Set", + directory_id=rid_set_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="top", + directory_id=rid_set_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="rIDSet", + directory_id=rid_set_dir.id, + ), + ) + + await self._session.flush() + + await self._session.refresh( + instance=rid_set_dir, + attribute_names=["attributes"], + with_for_update=None, + ) + + return rid_set_dir + + async def set_allocation_attrs( + self, + rid_set_id: int, + allocation_params: RIDSetAllocationParamsDTO, + ) -> None: + """Set next RID attribute in RID Set directory.""" + self._session.add( + Attribute( + name="rIDNextRID", + value=str(allocation_params.next_rid), + directory_id=rid_set_id, + ), + ) + self._session.add( + Attribute( + name="rIDPreviousAllocationPool", + value=str(allocation_params.previous_allocation_pool), + directory_id=rid_set_id, + ), + ) + self._session.add( + Attribute( + name="rIDAllocationPool", + value=str(allocation_params.allocation_pool), + directory_id=rid_set_id, + ), + ) + + async def get_rid_allocation_pool(self, rid_set_id: int) -> int: + """Get RID allocation pool from RID Set directory.""" + allocation_pool = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "rIDAllocationPool", + qa(Attribute.directory_id) == rid_set_id, + ), + ) + if not (allocation_pool and allocation_pool.value): + raise RIDManagerRidAllocationPoolNotFoundError( + "RID allocation pool not found", + ) + return int(allocation_pool.value) + + async def get_rid_previous_allocation_pool( + self, + rid_set_id: int, + ) -> int: + """Get previous RID allocation pool from RID Set directory.""" + previous_allocation_pool = await self._session.scalar( + select(Attribute) + .where( + qa(Attribute.name) == "rIDPreviousAllocationPool", + qa(Attribute.directory_id) == rid_set_id, + ) + .with_for_update(), + ) + if not (previous_allocation_pool and previous_allocation_pool.value): + raise RIDManagerRidPreviousAllocationPoolNotFoundError( + "previous RID allocation pool not found", + ) + return int(previous_allocation_pool.value) + + async def get_rid_next_rid(self, rid_set_id: int) -> int: + """Get next RID from RID Set directory.""" + next_rid = await self._session.scalar( + select(Attribute) + .where( + qa(Attribute.name) == "rIDNextRID", + qa(Attribute.directory_id) == rid_set_id, + ) + .with_for_update(), + ) + if not (next_rid and next_rid.value): + raise RIDManagerRidNextRIDNotFoundError("next RID not found") + return int(next_rid.value) + + async def update_next_rid( + self, + rid_set_id: int, + next_rid: int, + ) -> None: + """Update next RID.""" + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDNextRID", + qa(Attribute.directory_id) == rid_set_id, + ) + .values(value=str(next_rid)), + ) + + async def reset_attrs_when_pool_exceeded( + self, + rid_set_id: int, + allocation_pool: int, + previous_allocation_pool: int, + next_rid: int, + ) -> None: + """Reset RID pools when pool exceeded.""" + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDAllocationPool", + qa(Attribute.directory_id) == rid_set_id, + ) + .values(value=str(allocation_pool)), + ) + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDPreviousAllocationPool", + qa(Attribute.directory_id) == rid_set_id, + ) + .values(value=str(previous_allocation_pool)), + ) + await self.update_next_rid( + rid_set_id, + next_rid, + ) diff --git a/app/ldap_protocol/rid_manager/rid_set_use_case.py b/app/ldap_protocol/rid_manager/rid_set_use_case.py new file mode 100644 index 000000000..26d6767cb --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_set_use_case.py @@ -0,0 +1,148 @@ +"""RID Set use case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway +from ldap_protocol.rid_manager.utils import from_qword +from ldap_protocol.roles.role_use_case import RoleUseCase + + +class RIDSetUseCase: + """RID Set use case.""" + + def __init__( + self, + gateway: RIDSetGateway, + entity_type_use_case: EntityTypeUseCase, + session: AsyncSession, + rid_manager_use_case: RIDManagerUseCase, + role_use_case: RoleUseCase, + ) -> None: + """Initialize RID Set use case.""" + self._gateway = gateway + self._entity_type_use_case = entity_type_use_case + self._session = session + self._rid_manager_use_case = rid_manager_use_case + self._role_use_case = role_use_case + + async def get(self, domain_controller: Directory) -> Directory: + """Get RID Set directory.""" + return await self._gateway.get(domain_controller) + + async def get_rid_set_id(self) -> int: + """Get RID Set ID.""" + return await self._gateway.get_rid_set_id() + + async def add( + self, + domain_controller: Directory, + allocation_params: RIDSetAllocationParamsDTO, + ) -> Directory: + """Create RID Set directory.""" + rid_set = await self._gateway.add(domain_controller) + await self._entity_type_use_case.attach_entity_type_to_directory( + directory=rid_set, + is_system_entity_type=True, + object_class_names={"top", "rIDSet"}, + ) + + await self._gateway.set_allocation_attrs( + rid_set.id, + allocation_params, + ) + await self.inherit_parent_aces( + domain_controller=domain_controller, + rid_set=rid_set, + ) + await self._session.flush() + return rid_set + + def is_pool_exceeded( + self, + current_next_rid: int, + previous_allocation_pool: int, + ) -> bool: + """Check if RID pool is exceeded.""" + _, upper = from_qword(previous_allocation_pool) + + return current_next_rid + 1 > upper + + async def allocate_next_rid(self, rid_set_id: int) -> int: + """Allocate next RID.""" + async with self._session.begin_nested(): + current_next_rid = await self._gateway.get_rid_next_rid(rid_set_id) + previous_allocation_pool = ( + await self._gateway.get_rid_previous_allocation_pool( + rid_set_id, + ) + ) + + if not self.is_pool_exceeded( + current_next_rid, + previous_allocation_pool, + ): + new_next_rid = current_next_rid + 1 + await self._gateway.update_next_rid( + rid_set_id, + new_next_rid, + ) + else: + new_next_rid = await self.rebind_next_rid_from_new_pool( + rid_set_id, + ) + + return new_next_rid + + async def rebind_next_rid_from_new_pool( + self, + rid_set_id: int, + ) -> int: + """Rebind next RID from new pool.""" + new_allocation_pool = await self._rid_manager_use_case.allocate_pool() + + current_allocation_pool = await self._gateway.get_rid_allocation_pool( + rid_set_id, + ) + lower, _ = from_qword(current_allocation_pool) + await self._gateway.reset_attrs_when_pool_exceeded( + rid_set_id=rid_set_id, + next_rid=lower, + allocation_pool=new_allocation_pool, + previous_allocation_pool=current_allocation_pool, + ) + return lower + + async def generate_rid_set_attrs(self) -> RIDSetAllocationParamsDTO: + """Generate RID Set attributes.""" + previous_allocation_pool = ( + await self._rid_manager_use_case.allocate_pool() + ) + allocation_pool = await self._rid_manager_use_case.allocate_pool() + lower, _ = from_qword(previous_allocation_pool) + + return RIDSetAllocationParamsDTO( + next_rid=lower, + allocation_pool=allocation_pool, + previous_allocation_pool=previous_allocation_pool, + ) + + async def inherit_parent_aces( + self, + domain_controller: Directory, + rid_set: Directory, + ) -> None: + """Inherit parent ACEs to RID Set directory.""" + await self._role_use_case.inherit_parent_aces( + parent_directory=domain_controller, + directory=rid_set, + ) diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py new file mode 100644 index 000000000..e8d4117bd --- /dev/null +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -0,0 +1,166 @@ +"""RID Manager Gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import secrets + +from sqlalchemy import exists, select +from sqlalchemy.ext.asyncio import AsyncSession + +from constants import SYSTEM_CONTAINER_NAME +from entities import Attribute, Directory +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerSystemContainerNotFoundError, +) +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + + +class RIDManagerSetupGateway: + """Gateway for RID Manager setup database operations.""" + + def __init__( + self, + session: AsyncSession, + entity_type_use_case: EntityTypeUseCase, + ) -> None: + """Initialize RID Manager setup gateway.""" + self._session = session + self._entity_type_use_case = entity_type_use_case + + async def get_system_container(self) -> Directory: + """Get System container directory. + + :return: System container directory + """ + base_dn_list = await get_base_directories(self._session) + + domain = base_dn_list[0] + + query = select(Directory).where( + qa(Directory.name) == SYSTEM_CONTAINER_NAME, + qa(Directory.parent_id) == domain.id, + ) + + system_container = await self._session.scalar(query) + + if not system_container: + raise RIDManagerSystemContainerNotFoundError( + "System container not found", + ) + + return system_container + + async def set_rid_manager(self) -> Directory: + """Create RID Manager directory.""" + system_container = await self.get_system_container() + + rid_manager_dir = Directory( + is_system=True, + name="RID Manager$", + ) + rid_manager_dir.create_path(system_container, "cn") + + self._session.add(rid_manager_dir) + await self._session.flush() + + rid_manager_dir.parent_id = system_container.id + await self._session.refresh(rid_manager_dir, ["id"]) + + self._session.add( + Attribute( + name="cn", + value="RID Manager$", + directory_id=rid_manager_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="top", + directory_id=rid_manager_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="rIDManager", + directory_id=rid_manager_dir.id, + ), + ) + + await self._session.flush() + + await self._session.refresh( + instance=rid_manager_dir, + attribute_names=["attributes"], + with_for_update=None, + ) + + await self._entity_type_use_case.attach_entity_type_to_directory( + directory=rid_manager_dir, + is_system_entity_type=True, + object_class_names={"top", "rIDManager"}, + ) + + await self._session.flush() + + return rid_manager_dir + + async def set_rid_available_pool( + self, + rid_manager_dir: Directory, + qword_value: int, + ) -> None: + """Set rIDAvailablePool attribute in domain. + + Updates the global RID pool counter. + + :param rid_manager_dir: RID Manager directory object + :param qword_value: New QWORD value (64-bit) + """ + self._session.add( + Attribute( + directory_id=rid_manager_dir.id, + name="rIDAvailablePool", + value=str(qword_value), + ), + ) + + await self._session.flush() + + def _generate_domain_sid_identifier(self) -> str: + """Generate Domain Identifier for Active Directory domain.""" + return ( + f"{secrets.randbits(32)}" + f"-{secrets.randbits(32)}-{secrets.randbits(32)}" + ) + + async def create_domain_identifier(self, domain_id: int) -> None: + """Add domain identifier to domain.""" + domain_identifer = await self._session.scalar( + select( + exists(Attribute), + ).where( + qa(Attribute.name) == "DomainIdentifier", + ), + ) + + if domain_identifer: + return + + self._session.add( + Attribute( + name="DomainIdentifier", + value=f"{self._generate_domain_sid_identifier()}", + directory_id=domain_id, + ), + ) + await self._session.flush() diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py new file mode 100644 index 000000000..34bb18074 --- /dev/null +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -0,0 +1,82 @@ +"""RID Manager for issuing RID from pools. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE + +""" + +from config import Settings +from entities import Directory +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase +from ldap_protocol.rid_manager.setup_gateway import RIDManagerSetupGateway +from ldap_protocol.rid_manager.utils import to_qword +from ldap_protocol.roles.ace_dao import AccessControlEntryDAO +from ldap_protocol.roles.role_use_case import RoleUseCase + + +class RIDManagerSetupUseCase: + """RID Manager setup use case.""" + + RID_MIN = 1100 + RID_AVAILABLE_MAX = 1073741822 # 30-bit max (2^30 - 2) + + def __init__( + self, + rid_manager_setup_gateway: RIDManagerSetupGateway, + role_use_case: RoleUseCase, + access_control_entry_dao: AccessControlEntryDAO, + rid_set_use_case: RIDSetUseCase, + rid_manager_use_case: RIDManagerUseCase, + settings: Settings, + ) -> None: + """Initialize RID Manager setup use case. + + :param rid_manager_setup_gateway: Gateway for setup operations + :param role_use_case: Role use case + """ + self._gateway = rid_manager_setup_gateway + self._role_use_case = role_use_case + self._access_control_entry_dao = access_control_entry_dao + self._settings = settings + self._rid_set_use_case = rid_set_use_case + self._rid_manager_use_case = rid_manager_use_case + + async def setup(self) -> None: + """Create RID Manager.""" + rid_manager_dir = await self._gateway.set_rid_manager() + qword = to_qword(self.RID_MIN, self.RID_AVAILABLE_MAX) + await self._gateway.set_rid_available_pool( + rid_manager_dir, + qword, + ) + dc = await self._rid_manager_use_case.get_domain_controller() + await self._rid_set_use_case.add( + dc, + await self._rid_set_use_case.generate_rid_set_attrs(), + ) + + await self.inherit_aces( + rid_manager_dir, + ) + + async def inherit_aces( + self, + rid_manager_dir: Directory, + ) -> None: + """Inherit ACEs from domain root to RID Manager directory. + + Instead of creating a special ACE or role for RID Manager, + we reuse the existing ACL model: all ACEs that apply to the + domain root (including Domain Admins) are inherited by the + `CN=RID Manager$` directory, similar to how it is done in + migration `ebf19750805e_add_domain_controllers_ou`. + """ + await self._role_use_case.inherit_parent_aces( + parent_directory=await self._gateway.get_system_container(), + directory=rid_manager_dir, + ) + + async def create_domain_identifier(self, domain_id: int) -> None: + """Create domain identifier.""" + await self._gateway.create_domain_identifier(domain_id) diff --git a/app/ldap_protocol/rid_manager/utils.py b/app/ldap_protocol/rid_manager/utils.py new file mode 100644 index 000000000..eb6f3835b --- /dev/null +++ b/app/ldap_protocol/rid_manager/utils.py @@ -0,0 +1,23 @@ +"""RID Manager utils.""" + + +def to_qword(lower: int, upper: int) -> int: + """Create QWORD (64-bit) from two DWORDs (32-bit each).""" + if lower < 0 or lower > 0xFFFFFFFF: + raise ValueError(f"Lower boundary out of range: {lower}") + if upper < 0 or upper > 0xFFFFFFFF: + raise ValueError(f"Upper boundary out of range: {upper}") + + qword = (upper << 32) | lower + + return qword + + +def from_qword(qword: int) -> tuple[int, int]: + """Split QWORD (64-bit) into two DWORDs (lower, upper).""" + if qword < 0 or qword > 0xFFFFFFFFFFFFFFFF: + raise ValueError(f"QWORD out of range: {qword}") + + lower = qword & 0xFFFFFFFF + upper = (qword >> 32) & 0xFFFFFFFF + return lower, upper diff --git a/app/ldap_protocol/rootdse/reader.py b/app/ldap_protocol/rootdse/reader.py index 065be0a54..6e4464416 100644 --- a/app/ldap_protocol/rootdse/reader.py +++ b/app/ldap_protocol/rootdse/reader.py @@ -8,6 +8,7 @@ from config import Settings from constants import DEFAULT_DC_POSTFIX, UNC_PREFIX +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.utils.helpers import get_generalized_now from .dto import DomainControllerInfo @@ -87,14 +88,21 @@ async def get( class DCInfoReader: - def __init__(self, settings: Settings, gw: DomainReadProtocol) -> None: + def __init__( + self, + settings: Settings, + gw: DomainReadProtocol, + object_sid_use_case: ObjectSIDUseCase, + ) -> None: self._settings = settings self._gw = gw + self._object_sid_use_case = object_sid_use_case async def get(self) -> DomainControllerInfo: domain = await self._gw.get_domain() dns = domain.name.lower() nb_domain = dns.split(".")[0].upper() + object_sid = await self._object_sid_use_case.get(domain.id) return DomainControllerInfo( net_bios_domain=nb_domain, @@ -102,6 +110,6 @@ async def get(self) -> DomainControllerInfo: unc=UNC_PREFIX + dns, dns=dns, dns_forest=dns, - object_sid=domain.object_sid, + object_sid=object_sid, object_guid=str(domain.object_guid), ) diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py index f723440a6..9974effad 100644 --- a/app/ldap_protocol/utils/async_cache.py +++ b/app/ldap_protocol/utils/async_cache.py @@ -44,3 +44,5 @@ async def wrapper(*args: tuple, **kwargs: dict) -> T: base_directories_cache = AsyncTTLCache[list[Directory]]() +domain_identifier_cache = AsyncTTLCache[str]() +rid_set_id_cache = AsyncTTLCache[int]() diff --git a/app/ldap_protocol/utils/cte.py b/app/ldap_protocol/utils/cte.py index 7b4628254..6b9c513af 100644 --- a/app/ldap_protocol/utils/cte.py +++ b/app/ldap_protocol/utils/cte.py @@ -6,6 +6,7 @@ from sqlalchemy import exists, or_ from sqlalchemy.ext.asyncio import AsyncScalarResult, AsyncSession +from sqlalchemy.orm import selectinload from sqlalchemy.sql.expression import select from sqlalchemy.sql.selectable import CTE @@ -237,6 +238,10 @@ async def get_all_parent_group_directories( if not directories_ids: return None - query = select(Directory).where(directory_table.c.id.in_(directories_ids)) + query = ( + select(Directory) + .where(directory_table.c.id.in_(directories_ids)) + .options(selectinload(qa(Directory.attributes))) + ) return await session.stream_scalars(query) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 296ff5978..e5db1444a 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -132,7 +132,6 @@ import functools import hashlib -import random import re import struct import time @@ -301,34 +300,6 @@ def string_to_sid(sid_string: str) -> bytes: return sid -def create_object_sid( - domain: Directory, - rid: int, - reserved: bool = False, -) -> str: - """Generate the objectSid attribute for an object. - - :param domain: domain directory - :param int rid: relative identifier - :param bool reserved: A flag indicating whether the RID is reserved. - If `True`, the given RID is used directly. If - `False`, 1000 is added to the given RID to generate - the final RID - :return str: the complete objectSid as a string - """ - return domain.object_sid + f"-{rid if reserved else 1000 + rid}" - - -def generate_domain_sid() -> str: - """Generate domain objectSid attr.""" - sub_authorities = [ - random.randint(1000000000, (1 << 32) - 1), - random.randint(1000000000, (1 << 32) - 1), - random.randint(100000000, 999999999), - ] - return "S-1-5-21-" + "-".join(str(part) for part in sub_authorities) - - def create_user_name(directory_id: int) -> str: """Create username by directory id. diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 2e078b840..d3380776e 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -21,7 +21,7 @@ from sqlalchemy.sql.expression import ColumnElement from entities import Attribute, Directory, Group, User -from enums import SamAccountTypeCodes +from enums import SamAccountTypeCodes, SidPrefix from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, AttributeValueValidatorError, @@ -36,7 +36,6 @@ from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( create_integer_hash, - create_object_sid, dn_is_base_directory, ft_now, validate_entry, @@ -190,20 +189,43 @@ async def get_directory_by_rid( rid: str, session: AsyncSession, ) -> Directory | None: - """Get directory by relative ID (rid). - - :param str rid: relative ID - :param AsyncSession session: SA session - :return Directory | None: directory or None - """ query = ( select(Directory) - .options(joinedload(qa(Directory.group))) - .filter(qa(Directory.object_sid).endswith(f"-{rid}")) + .join( + Attribute, + qa(Attribute.directory_id) == qa(Directory.id), + ) + .options( + selectinload(qa(Directory.attributes)), + joinedload(qa(Directory.group)), + ) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).endswith(f"-{rid}"), + ) ) return await session.scalar(query) +async def groups_include_primary_rid( + session: AsyncSession, + groups: list[Group], + primary_group_id: str, +) -> bool: + directory_ids = {g.directory_id for g in groups} + + stmt = ( + select(qa(Attribute.id)) + .where( + qa(Attribute.directory_id).in_(directory_ids), + qa(Attribute.name) == "objectSid", + qa(Attribute.value).endswith(f"-{primary_group_id}"), + ) + .limit(1) + ) + return await session.scalar(stmt) is not None + + async def get_groups(dn_list: list[str], session: AsyncSession) -> list[Group]: """Get dirs with groups by dn list.""" paths = [] @@ -225,6 +247,9 @@ async def get_groups(dn_list: list[str], session: AsyncSession) -> list[Group]: .options(selectinload(qa(Group.members))) .options( joinedload(qa(Group.directory)).selectinload(qa(Directory.groups)), + joinedload(qa(Group.directory)).selectinload( + qa(Directory.attributes), + ), ) ) @@ -247,7 +272,11 @@ async def get_group( query = ( select(Group) .join(qa(Group.directory), isouter=True) - .options(joinedload(qa(Group.directory))) + .options( + joinedload(qa(Group.directory)).selectinload( + qa(Directory.attributes), + ), + ) ) if validate_entry(dn): @@ -386,10 +415,12 @@ async def create_group( dir_.create_path(parent) session.add(group) - dir_.object_sid = create_object_sid( - base_dn_list[0], - rid=sid or dir_.id, - reserved=bool(sid), + session.add( + Attribute( + name="objectSid", + value=f"{SidPrefix.BUILT_IN_DOMAIN}-{sid or dir_.id}", + directory_id=dir_.id, + ), ) await session.flush() @@ -559,9 +590,16 @@ async def get_group_path_dn_by_primary_group_id( """ query = ( select(Directory) + .join( + Attribute, + qa(Attribute.directory_id) == qa(Directory.id), + ) .join(qa(Directory.group)) .options(contains_eager(qa(Directory.group))) - .filter(qa(Directory.object_sid).endswith(f"-{primary_group_id}")) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).endswith(f"-{primary_group_id}"), + ) ) directory = await session.scalar(query) diff --git a/app/repo/pg/tables.py b/app/repo/pg/tables.py index be17cef6f..4c7b641fb 100644 --- a/app/repo/pg/tables.py +++ b/app/repo/pg/tables.py @@ -145,7 +145,6 @@ def _compile_create_uc( key="updated_at", ), Column("depth", Integer, nullable=True), - Column("objectSid", String, nullable=True, key="object_sid"), Column( "objectGUID", PG_UUID(as_uuid=True), @@ -781,7 +780,6 @@ def _compile_create_uc( ), "objectclass": synonym("object_class"), "objectguid": synonym("object_guid"), - "objectsid": synonym("object_sid"), "whencreated": synonym("created_at"), "whenchanged": synonym("updated_at"), }, diff --git a/interface b/interface index 046449cdd..61e15e236 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 046449cdd568919cca12a7939366dcee7a54fdfa +Subproject commit 61e15e2367182a3e706c94cf8e1895d742840aa7 diff --git a/tests/conftest.py b/tests/conftest.py index 2d59cec3b..0e9afe287 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,7 +63,12 @@ from api.shadow.adapter import ShadowAdapter from authorization_provider_protocol import AuthorizationProviderProtocol from config import Settings -from constants import ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2 +from constants import ( + DOMAIN_CONTROLLERS_OU_NAME, + ENTITY_TYPE_DTOS_V1, + ENTITY_TYPE_DTOS_V2, +) +from entities import Directory from enums import AuthorizationRules from ioc import AuditRedisClient, MFACredsProvider, SessionStorageClient from ldap_protocol.auth import AuthManager, MFAManager @@ -174,6 +179,16 @@ PasswordBanWordUseCases, UserPasswordHistoryUseCases, ) +from ldap_protocol.rid_manager import ( + ObjectSIDGateway, + ObjectSIDUseCase, + RIDManagerGateway, + RIDManagerSetupGateway, + RIDManagerSetupUseCase, + RIDManagerUseCase, + RIDSetGateway, + RIDSetUseCase, +) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import RoleDTO @@ -189,6 +204,10 @@ from ldap_protocol.server import PoolClientHandler from ldap_protocol.session_storage import RedisSessionStorage, SessionStorage from ldap_protocol.session_storage.repository import SessionRepository +from ldap_protocol.utils.async_cache import ( + domain_identifier_cache, + rid_set_id_cache, +) from ldap_protocol.utils.queries import get_user from password_utils import PasswordUtils from repo.pg.master_gateway import PGMasterGateway @@ -819,6 +838,20 @@ def authorization_provider_protocol( ) rootdse_reader = provide(RootDSEReader, scope=Scope.REQUEST) dcinfo_reader = provide(DCInfoReader, scope=Scope.REQUEST) + rid_manager_gateway = provide(RIDManagerGateway, scope=Scope.REQUEST) + rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST) + rid_manager_setup_gateway = provide( + RIDManagerSetupGateway, + scope=Scope.REQUEST, + ) + rid_manager_setup_use_case = provide( + RIDManagerSetupUseCase, + scope=Scope.REQUEST, + ) + object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) + rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) + rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) @dataclass @@ -1025,8 +1058,11 @@ async def setup_session( session: AsyncSession, raw_audit_manager: RawAuditManager, password_utils: PasswordUtils, + settings: Settings, ) -> None: """Get session and acquire after completion.""" + domain_identifier_cache.clear() + rid_set_id_cache.clear() role_dao = RoleDAO(session) ace_dao = AccessControlEntryDAO(session) role_use_case = RoleUseCase(role_dao, ace_dao) @@ -1056,6 +1092,28 @@ async def setup_session( object_class_dao=object_class_dao, directory_dao=directory_dao, ) + rid_manager_gateway = RIDManagerGateway(session, settings) + + rid_manager_use_case = RIDManagerUseCase( + rid_manager_gateway, + session, + ) + rid_set_gateway = RIDSetGateway(session, settings) + + rid_set_use_case = RIDSetUseCase( + rid_set_gateway, + entity_type_use_case, + session, + rid_manager_use_case, + role_use_case, + ) + object_sid_gateway = ObjectSIDGateway(session) + object_sid_use_case = ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + ) directory_create_use_case = DirectoryCreateUseCase( session=session, entity_type_use_case=entity_type_use_case, @@ -1098,50 +1156,63 @@ async def setup_session( password_policy_validator, password_ban_word_repository, ) + + rid_manager_setup_gateway = RIDManagerSetupGateway( + session=session, + entity_type_use_case=entity_type_use_case, + ) + role_dao = RoleDAO(session) + ace_dao = AccessControlEntryDAO(session) + role_use_case = RoleUseCase(role_dao, ace_dao) + + rid_manager_setup_use_case = RIDManagerSetupUseCase( + rid_manager_setup_gateway=rid_manager_setup_gateway, + role_use_case=role_use_case, + rid_set_use_case=rid_set_use_case, + access_control_entry_dao=AccessControlEntryDAO(session), + settings=settings, + rid_manager_use_case=rid_manager_use_case, + ) setup_gateway = SetupGateway( session, password_utils, entity_type_use_case=entity_type_use_case, attribute_value_validator=attribute_value_validator, directory_dao=directory_dao, + object_sid_use_case=object_sid_use_case, ) - for entity_type_dto in chain(ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2): await entity_type_use_case.create_not_safe(entity_type_dto) - await session.flush() - await audit_use_case.create_policies() + domain = await setup_gateway.create_base_domain("md.test") + await rid_manager_setup_use_case.create_domain_identifier(domain.id) + await setup_gateway.setup_enviroment( - dn="md.test", + domain=domain, data=TEST_DATA, is_system=False, ) + dc_directory = Directory( + name=DOMAIN_CONTROLLERS_OU_NAME, + object_class="computer", + is_system=True, + ) + dc_directory.create_path(domain, "cn") + session.add(dc_directory) + await session.flush() + dc_directory.parent_id = domain.id + await session.refresh(dc_directory, ["id"]) + await session.flush() + dc = Directory( + name=settings.HOST_MACHINE_SHORT_NAME, + is_system=True, + ) - for _at_dto in ( - AttributeTypeDTO[None]( - oid="1.2.3.4.5.6.7.8", - name="attr_with_bvalue", - ldap_display_name="attrWithBvalue", - syntax="1.3.6.1.4.1.1466.115.121.1.40", # Octet String - single_value=True, - no_user_modification=False, - is_system=True, - system_flags=0, - is_included_anr=False, - ), - AttributeTypeDTO[None]( - oid="1.2.3.4.5.6.7.8.9", - name="testing_attr", - ldap_display_name="testingAttr", - syntax="1.3.6.1.4.1.1466.115.121.1.15", - single_value=True, - no_user_modification=False, - is_system=True, - system_flags=0, - is_included_anr=False, - ), - ): - await attribute_type_use_case.create(_at_dto) + dc.create_path(dc_directory, "cn") + session.add(dc) + await session.flush() + dc.parent_id = dc_directory.id + await session.refresh(dc, ["id"]) for attr_type_name in ( "description", @@ -1166,12 +1237,15 @@ async def setup_session( "organizationalPerson", "user", "domain", + "computer", "container", "organization", "domainDNS", "group", "inetOrgPerson", "posixAccount", + "rIDManager", + "rIDSet", ): _oc_dto = await object_class_use_case_legacy.get(_obj_class_name) _oc_dto.attribute_types_may = [ @@ -1184,11 +1258,43 @@ async def setup_session( ] await object_class_use_case.create(_oc_dto) # type: ignore + await session.flush() + + for _at_dto in ( + AttributeTypeDTO[None]( + oid="1.2.3.4.5.6.7.8", + name="attr_with_bvalue", + ldap_display_name="attrWithBvalue", + syntax="1.3.6.1.4.1.1466.115.121.1.40", # Octet String + single_value=True, + no_user_modification=False, + is_system=True, + system_flags=0, + is_included_anr=False, + ), + AttributeTypeDTO[None]( + oid="1.2.3.4.5.6.7.8.9", + name="testing_attr", + ldap_display_name="testingAttr", + syntax="1.3.6.1.4.1.1466.115.121.1.15", + single_value=True, + no_user_modification=False, + is_system=True, + system_flags=0, + is_included_anr=False, + ), + ): + await attribute_type_use_case.create(_at_dto) + + await audit_use_case.create_policies() + # NOTE: after setup environment we need base DN to be created await password_use_cases.create_default_domain_policy() await role_use_case.create_domain_admins_role() + await rid_manager_setup_use_case.setup() + await role_use_case._role_dao.create( # noqa: SLF001 dto=RoleDTO( name="TEST ONLY LOGIN ROLE", @@ -1709,6 +1815,87 @@ async def ctx_search( yield await c.get(LDAPSearchRequestContext) +@pytest_asyncio.fixture(scope="function") +async def rid_manager_gateway( + container: AsyncContainer, + settings: Settings, +) -> AsyncIterator[RIDManagerGateway]: + """Get RID Manager gateway.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDManagerGateway(session, settings) + + +@pytest_asyncio.fixture(scope="function") +async def rid_manager_use_case( + container: AsyncContainer, + rid_manager_gateway: RIDManagerGateway, +) -> AsyncIterator[RIDManagerUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDManagerUseCase(rid_manager_gateway, session) + + +@pytest_asyncio.fixture(scope="function") +async def rid_set_gateway( + container: AsyncContainer, + settings: Settings, +) -> AsyncIterator[RIDSetGateway]: + """Provide RIDSetGateway for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDSetGateway(session, settings) + + +@pytest_asyncio.fixture(scope="function") +async def rid_set_use_case( + container: AsyncContainer, + rid_manager_use_case: RIDManagerUseCase, + entity_type_use_case: EntityTypeUseCase, + rid_set_gateway: RIDSetGateway, + role_use_case: RoleUseCase, +) -> AsyncIterator[RIDSetUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDSetUseCase( + rid_set_gateway, + entity_type_use_case, + session, + rid_manager_use_case, + role_use_case, + ) + + +@pytest_asyncio.fixture(scope="function") +async def object_sid_gateway( + container: AsyncContainer, +) -> AsyncIterator[ObjectSIDGateway]: + """Provide ObjectSIDGateway for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield ObjectSIDGateway(session) + + +@pytest_asyncio.fixture(scope="function") +async def object_sid_use_case( + container: AsyncContainer, + rid_manager_use_case: RIDManagerUseCase, + rid_set_use_case: RIDSetUseCase, + object_sid_gateway: ObjectSIDGateway, +) -> AsyncIterator[ObjectSIDUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + ) + + def pytest_configure(config: pytest.Config) -> None: """Pytest hook to limit xdist workers based on Dragonfly DBs. diff --git a/tests/constants.py b/tests/constants.py index 68e980383..50345f139 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -10,9 +10,10 @@ DOMAIN_COMPUTERS_GROUP_NAME, DOMAIN_USERS_GROUP_NAME, GROUPS_CONTAINER_NAME, + SYSTEM_CONTAINER_NAME, USERS_CONTAINER_NAME, ) -from enums import EntityTypeNames, SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.objects import UserAccountControlFlag user_data_dict = { @@ -66,7 +67,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, - "objectSid": 512, + "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": "developers", @@ -82,6 +83,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, + "objectSid": 999, }, { "name": "admin login only", @@ -464,6 +466,14 @@ "entity_type_name": EntityTypeNames.CONFIGURATION, "object_class": "container", "attributes": {"objectClass": ["top", "configuration"]}, + }, + { + "name": SYSTEM_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, + "object_class": "organizationalUnit", + "attributes": { + "objectClass": ["top", "container"], + }, "children": [], }, ] diff --git a/tests/test_api/test_main/test_router/conftest.py b/tests/test_api/test_main/test_router/conftest.py index dc26b0577..c7f2427cd 100644 --- a/tests/test_api/test_main/test_router/conftest.py +++ b/tests/test_api/test_main/test_router/conftest.py @@ -19,6 +19,7 @@ from ldap_protocol.ldap_schema.object_class.object_class_dao import ( ObjectClassDAO, ) +from ldap_protocol.rid_manager.object_sid_use_case import ObjectSIDUseCase from ldap_protocol.utils.queries import get_base_directories from password_utils import PasswordUtils from tests.constants import TEST_SYSTEM_ADMIN_DATA @@ -29,6 +30,7 @@ async def add_system_administrator( session: AsyncSession, password_utils: PasswordUtils, setup_session: None, # noqa: ARG001 + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Create system administrator user for tests that require it.""" attribute_value_validator = AttributeValueValidator() @@ -51,6 +53,7 @@ async def add_system_administrator( entity_type_use_case, attribute_value_validator=attribute_value_validator, directory_dao=directory_dao, + object_sid_use_case=object_sid_use_case, ) domain = (await get_base_directories(session))[0] diff --git a/tests/test_api/test_main/test_router/test_modify_dn.py b/tests/test_api/test_main/test_router/test_modify_dn.py index 8313049f5..af0bb83ef 100644 --- a/tests/test_api/test_main/test_router/test_modify_dn.py +++ b/tests/test_api/test_main/test_router/test_modify_dn.py @@ -6,6 +6,7 @@ import pytest from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession from ldap_protocol.ldap_codes import LDAPCodes @@ -15,6 +16,7 @@ @pytest.mark.usefixtures("session") async def test_api_modify_dn_without_level_change( http_client: AsyncClient, + session: AsyncSession, ) -> None: """Test API for updating DN. @@ -40,7 +42,7 @@ async def test_api_modify_dn_without_level_change( data["search_result"][0]["object_name"] == "ou=testModifyDn1,dc=md,dc=test" ) - + session.expire_all() response = await http_client.put( "/entry/update/dn", json={ @@ -83,6 +85,7 @@ async def test_api_modify_dn_without_level_change( @pytest.mark.usefixtures("session") async def test_api_modify_dn_with_level_down( http_client: AsyncClient, + session: AsyncSession, ) -> None: """Test API for updating DN. @@ -109,6 +112,8 @@ async def test_api_modify_dn_with_level_down( == "cn=testGroup1,ou=testModifyDn2,ou=testModifyDn1,dc=md,dc=test" ) + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ @@ -151,6 +156,7 @@ async def test_api_modify_dn_with_level_down( @pytest.mark.usefixtures("session") async def test_api_modify_dn_with_level_up( http_client: AsyncClient, + session: AsyncSession, ) -> None: """Test API for updating DN. @@ -177,6 +183,8 @@ async def test_api_modify_dn_with_level_up( == "cn=testGroup2,ou=testModifyDn1,dc=md,dc=test" ) + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ @@ -217,7 +225,10 @@ async def test_api_modify_dn_with_level_up( @pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") -async def test_api_correct_update_dn(http_client: AsyncClient) -> None: +async def test_api_correct_update_dn( + http_client: AsyncClient, + session: AsyncSession, +) -> None: """Test API for update DN.""" old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" newrdn_user = "cn=new_test2" @@ -254,6 +265,8 @@ async def test_api_correct_update_dn(http_client: AsyncClient) -> None: if attr["type"] == "cn": assert attr["vals"] == ["user1"] + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ @@ -336,7 +349,10 @@ async def test_api_correct_update_dn(http_client: AsyncClient) -> None: @pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") -async def test_api_update_dn_with_parent(http_client: AsyncClient) -> None: +async def test_api_update_dn_with_parent( + http_client: AsyncClient, + session: AsyncSession, +) -> None: """Test API for update DN.""" old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" new_user_dn = "cn=new_test2,cn=Users,dc=md,dc=test" @@ -368,6 +384,8 @@ async def test_api_update_dn_with_parent(http_client: AsyncClient) -> None: assert groups_user + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index c4c604ed1..18aa019a8 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -132,6 +132,7 @@ async def test_api_search(http_client: AsyncClient) -> None: sub_dirs = { "cn=Groups,dc=md,dc=test", "cn=Configuration,dc=md,dc=test", + "ou=System,dc=md,dc=test", "cn=Users,dc=md,dc=test", "ou=testModifyDn1,dc=md,dc=test", "ou=testModifyDn3,dc=md,dc=test", @@ -662,7 +663,7 @@ async def test_api_get_group_path_dn_by_primary_group_id_not_found( http_client: AsyncClient, ) -> None: """Test api get group path DN by primary group id not found.""" - primary_group_id = 513 + primary_group_id = 5135 response = await http_client.get( f"entry/group/primary/{primary_group_id}", ) diff --git a/tests/test_ldap/test_object_sid.py b/tests/test_ldap/test_object_sid.py new file mode 100644 index 000000000..9432933bc --- /dev/null +++ b/tests/test_ldap/test_object_sid.py @@ -0,0 +1,147 @@ +"""Tests for RID Manager.""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from enums import SidPrefix +from ldap_protocol.rid_manager import RIDManagerUseCase +from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway +from ldap_protocol.rid_manager.object_sid_use_case import ObjectSIDUseCase +from ldap_protocol.rid_manager.rid_manager_gateway import RIDManagerGateway +from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase +from ldap_protocol.rid_manager.utils import from_qword, to_qword + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_rid_manager_allocate_pool( + rid_manager_use_case: RIDManagerUseCase, + rid_manager_gateway: RIDManagerGateway, +) -> None: + """Test RID Manager get domain controller.""" + available_pool = await rid_manager_gateway.get_rid_available_pool() + + await rid_manager_use_case.allocate_pool() + new_available_pool = await rid_manager_gateway.get_rid_available_pool() + lower, _ = from_qword(available_pool) + new_lower, _ = from_qword(new_available_pool) + + assert new_lower == lower + RIDManagerUseCase.RID_BLOCK_SIZE + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_next_rid( + rid_set_use_case: RIDSetUseCase, +) -> None: + """Test RID Manager get domain controller.""" + rid_set_id = await rid_set_use_case.get_rid_set_id() + next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) + new_next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) + assert new_next_rid == next_rid + 1 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_rid_set_reset_pool( + rid_set_use_case: RIDSetUseCase, + rid_manager_gateway: RIDManagerGateway, + rid_set_gateway: RIDSetGateway, +) -> None: + """Test RID Set pool reset.""" + rid_set_id = await rid_set_use_case.get_rid_set_id() + + available_pool_before = await rid_manager_gateway.get_rid_available_pool() + lower_before, _ = from_qword(available_pool_before) + allocation_pool_before = await rid_set_gateway.get_rid_allocation_pool( + rid_set_id, + ) + previous_pool_before = ( + await rid_set_gateway.get_rid_previous_allocation_pool(rid_set_id) + ) + + _, upper = from_qword(previous_pool_before) + await rid_set_gateway.update_next_rid(rid_set_id, upper) + + current_next_rid = await rid_set_gateway.get_rid_next_rid(rid_set_id) + assert ( + rid_set_use_case.is_pool_exceeded( + current_next_rid, + previous_pool_before, + ) + is True + ) + + await rid_set_use_case.allocate_next_rid(rid_set_id) + current_next_rid = await rid_set_gateway.get_rid_next_rid(rid_set_id) + previous_pool_mid = await rid_set_gateway.get_rid_previous_allocation_pool( + rid_set_id, + ) + assert ( + rid_set_use_case.is_pool_exceeded( + current_next_rid, + previous_pool_mid, + ) + is False + ) + + available_pool_after = await rid_manager_gateway.get_rid_available_pool() + lower_after, _ = from_qword(available_pool_after) + allocation_pool_after = await rid_set_gateway.get_rid_allocation_pool( + rid_set_id, + ) + previous_pool_after = ( + await rid_set_gateway.get_rid_previous_allocation_pool( + rid_set_id, + ) + ) + + assert lower_after == lower_before + RIDManagerUseCase.RID_BLOCK_SIZE + assert allocation_pool_after == to_qword( + lower_before, + lower_before + RIDManagerUseCase.RID_BLOCK_SIZE, + ) + assert previous_pool_after == allocation_pool_before + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_object_sid_add_updates_next_rid_and_prefix( + session: AsyncSession, + object_sid_use_case: ObjectSIDUseCase, + object_sid_gateway: ObjectSIDGateway, + rid_set_use_case: RIDSetUseCase, + rid_set_gateway: RIDSetGateway, + rid_manager_use_case: RIDManagerUseCase, +) -> None: + dc = await rid_manager_use_case.get_domain_controller() + rid_set_id = await rid_set_use_case.get_rid_set_id() + dc_id = dc.id + + next_before = await rid_set_gateway.get_rid_next_rid(rid_set_id) + + await object_sid_use_case.add(directory_id=dc_id) + await session.flush() + next_after = await rid_set_gateway.get_rid_next_rid(rid_set_id) + assert next_after == next_before + 1 + + sid_domain_attr = await object_sid_gateway.get(dc_id) + assert sid_domain_attr.startswith("S-1-5-21-") + + await object_sid_use_case.add( + directory_id=rid_set_id, + rid=512, + sid_prefix=SidPrefix.BUILT_IN_DOMAIN, + ) + await session.flush() + next_after_builtin = await rid_set_gateway.get_rid_next_rid(rid_set_id) + assert next_after_builtin == next_after + + sid_builtin_attr = await object_sid_gateway.get(rid_set_id) + assert sid_builtin_attr.startswith("S-1-5-32-") + assert sid_builtin_attr != sid_domain_attr diff --git a/tests/test_ldap/test_roles/test_search.py b/tests/test_ldap/test_roles/test_search.py index 54e9a0641..55a4d8532 100644 --- a/tests/test_ldap/test_roles/test_search.py +++ b/tests/test_ldap/test_roles/test_search.py @@ -108,9 +108,10 @@ async def test_role_search_3( "dn: cn=Groups,dc=md,dc=test", "dn: cn=Users,dc=md,dc=test", "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", - "dn: ou=test_bit_rules,dc=md,dc=test", + "dn: ou=System,dc=md,dc=test", "dn: ou=testModifyDn1,dc=md,dc=test", "dn: ou=testModifyDn3,dc=md,dc=test", + "dn: ou=test_bit_rules,dc=md,dc=test", ], expected_attrs_present=[], expected_attrs_absent=[], diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index adf89c1b6..d3db50388 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -968,7 +968,9 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: query = ( select(Directory) .options( - selectinload(qa(Directory.groups)).joinedload(qa(Group.directory)), + selectinload(qa(Directory.groups)) + .joinedload(qa(Group.directory)) + .selectinload(qa(Directory.attributes)), selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)), ) @@ -1062,7 +1064,7 @@ async def test_ldap_modify_primary_group_id_scenarios( attributes[attr.name].append(attr.value) if expected_primary_group: - assert attributes["primaryGroupID"] == [group_dir.relative_id] + assert attributes["primaryGroupID"] == [rid] else: assert "primaryGroupID" not in attributes diff --git a/tests/test_shedule.py b/tests/test_shedule.py index a952b94e1..3a7cc5dcb 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -17,6 +17,9 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) +from ldap_protocol.rid_manager import ObjectSIDUseCase +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase from ldap_protocol.roles.role_use_case import RoleUseCase @@ -88,11 +91,42 @@ async def test_add_domain_controller( settings: Settings, role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, + object_sid_use_case: ObjectSIDUseCase, + rid_manager_use_case: RIDManagerUseCase, + rid_set_use_case: RIDSetUseCase, + monkeypatch: pytest.MonkeyPatch, ) -> None: """Test add domain controller.""" + existing_dc = await rid_manager_use_case.get_domain_controller() + existing_rid_set_id = await rid_set_use_case.get_rid_set_id() + monkeypatch.setattr( + settings, + "HOST_MACHINE_SHORT_NAME", + f"{settings.HOST_MACHINE_SHORT_NAME}-test", + ) + + async def _get_existing_dc() -> object: + return existing_dc + + monkeypatch.setattr( + object_sid_use_case._rid_manager_use_case, # noqa: SLF001 + "get_domain_controller", + _get_existing_dc, + ) + + async def _get_existing_rid_set_id() -> int: + return existing_rid_set_id + + monkeypatch.setattr( + rid_set_use_case, + "get_rid_set_id", + _get_existing_rid_set_id, + ) await add_domain_controller( settings=settings, session=session, role_use_case=role_use_case, entity_type_use_case=entity_type_use_case, + object_sid_use_case=object_sid_use_case, + rid_set_use_case=rid_set_use_case, )