package main

import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// TestSnapshotBuildContinuesAfterRequesterCancellation verifies three things:
//   - a snapshot build started on behalf of one caller keeps running after
//     that caller's context is cancelled;
//   - a second caller asking for the same snapshot waits on the in-flight
//     build instead of starting another fetch;
//   - the finished snapshot is then served again without a further fetch.
func TestSnapshotBuildContinuesAfterRequesterCancellation(t *testing.T) {
	const waitTimeout = 2 * time.Second

	svc := NewGraphSnapshotService(nil, Config{})

	var fetchCalls atomic.Int32
	started := make(chan struct{})
	release := make(chan struct{})

	// Idempotent closers: an unexpected second fetch, or the cleanup below,
	// must never panic on a double close.
	var startedOnce, releaseOnce sync.Once
	closeStarted := func() { startedOnce.Do(func() { close(started) }) }
	closeRelease := func() { releaseOnce.Do(func() { close(release) }) }
	// If the test fails before releasing the fake fetch, unblock it on exit so
	// the build goroutine does not leak.
	t.Cleanup(closeRelease)

	expected := GraphResponse{
		Nodes: []Node{{ID: 1}},
		Edges: []Edge{{Source: 1, Target: 1, PredicateID: 0}},
		Meta:  &GraphMeta{GraphQueryID: "default", Nodes: 1, Edges: 1},
	}

	// checkMeta fails the test if got's node/edge counts differ from expected's.
	// label names the Get call that produced got.
	checkMeta := func(label string, got GraphResponse) {
		t.Helper()
		if got.Meta == nil || got.Meta.Nodes != expected.Meta.Nodes || got.Meta.Edges != expected.Meta.Edges {
			t.Fatalf("%s snapshot meta = %#v, want %#v", label, got.Meta, expected.Meta)
		}
	}

	svc.fetchSnapshot = func(ctx context.Context, _ *AnzoGraphClient, _ Config, nodeLimit int, edgeLimit int, graphQueryID string) (GraphResponse, error) {
		fetchCalls.Add(1)
		// This may run on a goroutine owned by svc rather than the test
		// goroutine, where t.Fatalf is not allowed; record the failure instead.
		if nodeLimit != 10 || edgeLimit != 20 || graphQueryID != "default" {
			t.Errorf("unexpected fetch args nodeLimit=%d edgeLimit=%d graphQueryID=%q", nodeLimit, edgeLimit, graphQueryID)
		}
		closeStarted()
		// Honor ctx: if the build were tied to the requester's context, the
		// cancellation below would abort it and the second Get would see it.
		select {
		case <-release:
			return expected, nil
		case <-ctx.Done():
			return GraphResponse{}, ctx.Err()
		}
	}

	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()

	firstErrCh := make(chan error, 1)
	go func() {
		_, err := svc.Get(ctx1, 10, 20, "default")
		firstErrCh <- err
	}()

	select {
	case <-started:
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for snapshot fetch to start")
	}
	cancel1()

	select {
	case err := <-firstErrCh:
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("first Get error = %v, want context.Canceled", err)
		}
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for first Get to return after cancellation")
	}

	secondSnapCh := make(chan GraphResponse, 1)
	secondErrCh := make(chan error, 1)
	go func() {
		snap, err := svc.Get(context.Background(), 10, 20, "default")
		if err != nil {
			secondErrCh <- err
			return
		}
		secondSnapCh <- snap
	}()

	// Best effort: give the second Get time to reach the service. If it started
	// its own fetch, fetchCalls would already be 2 here.
	time.Sleep(50 * time.Millisecond)
	if got := fetchCalls.Load(); got != 1 {
		t.Fatalf("fetchCalls after second waiter start = %d, want 1", got)
	}
	// The build is still blocked, so the second caller must still be waiting.
	select {
	case err := <-secondErrCh:
		t.Fatalf("second Get returned before the build was released: err = %v", err)
	case snap := <-secondSnapCh:
		t.Fatalf("second Get returned before the build was released: snapshot meta = %#v", snap.Meta)
	default:
	}

	closeRelease()

	select {
	case err := <-secondErrCh:
		t.Fatalf("second Get error = %v", err)
	case snap := <-secondSnapCh:
		checkMeta("second Get", snap)
	case <-time.After(waitTimeout):
		t.Fatal("timed out waiting for second Get to return")
	}

	if got := fetchCalls.Load(); got != 1 {
		t.Fatalf("fetchCalls after background completion = %d, want 1", got)
	}

	snap, err := svc.Get(context.Background(), 10, 20, "default")
	if err != nil {
		t.Fatalf("cached Get error = %v", err)
	}
	checkMeta("cached Get", snap)

	if got := fetchCalls.Load(); got != 1 {
		t.Fatalf("fetchCalls after cached Get = %d, want 1", got)
	}
}