from __future__ import annotations

import re
from typing import Any, Iterable

from ..models import GraphResponse, Node
from ..sparql_engine import SparqlEngine

# Characters that may not appear inside a SPARQL IRIREF (`<...>`): control
# characters, space, and <>"{}|^`\ . The same set is rejected in blank-node
# labels, because any of them would end the term early and corrupt (or
# inject into) the generated query. This is a guard, not a full validator.
_UNSAFE_TERM_CHARS = re.compile(r'[\x00-\x20<>"{}|^`\\]')


def _bnode_label(label: str) -> str:
    """Return *label* in canonical ``_:``-prefixed blank-node form."""
    return label if label.startswith("_:") else f"_:{label}"


def _node_key(node: Node) -> tuple[str, str]:
    """Return the ``(termType, identifier)`` key used to match a node to a result term.

    Blank-node identifiers are normalised with ``_bnode_label`` so the key
    agrees with what ``_term_key`` produces for query results, whether or not
    the snapshot stores the ``_:`` prefix.
    """
    if node.termType == "bnode":
        return ("bnode", _bnode_label(node.iri))
    return (node.termType, node.iri)


def _values_term(node: Node) -> str | None:
    """Render *node* as a term for a SPARQL ``VALUES`` block.

    Returns ``<iri>`` for ``uri`` nodes and a ``_:``-prefixed label for
    ``bnode`` nodes. Returns ``None`` for any other term type, and for
    identifiers containing characters that cannot be written safely inside
    the term.
    """
    iri = node.iri
    if node.termType not in ("uri", "bnode"):
        return None
    if _UNSAFE_TERM_CHARS.search(iri):
        # Skip rather than emit a term that would break the whole query.
        return None
    if node.termType == "uri":
        return f"<{iri}>"
    return _bnode_label(iri)


def selection_neighbors_query(*, selected_nodes: Iterable[Node], include_bnodes: bool) -> str:
    """Build the SPARQL query that finds direct neighbours of the selection.

    Neighbours are nodes connected to a selected node by ``rdf:type`` (where
    the object is an ``owl:Class``) or by ``rdfs:subClassOf``, in either
    direction. Literals and the selected node itself are filtered out; blank
    nodes are filtered out unless *include_bnodes* is true.

    If no selected node can be rendered as a query term, a query that matches
    nothing is returned.
    """
    values_terms = [
        term
        for term in (_values_term(node) for node in selected_nodes)
        if term is not None
    ]
    if not values_terms:
        # Caller should avoid running this query when selection is empty, but keep this safe.
        return "SELECT ?nbr WHERE { FILTER(false) }"

    bnode_filter = "" if include_bnodes else "FILTER(!isBlank(?nbr))"
    # NOTE(review): the SPARQL 1.1 grammar does not allow blank nodes in a
    # VALUES data block, so blank-node selections depend on the endpoint
    # accepting them as an extension - confirm against the engine in use.
    values = " ".join(values_terms)

    # Neighbors are defined as any node directly connected by rdf:type (to owl:Class)
    # or rdfs:subClassOf, in either direction (treating edges as undirected).
    return f"""
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX owl: <http://www.w3.org/2002/07/owl#>

SELECT DISTINCT ?nbr
WHERE {{
  VALUES ?sel {{ {values} }}
  {{
    ?sel rdf:type ?o .
    ?o rdf:type owl:Class .
    BIND(?o AS ?nbr)
  }}
  UNION
  {{
    ?s rdf:type ?sel .
    ?sel rdf:type owl:Class .
    BIND(?s AS ?nbr)
  }}
  UNION
  {{
    ?sel rdfs:subClassOf ?o .
    BIND(?o AS ?nbr)
  }}
  UNION
  {{
    ?s rdfs:subClassOf ?sel .
    BIND(?s AS ?nbr)
  }}
  FILTER(!isLiteral(?nbr))
  FILTER(?nbr != ?sel)
  {bnode_filter}
}}
"""


def _bindings(res: dict[str, Any]) -> list[dict[str, Any]]:
    """Return ``res["results"]["bindings"]``, or ``[]`` if either level is missing or empty."""
    results = res.get("results") or {}
    return results.get("bindings") or []


def _term_key(term: dict[str, Any], *, include_bnodes: bool) -> tuple[str, str] | None:
    """Convert a result-binding term into a ``(termType, identifier)`` lookup key.

    Returns ``None`` for terms with no type or value, for literals, and for
    blank nodes when *include_bnodes* is false. Blank-node values get the
    ``_:`` prefix only if they do not already carry it. Any other term type
    is keyed as ``"uri"``.
    """
    kind = term.get("type")
    value = term.get("value")
    if not kind or value is None:
        return None
    if kind == "literal":
        return None
    if kind == "bnode":
        if not include_bnodes:
            return None
        return ("bnode", _bnode_label(str(value)))
    return ("uri", str(value))


async def fetch_neighbor_ids_for_selection(
    sparql: SparqlEngine,
    *,
    snapshot: GraphResponse,
    selected_ids: list[int],
    include_bnodes: bool,
) -> list[int]:
    """Return the sorted ids of snapshot nodes that neighbour the selection.

    Args:
        sparql: Engine used to run the neighbour query via ``query_json``.
        snapshot: Graph snapshot whose ``nodes`` define the known ids.
        selected_ids: Ids of the selected nodes. Entries that are not ints or
            are not present in the snapshot are ignored.
        include_bnodes: When false, blank nodes are dropped from both the
            selection and the results.

    Returns:
        Ascending ids of neighbouring nodes that exist in the snapshot,
        excluding the selected nodes themselves. Empty if nothing usable is
        selected; in that case no query is run.
    """
    id_to_node: dict[int, Node] = {n.id: n for n in snapshot.nodes}

    selected_nodes: list[Node] = []
    selected_id_set: set[int] = set()
    for nid in selected_ids:
        if not isinstance(nid, int):
            continue
        node = id_to_node.get(nid)
        if node is None:
            continue
        if node.termType == "bnode" and not include_bnodes:
            continue
        selected_nodes.append(node)
        selected_id_set.add(nid)

    if not selected_nodes:
        return []

    # Index snapshot nodes by the same normalised key that _term_key yields
    # for result terms, so the two sides compare equal.
    key_to_id: dict[tuple[str, str], int] = {_node_key(n): n.id for n in snapshot.nodes}

    query = selection_neighbors_query(selected_nodes=selected_nodes, include_bnodes=include_bnodes)
    res = await sparql.query_json(query)

    neighbor_ids: set[int] = set()
    for binding in _bindings(res):
        key = _term_key(binding.get("nbr") or {}, include_bnodes=include_bnodes)
        if key is None:
            continue
        nid = key_to_id.get(key)
        if nid is None or nid in selected_id_set:
            continue
        neighbor_ids.add(nid)

    # Stable ordering for consistent frontend behavior.
    return sorted(neighbor_ids)