feat(gateway): add debug logging with file storage and retention feat(gateway): add audit logging for user actions feat(gateway): add request ID tracking and rate limit headers feat(gateway): add model aliases and load balancing strategies feat(gateway): add config hot-reload via SIGHUP feat(gateway): add CORS support feat(gateway): add data export API and dashboard endpoints feat(gateway): add dashboard pages for audit and debug logs feat(gateway): add concurrent request limiting per token feat(gateway): add streaming timeout support feat(gateway): add migration support for new schema fields
374 lines
9.3 KiB
Go
374 lines
9.3 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
|
|
"llm-gateway/internal/auth"
|
|
"llm-gateway/internal/storage"
|
|
)
|
|
|
|
// newTestDB creates an in-memory SQLite database wrapped in storage.DB.
|
|
// It creates the request_logs table needed by TodaySpend.
|
|
func newTestDB(t *testing.T) *storage.DB {
|
|
t.Helper()
|
|
sqlDB, err := sql.Open("sqlite", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("opening in-memory sqlite: %v", err)
|
|
}
|
|
t.Cleanup(func() { sqlDB.Close() })
|
|
|
|
// Create the minimal table needed for TodaySpend queries.
|
|
_, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
token_name TEXT,
|
|
cost_usd REAL,
|
|
timestamp INTEGER
|
|
)`)
|
|
if err != nil {
|
|
t.Fatalf("creating request_logs table: %v", err)
|
|
}
|
|
return &storage.DB{DB: sqlDB}
|
|
}
|
|
|
|
// okHandler is the terminal handler used behind the middleware under test:
// it ignores the request and answers with a bare 200 OK.
var okHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
})
|
|
|
|
func TestRateLimiter_Allow(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
rateLimitRPM int
|
|
numRequests int
|
|
wantAllowed int
|
|
wantDenied int
|
|
}{
|
|
{
|
|
name: "allows requests within limit",
|
|
rateLimitRPM: 10,
|
|
numRequests: 5,
|
|
wantAllowed: 5,
|
|
wantDenied: 0,
|
|
},
|
|
{
|
|
name: "denies requests over limit",
|
|
rateLimitRPM: 3,
|
|
numRequests: 6,
|
|
wantAllowed: 3,
|
|
wantDenied: 3,
|
|
},
|
|
{
|
|
name: "allows exactly up to limit",
|
|
rateLimitRPM: 5,
|
|
numRequests: 5,
|
|
wantAllowed: 5,
|
|
wantDenied: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
allowed := 0
|
|
denied := 0
|
|
for i := 0; i < tt.numRequests; i++ {
|
|
ok, _, _ := rl.allow("test-token", tt.rateLimitRPM)
|
|
if ok {
|
|
allowed++
|
|
} else {
|
|
denied++
|
|
}
|
|
}
|
|
|
|
if allowed != tt.wantAllowed {
|
|
t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed)
|
|
}
|
|
if denied != tt.wantDenied {
|
|
t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_TokenRefillsOverTime(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
rpm := 60 // 1 token per second refill rate
|
|
|
|
// Exhaust all tokens.
|
|
for i := 0; i < rpm; i++ {
|
|
ok, _, _ := rl.allow("refill-token", rpm)
|
|
if !ok {
|
|
t.Fatalf("request %d should have been allowed", i)
|
|
}
|
|
}
|
|
|
|
// Next request should be denied.
|
|
ok, _, _ := rl.allow("refill-token", rpm)
|
|
if ok {
|
|
t.Fatal("request should have been denied after exhausting tokens")
|
|
}
|
|
|
|
// Manually advance the bucket's lastRefill to simulate time passing.
|
|
rl.mu.Lock()
|
|
bucket := rl.buckets["refill-token"]
|
|
bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second)
|
|
rl.mu.Unlock()
|
|
|
|
// After 2 seconds at 1 token/sec, we should have ~2 tokens refilled.
|
|
ok, remaining, _ := rl.allow("refill-token", rpm)
|
|
if !ok {
|
|
t.Fatal("request should have been allowed after token refill")
|
|
}
|
|
// We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0.
|
|
if remaining < 0 {
|
|
t.Errorf("remaining = %d, want >= 0", remaining)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_AllowReturnValues(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
rateLimitRPM int
|
|
numRequests int
|
|
wantLastAllowed bool
|
|
wantLastRemaining int
|
|
}{
|
|
{
|
|
name: "remaining decrements correctly",
|
|
rateLimitRPM: 5,
|
|
numRequests: 1,
|
|
wantLastAllowed: true,
|
|
wantLastRemaining: 4,
|
|
},
|
|
{
|
|
name: "remaining is zero at limit",
|
|
rateLimitRPM: 3,
|
|
numRequests: 3,
|
|
wantLastAllowed: true,
|
|
wantLastRemaining: 0,
|
|
},
|
|
{
|
|
name: "denied returns zero remaining",
|
|
rateLimitRPM: 2,
|
|
numRequests: 3,
|
|
wantLastAllowed: false,
|
|
wantLastRemaining: 0,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
var allowed bool
|
|
var remaining int
|
|
for i := 0; i < tt.numRequests; i++ {
|
|
allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM)
|
|
}
|
|
|
|
if allowed != tt.wantLastAllowed {
|
|
t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed)
|
|
}
|
|
if remaining != tt.wantLastRemaining {
|
|
t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
rateLimitRPM int
|
|
numRequests int
|
|
wantStatusCode int
|
|
wantLimitHeader string
|
|
wantRetryAfter bool
|
|
}{
|
|
{
|
|
name: "sets rate limit headers on allowed request",
|
|
rateLimitRPM: 10,
|
|
numRequests: 1,
|
|
wantStatusCode: http.StatusOK,
|
|
wantLimitHeader: "10",
|
|
wantRetryAfter: false,
|
|
},
|
|
{
|
|
name: "sets Retry-After header on 429",
|
|
rateLimitRPM: 2,
|
|
numRequests: 3,
|
|
wantStatusCode: http.StatusTooManyRequests,
|
|
wantLimitHeader: "2",
|
|
wantRetryAfter: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
token := &auth.APIToken{
|
|
Name: "header-test-token",
|
|
RateLimitRPM: tt.rateLimitRPM,
|
|
}
|
|
|
|
handler := rl.Check(okHandler)
|
|
|
|
var rec *httptest.ResponseRecorder
|
|
for i := 0; i < tt.numRequests; i++ {
|
|
rec = httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
ctx := withAPIToken(req.Context(), token)
|
|
req = req.WithContext(ctx)
|
|
handler.ServeHTTP(rec, req)
|
|
}
|
|
|
|
// Check the last response.
|
|
if rec.Code != tt.wantStatusCode {
|
|
t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode)
|
|
}
|
|
|
|
// X-RateLimit-Limit header.
|
|
limitHeader := rec.Header().Get("X-RateLimit-Limit")
|
|
if limitHeader != tt.wantLimitHeader {
|
|
t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader)
|
|
}
|
|
|
|
// X-RateLimit-Remaining header must be present and numeric.
|
|
remainingHeader := rec.Header().Get("X-RateLimit-Remaining")
|
|
if remainingHeader == "" {
|
|
t.Error("X-RateLimit-Remaining header is missing")
|
|
} else if _, err := strconv.Atoi(remainingHeader); err != nil {
|
|
t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader)
|
|
}
|
|
|
|
// X-RateLimit-Reset header must be present and numeric.
|
|
resetHeader := rec.Header().Get("X-RateLimit-Reset")
|
|
if resetHeader == "" {
|
|
t.Error("X-RateLimit-Reset header is missing")
|
|
} else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil {
|
|
t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader)
|
|
}
|
|
|
|
// Retry-After header.
|
|
retryAfter := rec.Header().Get("Retry-After")
|
|
if tt.wantRetryAfter && retryAfter == "" {
|
|
t.Error("Retry-After header is missing on 429 response")
|
|
}
|
|
if !tt.wantRetryAfter && retryAfter != "" {
|
|
t.Errorf("Retry-After header should not be present, got %q", retryAfter)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
handler := rl.Check(okHandler)
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
// No API token in context.
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
token := &auth.APIToken{
|
|
Name: "unlimited-token",
|
|
RateLimitRPM: 0, // zero means unlimited
|
|
}
|
|
|
|
handler := rl.Check(okHandler)
|
|
|
|
for i := 0; i < 100; i++ {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
ctx := withAPIToken(req.Context(), token)
|
|
req = req.WithContext(ctx)
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_PerTokenIsolation(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
rpm := 2
|
|
|
|
// Exhaust token A.
|
|
for i := 0; i < rpm; i++ {
|
|
rl.allow("token-a", rpm)
|
|
}
|
|
ok, _, _ := rl.allow("token-a", rpm)
|
|
if ok {
|
|
t.Fatal("token-a should be rate limited")
|
|
}
|
|
|
|
// Token B should still have its own bucket.
|
|
ok, _, _ = rl.allow("token-b", rpm)
|
|
if !ok {
|
|
t.Fatal("token-b should not be affected by token-a's rate limit")
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_ResetAtIsFuture(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
// Consume one token so there's a deficit.
|
|
_, _, resetAt := rl.allow("reset-token", 10)
|
|
now := time.Now().Unix()
|
|
|
|
if resetAt < now {
|
|
t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now)
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) {
|
|
db := newTestDB(t)
|
|
rl := NewRateLimiter(db)
|
|
|
|
token := &auth.APIToken{
|
|
Name: "ctx-token",
|
|
RateLimitRPM: 10,
|
|
}
|
|
|
|
handler := rl.Check(okHandler)
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
ctx, cancel := context.WithCancel(req.Context())
|
|
ctx = withAPIToken(ctx, token)
|
|
cancel() // Cancel immediately.
|
|
req = req.WithContext(ctx)
|
|
|
|
// Should still process (rate limiter does not check context cancellation).
|
|
handler.ServeHTTP(rec, req)
|
|
// The handler itself may or may not respect cancelled context;
|
|
// the key point is no panic occurs.
|
|
}
|