Files
visualizador_instanciados/backend_go/keycloak_token_test.go

274 lines
7.8 KiB
Go

package main
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestKeycloakTokenManagerFetchTokenParsesAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/token" {
http.NotFound(w, r)
return
}
if got := r.Header.Get("Content-Type"); got != "application/x-www-form-urlencoded" {
t.Errorf("unexpected content-type: %s", got)
w.WriteHeader(http.StatusInternalServerError)
return
}
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
}))
defer server.Close()
cfg := Config{
KeycloakTokenEndpoint: server.URL + "/token",
KeycloakClientID: "anzograph",
KeycloakUsername: "user",
KeycloakPassword: "pass",
KeycloakScope: "openid",
}
manager := newKeycloakTokenManager(cfg, server.Client())
token, err := manager.fetchToken(context.Background())
if err != nil {
t.Fatalf("fetchToken returned error: %v", err)
}
if token != "fresh-token" {
t.Fatalf("expected fresh-token, got %q", token)
}
}
func TestAnzoGraphClientStartupFetchesFreshToken(t *testing.T) {
var tokenCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
tokenCalls.Add(1)
_, _ = io.WriteString(w, `{"access_token":"startup-token"}`)
case "/sparql":
if got := r.Header.Get("Authorization"); got != "Bearer startup-token" {
t.Errorf("expected startup bearer token, got %q", got)
w.WriteHeader(http.StatusInternalServerError)
return
}
_, _ = io.WriteString(w, `{"head":{},"boolean":true}`)
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := Config{
SparqlSourceMode: "external",
ExternalSparqlEndpoint: server.URL + "/sparql",
KeycloakTokenEndpoint: server.URL + "/token",
KeycloakClientID: "anzograph",
KeycloakUsername: "user",
KeycloakPassword: "pass",
KeycloakScope: "openid",
SparqlReadyTimeout: 2 * time.Second,
SparqlReadyRetries: 1,
SparqlReadyDelay: 1 * time.Millisecond,
SparqlTimeout: 2 * time.Second,
}
client := NewAnzoGraphClient(cfg)
client.client = server.Client()
client.tokenManager.client = server.Client()
if err := client.Startup(context.Background()); err != nil {
t.Fatalf("Startup returned error: %v", err)
}
if tokenCalls.Load() != 1 {
t.Fatalf("expected 1 startup token request, got %d", tokenCalls.Load())
}
}
func TestQueryRetriesOnceWhenJWTExpires(t *testing.T) {
var tokenCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
call := tokenCalls.Add(1)
if call != 1 {
t.Errorf("expected exactly 1 refresh call, got %d", call)
w.WriteHeader(http.StatusInternalServerError)
return
}
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
case "/sparql":
switch r.Header.Get("Authorization") {
case "Bearer expired-token":
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "Jwt is expired")
case "Bearer fresh-token":
_, _ = io.WriteString(w, `{"results":{"bindings":[{"s":{"type":"uri","value":"http://example.com/s"},"p":{"type":"uri","value":"http://example.com/p"},"o":{"type":"uri","value":"http://example.com/o"}}]}}`)
default:
t.Errorf("unexpected authorization header %q", r.Header.Get("Authorization"))
w.WriteHeader(http.StatusInternalServerError)
return
}
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := Config{
SparqlSourceMode: "external",
ExternalSparqlEndpoint: server.URL + "/sparql",
KeycloakTokenEndpoint: server.URL + "/token",
KeycloakClientID: "anzograph",
KeycloakUsername: "user",
KeycloakPassword: "pass",
KeycloakScope: "openid",
SparqlTimeout: 2 * time.Second,
}
client := NewAnzoGraphClient(cfg)
client.client = server.Client()
client.tokenManager.client = server.Client()
client.tokenManager.token = "expired-token"
raw, err := client.Query(context.Background(), "SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
if err != nil {
t.Fatalf("Query returned error: %v", err)
}
if string(raw) == "" {
t.Fatalf("expected successful response body after refresh")
}
if tokenCalls.Load() != 1 {
t.Fatalf("expected 1 refresh call, got %d", tokenCalls.Load())
}
}
func TestQueryDoesNotRefreshForNonExpiry401(t *testing.T) {
var tokenCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/token":
tokenCalls.Add(1)
_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
case "/sparql":
w.WriteHeader(http.StatusUnauthorized)
_, _ = io.WriteString(w, "RBAC: access denied")
default:
http.NotFound(w, r)
}
}))
defer server.Close()
cfg := Config{
SparqlSourceMode: "external",
ExternalSparqlEndpoint: server.URL + "/sparql",
KeycloakTokenEndpoint: server.URL + "/token",
KeycloakClientID: "anzograph",
KeycloakUsername: "user",
KeycloakPassword: "pass",
KeycloakScope: "openid",
SparqlTimeout: 2 * time.Second,
}
client := NewAnzoGraphClient(cfg)
client.client = server.Client()
client.tokenManager.client = server.Client()
client.tokenManager.token = "still-bad-token"
_, err := client.Query(context.Background(), "SELECT ?s ?p ?o WHERE { ?s ?p ?o }")
if err == nil {
t.Fatalf("expected non-expiry 401 to fail")
}
if tokenCalls.Load() != 0 {
t.Fatalf("expected no token refresh for non-expiry 401, got %d", tokenCalls.Load())
}
}
// TestConcurrentExpiredQueriesShareOneRefresh runs several Query calls in
// parallel while the token manager holds "expired-token". It checks that
// every query succeeds, that the token endpoint is hit exactly once, that
// each worker is eventually served with the fresh token, and that the SPARQL
// endpoint saw at least two requests per worker.
func TestConcurrentExpiredQueriesShareOneRefresh(t *testing.T) {
	// All three counters are written by the server's handler goroutines and
	// read by the test goroutine, so they are atomics: the final assertions
	// can then read them without relying on the HTTP round trip for ordering.
	var (
		tokenCalls  atomic.Int32
		sparqlCalls atomic.Int32
		seenFresh   atomic.Int32
	)

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/token":
			tokenCalls.Add(1)
			// Presumably keeps the refresh in flight long enough for the
			// other workers to reach it concurrently.
			time.Sleep(50 * time.Millisecond)
			_, _ = io.WriteString(w, `{"access_token":"fresh-token"}`)
		case "/sparql":
			sparqlCalls.Add(1)
			switch r.Header.Get("Authorization") {
			case "Bearer expired-token":
				w.WriteHeader(http.StatusUnauthorized)
				_, _ = io.WriteString(w, "Jwt is expired")
			case "Bearer fresh-token":
				seenFresh.Add(1)
				_, _ = io.WriteString(w, `{"head":{},"boolean":true}`)
			default:
				t.Errorf("unexpected authorization header %q", r.Header.Get("Authorization"))
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	cfg := Config{
		SparqlSourceMode:       "external",
		ExternalSparqlEndpoint: server.URL + "/sparql",
		KeycloakTokenEndpoint:  server.URL + "/token",
		KeycloakClientID:       "anzograph",
		KeycloakUsername:       "user",
		KeycloakPassword:       "pass",
		KeycloakScope:          "openid",
		SparqlTimeout:          2 * time.Second,
	}
	client := NewAnzoGraphClient(cfg)
	client.client = server.Client()
	client.tokenManager.client = server.Client()
	client.tokenManager.token = "expired-token"

	const workers = 5
	var wg sync.WaitGroup
	// Buffered to the worker count so no goroutine blocks on send.
	errs := make(chan error, workers)
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, err := client.Query(context.Background(), "ASK WHERE { ?s ?p ?o }")
			errs <- err
		}()
	}
	wg.Wait()
	close(errs)

	for err := range errs {
		if err != nil {
			t.Fatalf("concurrent query returned error: %v", err)
		}
	}
	if got := tokenCalls.Load(); got != 1 {
		t.Fatalf("expected exactly 1 shared refresh, got %d", got)
	}
	if got := seenFresh.Load(); got != workers {
		t.Fatalf("expected %d successful retried queries, got %d", workers, got)
	}
	if got := sparqlCalls.Load(); got < workers*2 {
		t.Fatalf("expected each worker to hit sparql before and after refresh, got %d calls", got)
	}
}