"""
Repository layer for User database operations.

Handles CRUD operations and user provisioning on first login.
"""

import logging
from typing import List, Optional

from sqlalchemy import inspect, select
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src.db_models import Role, User

logger = logging.getLogger(__name__)

# Role assigned to users provisioned on first login. It is the first role
# created by RoleRepository.ensure_default_roles_exist(), so on a fresh
# database it is expected to end up with role_id 1 (the ID itself is
# assigned by the database, not by this module).
DEFAULT_ROLE_NAME = "editor"


class RoleRepository:
    """Repository for Role-related database operations."""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_by_name(self, role_name: str) -> Optional[Role]:
        """
        Get a role by name.

        Args:
            role_name: Value to match against ``Role.role_name``

        Returns:
            The matching Role, or None if no role has that name
        """
        result = await self.session.execute(
            select(Role).where(Role.role_name == role_name)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, role_id: int) -> Optional[Role]:
        """
        Get a role by its database ID.

        Args:
            role_id: The role's primary key

        Returns:
            The matching Role, or None if not found
        """
        result = await self.session.execute(
            select(Role).where(Role.id == role_id)
        )
        return result.scalar_one_or_none()

    async def create(self, role_name: str) -> Role:
        """
        Create a new role.

        The role is added to the session and flushed so the database assigns
        its ID, then refreshed so the returned object reflects the stored row.
        Committing is left to the caller.

        Args:
            role_name: Name of the role to create

        Returns:
            The created Role
        """
        role = Role(role_name=role_name)
        self.session.add(role)
        await self.session.flush()
        await self.session.refresh(role)
        return role

    async def get_all(self) -> List[Role]:
        """
        Get all roles.

        Returns:
            List of all roles, ordered by ID
        """
        result = await self.session.execute(
            select(Role).order_by(Role.id)
        )
        return list(result.scalars().all())

    async def ensure_default_roles_exist(self) -> None:
        """
        Ensure default roles exist in the database. Called during application startup.

        Missing roles are created in the order editor, auditor, admin.
        Existing roles are left untouched.

        On an empty roles table whose ID sequence starts at 1, this order yields
        editor=1, auditor=2, admin=3. If some roles already exist, or the
        sequence has advanced, the IDs will differ.
        """
        # Creation order determines ID order on a fresh database (see docstring).
        default_roles = ["editor", "auditor", "admin"]
        for role_name in default_roles:
            existing = await self.get_by_name(role_name)
            if existing is None:
                logger.info("Creating default role: %s", role_name)
                await self.create(role_name)


class UserRepository:
    """
    Repository for User-related database operations.

    Every method that returns a User returns it with ``User.role`` loaded, so
    callers can read ``user.role`` without triggering lazy-load IO on the
    async session.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    @staticmethod
    def _select_with_role():
        """Build a ``SELECT`` of User that eagerly loads the ``role`` relationship."""
        return select(User).options(selectinload(User.role))

    async def _reload_with_role(self, user: User) -> User:
        """
        Re-read *user*'s columns and its ``role`` relationship from the database.

        Used instead of ``session.refresh(user)`` after a flush. A plain
        refresh expires every attribute and reloads only column data, which
        leaves a relationship that is not mapped for eager loading unloaded.
        Accessing ``user.role`` afterwards would then need implicit lazy-load
        IO, which AsyncSession does not allow. get_by_sub/get_by_id request
        ``selectinload(User.role)`` explicitly, which suggests the mapping
        loads it lazily (TODO: confirm in src.db_models).

        The row is looked up by the instance's identity key rather than by
        ``user.id``, so no possibly-expired attribute is touched.
        ``populate_existing`` overwrites the identity-mapped instance in
        place, so the returned object is *user* itself.

        Raises:
            InvalidRequestError: If the row no longer exists (as refresh() would).
        """
        reloaded = await self.session.get(
            User,
            inspect(user).identity,
            options=[selectinload(User.role)],
            populate_existing=True,
        )
        if reloaded is None:
            raise InvalidRequestError(f"Could not reload {user!r}; row no longer exists")
        return reloaded

    async def get_by_sub(self, sub: str) -> Optional[User]:
        """
        Get a user by their Keycloak subject ID.

        Args:
            sub: The Keycloak subject ID (unique identifier)

        Returns:
            User (with role loaded) if found, None otherwise
        """
        result = await self.session.execute(
            self._select_with_role().where(User.sub == sub)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """
        Get a user by their database ID.

        Args:
            user_id: The database user ID

        Returns:
            User (with role loaded) if found, None otherwise
        """
        result = await self.session.execute(
            self._select_with_role().where(User.id == user_id)
        )
        return result.scalar_one_or_none()

    async def create(self, sub: str, role_id: int, username: str, full_name: str | None = None) -> User:
        """
        Create a new user.

        The user is flushed (so the database assigns its ID) but not committed.

        Args:
            sub: The Keycloak subject ID
            role_id: The role ID to assign
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)

        Returns:
            The created User, with role loaded
        """
        user = User(sub=sub, role_id=role_id, username=username, full_name=full_name)
        self.session.add(user)
        await self.session.flush()
        return await self._reload_with_role(user)

    async def update_role(self, user_id: int, role_id: int) -> Optional[User]:
        """
        Update a user's role.

        Args:
            user_id: The user ID to update
            role_id: The new role ID

        Returns:
            The updated User with its *new* role loaded, or None if not found
        """
        user = await self.get_by_id(user_id)
        if user is None:
            return None
        user.role_id = role_id
        await self.session.flush()
        # Changing role_id does not update the already-loaded ``role`` object,
        # so re-read it to return the new role rather than the old one.
        return await self._reload_with_role(user)

    async def update_profile(self, user: User, username: str, full_name: str | None = None) -> User:
        """
        Update a user's profile info (username and full_name) from Keycloak.

        Called on subsequent logins to sync changes from Keycloak.

        Args:
            user: The user to update
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)

        Returns:
            The updated User, with role loaded
        """
        user.username = username
        user.full_name = full_name
        await self.session.flush()
        return await self._reload_with_role(user)

    async def get_or_create_default_role(self) -> Role:
        """
        Get the default user role, creating it if it doesn't exist.

        Returns:
            The Role named DEFAULT_ROLE_NAME
        """
        role_repo = RoleRepository(self.session)
        role = await role_repo.get_by_name(DEFAULT_ROLE_NAME)
        if role is None:
            logger.info("Creating default role: %s", DEFAULT_ROLE_NAME)
            role = await role_repo.create(DEFAULT_ROLE_NAME)
        return role

    async def get_or_create_user(self, sub: str, username: str, full_name: str | None = None) -> tuple[User, bool]:
        """
        Get an existing user or create a new one (Just-in-Time Provisioning).

        This is the main method called during login. It also updates
        username/full_name on subsequent logins to sync with Keycloak.

        Args:
            sub: The Keycloak subject ID
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)

        Returns:
            Tuple of (User, created), where created is True if a new user was
            created. The User is returned with role loaded in both cases.
        """
        user = await self.get_by_sub(sub)
        if user is not None:
            logger.debug("Found existing user with sub: %s", sub)
            # Sync profile fields on every login so Keycloak-side edits propagate.
            user = await self.update_profile(user, username, full_name)
            return user, False

        # NOTE(review): two concurrent first logins for the same sub can both
        # reach this point. If User.sub has a unique constraint, the second
        # flush raises IntegrityError. If that is seen in practice, wrap the
        # create in session.begin_nested() and re-fetch by sub on conflict.
        logger.info("Creating new user with sub: %s", sub)
        default_role = await self.get_or_create_default_role()
        user = await self.create(
            sub=sub, role_id=default_role.id, username=username, full_name=full_name
        )
        logger.info(
            "Created new user with id: %s, sub: %s, username: %s", user.id, sub, username
        )
        return user, True