274 lines
7.8 KiB
Go
274 lines
7.8 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestKeycloakTokenManagerFetchTokenParsesAccessToken(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/token" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
if got := r.Header.Get("Content-Type"); got != "application/x-www-form-urlencoded" {
|
|
t.Errorf("unexpected content-type: %s", got)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := Config{
|
|
KeycloakTokenEndpoint: server.URL + "/token",
|
|
KeycloakClientID: "anzograph",
|
|
KeycloakUsername: "user",
|
|
KeycloakPassword: "pass",
|
|
KeycloakScope: "openid",
|
|
}
|
|
manager := newKeycloakTokenManager(cfg, server.Client())
|
|
|
|
token, err := manager.fetchToken(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("fetchToken returned error: %v", err)
|
|
}
|
|
if token != "fresh-token" {
|
|
t.Fatalf("expected fresh-token, got %q", token)
|
|
}
|
|
}
|
|
|
|
func TestAnzoGraphClientStartupFetchesFreshToken(t *testing.T) {
|
|
var tokenCalls atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/token":
|
|
tokenCalls.Add(1)
|
|
_, _ = io.WriteString(w, `{"access_token":"startup-token"}`)
|
|
case "/sparql":
|
|
if got := r.Header.Get("Authorization"); got != "Bearer startup-token" {
|
|
t.Errorf("expected startup bearer token, got %q", got)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
_, _ = io.WriteString(w, `{"head":{},"boolean":true}`)
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := Config{
|
|
SparqlSourceMode: "external",
|
|
ExternalSparqlEndpoint: server.URL + "/sparql",
|
|
KeycloakTokenEndpoint: server.URL + "/token",
|
|
KeycloakClientID: "anzograph",
|
|
KeycloakUsername: "user",
|
|
KeycloakPassword: "pass",
|
|
KeycloakScope: "openid",
|
|
SparqlReadyTimeout: 2 * time.Second,
|
|
SparqlReadyRetries: 1,
|
|
SparqlReadyDelay: 1 * time.Millisecond,
|
|
SparqlTimeout: 2 * time.Second,
|
|
}
|
|
|
|
client := NewAnzoGraphClient(cfg)
|
|
client.client = server.Client()
|
|
client.tokenManager.client = server.Client()
|
|
|
|
if err := client.Startup(context.Background()); err != nil {
|
|
t.Fatalf("Startup returned error: %v", err)
|
|
}
|
|
if tokenCalls.Load() != 1 {
|
|
t.Fatalf("expected 1 startup token request, got %d", tokenCalls.Load())
|
|
}
|
|
}
|
|
|
|
func TestQueryRetriesOnceWhenJWTExpires(t *testing.T) {
|
|
var tokenCalls atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/token":
|
|
call := tokenCalls.Add(1)
|
|
if call != 1 {
|
|
t.Errorf("expected exactly 1 refresh call, got %d", call)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
|
|
case "/sparql":
|
|
switch r.Header.Get("Authorization") {
|
|
case "Bearer expired-token":
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = io.WriteString(w, "Jwt is expired")
|
|
case "Bearer fresh-token":
|
|
_, _ = io.WriteString(w, `{"results":{"bindings":[{"s":{"type":"uri","value":"http://example.com/s"},"p":{"type":"uri","value":"http://example.com/p"},"o":{"type":"uri","value":"http://example.com/o"}}]}}`)
|
|
default:
|
|
t.Errorf("unexpected authorization header %q", r.Header.Get("Authorization"))
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := Config{
|
|
SparqlSourceMode: "external",
|
|
ExternalSparqlEndpoint: server.URL + "/sparql",
|
|
KeycloakTokenEndpoint: server.URL + "/token",
|
|
KeycloakClientID: "anzograph",
|
|
KeycloakUsername: "user",
|
|
KeycloakPassword: "pass",
|
|
KeycloakScope: "openid",
|
|
SparqlTimeout: 2 * time.Second,
|
|
}
|
|
|
|
client := NewAnzoGraphClient(cfg)
|
|
client.client = server.Client()
|
|
client.tokenManager.client = server.Client()
|
|
client.tokenManager.token = "expired-token"
|
|
|
|
raw, err := client.Query(context.Background(), "SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
|
|
if err != nil {
|
|
t.Fatalf("Query returned error: %v", err)
|
|
}
|
|
if string(raw) == "" {
|
|
t.Fatalf("expected successful response body after refresh")
|
|
}
|
|
if tokenCalls.Load() != 1 {
|
|
t.Fatalf("expected 1 refresh call, got %d", tokenCalls.Load())
|
|
}
|
|
}
|
|
|
|
func TestQueryDoesNotRefreshForNonExpiry401(t *testing.T) {
|
|
var tokenCalls atomic.Int32
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/token":
|
|
tokenCalls.Add(1)
|
|
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
|
|
case "/sparql":
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = io.WriteString(w, "RBAC: access denied")
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := Config{
|
|
SparqlSourceMode: "external",
|
|
ExternalSparqlEndpoint: server.URL + "/sparql",
|
|
KeycloakTokenEndpoint: server.URL + "/token",
|
|
KeycloakClientID: "anzograph",
|
|
KeycloakUsername: "user",
|
|
KeycloakPassword: "pass",
|
|
KeycloakScope: "openid",
|
|
SparqlTimeout: 2 * time.Second,
|
|
}
|
|
|
|
client := NewAnzoGraphClient(cfg)
|
|
client.client = server.Client()
|
|
client.tokenManager.client = server.Client()
|
|
client.tokenManager.token = "still-bad-token"
|
|
|
|
_, err := client.Query(context.Background(), "SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
|
|
if err == nil {
|
|
t.Fatalf("expected non-expiry 401 to fail")
|
|
}
|
|
if tokenCalls.Load() != 0 {
|
|
t.Fatalf("expected no token refresh for non-expiry 401, got %d", tokenCalls.Load())
|
|
}
|
|
}
|
|
|
|
func TestConcurrentExpiredQueriesShareOneRefresh(t *testing.T) {
|
|
var tokenCalls atomic.Int32
|
|
var sparqlCalls atomic.Int32
|
|
var mu sync.Mutex
|
|
seenFresh := 0
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
switch r.URL.Path {
|
|
case "/token":
|
|
tokenCalls.Add(1)
|
|
time.Sleep(50 * time.Millisecond)
|
|
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
|
|
case "/sparql":
|
|
sparqlCalls.Add(1)
|
|
switch r.Header.Get("Authorization") {
|
|
case "Bearer expired-token":
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = io.WriteString(w, "Jwt is expired")
|
|
case "Bearer fresh-token":
|
|
mu.Lock()
|
|
seenFresh++
|
|
mu.Unlock()
|
|
_, _ = io.WriteString(w, `{"head":{},"boolean":true}`)
|
|
default:
|
|
t.Errorf("unexpected authorization header %q", r.Header.Get("Authorization"))
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
default:
|
|
http.NotFound(w, r)
|
|
}
|
|
}))
|
|
defer server.Close()
|
|
|
|
cfg := Config{
|
|
SparqlSourceMode: "external",
|
|
ExternalSparqlEndpoint: server.URL + "/sparql",
|
|
KeycloakTokenEndpoint: server.URL + "/token",
|
|
KeycloakClientID: "anzograph",
|
|
KeycloakUsername: "user",
|
|
KeycloakPassword: "pass",
|
|
KeycloakScope: "openid",
|
|
SparqlTimeout: 2 * time.Second,
|
|
}
|
|
|
|
client := NewAnzoGraphClient(cfg)
|
|
client.client = server.Client()
|
|
client.tokenManager.client = server.Client()
|
|
client.tokenManager.token = "expired-token"
|
|
|
|
const workers = 5
|
|
var wg sync.WaitGroup
|
|
errs := make(chan error, workers)
|
|
for i := 0; i < workers; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := client.Query(context.Background(), "ASK WHERE { ?s ?p ?o }")
|
|
errs <- err
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
close(errs)
|
|
|
|
for err := range errs {
|
|
if err != nil {
|
|
t.Fatalf("concurrent query returned error: %v", err)
|
|
}
|
|
}
|
|
if tokenCalls.Load() != 1 {
|
|
t.Fatalf("expected exactly 1 shared refresh, got %d", tokenCalls.Load())
|
|
}
|
|
if seenFresh != workers {
|
|
t.Fatalf("expected %d successful retried queries, got %d", workers, seenFresh)
|
|
}
|
|
if sparqlCalls.Load() < workers*2 {
|
|
t.Fatalf("expected each worker to hit sparql before and after refresh, got %d calls", sparqlCalls.Load())
|
|
}
|
|
}
|