package main import ( "context" "encoding/json" "errors" "fmt" "io" "log" "net/http" "net/url" "strings" "time" ) type sparqlQueryMetrics struct { ResponseBytes int64 RoundTripTime time.Duration BodyDecodeTime time.Duration BindingCount int } type countingReadCloser struct { io.ReadCloser bytesRead int64 } func (c *countingReadCloser) Read(p []byte) (int, error) { n, err := c.ReadCloser.Read(p) c.bytesRead += int64(n) return n, err } type cancelOnCloseReadCloser struct { io.ReadCloser cancel context.CancelFunc } func (c *cancelOnCloseReadCloser) Close() error { err := c.ReadCloser.Close() c.cancel() return err } type sparqlHTTPStatusError struct { StatusCode int Status string Body string } func (e *sparqlHTTPStatusError) Error() string { return fmt.Sprintf("sparql query failed: %s: %s", e.Status, e.Body) } func (c *AnzoGraphClient) queryRequestWithTimeout(ctx context.Context, query string, timeout time.Duration) (*http.Response, time.Duration, error) { ctx2, cancel := context.WithTimeout(ctx, timeout) resp, roundTripTime, err := c.queryRequest(ctx2, query, true) if err != nil { cancel() return nil, roundTripTime, err } resp.Body = &cancelOnCloseReadCloser{ReadCloser: resp.Body, cancel: cancel} return resp, roundTripTime, nil } func (c *AnzoGraphClient) queryRequest(ctx context.Context, query string, allowRefresh bool) (*http.Response, time.Duration, error) { resp, roundTripTime, err := c.queryRequestAttempt(ctx, query) if err == nil { return resp, roundTripTime, nil } var statusErr *sparqlHTTPStatusError if allowRefresh && errors.As(err, &statusErr) && c.shouldRefreshExpiredJWT(statusErr) { log.Printf("[auth] sparql_token_retry endpoint=%s reason=jwt_expired", c.endpoint) if _, refreshErr := c.refreshExternalToken(ctx, "sparql_jwt_expired"); refreshErr != nil { return nil, roundTripTime, fmt.Errorf("%w (token refresh failed: %v)", statusErr, refreshErr) } retryResp, retryRoundTripTime, retryErr := c.queryRequest(ctx, query, false) return retryResp, roundTripTime + retryRoundTripTime, retryErr } return nil, roundTripTime, err } func (c *AnzoGraphClient) queryRequestAttempt(ctx context.Context, query string) (*http.Response, time.Duration, error) { form := url.Values{} form.Set("query", query) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(form.Encode())) if err != nil { return nil, 0, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/sparql-results+json") authHeader, err := c.authorizationHeader(ctx, "sparql_query") if err != nil { return nil, 0, err } if authHeader != "" { req.Header.Set("Authorization", authHeader) } start := time.Now() resp, err := c.client.Do(req) if err != nil { return nil, 0, err } roundTripTime := time.Since(start) if resp.StatusCode < 200 || resp.StatusCode >= 300 { defer resp.Body.Close() body, readErr := io.ReadAll(resp.Body) if readErr != nil { return nil, roundTripTime, readErr } return nil, roundTripTime, &sparqlHTTPStatusError{ StatusCode: resp.StatusCode, Status: resp.Status, Body: strings.TrimSpace(string(body)), } } return resp, roundTripTime, nil } func (c *AnzoGraphClient) shouldRefreshExpiredJWT(err *sparqlHTTPStatusError) bool { if err == nil || !c.cfg.UsesExternalSparql() { return false } return err.StatusCode == http.StatusUnauthorized && strings.Contains(err.Body, "Jwt is expired") } func (c *AnzoGraphClient) QueryJSON(ctx context.Context, query string, target any) (sparqlQueryMetrics, error) { return c.queryJSONWithTimeout(ctx, query, c.cfg.SparqlTimeout, target) } func (c *AnzoGraphClient) queryJSONWithTimeout(ctx context.Context, query string, timeout time.Duration, target any) (sparqlQueryMetrics, error) { resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout) if err != nil { return sparqlQueryMetrics{}, err } counter := &countingReadCloser{ReadCloser: resp.Body} defer counter.Close() decodeStart := time.Now() if err := json.NewDecoder(counter).Decode(target); err != nil { return sparqlQueryMetrics{ ResponseBytes: counter.bytesRead, RoundTripTime: roundTripTime, BodyDecodeTime: time.Since(decodeStart), }, wrapSparqlJSONDecodeError(err) } return sparqlQueryMetrics{ ResponseBytes: counter.bytesRead, RoundTripTime: roundTripTime, BodyDecodeTime: time.Since(decodeStart), }, nil } func (c *AnzoGraphClient) QueryTripleBindingsStream( ctx context.Context, query string, visit func(binding sparqlTripleBinding) error, ) (sparqlQueryMetrics, error) { return c.queryTripleBindingsStreamWithTimeout(ctx, query, c.cfg.SparqlTimeout, visit) } func (c *AnzoGraphClient) queryTripleBindingsStreamWithTimeout( ctx context.Context, query string, timeout time.Duration, visit func(binding sparqlTripleBinding) error, ) (sparqlQueryMetrics, error) { resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout) if err != nil { return sparqlQueryMetrics{}, err } counter := &countingReadCloser{ReadCloser: resp.Body} defer counter.Close() decodeStart := time.Now() bindingCount, err := decodeBindingsStream(json.NewDecoder(counter), visit) if err != nil { return sparqlQueryMetrics{ ResponseBytes: counter.bytesRead, RoundTripTime: roundTripTime, BodyDecodeTime: time.Since(decodeStart), BindingCount: bindingCount, }, wrapSparqlJSONDecodeError(err) } return sparqlQueryMetrics{ ResponseBytes: counter.bytesRead, RoundTripTime: roundTripTime, BodyDecodeTime: time.Since(decodeStart), BindingCount: bindingCount, }, nil } func decodeBindingsStream(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) { tok, err := dec.Token() if err != nil { return 0, err } if delim, ok := tok.(json.Delim); !ok || delim != '{' { return 0, fmt.Errorf("invalid SPARQL JSON: expected top-level object") } foundResults := false bindingCount := 0 for dec.More() { keyToken, err := dec.Token() if err != nil { return bindingCount, err } key, ok := keyToken.(string) if !ok { return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected top-level field name") } switch key { case "results": foundResults = true n, err := decodeTripleBindingsObject(dec, visit) bindingCount += n if err != nil { return bindingCount, err } default: if err := discardJSONValue(dec); err != nil { return bindingCount, err } } } tok, err = dec.Token() if err != nil { return bindingCount, err } if delim, ok := tok.(json.Delim); !ok || delim != '}' { return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected top-level object terminator") } if !foundResults { return 0, fmt.Errorf("invalid SPARQL JSON: missing results field") } return bindingCount, nil } func decodeTripleBindingsObject(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) { tok, err := dec.Token() if err != nil { return 0, err } if delim, ok := tok.(json.Delim); !ok || delim != '{' { return 0, fmt.Errorf("invalid SPARQL JSON: expected results object") } bindingCount := 0 for dec.More() { keyToken, err := dec.Token() if err != nil { return bindingCount, err } key, ok := keyToken.(string) if !ok { return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected results field name") } if key != "bindings" { if err := discardJSONValue(dec); err != nil { return bindingCount, err } continue } tok, err := dec.Token() if err != nil { return bindingCount, err } if delim, ok := tok.(json.Delim); !ok || delim != '[' { return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected bindings array") } for dec.More() { var binding sparqlTripleBinding if err := dec.Decode(&binding); err != nil { return bindingCount, err } bindingCount++ if err := visit(binding); err != nil { return bindingCount, err } } tok, err = dec.Token() if err != nil { return bindingCount, err } if delim, ok := tok.(json.Delim); !ok || delim != ']' { return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected bindings array terminator") } } tok, err = dec.Token() if err != nil { return bindingCount, err } if delim, ok := tok.(json.Delim); !ok || delim != '}' { return bindingCount, fmt.Errorf("invalid SPARQL JSON: expected results object terminator") } return bindingCount, nil } func discardJSONValue(dec *json.Decoder) error { var discard json.RawMessage return dec.Decode(&discard) } func wrapSparqlJSONDecodeError(err error) error { if err == nil { return nil } if isTruncatedJSONError(err) { return fmt.Errorf("truncated SPARQL JSON: %w", err) } return fmt.Errorf("failed to decode SPARQL JSON: %w", err) } func isTruncatedJSONError(err error) bool { return errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) || strings.Contains(err.Error(), "unexpected end of JSON input") || strings.Contains(err.Error(), "unexpected EOF") }