backend: support external SPARQL and named-graph snapshots
This commit is contained in:
341
backend_go/sparql_decode.go
Normal file
341
backend_go/sparql_decode.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sparqlQueryMetrics carries the size and timing figures collected while
// executing one SPARQL query and consuming its response.
type sparqlQueryMetrics struct {
	// ResponseBytes is the number of response-body bytes read during decoding.
	ResponseBytes int64
	// RoundTripTime is the time spent waiting for the HTTP response; when a
	// request is retried after a token refresh, both attempts are summed.
	RoundTripTime time.Duration
	// BodyDecodeTime is the time spent decoding the response body.
	BodyDecodeTime time.Duration
	// BindingCount is the number of bindings decoded by the streaming query
	// path; the non-streaming path leaves it at zero.
	BindingCount int
}
|
||||
|
||||
// countingReadCloser wraps an io.ReadCloser and keeps a running total of the
// bytes that have been read through it.
type countingReadCloser struct {
	io.ReadCloser
	// bytesRead is the sum of the byte counts returned by every Read call.
	bytesRead int64
}

// Read forwards to the wrapped reader and adds the number of bytes it
// returned to the running total, whether or not an error came back with them.
func (c *countingReadCloser) Read(p []byte) (int, error) {
	got, readErr := c.ReadCloser.Read(p)
	c.bytesRead += int64(got)
	return got, readErr
}
|
||||
|
||||
// cancelOnCloseReadCloser couples a response body with the cancel function of
// the context its request ran under, so that closing the body also releases
// that context.
type cancelOnCloseReadCloser struct {
	io.ReadCloser
	// cancel is invoked once the wrapped body has been closed.
	cancel context.CancelFunc
}

// Close closes the wrapped body first and then calls cancel. The error from
// the wrapped Close is returned unchanged.
func (c *cancelOnCloseReadCloser) Close() error {
	// Deferred so that cancel still runs after the underlying Close returns.
	defer c.cancel()
	return c.ReadCloser.Close()
}
|
||||
|
||||
// sparqlHTTPStatusError describes a SPARQL request that was answered with a
// non-2xx HTTP status.
type sparqlHTTPStatusError struct {
	// StatusCode is the numeric HTTP status of the response.
	StatusCode int
	// Status is the status line text of the response.
	Status string
	// Body is the response body with surrounding whitespace trimmed.
	Body string
}

// Compile-time check that the pointer type satisfies the error interface.
var _ error = (*sparqlHTTPStatusError)(nil)

// Error renders the status text followed by the response body.
func (e *sparqlHTTPStatusError) Error() string {
	status, body := e.Status, e.Body
	return fmt.Sprintf("sparql query failed: %s: %s", status, body)
}
|
||||
|
||||
func (c *AnzoGraphClient) queryRequestWithTimeout(ctx context.Context, query string, timeout time.Duration) (*http.Response, time.Duration, error) {
|
||||
ctx2, cancel := context.WithTimeout(ctx, timeout)
|
||||
resp, roundTripTime, err := c.queryRequest(ctx2, query, true)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, roundTripTime, err
|
||||
}
|
||||
resp.Body = &cancelOnCloseReadCloser{ReadCloser: resp.Body, cancel: cancel}
|
||||
return resp, roundTripTime, nil
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) queryRequest(ctx context.Context, query string, allowRefresh bool) (*http.Response, time.Duration, error) {
|
||||
resp, roundTripTime, err := c.queryRequestAttempt(ctx, query)
|
||||
if err == nil {
|
||||
return resp, roundTripTime, nil
|
||||
}
|
||||
|
||||
var statusErr *sparqlHTTPStatusError
|
||||
if allowRefresh && errors.As(err, &statusErr) && c.shouldRefreshExpiredJWT(statusErr) {
|
||||
log.Printf("[auth] sparql_token_retry endpoint=%s reason=jwt_expired", c.endpoint)
|
||||
if _, refreshErr := c.refreshExternalToken(ctx, "sparql_jwt_expired"); refreshErr != nil {
|
||||
return nil, roundTripTime, fmt.Errorf("%w (token refresh failed: %v)", statusErr, refreshErr)
|
||||
}
|
||||
|
||||
retryResp, retryRoundTripTime, retryErr := c.queryRequest(ctx, query, false)
|
||||
return retryResp, roundTripTime + retryRoundTripTime, retryErr
|
||||
}
|
||||
|
||||
return nil, roundTripTime, err
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) queryRequestAttempt(ctx context.Context, query string) (*http.Response, time.Duration, error) {
|
||||
form := url.Values{}
|
||||
form.Set("query", query)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/sparql-results+json")
|
||||
authHeader, err := c.authorizationHeader(ctx, "sparql_query")
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if authHeader != "" {
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
roundTripTime := time.Since(start)
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer resp.Body.Close()
|
||||
body, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return nil, roundTripTime, readErr
|
||||
}
|
||||
return nil, roundTripTime, &sparqlHTTPStatusError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
Body: strings.TrimSpace(string(body)),
|
||||
}
|
||||
}
|
||||
|
||||
return resp, roundTripTime, nil
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) shouldRefreshExpiredJWT(err *sparqlHTTPStatusError) bool {
|
||||
if err == nil || !c.cfg.UsesExternalSparql() {
|
||||
return false
|
||||
}
|
||||
return err.StatusCode == http.StatusUnauthorized && strings.Contains(err.Body, "Jwt is expired")
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) QueryJSON(ctx context.Context, query string, target any) (sparqlQueryMetrics, error) {
|
||||
return c.queryJSONWithTimeout(ctx, query, c.cfg.SparqlTimeout, target)
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) queryJSONWithTimeout(ctx context.Context, query string, timeout time.Duration, target any) (sparqlQueryMetrics, error) {
|
||||
resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout)
|
||||
if err != nil {
|
||||
return sparqlQueryMetrics{}, err
|
||||
}
|
||||
counter := &countingReadCloser{ReadCloser: resp.Body}
|
||||
defer counter.Close()
|
||||
|
||||
decodeStart := time.Now()
|
||||
if err := json.NewDecoder(counter).Decode(target); err != nil {
|
||||
return sparqlQueryMetrics{
|
||||
ResponseBytes: counter.bytesRead,
|
||||
RoundTripTime: roundTripTime,
|
||||
BodyDecodeTime: time.Since(decodeStart),
|
||||
}, wrapSparqlJSONDecodeError(err)
|
||||
}
|
||||
|
||||
return sparqlQueryMetrics{
|
||||
ResponseBytes: counter.bytesRead,
|
||||
RoundTripTime: roundTripTime,
|
||||
BodyDecodeTime: time.Since(decodeStart),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) QueryTripleBindingsStream(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
visit func(binding sparqlTripleBinding) error,
|
||||
) (sparqlQueryMetrics, error) {
|
||||
return c.queryTripleBindingsStreamWithTimeout(ctx, query, c.cfg.SparqlTimeout, visit)
|
||||
}
|
||||
|
||||
func (c *AnzoGraphClient) queryTripleBindingsStreamWithTimeout(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
timeout time.Duration,
|
||||
visit func(binding sparqlTripleBinding) error,
|
||||
) (sparqlQueryMetrics, error) {
|
||||
resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout)
|
||||
if err != nil {
|
||||
return sparqlQueryMetrics{}, err
|
||||
}
|
||||
counter := &countingReadCloser{ReadCloser: resp.Body}
|
||||
defer counter.Close()
|
||||
|
||||
decodeStart := time.Now()
|
||||
bindingCount, err := decodeBindingsStream(json.NewDecoder(counter), visit)
|
||||
if err != nil {
|
||||
return sparqlQueryMetrics{
|
||||
ResponseBytes: counter.bytesRead,
|
||||
RoundTripTime: roundTripTime,
|
||||
BodyDecodeTime: time.Since(decodeStart),
|
||||
BindingCount: bindingCount,
|
||||
}, wrapSparqlJSONDecodeError(err)
|
||||
}
|
||||
|
||||
return sparqlQueryMetrics{
|
||||
ResponseBytes: counter.bytesRead,
|
||||
RoundTripTime: roundTripTime,
|
||||
BodyDecodeTime: time.Since(decodeStart),
|
||||
BindingCount: bindingCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeBindingsStream(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if delim, ok := tok.(json.Delim); !ok || delim != '{' {
|
||||
return 0, fmt.Errorf("invalid SPARQL JSON: expected top-level object")
|
||||
}
|
||||
|
||||
foundResults := false
|
||||
bindingCount := 0
|
||||
for dec.More() {
|
||||
keyToken, err := dec.Token()
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
|
||||
key, ok := keyToken.(string)
|
||||
if !ok {
|
||||
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected top-level field name")
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "results":
|
||||
foundResults = true
|
||||
n, err := decodeTripleBindingsObject(dec, visit)
|
||||
bindingCount += n
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
default:
|
||||
if err := discardJSONValue(dec); err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tok, err = dec.Token()
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
if delim, ok := tok.(json.Delim); !ok || delim != '}' {
|
||||
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected top-level object terminator")
|
||||
}
|
||||
|
||||
if !foundResults {
|
||||
return 0, fmt.Errorf("invalid SPARQL JSON: missing results field")
|
||||
}
|
||||
|
||||
return bindingCount, nil
|
||||
}
|
||||
|
||||
func decodeTripleBindingsObject(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) {
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if delim, ok := tok.(json.Delim); !ok || delim != '{' {
|
||||
return 0, fmt.Errorf("invalid SPARQL JSON: expected results object")
|
||||
}
|
||||
|
||||
bindingCount := 0
|
||||
for dec.More() {
|
||||
keyToken, err := dec.Token()
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
|
||||
key, ok := keyToken.(string)
|
||||
if !ok {
|
||||
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected results field name")
|
||||
}
|
||||
|
||||
if key != "bindings" {
|
||||
if err := discardJSONValue(dec); err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
tok, err := dec.Token()
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
if delim, ok := tok.(json.Delim); !ok || delim != '[' {
|
||||
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected bindings array")
|
||||
}
|
||||
|
||||
for dec.More() {
|
||||
var binding sparqlTripleBinding
|
||||
if err := dec.Decode(&binding); err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
bindingCount++
|
||||
if err := visit(binding); err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
}
|
||||
|
||||
tok, err = dec.Token()
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
if delim, ok := tok.(json.Delim); !ok || delim != ']' {
|
||||
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected bindings array terminator")
|
||||
}
|
||||
}
|
||||
|
||||
tok, err = dec.Token()
|
||||
if err != nil {
|
||||
return bindingCount, err
|
||||
}
|
||||
if delim, ok := tok.(json.Delim); !ok || delim != '}' {
|
||||
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected results object terminator")
|
||||
}
|
||||
|
||||
return bindingCount, nil
|
||||
}
|
||||
|
||||
// discardJSONValue consumes the next complete JSON value from dec and throws
// it away, returning whatever error the decoder reports.
func discardJSONValue(dec *json.Decoder) error {
	return dec.Decode(new(json.RawMessage))
}
|
||||
|
||||
func wrapSparqlJSONDecodeError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if isTruncatedJSONError(err) {
|
||||
return fmt.Errorf("truncated SPARQL JSON: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to decode SPARQL JSON: %w", err)
|
||||
}
|
||||
|
||||
// isTruncatedJSONError reports whether err indicates that the JSON input
// ended before a complete document was read.
//
// It matches io.ErrUnexpectedEOF and io.EOF anywhere in the error chain, and
// falls back to checking the message text for errors that carry the same
// wording without wrapping those sentinels. A nil error is never truncated.
func isTruncatedJSONError(err error) bool {
	if err == nil {
		// Guard: calling Error on a nil error below would panic.
		return false
	}
	if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
		return true
	}
	msg := err.Error()
	return strings.Contains(msg, "unexpected end of JSON input") ||
		strings.Contains(msg, "unexpected EOF")
}
|
||||
Reference in New Issue
Block a user