342 lines
8.9 KiB
Go
342 lines
8.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// sparqlQueryMetrics carries the measurements collected while running one
// SPARQL query and consuming its response body.
type sparqlQueryMetrics struct {
	// ResponseBytes is the number of body bytes read while decoding.
	ResponseBytes int64

	// RoundTripTime is the time spent waiting for the HTTP exchange to
	// return a response.
	RoundTripTime time.Duration

	// BodyDecodeTime is the time spent reading and decoding the body.
	BodyDecodeTime time.Duration

	// BindingCount is the number of bindings decoded by the streaming
	// query path.
	BindingCount int
}
|
|
|
|
// countingReadCloser decorates an io.ReadCloser and keeps a running total of
// the bytes that have been read through it.
type countingReadCloser struct {
	io.ReadCloser

	// bytesRead is the sum of the byte counts returned by Read so far.
	bytesRead int64
}

// Read forwards to the wrapped reader and adds the number of bytes it
// reported to bytesRead. The count is added even when the wrapped reader
// also returns an error, since a Read may return data alongside an error.
func (c *countingReadCloser) Read(p []byte) (n int, err error) {
	n, err = c.ReadCloser.Read(p)
	c.bytesRead += int64(n)
	return n, err
}
|
|
|
|
// cancelOnCloseReadCloser ties a context cancel function to the lifetime of
// a body: the cancel function runs when the body is closed.
type cancelOnCloseReadCloser struct {
	io.ReadCloser

	// cancel is invoked once per Close call, after the wrapped Close.
	cancel context.CancelFunc
}

// Close closes the wrapped body first and then calls cancel. The error from
// the wrapped Close is returned unchanged.
func (c *cancelOnCloseReadCloser) Close() error {
	// Deferred so that cancel runs after the wrapped Close has returned.
	defer c.cancel()
	return c.ReadCloser.Close()
}
|
|
|
|
// sparqlHTTPStatusError describes an HTTP response to a SPARQL request whose
// status code was outside the 2xx range.
type sparqlHTTPStatusError struct {
	// StatusCode is the numeric HTTP status code of the response.
	StatusCode int
	// Status is the HTTP status line text of the response.
	Status string
	// Body is the response body text, with surrounding whitespace trimmed.
	Body string
}

// Error implements the error interface, rendering the status text followed
// by the response body.
func (e *sparqlHTTPStatusError) Error() string {
	status, detail := e.Status, e.Body
	return fmt.Sprintf("sparql query failed: %s: %s", status, detail)
}
|
|
|
|
func (c *AnzoGraphClient) queryRequestWithTimeout(ctx context.Context, query string, timeout time.Duration) (*http.Response, time.Duration, error) {
|
|
ctx2, cancel := context.WithTimeout(ctx, timeout)
|
|
resp, roundTripTime, err := c.queryRequest(ctx2, query, true)
|
|
if err != nil {
|
|
cancel()
|
|
return nil, roundTripTime, err
|
|
}
|
|
resp.Body = &cancelOnCloseReadCloser{ReadCloser: resp.Body, cancel: cancel}
|
|
return resp, roundTripTime, nil
|
|
}
|
|
|
|
func (c *AnzoGraphClient) queryRequest(ctx context.Context, query string, allowRefresh bool) (*http.Response, time.Duration, error) {
|
|
resp, roundTripTime, err := c.queryRequestAttempt(ctx, query)
|
|
if err == nil {
|
|
return resp, roundTripTime, nil
|
|
}
|
|
|
|
var statusErr *sparqlHTTPStatusError
|
|
if allowRefresh && errors.As(err, &statusErr) && c.shouldRefreshExpiredJWT(statusErr) {
|
|
log.Printf("[auth] sparql_token_retry endpoint=%s reason=jwt_expired", c.endpoint)
|
|
if _, refreshErr := c.refreshExternalToken(ctx, "sparql_jwt_expired"); refreshErr != nil {
|
|
return nil, roundTripTime, fmt.Errorf("%w (token refresh failed: %v)", statusErr, refreshErr)
|
|
}
|
|
|
|
retryResp, retryRoundTripTime, retryErr := c.queryRequest(ctx, query, false)
|
|
return retryResp, roundTripTime + retryRoundTripTime, retryErr
|
|
}
|
|
|
|
return nil, roundTripTime, err
|
|
}
|
|
|
|
func (c *AnzoGraphClient) queryRequestAttempt(ctx context.Context, query string) (*http.Response, time.Duration, error) {
|
|
form := url.Values{}
|
|
form.Set("query", query)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/sparql-results+json")
|
|
authHeader, err := c.authorizationHeader(ctx, "sparql_query")
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if authHeader != "" {
|
|
req.Header.Set("Authorization", authHeader)
|
|
}
|
|
|
|
start := time.Now()
|
|
resp, err := c.client.Do(req)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
roundTripTime := time.Since(start)
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
defer resp.Body.Close()
|
|
body, readErr := io.ReadAll(resp.Body)
|
|
if readErr != nil {
|
|
return nil, roundTripTime, readErr
|
|
}
|
|
return nil, roundTripTime, &sparqlHTTPStatusError{
|
|
StatusCode: resp.StatusCode,
|
|
Status: resp.Status,
|
|
Body: strings.TrimSpace(string(body)),
|
|
}
|
|
}
|
|
|
|
return resp, roundTripTime, nil
|
|
}
|
|
|
|
func (c *AnzoGraphClient) shouldRefreshExpiredJWT(err *sparqlHTTPStatusError) bool {
|
|
if err == nil || !c.cfg.UsesExternalSparql() {
|
|
return false
|
|
}
|
|
return err.StatusCode == http.StatusUnauthorized && strings.Contains(err.Body, "Jwt is expired")
|
|
}
|
|
|
|
func (c *AnzoGraphClient) QueryJSON(ctx context.Context, query string, target any) (sparqlQueryMetrics, error) {
|
|
return c.queryJSONWithTimeout(ctx, query, c.cfg.SparqlTimeout, target)
|
|
}
|
|
|
|
func (c *AnzoGraphClient) queryJSONWithTimeout(ctx context.Context, query string, timeout time.Duration, target any) (sparqlQueryMetrics, error) {
|
|
resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout)
|
|
if err != nil {
|
|
return sparqlQueryMetrics{}, err
|
|
}
|
|
counter := &countingReadCloser{ReadCloser: resp.Body}
|
|
defer counter.Close()
|
|
|
|
decodeStart := time.Now()
|
|
if err := json.NewDecoder(counter).Decode(target); err != nil {
|
|
return sparqlQueryMetrics{
|
|
ResponseBytes: counter.bytesRead,
|
|
RoundTripTime: roundTripTime,
|
|
BodyDecodeTime: time.Since(decodeStart),
|
|
}, wrapSparqlJSONDecodeError(err)
|
|
}
|
|
|
|
return sparqlQueryMetrics{
|
|
ResponseBytes: counter.bytesRead,
|
|
RoundTripTime: roundTripTime,
|
|
BodyDecodeTime: time.Since(decodeStart),
|
|
}, nil
|
|
}
|
|
|
|
func (c *AnzoGraphClient) QueryTripleBindingsStream(
|
|
ctx context.Context,
|
|
query string,
|
|
visit func(binding sparqlTripleBinding) error,
|
|
) (sparqlQueryMetrics, error) {
|
|
return c.queryTripleBindingsStreamWithTimeout(ctx, query, c.cfg.SparqlTimeout, visit)
|
|
}
|
|
|
|
func (c *AnzoGraphClient) queryTripleBindingsStreamWithTimeout(
|
|
ctx context.Context,
|
|
query string,
|
|
timeout time.Duration,
|
|
visit func(binding sparqlTripleBinding) error,
|
|
) (sparqlQueryMetrics, error) {
|
|
resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout)
|
|
if err != nil {
|
|
return sparqlQueryMetrics{}, err
|
|
}
|
|
counter := &countingReadCloser{ReadCloser: resp.Body}
|
|
defer counter.Close()
|
|
|
|
decodeStart := time.Now()
|
|
bindingCount, err := decodeBindingsStream(json.NewDecoder(counter), visit)
|
|
if err != nil {
|
|
return sparqlQueryMetrics{
|
|
ResponseBytes: counter.bytesRead,
|
|
RoundTripTime: roundTripTime,
|
|
BodyDecodeTime: time.Since(decodeStart),
|
|
BindingCount: bindingCount,
|
|
}, wrapSparqlJSONDecodeError(err)
|
|
}
|
|
|
|
return sparqlQueryMetrics{
|
|
ResponseBytes: counter.bytesRead,
|
|
RoundTripTime: roundTripTime,
|
|
BodyDecodeTime: time.Since(decodeStart),
|
|
BindingCount: bindingCount,
|
|
}, nil
|
|
}
|
|
|
|
func decodeBindingsStream(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if delim, ok := tok.(json.Delim); !ok || delim != '{' {
|
|
return 0, fmt.Errorf("invalid SPARQL JSON: expected top-level object")
|
|
}
|
|
|
|
foundResults := false
|
|
bindingCount := 0
|
|
for dec.More() {
|
|
keyToken, err := dec.Token()
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
|
|
key, ok := keyToken.(string)
|
|
if !ok {
|
|
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected top-level field name")
|
|
}
|
|
|
|
switch key {
|
|
case "results":
|
|
foundResults = true
|
|
n, err := decodeTripleBindingsObject(dec, visit)
|
|
bindingCount += n
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
default:
|
|
if err := discardJSONValue(dec); err != nil {
|
|
return bindingCount, err
|
|
}
|
|
}
|
|
}
|
|
|
|
tok, err = dec.Token()
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
if delim, ok := tok.(json.Delim); !ok || delim != '}' {
|
|
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected top-level object terminator")
|
|
}
|
|
|
|
if !foundResults {
|
|
return 0, fmt.Errorf("invalid SPARQL JSON: missing results field")
|
|
}
|
|
|
|
return bindingCount, nil
|
|
}
|
|
|
|
func decodeTripleBindingsObject(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) {
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if delim, ok := tok.(json.Delim); !ok || delim != '{' {
|
|
return 0, fmt.Errorf("invalid SPARQL JSON: expected results object")
|
|
}
|
|
|
|
bindingCount := 0
|
|
for dec.More() {
|
|
keyToken, err := dec.Token()
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
|
|
key, ok := keyToken.(string)
|
|
if !ok {
|
|
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected results field name")
|
|
}
|
|
|
|
if key != "bindings" {
|
|
if err := discardJSONValue(dec); err != nil {
|
|
return bindingCount, err
|
|
}
|
|
continue
|
|
}
|
|
|
|
tok, err := dec.Token()
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
if delim, ok := tok.(json.Delim); !ok || delim != '[' {
|
|
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected bindings array")
|
|
}
|
|
|
|
for dec.More() {
|
|
var binding sparqlTripleBinding
|
|
if err := dec.Decode(&binding); err != nil {
|
|
return bindingCount, err
|
|
}
|
|
bindingCount++
|
|
if err := visit(binding); err != nil {
|
|
return bindingCount, err
|
|
}
|
|
}
|
|
|
|
tok, err = dec.Token()
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
if delim, ok := tok.(json.Delim); !ok || delim != ']' {
|
|
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected bindings array terminator")
|
|
}
|
|
}
|
|
|
|
tok, err = dec.Token()
|
|
if err != nil {
|
|
return bindingCount, err
|
|
}
|
|
if delim, ok := tok.(json.Delim); !ok || delim != '}' {
|
|
return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected results object terminator")
|
|
}
|
|
|
|
return bindingCount, nil
|
|
}
|
|
|
|
// discardJSONValue consumes the next complete JSON value from dec and throws
// it away. It returns whatever error the decoder reports.
func discardJSONValue(dec *json.Decoder) error {
	// RawMessage captures the value's bytes without interpreting them.
	return dec.Decode(new(json.RawMessage))
}
|
|
|
|
func wrapSparqlJSONDecodeError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
if isTruncatedJSONError(err) {
|
|
return fmt.Errorf("truncated SPARQL JSON: %w", err)
|
|
}
|
|
return fmt.Errorf("failed to decode SPARQL JSON: %w", err)
|
|
}
|
|
|
|
// isTruncatedJSONError reports whether err indicates that JSON input ended
// before a complete value was read. A nil err yields false.
//
// It first checks the error chain for io.ErrUnexpectedEOF and io.EOF. It
// then falls back to matching the message text, which catches errors that
// carry the same wording without wrapping either sentinel.
func isTruncatedJSONError(err error) bool {
	// Guard first: calling Error() on a nil error would panic.
	if err == nil {
		return false
	}
	if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
		return true
	}

	msg := err.Error()
	return strings.Contains(msg, "unexpected end of JSON input") ||
		strings.Contains(msg, "unexpected EOF")
}
|