222 lines
7.1 KiB
Python
222 lines
7.1 KiB
Python
"""
|
|
Repository layer for User database operations.
|
|
Handles CRUD operations and user provisioning on first login.
|
|
"""
|
|
from typing import Optional, List
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
from src.db_models import User, Role
|
|
import logging
|
|
|
|
# Logger named after this module (``__name__``).
logger: logging.Logger = logging.getLogger(__name__)

# Role assigned to users provisioned on first login. It is the first role
# seeded by RoleRepository.ensure_default_roles_exist, so on a freshly seeded
# database it is expected to have role_id 1.
DEFAULT_ROLE_NAME: str = "editor"
|
|
|
|
|
|
class RoleRepository:
    """Repository for Role-related database operations.

    All methods run on the ``AsyncSession`` passed to the constructor. Write
    methods only ``flush``; committing the transaction is left to the caller.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_by_name(self, role_name: str) -> Optional[Role]:
        """
        Get a role by name.

        Args:
            role_name: Value to match against ``Role.role_name``

        Returns:
            The matching Role, or None if no row matches. Uses
            ``scalar_one_or_none``, so more than one match raises instead of
            silently picking one.
        """
        result = await self.session.execute(
            select(Role).where(Role.role_name == role_name)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, role_id: int) -> Optional[Role]:
        """
        Get a role by its primary key.

        Args:
            role_id: The role ID to look up

        Returns:
            The matching Role, or None if not found
        """
        result = await self.session.execute(
            select(Role).where(Role.id == role_id)
        )
        return result.scalar_one_or_none()

    async def create(self, role_name: str) -> Role:
        """
        Create a new role.

        The row is flushed (not committed) so that database-generated columns
        (e.g. ``id``, which is not set here) are populated. The instance is
        then refreshed from the database.

        Args:
            role_name: Name of the role to create

        Returns:
            The created Role
        """
        role = Role(role_name=role_name)
        self.session.add(role)
        await self.session.flush()
        await self.session.refresh(role)
        return role

    async def get_all(self) -> List[Role]:
        """
        Get all roles.

        Returns:
            List of all roles, ordered by ascending ``id``
        """
        result = await self.session.execute(
            select(Role).order_by(Role.id)
        )
        return list(result.scalars().all())

    async def ensure_default_roles_exist(self) -> None:
        """
        Ensure default roles exist in the database.

        Called during application startup. Any missing role among editor,
        auditor and admin is created, in that order.

        The ids editor=1, auditor=2, admin=3 follow only when the table starts
        empty and the database assigns ids sequentially (assumed, not enforced
        here). If some roles already exist, newly created ones get whatever
        ids come next.
        """
        # Creation order is significant: on a fresh table it determines the
        # role ids (editor first, so it presumably becomes id 1).
        default_roles = ["editor", "auditor", "admin"]

        for role_name in default_roles:
            existing = await self.get_by_name(role_name)
            if existing is None:
                # Lazy %-formatting: the message is only built if INFO is enabled.
                logger.info("Creating default role: %s", role_name)
                await self.create(role_name)
|
|
|
|
|
|
class UserRepository:
    """Repository for User-related database operations.

    Methods operate on the ``AsyncSession`` given to the constructor and only
    ``flush``; committing the transaction is the caller's responsibility.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def _get_one_with_role(self, condition) -> Optional[User]:
        """Return the single User matching *condition* with ``role`` eager-loaded, or None."""
        # selectinload fetches User.role up front. This is presumably so that
        # callers can read ``user.role`` without a lazy load, which
        # AsyncSession does not perform implicitly.
        result = await self.session.execute(
            select(User)
            .options(selectinload(User.role))
            .where(condition)
        )
        return result.scalar_one_or_none()

    async def get_by_sub(self, sub: str) -> Optional[User]:
        """
        Get a user by their Keycloak subject ID.

        Args:
            sub: The Keycloak subject ID (unique identifier)

        Returns:
            User (with ``role`` eager-loaded) if found, None otherwise
        """
        return await self._get_one_with_role(User.sub == sub)

    async def get_by_id(self, user_id: int) -> Optional[User]:
        """
        Get a user by their database ID.

        Args:
            user_id: The database user ID

        Returns:
            User (with ``role`` eager-loaded) if found, None otherwise
        """
        return await self._get_one_with_role(User.id == user_id)

    async def create(
        self, sub: str, role_id: int, username: str, full_name: str | None = None
    ) -> User:
        """
        Create a new user.

        The row is flushed (not committed) and then refreshed, so
        database-generated columns such as ``id`` are populated.

        Args:
            sub: The Keycloak subject ID
            role_id: The role ID to assign
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)

        Returns:
            The created User
        """
        user = User(sub=sub, role_id=role_id, username=username, full_name=full_name)
        self.session.add(user)
        await self.session.flush()
        # NOTE(review): users returned by get_by_sub/get_by_id have ``role``
        # eager-loaded, but nothing here loads it explicitly. Unless User.role
        # is configured to load eagerly in src.db_models, accessing
        # ``user.role`` on this object may fail under AsyncSession. Confirm.
        await self.session.refresh(user)
        return user

    async def update_role(self, user_id: int, role_id: int) -> Optional[User]:
        """
        Update a user's role.

        Args:
            user_id: The user ID to update
            role_id: The new role ID

        Returns:
            The updated User, or None if not found
        """
        user = await self.get_by_id(user_id)
        if user is None:
            return None

        user.role_id = role_id
        await self.session.flush()
        # Reload from the database so the instance reflects the new role_id.
        await self.session.refresh(user)
        return user

    async def update_profile(
        self, user: User, username: str, full_name: str | None = None
    ) -> User:
        """
        Update a user's profile info (username and full_name) from Keycloak.

        Called on subsequent logins to sync changes from Keycloak. Passing
        ``full_name=None`` clears any stored full name. When both values
        already match, the user is returned unchanged and no flush or refresh
        is issued.

        Args:
            user: The user to update
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)

        Returns:
            The updated User
        """
        if user.username == username and user.full_name == full_name:
            # Nothing to sync: skip the flush and the refresh SELECT that
            # would otherwise run on every login.
            return user

        user.username = username
        user.full_name = full_name
        await self.session.flush()
        await self.session.refresh(user)
        return user

    async def get_or_create_default_role(self) -> Role:
        """
        Get the default user role, creating it if it doesn't exist.

        Returns:
            The Role named ``DEFAULT_ROLE_NAME``
        """
        role_repo = RoleRepository(self.session)
        role = await role_repo.get_by_name(DEFAULT_ROLE_NAME)

        if role is None:
            logger.info("Creating default role: %s", DEFAULT_ROLE_NAME)
            role = await role_repo.create(DEFAULT_ROLE_NAME)

        return role

    async def get_or_create_user(
        self, sub: str, username: str, full_name: str | None = None
    ) -> tuple[User, bool]:
        """
        Get an existing user or create a new one (Just-in-Time Provisioning).

        This is the main method called during login.
        On subsequent logins it also updates username/full_name to stay in
        sync with Keycloak.

        Args:
            sub: The Keycloak subject ID
            username: The Keycloak preferred_username
            full_name: The Keycloak name claim (optional)

        Returns:
            Tuple of (User, created) where created is True if a new user was created
        """
        user = await self.get_by_sub(sub)

        if user is not None:
            logger.debug("Found existing user with sub: %s", sub)
            user = await self.update_profile(user, username, full_name)
            return user, False

        # NOTE(review): two concurrent first logins for the same sub can both
        # reach this point. If ``sub`` is unique in the schema, the second
        # flush would fail. Confirm whether callers handle that.
        logger.info("Creating new user with sub: %s", sub)

        default_role = await self.get_or_create_default_role()
        user = await self.create(
            sub=sub, role_id=default_role.id, username=username, full_name=full_name
        )

        logger.info(
            "Created new user with id: %s, sub: %s, username: %s",
            user.id,
            sub,
            username,
        )
        return user, True
|