ai-servers/llm-gateway/internal/proxy/concurrency_test.go
Ray Andrew 90adf6f3a8
feat(gateway): add circuit breaker, retry, and concurrency limit support
feat(gateway): add debug logging with file storage and retention

feat(gateway): add audit logging for user actions

feat(gateway): add request ID tracking and rate limit headers

feat(gateway): add model aliases and load balancing strategies

feat(gateway): add config hot-reload via SIGHUP

feat(gateway): add CORS support

feat(gateway): add data export API and dashboard endpoints

feat(gateway): add dashboard pages for audit and debug logs

feat(gateway): add concurrent request limiting per token

feat(gateway): add streaming timeout support

feat(gateway): add migration support for new schema fields
2026-02-15 04:21:40 -06:00

317 lines
7.6 KiB
Go

package proxy
import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"llm-gateway/internal/auth"
)
// TestConcurrencyLimiter_AllowsWithinLimit verifies that Check admits every
// request when the number of simultaneous requests does not exceed the
// token's MaxConcurrent.
//
// Each handler invocation announces its arrival on a buffered channel and
// then parks on gate. The test waits until the expected number of requests
// are parked at the same time before releasing them, so the assertion really
// covers concurrent admission instead of relying on a fixed sleep. A timeout
// bounds the wait so a limiter that wrongly rejects a request fails the test
// instead of hanging it.
func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) {
	tests := []struct {
		name          string
		maxConcurrent int
		numRequests   int
		wantAllowed   int
	}{
		{
			name:          "single request within limit",
			maxConcurrent: 5,
			numRequests:   1,
			wantAllowed:   1,
		},
		{
			name:          "all requests within limit",
			maxConcurrent: 5,
			numRequests:   5,
			wantAllowed:   5,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cl := NewConcurrencyLimiter()
			token := &auth.APIToken{
				Name:          "conc-token",
				MaxConcurrent: tt.maxConcurrent,
			}

			var allowed atomic.Int64
			var wg sync.WaitGroup

			// gate holds every admitted request inside the handler until closed.
			gate := make(chan struct{})
			// entered is buffered for all requests so the handler never blocks
			// while reporting that it has been reached.
			entered := make(chan struct{}, tt.numRequests)

			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				allowed.Add(1)
				entered <- struct{}{}
				<-gate // Block until released.
				w.WriteHeader(http.StatusOK)
			}))

			for i := 0; i < tt.numRequests; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					rec := httptest.NewRecorder()
					req := httptest.NewRequest(http.MethodGet, "/", nil)
					req = req.WithContext(withAPIToken(req.Context(), token))
					handler.ServeHTTP(rec, req)
				}()
			}

			// Wait until all expected requests are parked in the handler at once.
			timeout := time.After(2 * time.Second)
		waitForEntry:
			for i := 0; i < tt.wantAllowed; i++ {
				select {
				case <-entered:
				case <-timeout:
					t.Errorf("only %d of %d requests were in flight concurrently before timeout", i, tt.wantAllowed)
					break waitForEntry
				}
			}

			close(gate)
			wg.Wait()

			if int(allowed.Load()) != tt.wantAllowed {
				t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed)
			}
		})
	}
}
// TestConcurrencyLimiter_DeniesOverLimit verifies that requests beyond a
// token's MaxConcurrent are answered with 429 Too Many Requests while the
// permitted number of requests is still in flight.
//
// The test runs in two phases so the outcome does not depend on goroutine
// scheduling or a fixed sleep:
//
//  1. maxConcurrent requests are started and the test waits until each one
//     has reported that it is parked inside the handler.
//  2. The remaining requests are started. Because the gate is still shut,
//     only rejected requests can complete, so every status code collected in
//     this phase belongs to an over-limit request.
//
// A timeout bounds both waits so a misbehaving limiter fails the test instead
// of hanging it.
func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) {
	tests := []struct {
		name          string
		maxConcurrent int
		numRequests   int
		wantDenied    int
	}{
		{
			name:          "one over limit",
			maxConcurrent: 2,
			numRequests:   3,
			wantDenied:    1,
		},
		{
			name:          "many over limit",
			maxConcurrent: 1,
			numRequests:   5,
			wantDenied:    4,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cl := NewConcurrencyLimiter()
			token := &auth.APIToken{
				Name:          "conc-token",
				MaxConcurrent: tt.maxConcurrent,
			}

			var wg sync.WaitGroup

			// gate holds admitted requests inside the handler until closed.
			gate := make(chan struct{})
			// entered and codes are buffered for every request so neither the
			// handler nor the request goroutines ever block while reporting.
			entered := make(chan struct{}, tt.numRequests)
			codes := make(chan int, tt.numRequests)

			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				entered <- struct{}{}
				<-gate
				w.WriteHeader(http.StatusOK)
			}))

			// send issues one request for token and reports its status code.
			send := func() {
				defer wg.Done()
				rec := httptest.NewRecorder()
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				req = req.WithContext(withAPIToken(req.Context(), token))
				handler.ServeHTTP(rec, req)
				codes <- rec.Code
			}

			timeout := time.After(2 * time.Second)

			// Phase 1: saturate the limit and wait for every admitted request
			// to be parked inside the handler.
			for i := 0; i < tt.maxConcurrent; i++ {
				wg.Add(1)
				go send()
			}
			for i := 0; i < tt.maxConcurrent; i++ {
				select {
				case <-entered:
				case <-timeout:
					close(gate)
					wg.Wait()
					t.Fatalf("only %d of %d requests reached the handler before timeout", i, tt.maxConcurrent)
				}
			}

			// Phase 2: everything sent now is over the limit.
			overLimit := tt.numRequests - tt.maxConcurrent
			for i := 0; i < overLimit; i++ {
				wg.Add(1)
				go send()
			}

			denied := 0
		collect:
			for i := 0; i < overLimit; i++ {
				select {
				case code := <-codes:
					if code == http.StatusTooManyRequests {
						denied++
					} else {
						t.Errorf("over-limit request: status = %d, want %d", code, http.StatusTooManyRequests)
					}
				case <-timeout:
					t.Errorf("timed out after %d of %d over-limit requests completed", i, overLimit)
					break collect
				}
			}

			close(gate)
			wg.Wait()

			if denied != tt.wantDenied {
				t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
			}
		})
	}
}
// TestConcurrencyLimiter_CounterDecrementsAfterCompletion checks that a
// finished request gives its slot back: with MaxConcurrent set to 1, two
// back-to-back requests must both succeed, and the per-token counter must
// read zero once they are done.
func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) {
	cl := NewConcurrencyLimiter()
	token := &auth.APIToken{Name: "decrement-token", MaxConcurrent: 1}

	handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve sends one request carrying token through the limiter and returns
	// the recorded status code.
	serve := func() int {
		recorder := httptest.NewRecorder()
		request := httptest.NewRequest(http.MethodGet, "/", nil)
		request = request.WithContext(withAPIToken(request.Context(), token))
		handler.ServeHTTP(recorder, request)
		return recorder.Code
	}

	// The first request runs to completion, which should free the only slot.
	if status := serve(); status != http.StatusOK {
		t.Fatalf("first request: status = %d, want %d", status, http.StatusOK)
	}

	// With the slot released, an identical follow-up request must be admitted.
	if status := serve(); status != http.StatusOK {
		t.Errorf("second request after first completed: status = %d, want %d", status, http.StatusOK)
	}

	// Nothing is in flight any more, so the internal counter should be zero.
	counter := cl.getCounter(token.Name)
	if inFlight := counter.Load(); inFlight != 0 {
		t.Errorf("counter = %d, want 0 after all requests completed", inFlight)
	}
}
// TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited verifies that a
// token whose MaxConcurrent is zero or negative is not limited: all requests
// must be admitted even while every one of them is still in flight.
//
// Each handler invocation reports its arrival and then parks on gate. The
// test waits for all numRequests arrivals before releasing them, which proves
// the requests were held concurrently instead of assuming a sleep was long
// enough. A timeout bounds the wait so a limiter that rejects a request fails
// the test instead of hanging it.
func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) {
	tests := []struct {
		name          string
		maxConcurrent int
		numRequests   int
	}{
		{
			name:          "zero allows unlimited concurrent requests",
			maxConcurrent: 0,
			numRequests:   50,
		},
		{
			name:          "negative allows unlimited concurrent requests",
			maxConcurrent: -1,
			numRequests:   50,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cl := NewConcurrencyLimiter()
			token := &auth.APIToken{
				Name:          "unlimited-token",
				MaxConcurrent: tt.maxConcurrent,
			}

			var allowed atomic.Int64
			var wg sync.WaitGroup

			// gate holds every admitted request inside the handler until closed.
			gate := make(chan struct{})
			// entered is buffered for all requests so the handler never blocks
			// while reporting that it has been reached.
			entered := make(chan struct{}, tt.numRequests)

			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				allowed.Add(1)
				entered <- struct{}{}
				<-gate
				w.WriteHeader(http.StatusOK)
			}))

			for i := 0; i < tt.numRequests; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					rec := httptest.NewRecorder()
					req := httptest.NewRequest(http.MethodGet, "/", nil)
					req = req.WithContext(withAPIToken(req.Context(), token))
					handler.ServeHTTP(rec, req)
				}()
			}

			// Wait until every request is parked in the handler simultaneously.
			timeout := time.After(2 * time.Second)
		waitForEntry:
			for i := 0; i < tt.numRequests; i++ {
				select {
				case <-entered:
				case <-timeout:
					t.Errorf("only %d of %d requests were in flight concurrently before timeout", i, tt.numRequests)
					break waitForEntry
				}
			}

			close(gate)
			wg.Wait()

			if int(allowed.Load()) != tt.numRequests {
				t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests)
			}
		})
	}
}
// TestConcurrencyLimiter_NoToken checks that a request whose context carries
// no API token is passed straight through to the wrapped handler.
func TestConcurrencyLimiter_NoToken(t *testing.T) {
	cl := NewConcurrencyLimiter()

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	limited := cl.Check(okHandler)

	// The request is deliberately sent without attaching a token to its context.
	recorder := httptest.NewRecorder()
	limited.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	if got, want := recorder.Code, http.StatusOK; got != want {
		t.Errorf("status = %d, want %d (should pass through without token)", got, want)
	}
}
// TestConcurrencyLimiter_PerTokenIsolation verifies that concurrency is
// tracked per token: while token A's single slot is occupied by an in-flight
// request, a request for token B must still be admitted.
//
// The handler signals when token A's request has arrived and then parks it on
// gateA. The test waits for that signal before sending token B's request, so
// token A is guaranteed to be in flight at that moment instead of relying on
// a fixed sleep. Cleanup is deferred so the parked goroutine is always
// released, including when the test aborts early.
func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) {
	cl := NewConcurrencyLimiter()
	tokenA := &auth.APIToken{
		Name:          "token-a",
		MaxConcurrent: 1,
	}
	tokenB := &auth.APIToken{
		Name:          "token-b",
		MaxConcurrent: 1,
	}

	gateA := make(chan struct{})
	// enteredA is buffered so the handler never blocks while reporting arrival.
	enteredA := make(chan struct{}, 1)
	var wg sync.WaitGroup

	handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tok := getAPIToken(r.Context())
		if tok.Name == "token-a" {
			enteredA <- struct{}{}
			<-gateA // Block token A's request.
		}
		w.WriteHeader(http.StatusOK)
	}))

	// Start a request for token A that blocks inside the handler.
	wg.Add(1)
	go func() {
		defer wg.Done()
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req = req.WithContext(withAPIToken(req.Context(), tokenA))
		handler.ServeHTTP(rec, req)
	}()

	// Release token A's request and wait for it on every exit path.
	defer func() {
		close(gateA)
		wg.Wait()
	}()

	// Wait until token A's request is actually occupying its slot.
	select {
	case <-enteredA:
	case <-time.After(2 * time.Second):
		t.Fatal("token-a request did not reach the handler before timeout")
	}

	// Token B should not be affected by token A's in-flight request.
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(withAPIToken(req.Context(), tokenB))
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK)
	}
}