Added auditor logic
This commit is contained in:
@@ -9,12 +9,16 @@ from src.models import (
|
||||
TokenResponse, UserInfo, GroupResponse,
|
||||
TagResponse, RequirementResponse, PriorityResponse,
|
||||
RequirementCreateRequest, RequirementUpdateRequest,
|
||||
ProjectResponse, ProjectCreateRequest, ProjectUpdateRequest, ProjectMemberRequest
|
||||
ProjectResponse, ProjectCreateRequest, ProjectUpdateRequest, ProjectMemberRequest,
|
||||
ValidationStatusResponse, ValidationHistoryResponse, ValidationCreateRequest
|
||||
)
|
||||
from src.controller import AuthController
|
||||
from src.config import get_openid, get_settings
|
||||
from src.database import init_db, close_db, get_db
|
||||
from src.repositories import RoleRepository, GroupRepository, TagRepository, RequirementRepository, PriorityRepository, ProjectRepository
|
||||
from src.repositories import (
|
||||
RoleRepository, GroupRepository, TagRepository, RequirementRepository,
|
||||
PriorityRepository, ProjectRepository, ValidationStatusRepository, ValidationRepository
|
||||
)
|
||||
import logging
|
||||
|
||||
# Configure logging
|
||||
@@ -45,6 +49,23 @@ async def lifespan(app: FastAPI):
|
||||
await session.commit()
|
||||
logger.info("Default roles ensured")
|
||||
|
||||
# Ensure default validation statuses exist
|
||||
async with AsyncSessionLocal() as session:
|
||||
await session.execute(
|
||||
__import__('sqlalchemy').text(
|
||||
"""
|
||||
INSERT INTO validation_statuses (id, status_name) VALUES
|
||||
(1, 'Approved'),
|
||||
(2, 'Denied'),
|
||||
(3, 'Partial'),
|
||||
(4, 'Not Validated')
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
logger.info("Default validation statuses ensured")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
@@ -126,14 +147,27 @@ async def callback(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
|
||||
# Define the auth/me endpoint to get current user from cookie
|
||||
@app.get("/api/auth/me", response_model=UserInfo)
|
||||
async def get_current_user(request: Request):
|
||||
async def get_current_user(request: Request, db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
Get the current authenticated user from the session cookie.
|
||||
Includes role information from the database.
|
||||
|
||||
Returns:
|
||||
UserInfo: Information about the authenticated user.
|
||||
UserInfo: Information about the authenticated user including role.
|
||||
"""
|
||||
return AuthController.get_current_user(request)
|
||||
user_info = AuthController.get_current_user(request)
|
||||
|
||||
# Fetch role information from database
|
||||
from src.repositories import UserRepository
|
||||
user_repo = UserRepository(db)
|
||||
db_user = await user_repo.get_by_sub(user_info.sub)
|
||||
|
||||
if db_user:
|
||||
user_info.db_user_id = db_user.id
|
||||
user_info.role_id = db_user.role_id
|
||||
user_info.role = db_user.role.role_name if db_user.role else None
|
||||
|
||||
return user_info
|
||||
|
||||
|
||||
# Define the logout endpoint
|
||||
@@ -257,6 +291,25 @@ async def _get_current_user_db(request: Request, db: AsyncSession):
|
||||
return user
|
||||
|
||||
|
||||
def _require_role(user, allowed_role_ids: List[int], action: str = "perform this action"):
|
||||
"""
|
||||
Helper to check if user has one of the allowed roles.
|
||||
|
||||
Args:
|
||||
user: The database user object
|
||||
allowed_role_ids: List of role IDs that are permitted (e.g., [1, 3] for admin and user)
|
||||
action: Description of the action for error message
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 Forbidden if user's role is not in allowed list
|
||||
"""
|
||||
if user.role_id not in allowed_role_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Your role does not have permission to {action}"
|
||||
)
|
||||
|
||||
|
||||
async def _verify_project_membership(project_id: int, user_id: int, db: AsyncSession):
|
||||
"""Helper to verify user is a member of a project."""
|
||||
project_repo = ProjectRepository(db)
|
||||
@@ -473,10 +526,19 @@ def _build_requirement_response(req) -> RequirementResponse:
|
||||
"""Helper function to build RequirementResponse from a Requirement model."""
|
||||
# Determine validation status from latest validation
|
||||
validation_status = "Not Validated"
|
||||
validated_by = None
|
||||
validated_at = None
|
||||
validation_version = None
|
||||
|
||||
if req.validations:
|
||||
# Get the latest validation
|
||||
latest_validation = max(req.validations, key=lambda v: v.created_at or req.created_at)
|
||||
validation_status = latest_validation.status.status_name if latest_validation.status else "Not Validated"
|
||||
# Try to get username from user relationship
|
||||
if latest_validation.user:
|
||||
validated_by = latest_validation.user.sub
|
||||
validated_at = latest_validation.created_at
|
||||
validation_version = latest_validation.req_version_snapshot
|
||||
|
||||
return RequirementResponse(
|
||||
id=req.id,
|
||||
@@ -490,6 +552,9 @@ def _build_requirement_response(req) -> RequirementResponse:
|
||||
priority=req.priority if req.priority else None,
|
||||
groups=[GroupResponse.model_validate(g) for g in req.groups],
|
||||
validation_status=validation_status,
|
||||
validated_by=validated_by,
|
||||
validated_at=validated_at,
|
||||
validation_version=validation_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -611,6 +676,7 @@ async def create_requirement(
|
||||
"""
|
||||
Create a new requirement.
|
||||
User must be a member of the project.
|
||||
Auditors (role_id=2) cannot create requirements.
|
||||
|
||||
Args:
|
||||
req_data: The requirement data (must include project_id)
|
||||
@@ -620,6 +686,9 @@ async def create_requirement(
|
||||
"""
|
||||
user = await _get_current_user_db(request, db)
|
||||
|
||||
# Auditors (role_id=2) cannot create requirements
|
||||
_require_role(user, [1, 3], "create requirements")
|
||||
|
||||
# Verify user is a member of the project
|
||||
await _verify_project_membership(req_data.project_id, user.id, db)
|
||||
|
||||
@@ -648,6 +717,7 @@ async def update_requirement(
|
||||
"""
|
||||
Update an existing requirement.
|
||||
User must be a member of the requirement's project.
|
||||
Auditors (role_id=2) cannot edit requirements.
|
||||
|
||||
Args:
|
||||
requirement_id: The requirement ID to update
|
||||
@@ -658,6 +728,9 @@ async def update_requirement(
|
||||
"""
|
||||
user = await _get_current_user_db(request, db)
|
||||
|
||||
# Auditors (role_id=2) cannot edit requirements
|
||||
_require_role(user, [1, 3], "edit requirements")
|
||||
|
||||
req_repo = RequirementRepository(db)
|
||||
|
||||
# First check if requirement exists
|
||||
@@ -694,12 +767,16 @@ async def delete_requirement(
|
||||
"""
|
||||
Delete a requirement.
|
||||
User must be a member of the requirement's project.
|
||||
Auditors (role_id=2) cannot delete requirements.
|
||||
|
||||
Args:
|
||||
requirement_id: The requirement ID to delete
|
||||
"""
|
||||
user = await _get_current_user_db(request, db)
|
||||
|
||||
# Auditors (role_id=2) cannot delete requirements
|
||||
_require_role(user, [1, 3], "delete requirements")
|
||||
|
||||
req_repo = RequirementRepository(db)
|
||||
|
||||
# First check if requirement exists
|
||||
@@ -715,3 +792,137 @@ async def delete_requirement(
|
||||
|
||||
await req_repo.delete(requirement_id)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ===========================================
|
||||
# Validation Endpoints
|
||||
# ===========================================
|
||||
|
||||
@app.get("/api/validation-statuses", response_model=List[ValidationStatusResponse])
|
||||
async def get_validation_statuses(db: AsyncSession = Depends(get_db)):
|
||||
"""
|
||||
Get all validation statuses.
|
||||
|
||||
Returns:
|
||||
List of validation statuses (Approved, Denied, Partial, Not Validated).
|
||||
"""
|
||||
status_repo = ValidationStatusRepository(db)
|
||||
statuses = await status_repo.get_all()
|
||||
return [ValidationStatusResponse.model_validate(s) for s in statuses]
|
||||
|
||||
|
||||
@app.post("/api/requirements/{requirement_id}/validations", response_model=ValidationHistoryResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_validation(
|
||||
requirement_id: int,
|
||||
request: Request,
|
||||
validation_data: ValidationCreateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create a new validation for a requirement.
|
||||
Only auditors (role_id=2) and admins (role_id=1) can validate requirements.
|
||||
|
||||
Args:
|
||||
requirement_id: The requirement to validate
|
||||
validation_data: The validation status and optional comment
|
||||
|
||||
Returns:
|
||||
The created validation record.
|
||||
"""
|
||||
user = await _get_current_user_db(request, db)
|
||||
|
||||
# Only auditors (role_id=2) and admins (role_id=1) can validate
|
||||
_require_role(user, [1, 2], "validate requirements")
|
||||
|
||||
# Check if requirement exists and user has access
|
||||
req_repo = RequirementRepository(db)
|
||||
requirement = await req_repo.get_by_id(requirement_id)
|
||||
if not requirement:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Requirement with id {requirement_id} not found"
|
||||
)
|
||||
|
||||
# Verify user is a member of the requirement's project
|
||||
await _verify_project_membership(requirement.project_id, user.id, db)
|
||||
|
||||
# Verify status exists
|
||||
status_repo = ValidationStatusRepository(db)
|
||||
validation_status = await status_repo.get_by_id(validation_data.status_id)
|
||||
if not validation_status:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid validation status id {validation_data.status_id}"
|
||||
)
|
||||
|
||||
# Create the validation
|
||||
validation_repo = ValidationRepository(db)
|
||||
validation = await validation_repo.create(
|
||||
requirement_id=requirement_id,
|
||||
user_id=user.id,
|
||||
status_id=validation_data.status_id,
|
||||
req_version_snapshot=requirement.version,
|
||||
comment=validation_data.comment
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
return ValidationHistoryResponse(
|
||||
id=validation.id,
|
||||
status_name=validation_status.status_name,
|
||||
status_id=validation.status_id,
|
||||
req_version_snapshot=validation.req_version_snapshot,
|
||||
comment=validation.comment,
|
||||
created_at=validation.created_at,
|
||||
validator_username=user.sub,
|
||||
validator_id=user.id
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/requirements/{requirement_id}/validations", response_model=List[ValidationHistoryResponse])
|
||||
async def get_validation_history(
|
||||
requirement_id: int,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get the validation history for a requirement.
|
||||
Returns all validations ordered by date (newest first).
|
||||
|
||||
Args:
|
||||
requirement_id: The requirement to get validation history for
|
||||
|
||||
Returns:
|
||||
List of validation records with validator info.
|
||||
"""
|
||||
user = await _get_current_user_db(request, db)
|
||||
|
||||
# Check if requirement exists
|
||||
req_repo = RequirementRepository(db)
|
||||
requirement = await req_repo.get_by_id(requirement_id)
|
||||
if not requirement:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Requirement with id {requirement_id} not found"
|
||||
)
|
||||
|
||||
# Verify user is a member of the requirement's project
|
||||
await _verify_project_membership(requirement.project_id, user.id, db)
|
||||
|
||||
# Get validation history
|
||||
validation_repo = ValidationRepository(db)
|
||||
validations = await validation_repo.get_by_requirement_id(requirement_id)
|
||||
|
||||
return [
|
||||
ValidationHistoryResponse(
|
||||
id=v.id,
|
||||
status_name=v.status.status_name,
|
||||
status_id=v.status_id,
|
||||
req_version_snapshot=v.req_version_snapshot,
|
||||
comment=v.comment,
|
||||
created_at=v.created_at,
|
||||
validator_username=v.user.sub,
|
||||
validator_id=v.user_id
|
||||
)
|
||||
for v in validations
|
||||
]
|
||||
|
||||
@@ -20,6 +20,7 @@ class UserInfo(BaseModel):
|
||||
full_name: Optional[str] = None
|
||||
db_user_id: Optional[int] = None # Database user ID (populated after login)
|
||||
role: Optional[str] = None # User role name
|
||||
role_id: Optional[int] = None # User role ID (1=admin, 2=auditor, 3=user, etc.)
|
||||
|
||||
|
||||
# Project schemas
|
||||
@@ -106,6 +107,15 @@ class PriorityResponse(BaseModel):
|
||||
|
||||
|
||||
# Validation schemas
|
||||
class ValidationStatusResponse(BaseModel):
|
||||
"""Response schema for a validation status."""
|
||||
id: int
|
||||
status_name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ValidationResponse(BaseModel):
|
||||
"""Response schema for a validation."""
|
||||
id: int
|
||||
@@ -118,6 +128,27 @@ class ValidationResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ValidationHistoryResponse(BaseModel):
|
||||
"""Response schema for validation history with validator info."""
|
||||
id: int
|
||||
status_name: str
|
||||
status_id: int
|
||||
req_version_snapshot: int
|
||||
comment: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
validator_username: str
|
||||
validator_id: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class ValidationCreateRequest(BaseModel):
|
||||
"""Request schema for creating a validation."""
|
||||
status_id: int
|
||||
comment: Optional[str] = None
|
||||
|
||||
|
||||
# Requirement schemas
|
||||
class RequirementResponse(BaseModel):
|
||||
"""Response schema for a single requirement."""
|
||||
@@ -132,6 +163,9 @@ class RequirementResponse(BaseModel):
|
||||
priority: Optional[PriorityResponse] = None
|
||||
groups: List[GroupResponse] = []
|
||||
validation_status: Optional[str] = None # Computed from latest validation
|
||||
validated_by: Optional[str] = None # Username of the validator
|
||||
validated_at: Optional[datetime] = None # When the latest validation was made
|
||||
validation_version: Optional[int] = None # Version at which requirement was validated
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@@ -7,6 +7,8 @@ from src.repositories.tag_repository import TagRepository
|
||||
from src.repositories.requirement_repository import RequirementRepository
|
||||
from src.repositories.priority_repository import PriorityRepository
|
||||
from src.repositories.project_repository import ProjectRepository
|
||||
from src.repositories.validation_status_repository import ValidationStatusRepository
|
||||
from src.repositories.validation_repository import ValidationRepository
|
||||
|
||||
__all__ = [
|
||||
"UserRepository",
|
||||
@@ -16,4 +18,6 @@ __all__ = [
|
||||
"RequirementRepository",
|
||||
"PriorityRepository",
|
||||
"ProjectRepository",
|
||||
"ValidationStatusRepository",
|
||||
"ValidationRepository",
|
||||
]
|
||||
|
||||
@@ -31,6 +31,7 @@ class RequirementRepository:
|
||||
selectinload(Requirement.priority),
|
||||
selectinload(Requirement.groups),
|
||||
selectinload(Requirement.validations).selectinload(Validation.status),
|
||||
selectinload(Requirement.validations).selectinload(Validation.user),
|
||||
)
|
||||
.order_by(Requirement.created_at.desc())
|
||||
)
|
||||
@@ -53,6 +54,7 @@ class RequirementRepository:
|
||||
selectinload(Requirement.priority),
|
||||
selectinload(Requirement.groups),
|
||||
selectinload(Requirement.validations).selectinload(Validation.status),
|
||||
selectinload(Requirement.validations).selectinload(Validation.user),
|
||||
)
|
||||
.where(Requirement.project_id == project_id)
|
||||
.order_by(Requirement.created_at.desc())
|
||||
@@ -76,6 +78,7 @@ class RequirementRepository:
|
||||
selectinload(Requirement.priority),
|
||||
selectinload(Requirement.groups),
|
||||
selectinload(Requirement.validations).selectinload(Validation.status),
|
||||
selectinload(Requirement.validations).selectinload(Validation.user),
|
||||
)
|
||||
.where(Requirement.user_id == user_id)
|
||||
.order_by(Requirement.created_at.desc())
|
||||
@@ -99,6 +102,7 @@ class RequirementRepository:
|
||||
selectinload(Requirement.priority),
|
||||
selectinload(Requirement.groups),
|
||||
selectinload(Requirement.validations).selectinload(Validation.status),
|
||||
selectinload(Requirement.validations).selectinload(Validation.user),
|
||||
selectinload(Requirement.user),
|
||||
selectinload(Requirement.last_editor),
|
||||
)
|
||||
@@ -125,6 +129,7 @@ class RequirementRepository:
|
||||
selectinload(Requirement.priority),
|
||||
selectinload(Requirement.groups),
|
||||
selectinload(Requirement.validations).selectinload(Validation.status),
|
||||
selectinload(Requirement.validations).selectinload(Validation.user),
|
||||
)
|
||||
.join(Requirement.groups)
|
||||
.where(Group.id == group_id)
|
||||
@@ -159,6 +164,7 @@ class RequirementRepository:
|
||||
selectinload(Requirement.priority),
|
||||
selectinload(Requirement.groups),
|
||||
selectinload(Requirement.validations).selectinload(Validation.status),
|
||||
selectinload(Requirement.validations).selectinload(Validation.user),
|
||||
)
|
||||
.where(Requirement.tag_id == tag_id)
|
||||
)
|
||||
|
||||
102
backend/src/repositories/validation_repository.py
Normal file
102
backend/src/repositories/validation_repository.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Repository for Validation database operations.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.db_models import Validation, ValidationStatus, User
|
||||
|
||||
|
||||
class ValidationRepository:
|
||||
"""Repository for validation CRUD operations."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def create(
|
||||
self,
|
||||
requirement_id: int,
|
||||
user_id: int,
|
||||
status_id: int,
|
||||
req_version_snapshot: int,
|
||||
comment: Optional[str] = None
|
||||
) -> Validation:
|
||||
"""
|
||||
Create a new validation record.
|
||||
|
||||
Args:
|
||||
requirement_id: The requirement being validated
|
||||
user_id: The auditor performing the validation
|
||||
status_id: The validation status (1=Approved, 2=Denied, 3=Partial, 4=Not Validated)
|
||||
req_version_snapshot: The version of the requirement at validation time
|
||||
comment: Optional comment explaining the validation decision
|
||||
"""
|
||||
validation = Validation(
|
||||
requirement_id=requirement_id,
|
||||
user_id=user_id,
|
||||
status_id=status_id,
|
||||
req_version_snapshot=req_version_snapshot,
|
||||
comment=comment
|
||||
)
|
||||
self.db.add(validation)
|
||||
await self.db.flush()
|
||||
await self.db.refresh(validation)
|
||||
return validation
|
||||
|
||||
async def get_by_id(self, validation_id: int) -> Optional[Validation]:
|
||||
"""Get a validation by ID with related data."""
|
||||
result = await self.db.execute(
|
||||
select(Validation)
|
||||
.options(
|
||||
selectinload(Validation.status),
|
||||
selectinload(Validation.user)
|
||||
)
|
||||
.where(Validation.id == validation_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_requirement_id(self, requirement_id: int) -> List[Validation]:
|
||||
"""
|
||||
Get all validations for a requirement, ordered by creation date (newest first).
|
||||
Includes related status and user data.
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(Validation)
|
||||
.options(
|
||||
selectinload(Validation.status),
|
||||
selectinload(Validation.user)
|
||||
)
|
||||
.where(Validation.requirement_id == requirement_id)
|
||||
.order_by(Validation.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_latest_by_requirement_id(self, requirement_id: int) -> Optional[Validation]:
|
||||
"""
|
||||
Get the most recent validation for a requirement.
|
||||
Returns None if no validations exist (requirement is "Not Validated").
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(Validation)
|
||||
.options(
|
||||
selectinload(Validation.status),
|
||||
selectinload(Validation.user)
|
||||
)
|
||||
.where(Validation.requirement_id == requirement_id)
|
||||
.order_by(Validation.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def delete(self, validation_id: int) -> bool:
|
||||
"""Delete a validation by ID. Returns True if deleted, False if not found."""
|
||||
result = await self.db.execute(
|
||||
select(Validation).where(Validation.id == validation_id)
|
||||
)
|
||||
validation = result.scalar_one_or_none()
|
||||
if validation:
|
||||
await self.db.delete(validation)
|
||||
await self.db.flush()
|
||||
return True
|
||||
return False
|
||||
80
backend/src/repositories/validation_status_repository.py
Normal file
80
backend/src/repositories/validation_status_repository.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Repository for ValidationStatus database operations.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from src.db_models import ValidationStatus
|
||||
|
||||
|
||||
class ValidationStatusRepository:
|
||||
"""Repository for validation status CRUD operations."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_all(self) -> List[ValidationStatus]:
|
||||
"""Get all validation statuses."""
|
||||
result = await self.db.execute(
|
||||
select(ValidationStatus).order_by(ValidationStatus.id)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_id(self, status_id: int) -> Optional[ValidationStatus]:
|
||||
"""Get a validation status by ID."""
|
||||
result = await self.db.execute(
|
||||
select(ValidationStatus).where(ValidationStatus.id == status_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_name(self, status_name: str) -> Optional[ValidationStatus]:
|
||||
"""Get a validation status by name."""
|
||||
result = await self.db.execute(
|
||||
select(ValidationStatus).where(ValidationStatus.status_name == status_name)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, status_name: str) -> ValidationStatus:
|
||||
"""Create a new validation status."""
|
||||
status = ValidationStatus(status_name=status_name)
|
||||
self.db.add(status)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(status)
|
||||
return status
|
||||
|
||||
@staticmethod
|
||||
def seed_default_statuses(db) -> None:
|
||||
"""
|
||||
Seed default validation statuses if they don't exist.
|
||||
Statuses:
|
||||
1 - Approved: Requirement fully validated
|
||||
2 - Denied: Requirement rejected, needs rework
|
||||
3 - Partial: Part of requirement approved, needs more work
|
||||
4 - Not Validated: Default status, awaiting validation
|
||||
|
||||
Note: This is a synchronous method that uses the sync engine on startup.
|
||||
"""
|
||||
from sqlalchemy import text as sync_text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
default_statuses = [
|
||||
(1, "Approved"),
|
||||
(2, "Denied"),
|
||||
(3, "Partial"),
|
||||
(4, "Not Validated"),
|
||||
]
|
||||
|
||||
# Check if db is async or sync session
|
||||
if hasattr(db, 'execute'):
|
||||
for status_id, status_name in default_statuses:
|
||||
# Use raw SQL to insert with specific ID to maintain consistency
|
||||
db.execute(
|
||||
sync_text(
|
||||
"INSERT INTO validation_statuses (id, status_name) "
|
||||
"VALUES (:id, :name) "
|
||||
"ON CONFLICT (id) DO NOTHING"
|
||||
),
|
||||
{"id": status_id, "name": status_name}
|
||||
)
|
||||
|
||||
db.commit()
|
||||
Reference in New Issue
Block a user