138 lines
3.5 KiB
Python
138 lines
3.5 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Iterable
|
|
|
|
from ..models import GraphResponse, Node
|
|
from ..sparql_engine import SparqlEngine
|
|
|
|
|
|
# Characters that may not appear inside a SPARQL 1.1 IRIREF (``<...>``);
# code points 0x00-0x20 are rejected separately.
_IRIREF_FORBIDDEN = frozenset('<>"{}|^`\\')


def _values_term(node: Node) -> str | None:
    """Render *node* as a term for a SPARQL ``VALUES`` block.

    Returns ``<iri>`` for ``"uri"`` nodes and ``_:label`` for ``"bnode"``
    nodes (the ``_:`` prefix is added when the stored label lacks it).

    Returns ``None`` in these cases, so callers can skip the node instead of
    emitting a malformed (or injectable) query:

    * any other term type;
    * a missing or empty identifier;
    * an identifier that cannot be written into the query verbatim, e.g. an
      IRI containing ``>`` or whitespace.
    """
    iri = node.iri
    if not isinstance(iri, str) or not iri:
        # Nothing addressable; an empty ``<>`` would resolve against the base IRI.
        return None

    if node.termType == "uri":
        # A '>' or space would end the IRIREF early and splice the remainder
        # of the identifier into the query text.
        if any(ch in _IRIREF_FORBIDDEN or ord(ch) <= 0x20 for ch in iri):
            return None
        return f"<{iri}>"

    if node.termType == "bnode":
        label = iri[2:] if iri.startswith("_:") else iri
        # Conservative label check: letters/digits plus '_', '-', '.', not
        # starting with '-'/'.' and not ending with '.'. Anything else could
        # break or alter the query.
        if (
            not label
            or label[0] in "-."
            or label.endswith(".")
            or not all(ch.isalnum() or ch in "_-." for ch in label)
        ):
            return None
        # NOTE(review): the SPARQL 1.1 grammar (DataBlockValue) does not allow
        # blank nodes inside VALUES; confirm the configured engine accepts this.
        return f"_:{label}"

    return None
|
|
|
|
|
|
def selection_neighbors_query(*, selected_nodes: Iterable[Node], include_bnodes: bool) -> str:
    """Build a SPARQL SELECT that yields the direct neighbors of a selection.

    Every selected node is rendered with ``_values_term``; nodes it maps to
    ``None`` are dropped. If none remain, a query that can never match is
    returned, so the result is always runnable.

    A neighbor (bound as ``?nbr``) is a non-literal term linked to a selected
    node in either direction by:

    * ``rdf:type``, but only when the class side is typed ``owl:Class``; or
    * ``rdfs:subClassOf``.

    The selected node itself is excluded. Blank-node neighbors are filtered
    out unless *include_bnodes* is true.
    """
    rendered = [term for term in map(_values_term, selected_nodes) if term is not None]
    if not rendered:
        # Callers are expected to skip empty selections; stay safe regardless.
        return "SELECT ?nbr WHERE { FILTER(false) }"

    if include_bnodes:
        bnode_filter = ""
    else:
        bnode_filter = "FILTER(!isBlank(?nbr))"
    values = " ".join(rendered)

    # Four UNION branches: class-of, instance-of, superclass, subclass.
    return f"""
PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX owl: <http://www.w3.org/2002/07/owl#>

SELECT DISTINCT ?nbr
WHERE {{
  VALUES ?sel {{ {values} }}
  {{
    ?sel rdf:type ?o .
    ?o rdf:type owl:Class .
    BIND(?o AS ?nbr)
  }}
  UNION
  {{
    ?s rdf:type ?sel .
    ?sel rdf:type owl:Class .
    BIND(?s AS ?nbr)
  }}
  UNION
  {{
    ?sel rdfs:subClassOf ?o .
    BIND(?o AS ?nbr)
  }}
  UNION
  {{
    ?s rdfs:subClassOf ?sel .
    BIND(?s AS ?nbr)
  }}
  FILTER(!isLiteral(?nbr))
  FILTER(?nbr != ?sel)
  {bnode_filter}
}}
"""
|
|
|
|
|
|
def _bindings(res: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the row list (``results.bindings``) of a SPARQL JSON result.

    A missing or falsy ``results`` section or ``bindings`` entry yields ``[]``.
    """
    section = res.get("results")
    if not section:
        return []
    rows = section.get("bindings")
    return rows if rows else []
|
|
|
|
|
|
def _term_key(term: dict[str, Any], *, include_bnodes: bool) -> tuple[str, str] | None:
    """Map one SPARQL JSON result term to a ``(termType, iri)`` lookup key.

    Args:
        term: A single term object from a result row (``{"type": ..., "value": ...}``).
        include_bnodes: When false, blank-node terms produce no key.

    Returns:
        One of:

        * ``("uri", value)`` for IRIs;
        * ``("bnode", "_:label")`` for blank nodes (the ``_:`` prefix is
          added only if the value lacks it);
        * ``None`` when the term is incomplete or of any other type (e.g.
          ``literal``, legacy ``typed-literal``, RDF-star ``triple``).
    """
    kind = term.get("type")
    value = term.get("value")
    if not kind or value is None:
        return None

    if kind == "uri":
        return ("uri", str(value))

    if kind == "bnode":
        if not include_bnodes:
            return None
        label = str(value)
        # Avoid producing "_:_:x" when an engine already reports the prefix.
        return ("bnode", label if label.startswith("_:") else f"_:{label}")

    # Literals and unknown term kinds never identify a graph node; previously
    # they fell through and were misclassified as URIs.
    return None
|
|
|
|
|
|
def _node_key(node: Node) -> tuple[str, str]:
    """Return the ``(termType, iri)`` key used to match result terms to *node*.

    Blank-node labels are normalised to carry the ``_:`` prefix, because the
    snapshot may store them with or without it, while keys derived from
    query results always carry it.
    """
    iri = node.iri
    if node.termType == "bnode" and not iri.startswith("_:"):
        iri = f"_:{iri}"
    return (node.termType, iri)


async def fetch_neighbor_ids_for_selection(
    sparql: SparqlEngine,
    *,
    snapshot: GraphResponse,
    selected_ids: list[int],
    include_bnodes: bool,
) -> list[int]:
    """Return the ids of snapshot nodes directly connected to the selection.

    Args:
        sparql: Engine whose ``query_json`` runs the neighbor query.
        snapshot: Graph whose ``nodes`` define the id <-> term mapping. Only
            neighbors present in it are reported.
        selected_ids: Ids of the selected nodes. These entries are ignored:

            * non-int entries (including bools);
            * ids absent from the snapshot;
            * repeated ids;
            * blank nodes, unless *include_bnodes* is true.
        include_bnodes: Whether blank nodes may be selected and reported.

    Returns:
        Sorted, de-duplicated neighbor ids, excluding the selected nodes.
        Returns an empty list, without issuing any query, when no usable
        node is selected.
    """
    id_to_node: dict[int, Node] = {n.id: n for n in snapshot.nodes}

    selected_nodes: list[Node] = []
    selected_id_set: set[int] = set()
    for nid in selected_ids:
        # bool is an int subclass; True/False must not alias ids 1/0.
        if not isinstance(nid, int) or isinstance(nid, bool):
            continue
        if nid in selected_id_set:
            continue  # repeated id: don't duplicate it in the VALUES block
        n = id_to_node.get(nid)
        if n is None:
            continue
        if n.termType == "bnode" and not include_bnodes:
            continue
        selected_nodes.append(n)
        selected_id_set.add(nid)

    if not selected_nodes:
        return []

    # Later nodes win on duplicate keys, as with a plain assignment loop.
    key_to_id: dict[tuple[str, str], int] = {_node_key(n): n.id for n in snapshot.nodes}

    q = selection_neighbors_query(selected_nodes=selected_nodes, include_bnodes=include_bnodes)
    res = await sparql.query_json(q)

    neighbor_ids: set[int] = set()
    for b in _bindings(res):
        key = _term_key(b.get("nbr") or {}, include_bnodes=include_bnodes)
        if key is None:
            continue
        nid = key_to_id.get(key)
        if nid is not None and nid not in selected_id_set:
            neighbor_ids.add(nid)

    # Stable ordering for consistent frontend behavior.
    return sorted(neighbor_ids)
|