package main import ( "context" "io" "net/http" "net/http/httptest" "sync" "sync/atomic" "testing" "time" ) func TestKeycloakTokenManagerFetchTokenParsesAccessToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/token" { http.NotFound(w, r) return } if got := r.Header.Get("Content-Type"); got != "application/x-www-form-urlencoded" { t.Errorf("unexpected content-type: %s", got) w.WriteHeader(http.StatusInternalServerError) return } _, _ = io.WriteString(w, `{"access_token":"fresh-token"}`) })) defer server.Close() cfg := Config{ KeycloakTokenEndpoint: server.URL + "/token", KeycloakClientID: "anzograph", KeycloakUsername: "user", KeycloakPassword: "pass", KeycloakScope: "openid", } manager := newKeycloakTokenManager(cfg, server.Client()) token, err := manager.fetchToken(context.Background()) if err != nil { t.Fatalf("fetchToken returned error: %v", err) } if token != "fresh-token" { t.Fatalf("expected fresh-token, got %q", token) } } func TestAnzoGraphClientStartupFetchesFreshToken(t *testing.T) { var tokenCalls atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/token": tokenCalls.Add(1) _, _ = io.WriteString(w, `{"access_token":"startup-token"}`) case "/sparql": if got := r.Header.Get("Authorization"); got != "Bearer startup-token" { t.Errorf("expected startup bearer token, got %q", got) w.WriteHeader(http.StatusInternalServerError) return } _, _ = io.WriteString(w, `{"head":{},"boolean":true}`) default: http.NotFound(w, r) } })) defer server.Close() cfg := Config{ SparqlSourceMode: "external", ExternalSparqlEndpoint: server.URL + "/sparql", KeycloakTokenEndpoint: server.URL + "/token", KeycloakClientID: "anzograph", KeycloakUsername: "user", KeycloakPassword: "pass", KeycloakScope: "openid", SparqlReadyTimeout: 2 * time.Second, SparqlReadyRetries: 1, SparqlReadyDelay: 1 * time.Millisecond, SparqlTimeout: 2 * time.Second, } client := NewAnzoGraphClient(cfg) client.client = server.Client() client.tokenManager.client = server.Client() if err := client.Startup(context.Background()); err != nil { t.Fatalf("Startup returned error: %v", err) } if tokenCalls.Load() != 1 { t.Fatalf("expected 1 startup token request, got %d", tokenCalls.Load()) } } func TestQueryRetriesOnceWhenJWTExpires(t *testing.T) { var tokenCalls atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/token": call := tokenCalls.Add(1) if call != 1 { t.Errorf("expected exactly 1 refresh call, got %d", call) w.WriteHeader(http.StatusInternalServerError) return } _, _ = io.WriteString(w, `{"access_token":"fresh-token"}`) case "/sparql": switch r.Header.Get("Authorization") { case "Bearer expired-token": w.WriteHeader(http.StatusUnauthorized) _, _ = io.WriteString(w, "Jwt is expired") case "Bearer fresh-token": _, _ = io.WriteString(w, `{"results":{"bindings":[{"s":{"type":"uri","value":"http://example.com/s"},"p":{"type":"uri","value":"http://example.com/p"},"o":{"type":"uri","value":"http://example.com/o"}}]}}`) default: t.Errorf("unexpected authorization header %q", r.Header.Get("Authorization")) w.WriteHeader(http.StatusInternalServerError) return } default: http.NotFound(w, r) } })) defer server.Close() cfg := Config{ SparqlSourceMode: "external", ExternalSparqlEndpoint: server.URL + "/sparql", KeycloakTokenEndpoint: server.URL + "/token", KeycloakClientID: "anzograph", KeycloakUsername: "user", KeycloakPassword: "pass", KeycloakScope: "openid", SparqlTimeout: 2 * time.Second, } client := NewAnzoGraphClient(cfg) client.client = server.Client() client.tokenManager.client = server.Client() client.tokenManager.token = "expired-token" raw, err := client.Query(context.Background(), "SELECT ?s ?p ?o WHERE { ?s ?p ?o }") if err != nil { t.Fatalf("Query returned error: %v", err) } if string(raw) == "" { t.Fatalf("expected successful response body after refresh") } if tokenCalls.Load() != 1 { t.Fatalf("expected 1 refresh call, got %d", tokenCalls.Load()) } } func TestQueryDoesNotRefreshForNonExpiry401(t *testing.T) { var tokenCalls atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/token": tokenCalls.Add(1) _, _ = io.WriteString(w, `{"access_token":"fresh-token"}`) case "/sparql": w.WriteHeader(http.StatusUnauthorized) _, _ = io.WriteString(w, "RBAC: access denied") default: http.NotFound(w, r) } })) defer server.Close() cfg := Config{ SparqlSourceMode: "external", ExternalSparqlEndpoint: server.URL + "/sparql", KeycloakTokenEndpoint: server.URL + "/token", KeycloakClientID: "anzograph", KeycloakUsername: "user", KeycloakPassword: "pass", KeycloakScope: "openid", SparqlTimeout: 2 * time.Second, } client := NewAnzoGraphClient(cfg) client.client = server.Client() client.tokenManager.client = server.Client() client.tokenManager.token = "still-bad-token" _, err := client.Query(context.Background(), "SELECT ?s ?p ?o WHERE { ?s ?p ?o }") if err == nil { t.Fatalf("expected non-expiry 401 to fail") } if tokenCalls.Load() != 0 { t.Fatalf("expected no token refresh for non-expiry 401, got %d", tokenCalls.Load()) } } func TestConcurrentExpiredQueriesShareOneRefresh(t *testing.T) { var tokenCalls atomic.Int32 var sparqlCalls atomic.Int32 var mu sync.Mutex seenFresh := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/token": tokenCalls.Add(1) time.Sleep(50 * time.Millisecond) _, _ = io.WriteString(w, `{"access_token":"fresh-token"}`) case "/sparql": sparqlCalls.Add(1) switch r.Header.Get("Authorization") { case "Bearer expired-token": w.WriteHeader(http.StatusUnauthorized) _, _ = io.WriteString(w, "Jwt is expired") case "Bearer fresh-token": mu.Lock() seenFresh++ mu.Unlock() _, _ = io.WriteString(w, `{"head":{},"boolean":true}`) default: t.Errorf("unexpected authorization header %q", r.Header.Get("Authorization")) w.WriteHeader(http.StatusInternalServerError) return } default: http.NotFound(w, r) } })) defer server.Close() cfg := Config{ SparqlSourceMode: "external", ExternalSparqlEndpoint: server.URL + "/sparql", KeycloakTokenEndpoint: server.URL + "/token", KeycloakClientID: "anzograph", KeycloakUsername: "user", KeycloakPassword: "pass", KeycloakScope: "openid", SparqlTimeout: 2 * time.Second, } client := NewAnzoGraphClient(cfg) client.client = server.Client() client.tokenManager.client = server.Client() client.tokenManager.token = "expired-token" const workers = 5 var wg sync.WaitGroup errs := make(chan error, workers) for i := 0; i < workers; i++ { wg.Add(1) go func() { defer wg.Done() _, err := client.Query(context.Background(), "ASK WHERE { ?s ?p ?o }") errs <- err }() } wg.Wait() close(errs) for err := range errs { if err != nil { t.Fatalf("concurrent query returned error: %v", err) } } if tokenCalls.Load() != 1 { t.Fatalf("expected exactly 1 shared refresh, got %d", tokenCalls.Load()) } if seenFresh != workers { t.Fatalf("expected %d successful retried queries, got %d", workers, seenFresh) } if sparqlCalls.Load() < workers*2 { t.Fatalf("expected each worker to hit sparql before and after refresh, got %d calls", sparqlCalls.Load()) } }