Added DB connection and started creating api calls for the pages
This commit is contained in:
166
backend/src/repositories/user_repository.py
Normal file
166
backend/src/repositories/user_repository.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Repository layer for User database operations.
|
||||
Handles CRUD operations and user provisioning on first login.
|
||||
"""
|
||||
from typing import Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from src.db_models import User, Role
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default role for new users
|
||||
DEFAULT_ROLE_NAME = "user"
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for User-related database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_sub(self, sub: str) -> Optional[User]:
|
||||
"""
|
||||
Get a user by their Keycloak subject ID.
|
||||
|
||||
Args:
|
||||
sub: The Keycloak subject ID (unique identifier)
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.role))
|
||||
.where(User.sub == sub)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""
|
||||
Get a user by their database ID.
|
||||
|
||||
Args:
|
||||
user_id: The database user ID
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.role))
|
||||
.where(User.id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, sub: str, role_id: int) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
Args:
|
||||
sub: The Keycloak subject ID
|
||||
role_id: The role ID to assign
|
||||
|
||||
Returns:
|
||||
The created User
|
||||
"""
|
||||
user = User(sub=sub, role_id=role_id)
|
||||
self.session.add(user)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_or_create_default_role(self) -> Role:
|
||||
"""
|
||||
Get the default user role, creating it if it doesn't exist.
|
||||
|
||||
Returns:
|
||||
The default Role
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(Role).where(Role.role_name == DEFAULT_ROLE_NAME)
|
||||
)
|
||||
role = result.scalar_one_or_none()
|
||||
|
||||
if role is None:
|
||||
logger.info(f"Creating default role: {DEFAULT_ROLE_NAME}")
|
||||
role = Role(role_name=DEFAULT_ROLE_NAME)
|
||||
self.session.add(role)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(role)
|
||||
|
||||
return role
|
||||
|
||||
async def get_or_create_user(self, sub: str) -> tuple[User, bool]:
|
||||
"""
|
||||
Get an existing user or create a new one (Just-in-Time Provisioning).
|
||||
This is the main method called during login.
|
||||
|
||||
Args:
|
||||
sub: The Keycloak subject ID
|
||||
|
||||
Returns:
|
||||
Tuple of (User, created) where created is True if a new user was created
|
||||
"""
|
||||
# Check if user already exists
|
||||
user = await self.get_by_sub(sub)
|
||||
|
||||
if user is not None:
|
||||
logger.debug(f"Found existing user with sub: {sub}")
|
||||
return user, False
|
||||
|
||||
# User doesn't exist, create them with default role
|
||||
logger.info(f"Creating new user with sub: {sub}")
|
||||
|
||||
# Get or create the default role
|
||||
default_role = await self.get_or_create_default_role()
|
||||
|
||||
# Create the user
|
||||
user = await self.create(sub=sub, role_id=default_role.id)
|
||||
|
||||
logger.info(f"Created new user with id: {user.id}, sub: {sub}")
|
||||
return user, True
|
||||
|
||||
|
||||
class RoleRepository:
|
||||
"""Repository for Role-related database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_name(self, role_name: str) -> Optional[Role]:
|
||||
"""Get a role by name."""
|
||||
result = await self.session.execute(
|
||||
select(Role).where(Role.role_name == role_name)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, role_id: int) -> Optional[Role]:
|
||||
"""Get a role by ID."""
|
||||
result = await self.session.execute(
|
||||
select(Role).where(Role.id == role_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, role_name: str) -> Role:
|
||||
"""Create a new role."""
|
||||
role = Role(role_name=role_name)
|
||||
self.session.add(role)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(role)
|
||||
return role
|
||||
|
||||
async def ensure_default_roles_exist(self) -> None:
|
||||
"""
|
||||
Ensure default roles exist in the database.
|
||||
Called during application startup.
|
||||
"""
|
||||
default_roles = ["admin", "user", "viewer"]
|
||||
|
||||
for role_name in default_roles:
|
||||
existing = await self.get_by_name(role_name)
|
||||
if existing is None:
|
||||
logger.info(f"Creating default role: {role_name}")
|
||||
await self.create(role_name)
|
||||
Reference in New Issue
Block a user