Import Solver + neighbors via sparql query

This commit is contained in:
Oxy8
2026-03-04 13:49:14 -03:00
parent d4bfa5f064
commit a75b5b93da
15 changed files with 747 additions and 463 deletions

View File

@@ -32,6 +32,11 @@ Callers (frontend or other clients) interact with a single API surface (`/api/*`
- Used by `/api/nodes`, `/api/edges`, and `rdflib`-mode `/api/stats`.
- `pipelines/graph_snapshot.py`
- Pipeline used by `/api/graph` to return a `{nodes, edges}` snapshot via SPARQL (works for both RDFLib and AnzoGraph).
- `pipelines/layout_dag_radial.py`
- DAG layout helpers used by `pipelines/graph_snapshot.py`:
- cycle detection
- level-synchronous Kahn layering
- radial (ring-per-layer) positioning.
- `pipelines/snapshot_service.py`
- Snapshot cache layer used by `/api/graph` and `/api/stats` so the backend doesn't run expensive SPARQL twice.
- `pipelines/subclass_labels.py`
@@ -64,6 +69,14 @@ RDFLib mode:
- `TTL_PATH`: path inside the backend container to a `.ttl` file (example: `/data/o3po.ttl`)
- `MAX_TRIPLES`: optional int; if set, stops parsing after this many triples
Optional import-combining step (runs before the SPARQL engine starts):
- `COMBINE_OWL_IMPORTS_ON_START`: `true` to recursively load `TTL_PATH` (or `COMBINE_ENTRY_LOCATION`) plus `owl:imports` and write a combined TTL file.
- `COMBINE_ENTRY_LOCATION`: optional override for the entry file/URL to load (defaults to `TTL_PATH`)
- `COMBINE_OUTPUT_LOCATION`: optional explicit output path (defaults to `${dirname(entry)}/${COMBINE_OUTPUT_NAME}`)
- `COMBINE_OUTPUT_NAME`: output filename when `COMBINE_OUTPUT_LOCATION` is not set (default: `combined_ontology.ttl`)
- `COMBINE_FORCE`: `true` to rebuild even if the output file already exists
AnzoGraph mode:
- `SPARQL_HOST`: base host (example: `http://anzograph:8080`)
@@ -129,8 +142,8 @@ Returned in `nodes[]` (dense IDs; suitable for indexing in typed arrays):
- `id`: integer dense node ID used in edges
- `termType`: `"uri"` or `"bnode"`
- `iri`: URI string; blank nodes are normalized to `_:<id>`
- `label`: currently `null` in `/api/graph` snapshots (pipelines can be used to populate later)
- `x`/`y`: world-space coordinates for rendering (currently a deterministic spiral layout)
- `label`: `rdfs:label` when available (best-effort; prefers English)
- `x`/`y`: world-space coordinates for rendering (currently a radial layered layout derived from `rdfs:subClassOf`)
### Edge
@@ -149,11 +162,10 @@ Returned in `edges[]`:
## Snapshot Query (`/api/graph`)
`/api/graph` uses a SPARQL query that:
`/api/graph` currently uses a SPARQL query that returns only `rdfs:subClassOf` edges and `rdf:type` edges whose object is an `owl:Class`:
- selects triples `?s ?p ?o`
- excludes literal objects (`FILTER(!isLiteral(?o))`)
- excludes `rdfs:label`, `skos:prefLabel`, and `skos:altLabel` predicates
- selects bindings as `?s ?p ?o` (with `?p` bound to `rdf:type` or `rdfs:subClassOf`)
- excludes literal objects (`FILTER(!isLiteral(?o))`) for safety
- optionally excludes blank nodes (unless `INCLUDE_BNODES=true`)
- applies `LIMIT edge_limit`
@@ -161,6 +173,8 @@ The result bindings are mapped to dense node IDs (first-seen order) and returned
`/api/graph` also returns `meta` with snapshot counts and engine info so the frontend doesn't need to call `/api/stats`.
If a cycle is detected in the returned `rdfs:subClassOf` snapshot, `/api/graph` returns HTTP 422 (layout requires a DAG).
## Pipelines
### `pipelines/graph_snapshot.py`

View File

@@ -5,16 +5,25 @@ from typing import Any
def edge_retrieval_query(*, edge_limit: int, include_bnodes: bool) -> str:
bnode_filter = "" if include_bnodes else "FILTER(!isBlank(?s) && !isBlank(?o))"
return f"""
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX owl: <http://www.w3.org/2002/07/owl#>
SELECT ?s ?p ?o
WHERE {{
?s ?p ?o .
{{
VALUES ?p {{ rdf:type }}
?s ?p ?o .
?o rdf:type owl:Class .
}}
UNION
{{
VALUES ?p {{ rdfs:subClassOf }}
?s ?p ?o .
}}
FILTER(!isLiteral(?o))
FILTER(?p NOT IN (
<http://www.w3.org/2000/01/rdf-schema#label>,
<http://www.w3.org/2004/02/skos/core#prefLabel>,
<http://www.w3.org/2004/02/skos/core#altLabel>
))
{bnode_filter}
}}
LIMIT {edge_limit}
@@ -91,4 +100,3 @@ def graph_from_sparql_bindings(
]
return out_nodes, out_edges

View File

@@ -1,11 +1,29 @@
from __future__ import annotations
from contextlib import asynccontextmanager
import logging
import asyncio
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from .models import EdgesResponse, GraphResponse, NodesResponse, SparqlQueryRequest, StatsResponse
from .models import (
EdgesResponse,
GraphResponse,
NeighborsRequest,
NeighborsResponse,
NodesResponse,
SparqlQueryRequest,
StatsResponse,
)
from .pipelines.layout_dag_radial import CycleError
from .pipelines.owl_imports_combiner import (
build_combined_graph,
output_location_to_path,
resolve_output_location,
serialize_graph_to_ttl,
)
from .pipelines.selection_neighbors import fetch_neighbor_ids_for_selection
from .pipelines.snapshot_service import GraphSnapshotService
from .rdf_store import RDFStore
from .sparql_engine import RdflibEngine, SparqlEngine, create_sparql_engine
@@ -13,11 +31,33 @@ from .settings import Settings
settings = Settings()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
sparql: SparqlEngine = create_sparql_engine(settings)
rdflib_preloaded_graph = None
if settings.combine_owl_imports_on_start:
entry_location = settings.combine_entry_location or settings.ttl_path
output_location = resolve_output_location(
entry_location,
output_location=settings.combine_output_location,
output_name=settings.combine_output_name,
)
output_path = output_location_to_path(output_location)
if output_path.exists() and not settings.combine_force:
logger.info("Skipping combine step (output exists): %s", output_location)
else:
rdflib_preloaded_graph = await asyncio.to_thread(build_combined_graph, entry_location)
logger.info("Finished combining imports; serializing to: %s", output_location)
await asyncio.to_thread(serialize_graph_to_ttl, rdflib_preloaded_graph, output_location)
if settings.graph_backend == "rdflib":
settings.ttl_path = str(output_path)
sparql: SparqlEngine = create_sparql_engine(settings, rdflib_graph=rdflib_preloaded_graph)
await sparql.startup()
app.state.sparql = sparql
app.state.snapshot_service = GraphSnapshotService(sparql=sparql, settings=settings)
@@ -62,7 +102,10 @@ def health() -> dict[str, str]:
async def stats() -> StatsResponse:
# Stats reflect exactly what we send to the frontend (/api/graph), not global graph size.
svc: GraphSnapshotService = app.state.snapshot_service
snap = await svc.get(node_limit=50_000, edge_limit=100_000)
try:
snap = await svc.get(node_limit=50_000, edge_limit=100_000)
except CycleError as e:
raise HTTPException(status_code=422, detail=str(e)) from None
meta = snap.meta
return StatsResponse(
backend=meta.backend if meta else app.state.sparql.name,
@@ -81,6 +124,20 @@ async def sparql_query(req: SparqlQueryRequest) -> dict:
return data
@app.post("/api/neighbors", response_model=NeighborsResponse)
async def neighbors(req: NeighborsRequest) -> NeighborsResponse:
    """Return snapshot node ids directly connected to the selected ids.

    Resolves the selected dense ids against the cached graph snapshot, then
    asks the SPARQL engine for their neighbors.
    """
    svc: GraphSnapshotService = app.state.snapshot_service
    try:
        # Same contract as /api/graph and /api/stats: a cyclic subClassOf
        # snapshot cannot be laid out, so surface it as HTTP 422 rather than
        # letting CycleError escape as a 500.
        snap = await svc.get(node_limit=req.node_limit, edge_limit=req.edge_limit)
    except CycleError as e:
        raise HTTPException(status_code=422, detail=str(e)) from None
    sparql: SparqlEngine = app.state.sparql
    neighbor_ids = await fetch_neighbor_ids_for_selection(
        sparql,
        snapshot=snap,
        selected_ids=req.selected_ids,
        include_bnodes=settings.include_bnodes,
    )
    return NeighborsResponse(selected_ids=req.selected_ids, neighbor_ids=neighbor_ids)
@app.get("/api/nodes", response_model=NodesResponse)
def nodes(
limit: int = Query(default=10_000, ge=1, le=200_000),
@@ -109,4 +166,7 @@ async def graph(
edge_limit: int = Query(default=100_000, ge=1, le=500_000),
) -> GraphResponse:
svc: GraphSnapshotService = app.state.snapshot_service
return await svc.get(node_limit=node_limit, edge_limit=edge_limit)
try:
return await svc.get(node_limit=node_limit, edge_limit=edge_limit)
except CycleError as e:
raise HTTPException(status_code=422, detail=str(e)) from None

View File

@@ -56,3 +56,14 @@ class GraphResponse(BaseModel):
class SparqlQueryRequest(BaseModel):
    """Request body for the raw SPARQL endpoint."""

    # Raw SPARQL query text submitted by the client.
    query: str
class NeighborsRequest(BaseModel):
    """Request body for /api/neighbors."""

    # Dense snapshot node ids the caller has selected.
    selected_ids: list[int]
    # Snapshot size bounds; should match the snapshot the ids came from.
    node_limit: int = 50_000
    edge_limit: int = 100_000
class NeighborsResponse(BaseModel):
    """Response body for /api/neighbors."""

    # Echo of the selection from the request.
    selected_ids: list[int]
    # Snapshot ids adjacent to the selection (sorted; selection itself excluded).
    neighbor_ids: list[int]

View File

@@ -1,10 +1,64 @@
from __future__ import annotations
from typing import Any
from ..graph_export import edge_retrieval_query, graph_from_sparql_bindings
from ..models import GraphResponse
from ..sparql_engine import SparqlEngine
from ..settings import Settings
from .layout_spiral import spiral_positions
from .layout_dag_radial import CycleError, level_synchronous_kahn_layers, radial_positions_from_layers
RDFS_LABEL = "http://www.w3.org/2000/01/rdf-schema#label"
def _bindings(res: dict[str, Any]) -> list[dict[str, Any]]:
return (((res.get("results") or {}).get("bindings")) or [])
def _label_score(label_binding: dict[str, Any]) -> int:
# Prefer English, then no-language, then anything else.
lang = (label_binding.get("xml:lang") or "").lower()
if lang == "en":
return 3
if lang == "":
return 2
return 1
async def _fetch_rdfs_labels_for_iris(
    sparql: SparqlEngine,
    iris: list[str],
    *,
    batch_size: int = 500,
) -> dict[str, str]:
    """Fetch rdfs:label values for the given IRIs, in batches of *batch_size*.

    Returns a mapping of IRI -> best label, where "best" is decided by
    `_label_score` (English preferred, then language-less literals).
    """
    best: dict[str, tuple[int, str]] = {}
    for start in range(0, len(iris), batch_size):
        chunk = iris[start : start + batch_size]
        values_clause = " ".join(f"<{u}>" for u in chunk)
        query = f"""
    SELECT ?s ?label
    WHERE {{
        VALUES ?s {{ {values_clause} }}
        ?s <{RDFS_LABEL}> ?label .
    }}
    """
        res = await sparql.query_json(query)
        for binding in _bindings(res):
            subject = (binding.get("s") or {}).get("value")
            term = binding.get("label") or {}
            # Only plain literal bindings carry a usable label.
            if not subject or term.get("type") != "literal":
                continue
            value = term.get("value")
            if value is None:
                continue
            candidate = (_label_score(term), str(value))
            current = best.get(subject)
            if current is None or candidate[0] > current[0]:
                best[subject] = candidate
    return {iri: label for iri, (_, label) in best.items()}
async def fetch_graph_snapshot(
@@ -28,11 +82,59 @@ async def fetch_graph_snapshot(
)
# Add positions so the frontend doesn't need to run a layout.
xs, ys = spiral_positions(len(nodes))
#
# We are exporting only rdfs:subClassOf triples. In the exported edges:
# source = subclass, target = superclass
# For hierarchical layout we invert edges to:
# superclass -> subclass
hier_edges: list[tuple[int, int]] = []
for e in edges:
s = e.get("source")
t = e.get("target")
try:
sid = int(s) # subclass
tid = int(t) # superclass
except Exception:
continue
hier_edges.append((tid, sid))
try:
layers = level_synchronous_kahn_layers(node_count=len(nodes), edges=hier_edges)
except CycleError as e:
# Add a small URI sample to aid debugging.
sample: list[str] = []
for nid in e.remaining_node_ids[:20]:
try:
sample.append(str(nodes[nid].get("iri")))
except Exception:
continue
raise CycleError(
processed=e.processed,
total=e.total,
remaining_node_ids=e.remaining_node_ids,
remaining_iri_sample=sample or None,
) from None
# Deterministic order within each ring/layer for stable layouts.
id_to_iri = [str(n.get("iri", "")) for n in nodes]
for layer in layers:
layer.sort(key=lambda nid: id_to_iri[nid])
xs, ys = radial_positions_from_layers(node_count=len(nodes), layers=layers)
for i, node in enumerate(nodes):
node["x"] = float(xs[i])
node["y"] = float(ys[i])
# Attach labels for URI nodes (blank nodes remain label-less).
uri_nodes = [n for n in nodes if n.get("termType") == "uri"]
if uri_nodes:
iris = [str(n["iri"]) for n in uri_nodes if isinstance(n.get("iri"), str)]
label_by_iri = await _fetch_rdfs_labels_for_iris(sparql, iris)
for n in uri_nodes:
iri = n.get("iri")
if isinstance(iri, str) and iri in label_by_iri:
n["label"] = label_by_iri[iri]
meta = GraphResponse.Meta(
backend=sparql.name,
ttl_path=settings.ttl_path if settings.graph_backend == "rdflib" else None,

View File

@@ -0,0 +1,141 @@
from __future__ import annotations
import math
from collections import deque
from typing import Iterable, Sequence
class CycleError(RuntimeError):
    """
    Raised when the requested layout requires a DAG, but a cycle is detected.
    `remaining_node_ids` are the node ids that still had indegree > 0 after Kahn.
    """

    def __init__(
        self,
        *,
        processed: int,
        total: int,
        remaining_node_ids: list[int],
        remaining_iri_sample: list[str] | None = None,
    ) -> None:
        self.processed = int(processed)
        self.total = int(total)
        self.remaining_node_ids = remaining_node_ids
        self.remaining_iri_sample = remaining_iri_sample
        # Build the human-readable message, optionally naming a few offenders.
        parts = [
            f"Cycle detected in subClassOf graph (processed {self.processed}/{self.total} nodes)."
        ]
        if remaining_iri_sample:
            parts.append(f"Example nodes: {', '.join(remaining_iri_sample)}")
        super().__init__(" ".join(parts))
def level_synchronous_kahn_layers(
    *,
    node_count: int,
    edges: Iterable[tuple[int, int]],
) -> list[list[int]]:
    """
    Level-synchronous Kahn's algorithm:
      - the whole zero-indegree frontier becomes one layer
      - only then are newly-unlocked nodes collected for the next layer

    `edges` are directed (u -> v). Raises CycleError if any node is never
    unlocked (i.e. the graph is not a DAG).
    """
    total = int(node_count)
    if total <= 0:
        return []
    successors: list[list[int]] = [[] for _ in range(total)]
    indegree = [0] * total
    for src, dst in edges:
        # Self-loops don't help layout and would trivially violate DAG-ness.
        if src == dst:
            continue
        # Silently drop edges referencing ids outside [0, total).
        if not (0 <= src < total and 0 <= dst < total):
            continue
        successors[src].append(dst)
        indegree[dst] += 1
    frontier = [i for i in range(total) if indegree[i] == 0]
    layers: list[list[int]] = []
    done = 0
    while frontier:
        layers.append(frontier)
        done += len(frontier)
        unlocked: list[int] = []
        for node in frontier:
            for succ in successors[node]:
                indegree[succ] -= 1
                if indegree[succ] == 0:
                    unlocked.append(succ)
        frontier = unlocked
    if done != total:
        stuck = [i for i in range(total) if indegree[i] > 0]
        raise CycleError(processed=done, total=total, remaining_node_ids=stuck)
    return layers
def radial_positions_from_layers(
    *,
    node_count: int,
    layers: Sequence[Sequence[int]],
    max_r: float = 5000.0,
) -> tuple[list[float], list[float]]:
    """
    Assign node positions in concentric rings (one ring per layer).

    - radius grows with the layer index, staying within ~[-max_r, max_r]
    - nodes within a layer are spread evenly by angle
    - each ring gets a "golden-angle" rotation so radial spokes don't align

    Ids outside [0, node_count) are ignored; unplaced nodes stay at (0, 0).
    """
    total = int(node_count)
    if total <= 0:
        return ([], [])
    xs = [0.0] * total
    ys = [0.0] * total
    if not layers:
        return (xs, ys)
    full_turn = 2.0 * math.pi
    golden_angle = math.pi * (3.0 - math.sqrt(5.0))
    ring_denom = float(len(layers) + 1)
    for ring_index, ring in enumerate(layers):
        count = len(ring)
        if count <= 0:
            continue
        radius = ((ring_index + 1) / ring_denom) * max_r
        # Deterministic per-ring rotation to avoid spoke artifacts.
        angle_offset = (ring_index * golden_angle) % full_turn
        angle_step = full_turn / float(count)
        for slot, raw_id in enumerate(ring):
            node_id = int(raw_id)
            if not (0 <= node_id < total):
                continue
            angle = angle_offset + angle_step * float(slot)
            xs[node_id] = radius * math.cos(angle)
            ys[node_id] = radius * math.sin(angle)
    return (xs, ys)

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
import logging
import os
from pathlib import Path
from urllib.parse import unquote, urlparse
from rdflib import Graph
from rdflib.namespace import OWL
logger = logging.getLogger(__name__)
def _is_http_url(location: str) -> bool:
scheme = urlparse(location).scheme.lower()
return scheme in {"http", "https"}
def _is_file_uri(location: str) -> bool:
return urlparse(location).scheme.lower() == "file"
def _file_uri_to_path(location: str) -> Path:
u = urlparse(location)
if u.scheme.lower() != "file":
raise ValueError(f"Not a file:// URI: {location!r}")
return Path(unquote(u.path))
def resolve_output_location(
    entry_location: str,
    *,
    output_location: str | None,
    output_name: str,
) -> str:
    """Decide where the combined TTL file should be written.

    An explicit *output_location* always wins. Otherwise the output is placed
    next to the entry document as *output_name*. An http(s) entry has no local
    directory to place the file in, so it requires an explicit output location.
    """
    if output_location:
        return output_location
    if _is_http_url(entry_location):
        raise ValueError(
            "COMBINE_ENTRY_LOCATION points to an http(s) URL; set COMBINE_OUTPUT_LOCATION to a writable file path."
        )
    if _is_file_uri(entry_location):
        entry_path = _file_uri_to_path(entry_location)
    else:
        entry_path = Path(entry_location)
    return str(entry_path.parent / output_name)
def _output_destination_to_path(output_location: str) -> Path:
    """Map an output location (plain path or file:// URI) to a writable Path.

    http(s) locations are rejected: the combiner can only write local files.
    """
    if _is_file_uri(output_location):
        return _file_uri_to_path(output_location)
    if _is_http_url(output_location):
        raise ValueError("Output location must be a local file path (or file:// URI), not http(s).")
    return Path(output_location)
def output_location_to_path(output_location: str) -> Path:
    """Public alias for `_output_destination_to_path`: map an output location
    (plain path or file:// URI) to a local Path."""
    return _output_destination_to_path(output_location)
def build_combined_graph(entry_location: str) -> Graph:
    """
    Recursively loads an RDF document (file path, file:// URI, or http(s) URL) and its
    owl:imports into a single in-memory graph.

    Failed loads are logged and skipped (best-effort), and each location is
    visited at most once, so import cycles terminate.
    """
    merged = Graph()
    seen: set[str] = set()

    def _load(location: str) -> None:
        if location in seen:
            return
        seen.add(location)
        logger.info("Loading ontology: %s", location)
        try:
            merged.parse(location=location)
        except Exception as exc:
            logger.warning("Failed to load %s (%s)", location, exc)
            return
        # Re-scan the merged graph: this parse may have added new owl:imports.
        # Materialize the list first so recursion doesn't mutate the graph
        # while we iterate over it.
        pending = [str(obj) for _, _, obj in merged.triples((None, OWL.imports, None))]
        for target in pending:
            if target not in seen:
                _load(target)

    _load(entry_location)
    return merged
def serialize_graph_to_ttl(graph: Graph, output_location: str) -> None:
    """Serialize *graph* as Turtle to *output_location*.

    Writes to a sibling `.tmp` file first, then renames it over the target,
    so a crash mid-write never leaves a truncated TTL behind.
    """
    destination = _output_destination_to_path(output_location)
    destination.parent.mkdir(parents=True, exist_ok=True)
    scratch = destination.with_suffix(destination.suffix + ".tmp")
    graph.serialize(destination=str(scratch), format="turtle")
    # os.replace is atomic on POSIX: readers see the old or the new file.
    os.replace(str(scratch), str(destination))

View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from typing import Any, Iterable
from ..models import GraphResponse, Node
from ..sparql_engine import SparqlEngine
def _values_term(node: Node) -> str | None:
iri = node.iri
if node.termType == "uri":
return f"<{iri}>"
if node.termType == "bnode":
if iri.startswith("_:"):
return iri
return f"_:{iri}"
return None
def selection_neighbors_query(*, selected_nodes: Iterable[Node], include_bnodes: bool) -> str:
    """Build a SPARQL query returning nodes adjacent to the selection.

    Adjacency means rdf:type (to an owl:Class) or rdfs:subClassOf, followed in
    both directions (edges treated as undirected). Literals are excluded, and
    blank nodes are excluded unless *include_bnodes* is set.
    """
    terms = [t for t in (_values_term(n) for n in selected_nodes) if t is not None]
    if not terms:
        # Caller should avoid running this query when selection is empty, but keep this safe.
        return "SELECT ?nbr WHERE { FILTER(false) }"
    bnode_filter = "" if include_bnodes else "FILTER(!isBlank(?nbr))"
    values = " ".join(terms)
    return f"""
    PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    PREFIX owl: <http://www.w3.org/2002/07/owl#>
    SELECT DISTINCT ?nbr
    WHERE {{
        VALUES ?sel {{ {values} }}
        {{
            ?sel rdf:type ?o .
            ?o rdf:type owl:Class .
            BIND(?o AS ?nbr)
        }}
        UNION
        {{
            ?s rdf:type ?sel .
            ?sel rdf:type owl:Class .
            BIND(?s AS ?nbr)
        }}
        UNION
        {{
            ?sel rdfs:subClassOf ?o .
            BIND(?o AS ?nbr)
        }}
        UNION
        {{
            ?s rdfs:subClassOf ?sel .
            BIND(?s AS ?nbr)
        }}
        FILTER(!isLiteral(?nbr))
        FILTER(?nbr != ?sel)
        {bnode_filter}
    }}
    """
def _bindings(res: dict[str, Any]) -> list[dict[str, Any]]:
return (((res.get("results") or {}).get("bindings")) or [])
def _term_key(term: dict[str, Any], *, include_bnodes: bool) -> tuple[str, str] | None:
t = term.get("type")
v = term.get("value")
if not t or v is None:
return None
if t == "literal":
return None
if t == "bnode":
if not include_bnodes:
return None
return ("bnode", f"_:{v}")
return ("uri", str(v))
async def fetch_neighbor_ids_for_selection(
    sparql: SparqlEngine,
    *,
    snapshot: GraphResponse,
    selected_ids: list[int],
    include_bnodes: bool,
) -> list[int]:
    """Resolve selected snapshot ids to RDF terms, query their neighbors via
    SPARQL, and map the results back to snapshot ids.

    Ids not present in the snapshot are ignored; the selection itself is
    excluded from the result. Returns a sorted list for stable frontend
    behavior.
    """
    nodes_by_id: dict[int, Node] = {node.id: node for node in snapshot.nodes}
    chosen: list[Node] = []
    chosen_ids: set[int] = set()
    for raw in selected_ids:
        if not isinstance(raw, int):
            continue
        node = nodes_by_id.get(raw)
        if node is None:
            continue
        # Blank-node selections are only meaningful in bnode mode.
        if node.termType == "bnode" and not include_bnodes:
            continue
        chosen.append(node)
        chosen_ids.add(raw)
    if not chosen:
        return []
    # Reverse lookup: (termType, iri) -> dense snapshot id.
    id_by_key: dict[tuple[str, str], int] = {
        (node.termType, node.iri): node.id for node in snapshot.nodes
    }
    query = selection_neighbors_query(selected_nodes=chosen, include_bnodes=include_bnodes)
    res = await sparql.query_json(query)
    found: set[int] = set()
    for binding in _bindings(res):
        key = _term_key(binding.get("nbr") or {}, include_bnodes=include_bnodes)
        if key is None:
            continue
        mapped = id_by_key.get(key)
        if mapped is None or mapped in chosen_ids:
            continue
        found.add(mapped)
    # Stable ordering for consistent frontend behavior.
    return sorted(found)

View File

@@ -16,6 +16,13 @@ class Settings(BaseSettings):
include_bnodes: bool = Field(default=False, alias="INCLUDE_BNODES")
max_triples: int | None = Field(default=None, alias="MAX_TRIPLES")
# Optional: Combine owl:imports into a single TTL file on backend startup.
combine_owl_imports_on_start: bool = Field(default=False, alias="COMBINE_OWL_IMPORTS_ON_START")
combine_entry_location: str | None = Field(default=None, alias="COMBINE_ENTRY_LOCATION")
combine_output_location: str | None = Field(default=None, alias="COMBINE_OUTPUT_LOCATION")
combine_output_name: str = Field(default="combined_ontology.ttl", alias="COMBINE_OUTPUT_NAME")
combine_force: bool = Field(default=False, alias="COMBINE_FORCE")
# AnzoGraph / SPARQL endpoint configuration
sparql_host: str = Field(default="http://anzograph:8080", alias="SPARQL_HOST")
# If not set, the backend uses `${SPARQL_HOST}/sparql`.

View File

@@ -24,11 +24,13 @@ class SparqlEngine(Protocol):
class RdflibEngine:
name = "rdflib"
def __init__(self, *, ttl_path: str):
def __init__(self, *, ttl_path: str, graph: Graph | None = None):
self.ttl_path = ttl_path
self.graph: Graph | None = None
self.graph: Graph | None = graph
async def startup(self) -> None:
if self.graph is not None:
return
g = Graph()
g.parse(self.ttl_path, format="turtle")
self.graph = g
@@ -167,9 +169,9 @@ class AnzoGraphEngine:
raise RuntimeError(f"AnzoGraph not ready at {self.endpoint}") from last_err
def create_sparql_engine(settings: Settings) -> SparqlEngine:
def create_sparql_engine(settings: Settings, *, rdflib_graph: Graph | None = None) -> SparqlEngine:
if settings.graph_backend == "rdflib":
return RdflibEngine(ttl_path=settings.ttl_path)
return RdflibEngine(ttl_path=settings.ttl_path, graph=rdflib_graph)
if settings.graph_backend == "anzograph":
return AnzoGraphEngine(settings=settings)
raise RuntimeError(f"Unsupported GRAPH_BACKEND={settings.graph_backend!r}")