154 lines
4.6 KiB
Python
154 lines
4.6 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from ..sparql_engine import SparqlEngine
|
|
|
|
# RDF Schema vocabulary IRIs interpolated into the SPARQL queries in this module.
_RDFS_NS = "http://www.w3.org/2000/01/rdf-schema#"
RDFS_SUBCLASS_OF = _RDFS_NS + "subClassOf"
RDFS_LABEL = _RDFS_NS + "label"
|
|
|
|
|
|
def _bindings(res: dict[str, Any]) -> list[dict[str, Any]]:
|
|
return (((res.get("results") or {}).get("bindings")) or [])
|
|
|
|
|
|
def _term_key(term: dict[str, Any]) -> tuple[str, str] | None:
|
|
t = term.get("type")
|
|
v = term.get("value")
|
|
if not t or v is None:
|
|
return None
|
|
if t == "literal":
|
|
return None
|
|
if t == "bnode":
|
|
return ("bnode", str(v))
|
|
return ("uri", str(v))
|
|
|
|
|
|
def _key_to_entity_string(key: tuple[str, str]) -> str:
|
|
t, v = key
|
|
if t == "bnode":
|
|
return f"_:{v}"
|
|
return v
|
|
|
|
|
|
def _label_score(binding: dict[str, Any]) -> int:
|
|
"""
|
|
Higher is better.
|
|
Prefer English, then no-language, then anything else.
|
|
"""
|
|
lang = (binding.get("xml:lang") or "").lower()
|
|
if lang == "en":
|
|
return 3
|
|
if lang == "":
|
|
return 2
|
|
return 1
|
|
|
|
|
|
async def extract_subclass_entities_and_labels(
    sparql: SparqlEngine,
    *,
    include_bnodes: bool,
    label_batch_size: int = 500,
) -> tuple[list[str], list[str | None]]:
    """
    Collect every entity that takes part in an rdfs:subClassOf triple, with a label.

    Pipeline:
      1) Query all rdfs:subClassOf triples.
      2) Build a unique set of entity terms from subjects+objects, convert to
         a sorted list.
      3) Fetch rdfs:label for those entities and return an aligned labels list.

    Args:
      sparql: engine whose awaitable ``query_json(query)`` returns a SPARQL
        JSON results dict.
      include_bnodes: when False, triples whose subject or object is a blank
        node are filtered out in the query, and blank nodes are also dropped
        locally.
      label_batch_size: maximum number of IRIs placed in one VALUES clause
        when fetching labels. Must be >= 1.

    Returns:
      entities: list[str] (IRI or "_:bnodeId"), sorted by term kind then value
      labels: list[str|None], aligned with entities; None when no usable
        label was found

    Raises:
      ValueError: if ``label_batch_size`` is less than 1.

    When several labels exist for an entity, the one with the highest
    ``_label_score`` wins; among equal scores the first one returned by the
    endpoint is kept.

    NOTE(review): with ``include_bnodes=False`` the FILTER removes the whole
    triple, so a named class whose only subClassOf links involve a blank node
    does not appear in the result. Confirm this is intended.
    """
    # A zero step makes range() raise and a negative step makes it empty, which
    # would silently yield no labels at all; reject both before querying.
    if label_batch_size < 1:
        raise ValueError(f"label_batch_size must be >= 1, got {label_batch_size}")

    bnode_filter = "" if include_bnodes else "FILTER(!isBlank(?s) && !isBlank(?o))"
    subclass_q = f"""
    SELECT ?s ?o
    WHERE {{
      ?s <{RDFS_SUBCLASS_OF}> ?o .
      FILTER(!isLiteral(?o))
      {bnode_filter}
    }}
    """
    res = await sparql.query_json(subclass_q)

    entity_keys: set[tuple[str, str]] = set()
    for row in _bindings(res):
        for var in ("s", "o"):
            key = _term_key(row.get(var) or {})
            # Blank nodes are dropped here as well as in the query.
            if key is not None and (include_bnodes or key[0] != "bnode"):
                entity_keys.add(key)

    # Deterministic ordering: by term kind, then by value.
    entity_key_list = sorted(entity_keys)
    entities = [_key_to_entity_string(k) for k in entity_key_list]

    # Best (score, label) seen so far, keyed by term key.
    best_label_by_key: dict[tuple[str, str], tuple[int, str]] = {}

    # URIs can be batch-queried via VALUES. IRIs that cannot be written
    # verbatim inside <...> are skipped (their label stays None) so that one
    # malformed IRI cannot break or alter the query for its whole batch.
    uri_values = [v for (t, v) in entity_key_list if t == "uri" and _is_safe_iri(v)]
    for start in range(0, len(uri_values), label_batch_size):
        batch = uri_values[start : start + label_batch_size]
        values = " ".join(f"<{u}>" for u in batch)
        labels_q = f"""
        SELECT ?s ?label
        WHERE {{
          VALUES ?s {{ {values} }}
          ?s <{RDFS_LABEL}> ?label .
        }}
        """
        lres = await sparql.query_json(labels_q)
        for row in _bindings(lres):
            key = _term_key(row.get("s") or {})
            if key is None or key[0] != "uri":
                continue
            _keep_best_label(best_label_by_key, key, row.get("label") or {})

    # Blank nodes can't reliably be addressed by ID across queries, so fetch
    # all bnode labels and filter locally.
    # NOTE(review): the local match assumes the endpoint reports the same
    # blank-node id in this query as in the subClassOf query -- confirm.
    # The set is empty when include_bnodes is False (bnodes were dropped above).
    bnode_keys = {k for k in entity_key_list if k[0] == "bnode"}
    if bnode_keys:
        bnode_labels_q = f"""
        SELECT ?s ?label
        WHERE {{
          ?s <{RDFS_LABEL}> ?label .
          FILTER(isBlank(?s))
        }}
        """
        blres = await sparql.query_json(bnode_labels_q)
        for row in _bindings(blres):
            key = _term_key(row.get("s") or {})
            if key is None or key not in bnode_keys:
                continue
            _keep_best_label(best_label_by_key, key, row.get("label") or {})

    labels: list[str | None] = [
        best_label_by_key[k][1] if k in best_label_by_key else None
        for k in entity_key_list
    ]
    return entities, labels


# Characters that may not appear inside a SPARQL IRIREF (``<...>``), in
# addition to the control/space range U+0000-U+0020.
_IRIREF_FORBIDDEN = frozenset('<>"{}|^`\\')


def _is_safe_iri(iri: str) -> bool:
    """Return True if *iri* can be written verbatim inside ``<...>`` in a query."""
    return not any(ch <= " " or ch in _IRIREF_FORBIDDEN for ch in iri)


def _keep_best_label(
    best: dict[tuple[str, str], tuple[int, str]],
    key: tuple[str, str],
    label_term: dict[str, Any],
) -> None:
    """Record *label_term* for *key* in *best* if it is a literal that outranks the current entry."""
    # "typed-literal" is the legacy spelling of a datatyped literal.
    if label_term.get("type") not in ("literal", "typed-literal"):
        return
    value = label_term.get("value")
    if value is None:
        return
    score = _label_score(label_term)
    current = best.get(key)
    # Strictly greater: on equal scores the first label seen is kept.
    if current is None or score > current[0]:
        best[key] = (score, str(value))
|
|
|