Import Solver + neighbors via sparql query
This commit is contained in:
@@ -1,11 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .models import EdgesResponse, GraphResponse, NodesResponse, SparqlQueryRequest, StatsResponse
|
||||
from .models import (
|
||||
EdgesResponse,
|
||||
GraphResponse,
|
||||
NeighborsRequest,
|
||||
NeighborsResponse,
|
||||
NodesResponse,
|
||||
SparqlQueryRequest,
|
||||
StatsResponse,
|
||||
)
|
||||
from .pipelines.layout_dag_radial import CycleError
|
||||
from .pipelines.owl_imports_combiner import (
|
||||
build_combined_graph,
|
||||
output_location_to_path,
|
||||
resolve_output_location,
|
||||
serialize_graph_to_ttl,
|
||||
)
|
||||
from .pipelines.selection_neighbors import fetch_neighbor_ids_for_selection
|
||||
from .pipelines.snapshot_service import GraphSnapshotService
|
||||
from .rdf_store import RDFStore
|
||||
from .sparql_engine import RdflibEngine, SparqlEngine, create_sparql_engine
|
||||
@@ -13,11 +31,33 @@ from .settings import Settings
|
||||
|
||||
|
||||
settings = Settings()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
sparql: SparqlEngine = create_sparql_engine(settings)
|
||||
rdflib_preloaded_graph = None
|
||||
|
||||
if settings.combine_owl_imports_on_start:
|
||||
entry_location = settings.combine_entry_location or settings.ttl_path
|
||||
output_location = resolve_output_location(
|
||||
entry_location,
|
||||
output_location=settings.combine_output_location,
|
||||
output_name=settings.combine_output_name,
|
||||
)
|
||||
|
||||
output_path = output_location_to_path(output_location)
|
||||
if output_path.exists() and not settings.combine_force:
|
||||
logger.info("Skipping combine step (output exists): %s", output_location)
|
||||
else:
|
||||
rdflib_preloaded_graph = await asyncio.to_thread(build_combined_graph, entry_location)
|
||||
logger.info("Finished combining imports; serializing to: %s", output_location)
|
||||
await asyncio.to_thread(serialize_graph_to_ttl, rdflib_preloaded_graph, output_location)
|
||||
|
||||
if settings.graph_backend == "rdflib":
|
||||
settings.ttl_path = str(output_path)
|
||||
|
||||
sparql: SparqlEngine = create_sparql_engine(settings, rdflib_graph=rdflib_preloaded_graph)
|
||||
await sparql.startup()
|
||||
app.state.sparql = sparql
|
||||
app.state.snapshot_service = GraphSnapshotService(sparql=sparql, settings=settings)
|
||||
@@ -62,7 +102,10 @@ def health() -> dict[str, str]:
|
||||
async def stats() -> StatsResponse:
|
||||
# Stats reflect exactly what we send to the frontend (/api/graph), not global graph size.
|
||||
svc: GraphSnapshotService = app.state.snapshot_service
|
||||
snap = await svc.get(node_limit=50_000, edge_limit=100_000)
|
||||
try:
|
||||
snap = await svc.get(node_limit=50_000, edge_limit=100_000)
|
||||
except CycleError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e)) from None
|
||||
meta = snap.meta
|
||||
return StatsResponse(
|
||||
backend=meta.backend if meta else app.state.sparql.name,
|
||||
@@ -81,6 +124,20 @@ async def sparql_query(req: SparqlQueryRequest) -> dict:
|
||||
return data
|
||||
|
||||
|
||||
@app.post("/api/neighbors", response_model=NeighborsResponse)
|
||||
async def neighbors(req: NeighborsRequest) -> NeighborsResponse:
|
||||
svc: GraphSnapshotService = app.state.snapshot_service
|
||||
snap = await svc.get(node_limit=req.node_limit, edge_limit=req.edge_limit)
|
||||
sparql: SparqlEngine = app.state.sparql
|
||||
neighbor_ids = await fetch_neighbor_ids_for_selection(
|
||||
sparql,
|
||||
snapshot=snap,
|
||||
selected_ids=req.selected_ids,
|
||||
include_bnodes=settings.include_bnodes,
|
||||
)
|
||||
return NeighborsResponse(selected_ids=req.selected_ids, neighbor_ids=neighbor_ids)
|
||||
|
||||
|
||||
@app.get("/api/nodes", response_model=NodesResponse)
|
||||
def nodes(
|
||||
limit: int = Query(default=10_000, ge=1, le=200_000),
|
||||
@@ -109,4 +166,7 @@ async def graph(
|
||||
edge_limit: int = Query(default=100_000, ge=1, le=500_000),
|
||||
) -> GraphResponse:
|
||||
svc: GraphSnapshotService = app.state.snapshot_service
|
||||
return await svc.get(node_limit=node_limit, edge_limit=edge_limit)
|
||||
try:
|
||||
return await svc.get(node_limit=node_limit, edge_limit=edge_limit)
|
||||
except CycleError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e)) from None
|
||||
|
||||
Reference in New Issue
Block a user