from __future__ import annotations

import asyncio
from dataclasses import dataclass

from ..models import GraphResponse
from ..sparql_engine import SparqlEngine
from ..settings import Settings
from .graph_snapshot import fetch_graph_snapshot


@dataclass(frozen=True)
class SnapshotKey:
    """Immutable, hashable key identifying one cached snapshot configuration."""

    node_limit: int
    edge_limit: int
    include_bnodes: bool


class GraphSnapshotService:
    """
    In-memory cache in front of ``fetch_graph_snapshot``.

    A snapshot is fetched at most once per ``SnapshotKey`` and then served
    from the cache, so the backend doesn't re-run expensive SPARQL for
    stats/graph. Entries are never evicted by this class.
    """

    def __init__(self, *, sparql: SparqlEngine, settings: Settings):
        """Store the collaborators and start with an empty cache.

        Args:
            sparql: Engine handed through to ``fetch_graph_snapshot``.
            settings: Settings handed through to ``fetch_graph_snapshot``;
                its ``include_bnodes`` value is also part of the cache key.
        """
        self._sparql = sparql
        self._settings = settings
        # Snapshots that have already been fetched, by configuration.
        self._cache: dict[SnapshotKey, GraphResponse] = {}
        # One lock per key so callers wanting the same snapshot queue up
        # behind a single fetch.
        self._locks: dict[SnapshotKey, asyncio.Lock] = {}
        # Guards get-or-create access to ``self._locks``.
        self._global_lock = asyncio.Lock()

    async def _lock_for(self, key: SnapshotKey) -> asyncio.Lock:
        """Return the lock dedicated to *key*, creating it on first use."""
        async with self._global_lock:
            try:
                return self._locks[key]
            except KeyError:
                fresh_lock = asyncio.Lock()
                self._locks[key] = fresh_lock
                return fresh_lock

    async def get(self, *, node_limit: int, edge_limit: int) -> GraphResponse:
        """Return the snapshot for the given limits, fetching it if needed.

        Args:
            node_limit: Node limit forwarded to ``fetch_graph_snapshot``.
            edge_limit: Edge limit forwarded to ``fetch_graph_snapshot``.

        Returns:
            The cached snapshot when one exists for this configuration,
            otherwise the freshly fetched one (which is then cached).

        Any exception raised by ``fetch_graph_snapshot`` propagates to the
        caller and nothing is cached for that key.
        """
        key = SnapshotKey(
            node_limit=node_limit,
            edge_limit=edge_limit,
            include_bnodes=self._settings.include_bnodes,
        )

        # Fast path: a cache hit needs no locking at all.
        snapshot = self._cache.get(key)
        if snapshot is None:
            key_lock = await self._lock_for(key)
            async with key_lock:
                # Another caller may have filled the cache while we were
                # waiting for the lock, so look again before fetching.
                snapshot = self._cache.get(key)
                if snapshot is None:
                    snapshot = await fetch_graph_snapshot(
                        self._sparql,
                        settings=self._settings,
                        node_limit=node_limit,
                        edge_limit=edge_limit,
                    )
                    self._cache[key] = snapshot
        return snapshot