# File: periodic-table/backend/src/repositories/user_repository.py
# (source-viewer header: 222 lines, 7.1 KiB, Python)
"""
Repository layer for User database operations.
Handles CRUD operations and user provisioning on first login.
"""
from typing import Optional, List
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from src.db_models import User, Role
import logging
logger = logging.getLogger(name=__name__)

# Role name given to users provisioned on their first login. "editor" is the
# first role RoleRepository.ensure_default_roles_exist inserts, so on a fresh
# database (auto-incrementing IDs starting at 1) it ends up as role_id 1.
DEFAULT_ROLE_NAME = "editor"
class RoleRepository:
    """Data-access helpers for the ``Role`` table.

    All queries run on the ``AsyncSession`` passed to the constructor. Writes
    are flushed so generated IDs become available, but this class never
    commits; the session's owner is responsible for that.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_by_name(self, role_name: str) -> Optional[Role]:
        """Return the role whose ``role_name`` equals *role_name*, or ``None``."""
        stmt = select(Role).where(Role.role_name == role_name)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_id(self, role_id: int) -> Optional[Role]:
        """Return the role with primary key *role_id*, or ``None``."""
        stmt = select(Role).where(Role.id == role_id)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create(self, role_name: str) -> Role:
        """Insert a role named *role_name* and return it with its ID populated."""
        new_role = Role(role_name=role_name)
        self.session.add(new_role)
        # Flush so the INSERT is emitted and the primary key assigned, then
        # read the row back into the instance.
        await self.session.flush()
        await self.session.refresh(new_role)
        return new_role

    async def get_all(self) -> List[Role]:
        """Return every role, ordered by ascending ID.

        Returns:
            A list of all Role rows.
        """
        stmt = select(Role).order_by(Role.id)
        rows = (await self.session.execute(stmt)).scalars().all()
        return list(rows)

    async def ensure_default_roles_exist(self) -> None:
        """Insert whichever built-in roles are missing; intended for startup.

        Roles are checked and created one at a time in the order
        editor, auditor, admin. Roles that already exist are left alone.
        The rest of the application expects editor=1, auditor=2, admin=3,
        which presumably only holds when the table starts empty with an
        auto-incrementing key beginning at 1 — confirm for existing databases.
        """
        # Insertion order decides which ID each role receives.
        for name in ("editor", "auditor", "admin"):
            if await self.get_by_name(name) is not None:
                continue
            logger.info(f"Creating default role: {name}")
            await self.create(name)
class UserRepository:
    """Data-access helpers for the ``User`` table, including first-login provisioning.

    Every ``User`` this repository returns has its ``role`` relationship
    loaded via ``selectinload``. Callers can therefore read ``user.role``
    without triggering a lazy load, which an ``AsyncSession`` cannot perform
    implicitly unless the mapper configures an eager loader. Writes are
    flushed but never committed here.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def _reload_with_role(self, user: User) -> User:
        """Re-read *user*'s row from the database together with its ``role``.

        Used after a flush instead of ``session.refresh(user)``. A plain
        refresh reloads the columns but does not reliably load a relationship
        that was never eagerly loaded for this instance (e.g. a brand-new
        user). This SELECT always populates ``role``. ``populate_existing``
        makes the query overwrite the identity-mapped instance in place, so
        the same object comes back with fresh values.

        Raises:
            sqlalchemy.exc.NoResultFound: if the row no longer exists.
        """
        result = await self.session.execute(
            select(User)
            .options(selectinload(User.role))
            .where(User.id == user.id)
            .execution_options(populate_existing=True)
        )
        return result.scalar_one()

    async def get_by_sub(self, sub: str) -> Optional[User]:
        """
        Get a user by their Keycloak subject ID, with ``role`` eagerly loaded.
        Args:
            sub: The Keycloak subject ID (unique identifier)
        Returns:
            User if found, None otherwise
        """
        result = await self.session.execute(
            select(User)
            .options(selectinload(User.role))
            .where(User.sub == sub)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """
        Get a user by their database ID, with ``role`` eagerly loaded.
        Args:
            user_id: The database user ID
        Returns:
            User if found, None otherwise
        """
        result = await self.session.execute(
            select(User)
            .options(selectinload(User.role))
            .where(User.id == user_id)
        )
        return result.scalar_one_or_none()

    async def create(self, sub: str, role_id: int, username: str, full_name: str | None = None) -> User:
        """
        Create a new user and flush it so its ID is assigned.
        Args:
            sub: The Keycloak subject ID
            role_id: The role ID to assign
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)
        Returns:
            The created User, with ``role`` loaded like the get_by_* methods
        """
        user = User(sub=sub, role_id=role_id, username=username, full_name=full_name)
        self.session.add(user)
        await self.session.flush()
        # A plain refresh() would leave ``role`` unloaded on this new instance;
        # reload it eagerly so first-login callers can access user.role.
        return await self._reload_with_role(user)

    async def update_role(self, user_id: int, role_id: int) -> Optional[User]:
        """
        Update a user's role.
        Args:
            user_id: The user ID to update
            role_id: The new role ID
        Returns:
            The updated User (with ``role`` reflecting the new role_id),
            or None if no user has that ID
        """
        user = await self.get_by_id(user_id)
        if user is None:
            return None
        user.role_id = role_id
        await self.session.flush()
        # The ``role`` loaded by get_by_id still points at the previous Role
        # object; reload so role_id and role agree.
        return await self._reload_with_role(user)

    async def update_profile(self, user: User, username: str, full_name: str | None = None) -> User:
        """
        Update a user's profile info (username and full_name) from Keycloak.
        Called on subsequent logins to sync changes from Keycloak.
        Args:
            user: The user to update
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional); None clears it
        Returns:
            The updated User, with ``role`` loaded
        """
        user.username = username
        user.full_name = full_name
        await self.session.flush()
        return await self._reload_with_role(user)

    async def get_or_create_default_role(self) -> Role:
        """
        Get the default user role (DEFAULT_ROLE_NAME), creating it if missing.
        Returns:
            The default Role
        """
        role_repo = RoleRepository(self.session)
        role = await role_repo.get_by_name(DEFAULT_ROLE_NAME)
        if role is None:
            logger.info(f"Creating default role: {DEFAULT_ROLE_NAME}")
            role = await role_repo.create(DEFAULT_ROLE_NAME)
        return role

    async def get_or_create_user(self, sub: str, username: str, full_name: str | None = None) -> tuple[User, bool]:
        """
        Get an existing user or create a new one (Just-in-Time Provisioning).
        This is the main method called during login.
        Also updates username/full_name on subsequent logins to sync with Keycloak.
        Both paths return a User with ``role`` loaded.
        Args:
            sub: The Keycloak subject ID
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)
        Returns:
            Tuple of (User, created) where created is True if a new user was created
        """
        # Check if user already exists
        user = await self.get_by_sub(sub)
        if user is not None:
            logger.debug(f"Found existing user with sub: {sub}")
            # Update profile info on subsequent logins to sync with Keycloak
            user = await self.update_profile(user, username, full_name)
            return user, False
        # User doesn't exist, create them with default role.
        # NOTE(review): two concurrent first logins for the same sub can both
        # reach this point; if ``sub`` carries a unique constraint the second
        # flush will fail with an integrity error — confirm whether callers
        # need to retry via get_by_sub in that case.
        logger.info(f"Creating new user with sub: {sub}")
        # Get or create the default role
        default_role = await self.get_or_create_default_role()
        # Create the user
        user = await self.create(sub=sub, role_id=default_role.id, username=username, full_name=full_name)
        logger.info(f"Created new user with id: {user.id}, sub: {sub}, username: {username}")
        return user, True