Files
visualizador_instanciados/backend_go/sparql_decode.go

342 lines
8.9 KiB
Go

package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
)
// sparqlQueryMetrics records size and timing measurements for a single SPARQL
// query issued through AnzoGraphClient.
type sparqlQueryMetrics struct {
	// ResponseBytes is the number of response-body bytes read while decoding.
	ResponseBytes int64
	// RoundTripTime is the time spent in the HTTP client call(s) for the query,
	// summed across a token-refresh retry when one happens.
	RoundTripTime time.Duration
	// BodyDecodeTime is the time spent decoding the response body; for the
	// streaming path it also includes time spent in the per-binding callback.
	BodyDecodeTime time.Duration
	// BindingCount is the number of bindings decoded by the streaming path
	// (left at zero by QueryJSON).
	BindingCount int
}
// countingReadCloser wraps an io.ReadCloser and keeps a running total of the
// bytes read through it. Close is promoted unchanged from the embedded value.
type countingReadCloser struct {
	io.ReadCloser
	bytesRead int64 // total bytes returned by Read so far
}

// Read delegates to the wrapped reader and adds the byte count it reported to
// bytesRead, including bytes returned together with a non-nil error.
func (c *countingReadCloser) Read(p []byte) (int, error) {
	got, readErr := c.ReadCloser.Read(p)
	c.bytesRead += int64(got)
	return got, readErr
}
// cancelOnCloseReadCloser binds a context's cancel function to a body's
// lifetime: closing it closes the wrapped reader first and then cancels.
type cancelOnCloseReadCloser struct {
	io.ReadCloser
	cancel context.CancelFunc // invoked after the wrapped Close returns
}

// Close closes the wrapped reader, then calls cancel, and returns the error
// (if any) from the wrapped Close.
func (c *cancelOnCloseReadCloser) Close() error {
	defer c.cancel()
	return c.ReadCloser.Close()
}
// sparqlHTTPStatusError describes a SPARQL endpoint response whose status code
// fell outside the 2xx range.
type sparqlHTTPStatusError struct {
	StatusCode int    // numeric HTTP status code
	Status     string // status text as reported by net/http (e.g. "401 Unauthorized")
	Body       string // response body with surrounding whitespace trimmed
}

// Error renders the status text followed by the response body.
func (e *sparqlHTTPStatusError) Error() string {
	const format = "sparql query failed: %s: %s"
	return fmt.Sprintf(format, e.Status, e.Body)
}
// queryRequestWithTimeout issues the query under a child context bounded by
// timeout. On success the context stays alive until the caller closes
// resp.Body (the body is wrapped so that Close also cancels), which means the
// timeout covers reading the body as well. On failure the context is
// cancelled immediately.
func (c *AnzoGraphClient) queryRequestWithTimeout(ctx context.Context, query string, timeout time.Duration) (*http.Response, time.Duration, error) {
	timeoutCtx, cancel := context.WithTimeout(ctx, timeout)
	resp, rtt, err := c.queryRequest(timeoutCtx, query, true)
	if err == nil {
		resp.Body = &cancelOnCloseReadCloser{ReadCloser: resp.Body, cancel: cancel}
		return resp, rtt, nil
	}
	cancel()
	return nil, rtt, err
}
// queryRequest performs one query attempt. When allowRefresh is true and the
// endpoint rejected the request because the external JWT expired (see
// shouldRefreshExpiredJWT), it refreshes the token and retries exactly once.
// The returned duration is the sum of the attempts' round-trip times.
func (c *AnzoGraphClient) queryRequest(ctx context.Context, query string, allowRefresh bool) (*http.Response, time.Duration, error) {
	resp, firstRTT, err := c.queryRequestAttempt(ctx, query)
	if err == nil {
		return resp, firstRTT, nil
	}
	if !allowRefresh {
		return nil, firstRTT, err
	}

	var statusErr *sparqlHTTPStatusError
	if !errors.As(err, &statusErr) || !c.shouldRefreshExpiredJWT(statusErr) {
		return nil, firstRTT, err
	}

	log.Printf("[auth] sparql_token_retry endpoint=%s reason=jwt_expired", c.endpoint)
	if _, refreshErr := c.refreshExternalToken(ctx, "sparql_jwt_expired"); refreshErr != nil {
		return nil, firstRTT, fmt.Errorf("%w (token refresh failed: %v)", statusErr, refreshErr)
	}

	// Single retry with the refreshed token; a second expiry is not retried.
	retryResp, retryRTT, retryErr := c.queryRequestAttempt(ctx, query)
	return retryResp, firstRTT + retryRTT, retryErr
}
func (c *AnzoGraphClient) queryRequestAttempt(ctx context.Context, query string) (*http.Response, time.Duration, error) {
form := url.Values{}
form.Set("query", query)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, 0, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/sparql-results+json")
authHeader, err := c.authorizationHeader(ctx, "sparql_query")
if err != nil {
return nil, 0, err
}
if authHeader != "" {
req.Header.Set("Authorization", authHeader)
}
start := time.Now()
resp, err := c.client.Do(req)
if err != nil {
return nil, 0, err
}
roundTripTime := time.Since(start)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer resp.Body.Close()
body, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, roundTripTime, readErr
}
return nil, roundTripTime, &sparqlHTTPStatusError{
StatusCode: resp.StatusCode,
Status: resp.Status,
Body: strings.TrimSpace(string(body)),
}
}
return resp, roundTripTime, nil
}
// shouldRefreshExpiredJWT reports whether a failed query is worth retrying
// after a token refresh. That is the case only when the client is configured
// for an external SPARQL endpoint and the server answered 401 with a body
// containing "Jwt is expired".
func (c *AnzoGraphClient) shouldRefreshExpiredJWT(err *sparqlHTTPStatusError) bool {
	switch {
	case err == nil:
		return false
	case !c.cfg.UsesExternalSparql():
		return false
	case err.StatusCode != http.StatusUnauthorized:
		return false
	default:
		return strings.Contains(err.Body, "Jwt is expired")
	}
}
// QueryJSON runs query with the configured SPARQL timeout (cfg.SparqlTimeout)
// and decodes the JSON response body into target.
func (c *AnzoGraphClient) QueryJSON(ctx context.Context, query string, target any) (sparqlQueryMetrics, error) {
	timeout := c.cfg.SparqlTimeout
	return c.queryJSONWithTimeout(ctx, query, timeout, target)
}
// queryJSONWithTimeout runs query bounded by timeout and decodes the first JSON
// value of the response body into target. Metrics are returned even when
// decoding fails; decode errors are classified by wrapSparqlJSONDecodeError.
func (c *AnzoGraphClient) queryJSONWithTimeout(ctx context.Context, query string, timeout time.Duration, target any) (sparqlQueryMetrics, error) {
	resp, rtt, err := c.queryRequestWithTimeout(ctx, query, timeout)
	if err != nil {
		return sparqlQueryMetrics{}, err
	}
	body := &countingReadCloser{ReadCloser: resp.Body}
	defer body.Close()

	began := time.Now()
	decodeErr := json.NewDecoder(body).Decode(target)
	metrics := sparqlQueryMetrics{
		ResponseBytes:  body.bytesRead,
		RoundTripTime:  rtt,
		BodyDecodeTime: time.Since(began),
	}
	if decodeErr != nil {
		return metrics, wrapSparqlJSONDecodeError(decodeErr)
	}
	return metrics, nil
}
// QueryTripleBindingsStream runs query with the configured SPARQL timeout
// (cfg.SparqlTimeout) and calls visit once per result binding as the response
// is decoded incrementally.
func (c *AnzoGraphClient) QueryTripleBindingsStream(
	ctx context.Context,
	query string,
	visit func(binding sparqlTripleBinding) error,
) (sparqlQueryMetrics, error) {
	timeout := c.cfg.SparqlTimeout
	return c.queryTripleBindingsStreamWithTimeout(ctx, query, timeout, visit)
}
// queryTripleBindingsStreamWithTimeout runs query bounded by timeout and
// streams each decoded binding to visit. Metrics (bytes read, round-trip and
// decode time, bindings decoded) are returned even on failure.
//
// Errors produced while decoding the JSON stream are classified by
// wrapSparqlJSONDecodeError. An error returned by visit is propagated
// unchanged. It is the caller's own error, not a JSON problem, so it must not
// be reported as a decode failure or as truncation.
func (c *AnzoGraphClient) queryTripleBindingsStreamWithTimeout(
	ctx context.Context,
	query string,
	timeout time.Duration,
	visit func(binding sparqlTripleBinding) error,
) (sparqlQueryMetrics, error) {
	resp, roundTripTime, err := c.queryRequestWithTimeout(ctx, query, timeout)
	if err != nil {
		return sparqlQueryMetrics{}, err
	}
	counter := &countingReadCloser{ReadCloser: resp.Body}
	defer counter.Close()

	// Remember whether the failure came from the callback; the decoder stops
	// at the first visit error and returns it as-is.
	visitFailed := false
	trackedVisit := func(binding sparqlTripleBinding) error {
		if err := visit(binding); err != nil {
			visitFailed = true
			return err
		}
		return nil
	}

	decodeStart := time.Now()
	bindingCount, err := decodeBindingsStream(json.NewDecoder(counter), trackedVisit)
	metrics := sparqlQueryMetrics{
		ResponseBytes:  counter.bytesRead,
		RoundTripTime:  roundTripTime,
		BodyDecodeTime: time.Since(decodeStart),
		BindingCount:   bindingCount,
	}
	switch {
	case err == nil:
		return metrics, nil
	case visitFailed:
		return metrics, err
	default:
		return metrics, wrapSparqlJSONDecodeError(err)
	}
}
// decodeBindingsStream walks a SPARQL JSON results document token by token.
// Every top-level "results" member is handed to decodeTripleBindingsObject,
// which forwards each binding to visit; all other top-level members are
// skipped. It returns the number of bindings decoded together with the first
// error encountered. A well-formed document without a "results" member is
// rejected.
func decodeBindingsStream(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) {
	open, err := dec.Token()
	if err != nil {
		return 0, err
	}
	if d, isDelim := open.(json.Delim); !isDelim || d != '{' {
		return 0, fmt.Errorf("invalid SPARQL JSON: expected top-level object")
	}

	total := 0
	sawResults := false
	for dec.More() {
		nameTok, err := dec.Token()
		if err != nil {
			return total, err
		}
		name, isString := nameTok.(string)
		if !isString {
			return total, fmt.Errorf("invalid SPARQL JSON: expected top-level field name")
		}
		if name != "results" {
			if err := discardJSONValue(dec); err != nil {
				return total, err
			}
			continue
		}
		sawResults = true
		decoded, err := decodeTripleBindingsObject(dec, visit)
		total += decoded
		if err != nil {
			return total, err
		}
	}

	closeTok, err := dec.Token()
	if err != nil {
		return total, err
	}
	if d, isDelim := closeTok.(json.Delim); !isDelim || d != '}' {
		return total, fmt.Errorf("invalid SPARQL JSON: expected top-level object terminator")
	}
	if !sawResults {
		return 0, fmt.Errorf("invalid SPARQL JSON: missing results field")
	}
	return total, nil
}
// decodeTripleBindingsObject consumes the value of a "results" member, which
// must be a JSON object. For each "bindings" member it decodes the array
// elements one at a time into sparqlTripleBinding values and passes each to
// visit; any other member is skipped. It stops at the first error, whether from
// the JSON stream or from visit. The returned count includes a binding whose
// visit call failed.
func decodeTripleBindingsObject(dec *json.Decoder, visit func(binding sparqlTripleBinding) error) (int, error) {
	open, err := dec.Token()
	if err != nil {
		return 0, err
	}
	if d, isDelim := open.(json.Delim); !isDelim || d != '{' {
		return 0, fmt.Errorf("invalid SPARQL JSON: expected results object")
	}

	visited := 0
	for dec.More() {
		nameTok, err := dec.Token()
		if err != nil {
			return visited, err
		}
		name, isString := nameTok.(string)
		if !isString {
			return visited, fmt.Errorf("invalid SPARQL JSON: expected results field name")
		}
		if name == "bindings" {
			arrOpen, err := dec.Token()
			if err != nil {
				return visited, err
			}
			if d, isDelim := arrOpen.(json.Delim); !isDelim || d != '[' {
				return visited, fmt.Errorf("invalid SPARQL JSON: expected bindings array")
			}
			for dec.More() {
				var binding sparqlTripleBinding
				if err := dec.Decode(&binding); err != nil {
					return visited, err
				}
				visited++
				if err := visit(binding); err != nil {
					return visited, err
				}
			}
			arrClose, err := dec.Token()
			if err != nil {
				return visited, err
			}
			if d, isDelim := arrClose.(json.Delim); !isDelim || d != ']' {
				return visited, fmt.Errorf("invalid SPARQL JSON: expected bindings array terminator")
			}
			continue
		}
		if err := discardJSONValue(dec); err != nil {
			return visited, err
		}
	}

	closeTok, err := dec.Token()
	if err != nil {
		return visited, err
	}
	if d, isDelim := closeTok.(json.Delim); !isDelim || d != '}' {
		return visited, fmt.Errorf("invalid SPARQL JSON: expected results object terminator")
	}
	return visited, nil
}
func discardJSONValue(dec *json.Decoder) error {
var discard json.RawMessage
return dec.Decode(&discard)
}
// wrapSparqlJSONDecodeError adds context to an error from decoding a SPARQL
// JSON body. Truncation-like errors get a "truncated SPARQL JSON" prefix, and
// everything else gets "failed to decode SPARQL JSON". The original error stays
// reachable via errors.Is/As. A nil error yields nil.
func wrapSparqlJSONDecodeError(err error) error {
	switch {
	case err == nil:
		return nil
	case isTruncatedJSONError(err):
		return fmt.Errorf("truncated SPARQL JSON: %w", err)
	default:
		return fmt.Errorf("failed to decode SPARQL JSON: %w", err)
	}
}
// isTruncatedJSONError reports whether err indicates that the JSON input ended
// before a complete value was read. It matches io.EOF and io.ErrUnexpectedEOF
// (also when wrapped), and it also inspects the error text, because some
// decoders report incomplete input only through their message (e.g. "unexpected
// end of JSON input"). A nil error is never considered truncated.
func isTruncatedJSONError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) {
		return true
	}
	msg := err.Error()
	return strings.Contains(msg, "unexpected end of JSON input") ||
		strings.Contains(msg, "unexpected EOF")
}