Added DB connection and started creating api calls for the pages
This commit is contained in:
@@ -21,6 +21,22 @@ class Settings(BaseSettings):
|
||||
cookie_max_age: int = Field(default=3600, env="COOKIE_MAX_AGE")
|
||||
cookie_name: str = Field(default="access_token", env="COOKIE_NAME")
|
||||
|
||||
# Database settings
|
||||
database_host: str = Field(default="postgres", env="DATABASE_HOST")
|
||||
database_port: int = Field(default=5432, env="DATABASE_PORT")
|
||||
database_name: str = Field(default="periodic_table", env="DATABASE_NAME")
|
||||
database_user: str = Field(default="postgres", env="DATABASE_USER")
|
||||
database_password: str = Field(default="postgres", env="DATABASE_PASSWORD")
|
||||
database_echo: bool = Field(default=False, env="DATABASE_ECHO")
|
||||
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""Construct the async database URL for SQLAlchemy."""
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.database_user}:{self.database_password}"
|
||||
f"@{self.database_host}:{self.database_port}/{self.database_name}"
|
||||
)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.responses import RedirectResponse, JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.models import TokenResponse, UserInfo
|
||||
from src.service import AuthService
|
||||
from src.service import AuthService, UserService
|
||||
from src.config import get_settings
|
||||
from src.database import get_db
|
||||
|
||||
# Initialize HTTPBearer security dependency
|
||||
bearer_scheme = HTTPBearer()
|
||||
@@ -35,13 +37,15 @@ class AuthController:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def login(keycode: str, request: Request) -> RedirectResponse:
|
||||
async def login(keycode: str, request: Request, db: AsyncSession) -> RedirectResponse:
|
||||
"""
|
||||
Authenticate user, set HTTP-only cookie, and redirect to frontend.
|
||||
Authenticate user, provision in database if needed, set HTTP-only cookie,
|
||||
and redirect to frontend.
|
||||
|
||||
Args:
|
||||
keycode (str): The authorization code from Keycloak.
|
||||
request (Request): The FastAPI request object.
|
||||
db (AsyncSession): Database session for user provisioning.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the authentication fails.
|
||||
@@ -50,7 +54,8 @@ class AuthController:
|
||||
RedirectResponse: Redirects to frontend with cookie set.
|
||||
"""
|
||||
# Authenticate the user using the AuthService
|
||||
access_token = AuthService.authenticate_user(keycode, request)
|
||||
token_response = AuthService.authenticate_user(keycode, request)
|
||||
access_token = token_response.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
@@ -58,6 +63,10 @@ class AuthController:
|
||||
detail="Authentication failed",
|
||||
)
|
||||
|
||||
# Provision user in database (JIT provisioning)
|
||||
# This creates the user if they don't exist
|
||||
user_id, is_new_user = await UserService.provision_user_on_login(access_token, db)
|
||||
|
||||
# Create redirect response to frontend
|
||||
response = RedirectResponse(
|
||||
url=f"{settings.frontend_url}/dashboard",
|
||||
|
||||
61
backend/src/database.py
Normal file
61
backend/src/database.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from src.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Create async engine for PostgreSQL
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.database_echo,
|
||||
pool_pre_ping=True,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
)
|
||||
|
||||
# Create async session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
# Base class for SQLAlchemy models
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
"""
|
||||
Dependency that provides a database session.
|
||||
Use with FastAPI's Depends() for automatic session management.
|
||||
"""
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""
|
||||
Initialize the database by creating all tables.
|
||||
This should be called on application startup.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
# Import all models here to ensure they are registered
|
||||
from src import db_models # noqa: F401
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""
|
||||
Close the database connection pool.
|
||||
This should be called on application shutdown.
|
||||
"""
|
||||
await engine.dispose()
|
||||
230
backend/src/db_models.py
Normal file
230
backend/src/db_models.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
SQLAlchemy ORM models for the Periodic Table Requirements application.
|
||||
Based on the database schema defined in periodic-table.sql
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from sqlalchemy import (
|
||||
Integer, String, Text, ForeignKey, DateTime, Boolean,
|
||||
UniqueConstraint, Index
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from src.database import Base
|
||||
|
||||
|
||||
class Role(Base):
|
||||
"""User roles for access control."""
|
||||
__tablename__ = "roles"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
role_name: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
users: Mapped[List["User"]] = relationship("User", back_populates="role")
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""
|
||||
Users table - populated on first login via Keycloak.
|
||||
The 'sub' field is the Keycloak subject ID (unique identifier).
|
||||
"""
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
sub: Mapped[str] = mapped_column(Text, nullable=False, unique=True) # Keycloak subject ID
|
||||
role_id: Mapped[int] = mapped_column(Integer, ForeignKey("roles.id"), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=datetime.utcnow,
|
||||
nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
role: Mapped["Role"] = relationship("Role", back_populates="users")
|
||||
requirements: Mapped[List["Requirement"]] = relationship(
|
||||
"Requirement",
|
||||
back_populates="user",
|
||||
foreign_keys="Requirement.user_id"
|
||||
)
|
||||
edited_requirements: Mapped[List["Requirement"]] = relationship(
|
||||
"Requirement",
|
||||
back_populates="last_editor",
|
||||
foreign_keys="Requirement.last_editor_id"
|
||||
)
|
||||
validations: Mapped[List["Validation"]] = relationship("Validation", back_populates="user")
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
"""Requirement tags (e.g., GSR, SFR)."""
|
||||
__tablename__ = "tags"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
tag_code: Mapped[str] = mapped_column(String(10), nullable=False, unique=True)
|
||||
tag_description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
# Relationships
|
||||
requirements: Mapped[List["Requirement"]] = relationship("Requirement", back_populates="tag")
|
||||
|
||||
|
||||
class Group(Base):
|
||||
"""Requirement groups for categorization."""
|
||||
__tablename__ = "groups"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_name: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
|
||||
hex_color: Mapped[str] = mapped_column(String(7), nullable=False) # e.g., #FF5733
|
||||
|
||||
# Relationships
|
||||
requirements: Mapped[List["Requirement"]] = relationship(
|
||||
"Requirement",
|
||||
secondary="requirements_groups",
|
||||
back_populates="groups"
|
||||
)
|
||||
|
||||
|
||||
class Priority(Base):
|
||||
"""Priority levels for requirements."""
|
||||
__tablename__ = "priorities"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
priority_name: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
|
||||
priority_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
# Relationships
|
||||
requirements: Mapped[List["Requirement"]] = relationship("Requirement", back_populates="priority")
|
||||
|
||||
|
||||
class ValidationStatus(Base):
|
||||
"""Validation status options."""
|
||||
__tablename__ = "validation_statuses"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
status_name: Mapped[str] = mapped_column(Text, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
validations: Mapped[List["Validation"]] = relationship("Validation", back_populates="status")
|
||||
|
||||
|
||||
class Requirement(Base):
|
||||
"""Main requirements table."""
|
||||
__tablename__ = "requirements"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
tag_id: Mapped[int] = mapped_column(Integer, ForeignKey("tags.id"), nullable=False)
|
||||
last_editor_id: Mapped[Optional[int]] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("users.id"),
|
||||
nullable=True
|
||||
)
|
||||
req_name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
req_desc: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
priority_id: Mapped[Optional[int]] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("priorities.id"),
|
||||
nullable=True
|
||||
)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=datetime.utcnow,
|
||||
nullable=True
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=datetime.utcnow,
|
||||
onupdate=datetime.utcnow,
|
||||
nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
back_populates="requirements",
|
||||
foreign_keys=[user_id]
|
||||
)
|
||||
last_editor: Mapped[Optional["User"]] = relationship(
|
||||
"User",
|
||||
back_populates="edited_requirements",
|
||||
foreign_keys=[last_editor_id]
|
||||
)
|
||||
tag: Mapped["Tag"] = relationship("Tag", back_populates="requirements")
|
||||
priority: Mapped[Optional["Priority"]] = relationship("Priority", back_populates="requirements")
|
||||
groups: Mapped[List["Group"]] = relationship(
|
||||
"Group",
|
||||
secondary="requirements_groups",
|
||||
back_populates="requirements"
|
||||
)
|
||||
validations: Mapped[List["Validation"]] = relationship("Validation", back_populates="requirement")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("idx_req_tag", "tag_id"),
|
||||
Index("idx_req_priority", "priority_id"),
|
||||
Index("idx_req_user", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
class RequirementGroup(Base):
|
||||
"""Join table for many-to-many relationship between requirements and groups."""
|
||||
__tablename__ = "requirements_groups"
|
||||
|
||||
requirement_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("requirements.id", ondelete="CASCADE"),
|
||||
primary_key=True
|
||||
)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("groups.id", ondelete="CASCADE"),
|
||||
primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class Validation(Base):
|
||||
"""Validation records for requirements."""
|
||||
__tablename__ = "validations"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
requirement_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("requirements.id", ondelete="CASCADE"),
|
||||
nullable=False
|
||||
)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
status_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("validation_statuses.id"),
|
||||
nullable=False
|
||||
)
|
||||
req_version_snapshot: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
comment: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=datetime.utcnow,
|
||||
nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
requirement: Mapped["Requirement"] = relationship("Requirement", back_populates="validations")
|
||||
user: Mapped["User"] = relationship("User", back_populates="validations")
|
||||
status: Mapped["ValidationStatus"] = relationship("ValidationStatus", back_populates="validations")
|
||||
|
||||
|
||||
class RequirementHistory(Base):
|
||||
"""
|
||||
Historical records of requirement changes.
|
||||
Note: This is populated by a database trigger, not by the application.
|
||||
"""
|
||||
__tablename__ = "requirements_history"
|
||||
|
||||
history_id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
original_req_id: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
req_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
req_desc: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
priority_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
tag_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
version: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
valid_from: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
valid_to: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
edited_by: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
@@ -1,17 +1,60 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List
|
||||
from fastapi import FastAPI, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.responses import RedirectResponse
|
||||
from src.models import TokenResponse, UserInfo
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.models import TokenResponse, UserInfo, GroupResponse
|
||||
from src.controller import AuthController
|
||||
from src.config import get_openid, get_settings
|
||||
from src.database import init_db, close_db, get_db
|
||||
from src.repositories import RoleRepository, GroupRepository
|
||||
import logging
|
||||
|
||||
# Initialize the FastAPI app
|
||||
app = FastAPI(title="Keycloak Auth API", version="1.0.0")
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get settings
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Application lifespan manager.
|
||||
Handles startup and shutdown events.
|
||||
"""
|
||||
# Startup
|
||||
logger.info("Starting up application...")
|
||||
logger.info("Initializing database...")
|
||||
await init_db()
|
||||
logger.info("Database initialized successfully")
|
||||
|
||||
# Ensure default roles exist
|
||||
from src.database import AsyncSessionLocal
|
||||
async with AsyncSessionLocal() as session:
|
||||
role_repo = RoleRepository(session)
|
||||
await role_repo.ensure_default_roles_exist()
|
||||
await session.commit()
|
||||
logger.info("Default roles ensured")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down application...")
|
||||
await close_db()
|
||||
logger.info("Database connection closed")
|
||||
|
||||
|
||||
# Initialize the FastAPI app
|
||||
app = FastAPI(
|
||||
title="Keycloak Auth API",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -65,15 +108,15 @@ async def login(request: Request):
|
||||
|
||||
# Define the callback endpoint
|
||||
@app.get("/api/callback", include_in_schema=False)
|
||||
async def callback(request: Request):
|
||||
async def callback(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
OAuth callback endpoint that exchanges the authorization code for a token
|
||||
and sets it as an HTTP-only cookie.
|
||||
OAuth callback endpoint that exchanges the authorization code for a token,
|
||||
provisions the user in the database if needed, and sets it as an HTTP-only cookie.
|
||||
"""
|
||||
# Extract the code from the URL
|
||||
keycode = request.query_params.get('code')
|
||||
|
||||
return AuthController.login(str(keycode), request)
|
||||
return await AuthController.login(str(keycode), request, db)
|
||||
|
||||
|
||||
# Define the auth/me endpoint to get current user from cookie
|
||||
@@ -116,3 +159,20 @@ async def protected_endpoint(
|
||||
UserInfo: Information about the authenticated user.
|
||||
"""
|
||||
return AuthController.protected_endpoint(credentials)
|
||||
|
||||
|
||||
# ===========================================
|
||||
# Groups Endpoints
|
||||
# ===========================================
|
||||
|
||||
@app.get("/api/groups", response_model=List[GroupResponse])
|
||||
async def get_groups(db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
Get all groups.
|
||||
|
||||
Returns:
|
||||
List of all groups with their names and colors.
|
||||
"""
|
||||
group_repo = GroupRepository(db)
|
||||
groups = await group_repo.get_all()
|
||||
return [GroupResponse.model_validate(g) for g in groups]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
|
||||
@@ -13,6 +13,25 @@ class TokenResponse(BaseModel):
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
sub: Optional[str] = None # Keycloak subject ID
|
||||
preferred_username: str
|
||||
email: Optional[str] = None
|
||||
full_name: Optional[str] = None
|
||||
db_user_id: Optional[int] = None # Database user ID (populated after login)
|
||||
role: Optional[str] = None # User role name
|
||||
|
||||
|
||||
# Group schemas
|
||||
class GroupResponse(BaseModel):
|
||||
"""Response schema for a single group."""
|
||||
id: int
|
||||
group_name: str
|
||||
hex_color: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class GroupListResponse(BaseModel):
|
||||
"""Response schema for list of groups."""
|
||||
groups: List[GroupResponse]
|
||||
|
||||
11
backend/src/repositories/__init__.py
Normal file
11
backend/src/repositories/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Repository layer for database operations.
|
||||
"""
|
||||
from src.repositories.user_repository import UserRepository, RoleRepository
|
||||
from src.repositories.group_repository import GroupRepository
|
||||
|
||||
__all__ = [
|
||||
"UserRepository",
|
||||
"RoleRepository",
|
||||
"GroupRepository",
|
||||
]
|
||||
76
backend/src/repositories/group_repository.py
Normal file
76
backend/src/repositories/group_repository.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Repository layer for Group database operations.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.db_models import Group
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroupRepository:
|
||||
"""Repository for Group-related database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_all(self) -> List[Group]:
|
||||
"""
|
||||
Get all groups.
|
||||
|
||||
Returns:
|
||||
List of all groups
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(Group).order_by(Group.group_name)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_id(self, group_id: int) -> Optional[Group]:
|
||||
"""
|
||||
Get a group by ID.
|
||||
|
||||
Args:
|
||||
group_id: The group ID
|
||||
|
||||
Returns:
|
||||
Group if found, None otherwise
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(Group).where(Group.id == group_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_name(self, group_name: str) -> Optional[Group]:
|
||||
"""
|
||||
Get a group by name.
|
||||
|
||||
Args:
|
||||
group_name: The group name
|
||||
|
||||
Returns:
|
||||
Group if found, None otherwise
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(Group).where(Group.group_name == group_name)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, group_name: str, hex_color: str) -> Group:
|
||||
"""
|
||||
Create a new group.
|
||||
|
||||
Args:
|
||||
group_name: The group name
|
||||
hex_color: The hex color code (e.g., #FF5733)
|
||||
|
||||
Returns:
|
||||
The created Group
|
||||
"""
|
||||
group = Group(group_name=group_name, hex_color=hex_color)
|
||||
self.session.add(group)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(group)
|
||||
return group
|
||||
166
backend/src/repositories/user_repository.py
Normal file
166
backend/src/repositories/user_repository.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Repository layer for User database operations.
|
||||
Handles CRUD operations and user provisioning on first login.
|
||||
"""
|
||||
from typing import Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from src.db_models import User, Role
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default role for new users
|
||||
DEFAULT_ROLE_NAME = "user"
|
||||
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for User-related database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_sub(self, sub: str) -> Optional[User]:
|
||||
"""
|
||||
Get a user by their Keycloak subject ID.
|
||||
|
||||
Args:
|
||||
sub: The Keycloak subject ID (unique identifier)
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.role))
|
||||
.where(User.sub == sub)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, user_id: int) -> Optional[User]:
|
||||
"""
|
||||
Get a user by their database ID.
|
||||
|
||||
Args:
|
||||
user_id: The database user ID
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(User)
|
||||
.options(selectinload(User.role))
|
||||
.where(User.id == user_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, sub: str, role_id: int) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
Args:
|
||||
sub: The Keycloak subject ID
|
||||
role_id: The role ID to assign
|
||||
|
||||
Returns:
|
||||
The created User
|
||||
"""
|
||||
user = User(sub=sub, role_id=role_id)
|
||||
self.session.add(user)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(user)
|
||||
return user
|
||||
|
||||
async def get_or_create_default_role(self) -> Role:
|
||||
"""
|
||||
Get the default user role, creating it if it doesn't exist.
|
||||
|
||||
Returns:
|
||||
The default Role
|
||||
"""
|
||||
result = await self.session.execute(
|
||||
select(Role).where(Role.role_name == DEFAULT_ROLE_NAME)
|
||||
)
|
||||
role = result.scalar_one_or_none()
|
||||
|
||||
if role is None:
|
||||
logger.info(f"Creating default role: {DEFAULT_ROLE_NAME}")
|
||||
role = Role(role_name=DEFAULT_ROLE_NAME)
|
||||
self.session.add(role)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(role)
|
||||
|
||||
return role
|
||||
|
||||
async def get_or_create_user(self, sub: str) -> tuple[User, bool]:
|
||||
"""
|
||||
Get an existing user or create a new one (Just-in-Time Provisioning).
|
||||
This is the main method called during login.
|
||||
|
||||
Args:
|
||||
sub: The Keycloak subject ID
|
||||
|
||||
Returns:
|
||||
Tuple of (User, created) where created is True if a new user was created
|
||||
"""
|
||||
# Check if user already exists
|
||||
user = await self.get_by_sub(sub)
|
||||
|
||||
if user is not None:
|
||||
logger.debug(f"Found existing user with sub: {sub}")
|
||||
return user, False
|
||||
|
||||
# User doesn't exist, create them with default role
|
||||
logger.info(f"Creating new user with sub: {sub}")
|
||||
|
||||
# Get or create the default role
|
||||
default_role = await self.get_or_create_default_role()
|
||||
|
||||
# Create the user
|
||||
user = await self.create(sub=sub, role_id=default_role.id)
|
||||
|
||||
logger.info(f"Created new user with id: {user.id}, sub: {sub}")
|
||||
return user, True
|
||||
|
||||
|
||||
class RoleRepository:
|
||||
"""Repository for Role-related database operations."""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_name(self, role_name: str) -> Optional[Role]:
|
||||
"""Get a role by name."""
|
||||
result = await self.session.execute(
|
||||
select(Role).where(Role.role_name == role_name)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, role_id: int) -> Optional[Role]:
|
||||
"""Get a role by ID."""
|
||||
result = await self.session.execute(
|
||||
select(Role).where(Role.id == role_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, role_name: str) -> Role:
|
||||
"""Create a new role."""
|
||||
role = Role(role_name=role_name)
|
||||
self.session.add(role)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(role)
|
||||
return role
|
||||
|
||||
async def ensure_default_roles_exist(self) -> None:
|
||||
"""
|
||||
Ensure default roles exist in the database.
|
||||
Called during application startup.
|
||||
"""
|
||||
default_roles = ["admin", "user", "viewer"]
|
||||
|
||||
for role_name in default_roles:
|
||||
existing = await self.get_by_name(role_name)
|
||||
if existing is None:
|
||||
logger.info(f"Creating default role: {role_name}")
|
||||
await self.create(role_name)
|
||||
@@ -1,8 +1,10 @@
|
||||
from fastapi import HTTPException, status, Request
|
||||
from keycloak.exceptions import KeycloakAuthenticationError, KeycloakPostError
|
||||
from keycloak import KeycloakOpenID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.config import get_settings
|
||||
from src.models import UserInfo
|
||||
from src.repositories import UserRepository
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -21,9 +23,10 @@ def get_keycloak_openid():
|
||||
|
||||
class AuthService:
|
||||
@staticmethod
|
||||
def authenticate_user(keycode: str, request: Request) -> str:
|
||||
def authenticate_user(keycode: str, request: Request) -> dict:
|
||||
"""
|
||||
Authenticate the user using Keycloak and return an access token.
|
||||
Authenticate the user using Keycloak and return the full token response.
|
||||
Returns the full token dict to allow access to the access_token.
|
||||
"""
|
||||
try:
|
||||
# Use the same redirect_uri that was used in the login endpoint
|
||||
@@ -46,7 +49,7 @@ class AuthService:
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
logger.info("Token exchange successful")
|
||||
return token["access_token"]
|
||||
return token
|
||||
except KeycloakAuthenticationError as exc:
|
||||
logger.error(f"KeycloakAuthenticationError: {exc}")
|
||||
raise HTTPException(
|
||||
@@ -80,6 +83,7 @@ class AuthService:
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
|
||||
)
|
||||
return UserInfo(
|
||||
sub=user_info.get("sub"),
|
||||
preferred_username=user_info["preferred_username"],
|
||||
email=user_info.get("email"),
|
||||
full_name=user_info.get("name"),
|
||||
@@ -89,3 +93,64 @@ class AuthService:
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def decode_token(token: str) -> dict:
|
||||
"""
|
||||
Decode the access token to extract claims without full verification.
|
||||
Used to get the 'sub' claim for user provisioning.
|
||||
"""
|
||||
try:
|
||||
keycloak_openid = get_keycloak_openid()
|
||||
# Decode token - this validates the signature
|
||||
token_info = keycloak_openid.decode_token(
|
||||
token,
|
||||
validate=True
|
||||
)
|
||||
return token_info
|
||||
except Exception as exc:
|
||||
logger.error(f"Error decoding token: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not decode token",
|
||||
) from exc
|
||||
|
||||
|
||||
class UserService:
|
||||
"""Service for user-related operations."""
|
||||
|
||||
@staticmethod
|
||||
async def provision_user_on_login(
|
||||
token: str,
|
||||
db: AsyncSession
|
||||
) -> tuple[int, bool]:
|
||||
"""
|
||||
Provision a user in the database on first login (JIT provisioning).
|
||||
|
||||
Args:
|
||||
token: The access token from Keycloak
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Tuple of (user_id, is_new_user)
|
||||
"""
|
||||
# Decode the token to get the 'sub' claim
|
||||
token_info = AuthService.decode_token(token)
|
||||
sub = token_info.get("sub")
|
||||
|
||||
if not sub:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Token does not contain 'sub' claim"
|
||||
)
|
||||
|
||||
# Get or create the user
|
||||
user_repo = UserRepository(db)
|
||||
user, created = await user_repo.get_or_create_user(sub)
|
||||
|
||||
if created:
|
||||
logger.info(f"New user provisioned: {sub} -> user_id: {user.id}")
|
||||
else:
|
||||
logger.debug(f"Existing user logged in: {sub} -> user_id: {user.id}")
|
||||
|
||||
return user.id, created
|
||||
|
||||
Reference in New Issue
Block a user