From 90adf6f3a84f31d35953ca3da41f2b4e47471491 Mon Sep 17 00:00:00 2001 From: Ray Andrew Date: Sun, 15 Feb 2026 04:21:40 -0600 Subject: [PATCH] feat(gateway): add circuit breaker, retry, and concurrency limit support feat(gateway): add debug logging with file storage and retention feat(gateway): add audit logging for user actions feat(gateway): add request ID tracking and rate limit headers feat(gateway): add model aliases and load balancing strategies feat(gateway): add config hot-reload via SIGHUP feat(gateway): add CORS support feat(gateway): add data export API and dashboard endpoints feat(gateway): add dashboard pages for audit and debug logs feat(gateway): add concurrent request limiting per token feat(gateway): add streaming timeout support feat(gateway): add migration support for new schema fields --- llm-gateway/.gitignore | 3 + llm-gateway/cmd/gateway/main.go | 99 ++- llm-gateway/go.mod | 1 + llm-gateway/go.sum | 2 + llm-gateway/internal/auth/handlers.go | 47 ++ llm-gateway/internal/auth/store.go | 26 +- llm-gateway/internal/auth/store_test.go | 300 +++++++ llm-gateway/internal/cache/cache_test.go | 112 +++ llm-gateway/internal/config/config.go | 107 ++- llm-gateway/internal/config/config_test.go | 738 ++++++++++++++++++ llm-gateway/internal/config/watcher.go | 27 + llm-gateway/internal/dashboard/api.go | 92 ++- llm-gateway/internal/dashboard/export.go | 297 +++++++ llm-gateway/internal/dashboard/handler.go | 86 +- .../internal/dashboard/templates/layout.html | 20 + .../dashboard/templates/partials/audit.html | 83 ++ .../templates/partials/dashboard.html | 8 +- .../dashboard/templates/partials/debug.html | 100 +++ .../dashboard/templates/partials/logs.html | 13 + llm-gateway/internal/metrics/prometheus.go | 20 + llm-gateway/internal/provider/balancer.go | 144 ++++ .../internal/provider/balancer_test.go | 294 +++++++ llm-gateway/internal/provider/health.go | 161 +++- llm-gateway/internal/provider/health_test.go | 345 ++++++++ llm-gateway/internal/provider/openai.go | 6 + llm-gateway/internal/provider/registry.go | 103 ++- .../internal/provider/registry_test.go | 282 +++++++ llm-gateway/internal/proxy/concurrency.go | 51 ++ .../internal/proxy/concurrency_test.go | 317 ++++++++ llm-gateway/internal/proxy/handler.go | 135 +++- llm-gateway/internal/proxy/ratelimit.go | 40 +- llm-gateway/internal/proxy/ratelimit_test.go | 374 +++++++++ llm-gateway/internal/proxy/stream.go | 91 ++- llm-gateway/internal/storage/audit.go | 102 +++ llm-gateway/internal/storage/debuglog.go | 250 ++++++ llm-gateway/internal/storage/logger.go | 7 +- .../migrations/004_token_concurrency.down.sql | 4 + .../migrations/004_token_concurrency.up.sql | 1 + .../migrations/005_request_id.down.sql | 1 + .../storage/migrations/005_request_id.up.sql | 2 + .../storage/migrations/006_audit_log.down.sql | 1 + .../storage/migrations/006_audit_log.up.sql | 14 + .../storage/migrations/007_debug_log.down.sql | 1 + .../storage/migrations/007_debug_log.up.sql | 14 + .../migrations/008_debug_log_files.down.sql | 1 + .../migrations/008_debug_log_files.up.sql | 1 + 46 files changed, 4814 insertions(+), 109 deletions(-) create mode 100644 llm-gateway/internal/auth/store_test.go create mode 100644 llm-gateway/internal/cache/cache_test.go create mode 100644 llm-gateway/internal/config/config_test.go create mode 100644 llm-gateway/internal/config/watcher.go create mode 100644 llm-gateway/internal/dashboard/export.go create mode 100644 llm-gateway/internal/dashboard/templates/partials/audit.html create mode 100644 llm-gateway/internal/dashboard/templates/partials/debug.html create mode 100644 llm-gateway/internal/provider/balancer.go create mode 100644 llm-gateway/internal/provider/balancer_test.go create mode 100644 llm-gateway/internal/provider/health_test.go create mode 100644 llm-gateway/internal/provider/registry_test.go create mode 100644 llm-gateway/internal/proxy/concurrency.go create mode 100644 llm-gateway/internal/proxy/concurrency_test.go create mode 100644 llm-gateway/internal/proxy/ratelimit_test.go create mode 100644 llm-gateway/internal/storage/audit.go create mode 100644 llm-gateway/internal/storage/debuglog.go create mode 100644 llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql create mode 100644 llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql create mode 100644 llm-gateway/internal/storage/migrations/005_request_id.down.sql create mode 100644 llm-gateway/internal/storage/migrations/005_request_id.up.sql create mode 100644 llm-gateway/internal/storage/migrations/006_audit_log.down.sql create mode 100644 llm-gateway/internal/storage/migrations/006_audit_log.up.sql create mode 100644 llm-gateway/internal/storage/migrations/007_debug_log.down.sql create mode 100644 llm-gateway/internal/storage/migrations/007_debug_log.up.sql create mode 100644 llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql create mode 100644 llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql diff --git a/llm-gateway/.gitignore b/llm-gateway/.gitignore index 8bf05e4..759cb20 100644 --- a/llm-gateway/.gitignore +++ b/llm-gateway/.gitignore @@ -8,6 +8,9 @@ llm-gateway *.db-wal *.db-shm +# Debug log files +debug-logs/ + # Local config configs/config.local.yaml diff --git a/llm-gateway/cmd/gateway/main.go b/llm-gateway/cmd/gateway/main.go index b819d9c..8bd10b2 100644 --- a/llm-gateway/cmd/gateway/main.go +++ b/llm-gateway/cmd/gateway/main.go @@ -7,11 +7,13 @@ import ( "net/http" "os" "os/signal" + "path/filepath" "syscall" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + gocors "github.com/go-chi/cors" "github.com/prometheus/client_golang/prometheus/promhttp" "llm-gateway/internal/auth" @@ -91,7 +93,7 @@ func main() { log.Printf("Registered %d models", len(cfg.Models)) // Provider health tracker - healthTracker := provider.NewHealthTracker(5 * time.Minute) + healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker) // Auth store (static tokens checked in-memory, not seeded to DB) var staticTokens []auth.StaticToken @@ -102,6 +104,7 @@ func main() { Key: t.Key, RateLimitRPM: t.RateLimitRPM, DailyBudgetUSD: t.DailyBudgetUSD, + MaxConcurrent: t.MaxConcurrent, }) log.Printf("Loaded static token: %s", t.Name) } @@ -110,6 +113,17 @@ func main() { authMiddleware := auth.NewMiddleware(authStore) authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret) + // Audit logger + auditLogger := storage.NewAuditLogger(db) + authHandlers.SetAuditLogger(auditLogger) + + // Debug logger + debugDataDir := cfg.Debug.DataDir + if debugDataDir == "" { + debugDataDir = filepath.Dir(cfg.Database.Path) + } + debugLogger := storage.NewDebugLogger(db, cfg.Debug.Enabled, debugDataDir) + // Seed default admin seedDefaultAdmin(cfg, authStore) @@ -118,22 +132,43 @@ func main() { // Handlers proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker) + proxyHandler.SetDebugLogger(debugLogger) modelsHandler := proxy.NewModelsHandler(registry) proxyAuth := proxy.NewAuthMiddleware(authStore) rateLimiter := proxy.NewRateLimiter(db) + concurrencyLimiter := proxy.NewConcurrencyLimiter() statsAPI := dashboard.NewStatsAPI(db, authStore) statsAPI.SetHealthTracker(healthTracker) + statsAPI.SetAuditLogger(auditLogger) + statsAPI.SetDebugLogger(debugLogger) if c != nil { statsAPI.SetCache(c) } dash := dashboard.NewDashboard(authStore, statsAPI) dash.SetRegistry(registry) + dash.SetAuditLogger(auditLogger) + dash.SetDebugLogger(debugLogger) if c != nil { dash.SetCache(c) } + // Export handler + exportHandler := dashboard.NewExportHandler(db, authStore) + // Router r := chi.NewRouter() + + // CORS (before other middleware) + if cfg.CORS.Enabled { + r.Use(gocors.Handler(gocors.Options{ + AllowedOrigins: cfg.CORS.AllowedOrigins, + AllowedMethods: cfg.CORS.AllowedMethods, + AllowedHeaders: cfg.CORS.AllowedHeaders, + MaxAge: cfg.CORS.MaxAge, + AllowCredentials: true, + })) + } + r.Use(middleware.RealIP) r.Use(middleware.Recoverer) r.Use(middleware.RequestID) @@ -159,6 +194,7 @@ func main() { r.Group(func(r chi.Router) { r.Use(proxyAuth.Authenticate) r.Use(rateLimiter.Check) + r.Use(concurrencyLimiter.Check) r.Post("/v1/chat/completions", proxyHandler.ChatCompletions) r.Get("/v1/models", modelsHandler.ListModels) }) @@ -192,6 +228,8 @@ func main() { r.Group(func(r chi.Router) { r.Use(authMiddleware.RequireAdmin) r.Get("/users", dash.UsersPage) + r.Get("/audit", dash.AuditPage) + r.Get("/debug", dash.DebugPage) }) // Auth API @@ -224,16 +262,29 @@ func main() { r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler) r.Get("/api/stats/cache", statsAPI.CacheStats) - // Admin-only: user management + // Data export + r.Get("/api/export/logs", exportHandler.ExportLogs) + r.Get("/api/export/stats", exportHandler.ExportStats) + + // Admin-only: user management, audit, debug r.Group(func(r chi.Router) { r.Use(authMiddleware.RequireAdmin) r.Get("/api/auth/users", authHandlers.ListUsers) r.Post("/api/auth/users", authHandlers.CreateUser) r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser) + + // Audit log + r.Get("/api/stats/audit", statsAPI.AuditLogs) + + // Debug logging + r.Post("/api/debug/toggle", statsAPI.DebugToggle) + r.Get("/api/debug/status", statsAPI.DebugStatus) + r.Get("/api/debug/logs", statsAPI.DebugLogs) + r.Get("/api/debug/logs/{requestID}", statsAPI.DebugLogByRequestID) }) }) - // Periodic session cleanup + // Periodic session cleanup and debug log cleanup go func() { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() @@ -241,6 +292,9 @@ func main() { if err := authStore.CleanExpiredSessions(); err != nil { log.Printf("WARNING: session cleanup failed: %v", err) } + if err := debugLogger.Cleanup(cfg.Debug.RetentionDays); err != nil { + log.Printf("WARNING: debug log cleanup failed: %v", err) + } } }() @@ -253,6 +307,45 @@ func main() { IdleTimeout: 120 * time.Second, } + // Config hot-reload via SIGHUP + config.WatchReload(*configPath, func(newCfg *config.Config) { + // Reload registry (models, providers, routes) + if err := registry.Reload(newCfg); err != nil { + log.Printf("ERROR: registry reload failed: %v", err) + return + } + log.Printf("Reloaded %d models", len(newCfg.Models)) + + // Reload pricing + for i, m := range newCfg.Models { + for j, rt := range m.Routes { + if rt.Pricing.Input == 0 && rt.Pricing.Output == 0 { + pricingLookup.FillMissing(rt.Provider, rt.Model, + &newCfg.Models[i].Routes[j].Pricing.Input, + &newCfg.Models[i].Routes[j].Pricing.Output) + } + } + } + + // Reload static tokens + var newStaticTokens []auth.StaticToken + for _, t := range newCfg.Tokens { + if t.Key != "" { + newStaticTokens = append(newStaticTokens, auth.StaticToken{ + Name: t.Name, + Key: t.Key, + RateLimitRPM: t.RateLimitRPM, + DailyBudgetUSD: t.DailyBudgetUSD, + MaxConcurrent: t.MaxConcurrent, + }) + } + } + authStore.SetStaticTokens(newStaticTokens) + + // Update config pointer for retry/debug/etc + cfg = newCfg + }) + // Graceful shutdown done := make(chan os.Signal, 1) signal.Notify(done, os.Interrupt, syscall.SIGTERM) diff --git a/llm-gateway/go.mod b/llm-gateway/go.mod index 1d54b54..231f0a7 100644 --- a/llm-gateway/go.mod +++ b/llm-gateway/go.mod @@ -19,6 +19,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-chi/cors v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/llm-gateway/go.sum b/llm-gateway/go.sum index 28784f6..853ad13 100644 --- a/llm-gateway/go.sum +++ b/llm-gateway/go.sum @@ -18,6 +18,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= +github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= diff --git a/llm-gateway/internal/auth/handlers.go b/llm-gateway/internal/auth/handlers.go index afbc867..59ca0c7 100644 --- a/llm-gateway/internal/auth/handlers.go +++ b/llm-gateway/internal/auth/handlers.go @@ -13,12 +13,15 @@ import ( "time" "github.com/go-chi/chi/v5" + + "llm-gateway/internal/storage" ) type Handlers struct { store *Store sessionSecret string loginLimiter *loginRateLimiter + auditLogger *storage.AuditLogger } func NewHandlers(store *Store, sessionSecret string) *Handlers { @@ -29,6 +32,36 @@ func NewHandlers(store *Store, sessionSecret string) *Handlers { } } +func (h *Handlers) SetAuditLogger(al *storage.AuditLogger) { + h.auditLogger = al +} + +func (h *Handlers) audit(r *http.Request, action, targetType, targetID, details string) { + if h.auditLogger == nil { + return + } + user := UserFromContext(r.Context()) + var userID int64 + var username string + if user != nil { + userID = user.ID + username = user.Username + } + ip := r.RemoteAddr + if fwd := r.Header.Get("X-Real-IP"); fwd != "" { + ip = fwd + } + h.auditLogger.Log(storage.AuditEntry{ + UserID: userID, + Username: username, + Action: action, + TargetType: targetType, + TargetID: targetID, + Details: details, + IPAddress: ip, + }) +} + // Login brute-force protection type loginRateLimiter struct { mu sync.Mutex @@ -126,6 +159,8 @@ func (h *Handlers) Setup(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "auth.setup", "user", fmt.Sprintf("%d", user.ID), "initial setup") + h.setSessionCookie(w, sessionID) writeJSON(w, map[string]any{ "user": map[string]any{ @@ -187,6 +222,8 @@ func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "auth.login", "user", fmt.Sprintf("%d", user.ID), user.Username) + h.setSessionCookie(w, sessionID) writeJSON(w, map[string]any{ "require_totp": false, @@ -256,6 +293,8 @@ func (h *Handlers) LoginTOTP(w http.ResponseWriter, r *http.Request) { } func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) { + h.audit(r, "auth.logout", "", "", "") + cookie, err := r.Cookie(sessionCookieName) if err == nil { h.store.DeleteSession(cookie.Value) @@ -347,6 +386,7 @@ func (h *Handlers) TOTPVerify(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "totp.enable", "user", fmt.Sprintf("%d", user.ID), "") writeJSON(w, map[string]string{"status": "totp_enabled"}) } @@ -362,6 +402,7 @@ func (h *Handlers) TOTPDisable(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "totp.disable", "user", fmt.Sprintf("%d", user.ID), "") writeJSON(w, map[string]string{"status": "totp_disabled"}) } @@ -420,6 +461,7 @@ func (h *Handlers) CreateUser(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "user.create", "user", fmt.Sprintf("%d", user.ID), user.Username) writeJSON(w, map[string]any{ "id": user.ID, "username": user.Username, @@ -447,6 +489,7 @@ func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "user.delete", "user", idStr, "") writeJSON(w, map[string]string{"status": "deleted"}) } @@ -507,6 +550,7 @@ func (h *Handlers) CreateToken(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "token.create", "token", fmt.Sprintf("%d", token.ID), req.Name) writeJSON(w, map[string]any{ "key": plainKey, "token": token, @@ -545,6 +589,7 @@ func (h *Handlers) DeleteToken(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "token.delete", "token", idStr, "") writeJSON(w, map[string]string{"status": "deleted"}) } @@ -587,6 +632,7 @@ func (h *Handlers) ChangePassword(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "password.change", "user", fmt.Sprintf("%d", user.ID), "") writeJSON(w, map[string]string{"status": "password_updated"}) } @@ -621,6 +667,7 @@ func (h *Handlers) ChangeUsername(w http.ResponseWriter, r *http.Request) { return } + h.audit(r, "username.change", "user", fmt.Sprintf("%d", user.ID), req.NewUsername) writeJSON(w, map[string]string{"status": "username_updated"}) } diff --git a/llm-gateway/internal/auth/store.go b/llm-gateway/internal/auth/store.go index 170af0c..cfeec0c 100644 --- a/llm-gateway/internal/auth/store.go +++ b/llm-gateway/internal/auth/store.go @@ -38,6 +38,7 @@ type APIToken struct { UserID int64 `json:"user_id"` RateLimitRPM int `json:"rate_limit_rpm"` DailyBudgetUSD float64 `json:"daily_budget_usd"` + MaxConcurrent int `json:"max_concurrent"` CreatedAt int64 `json:"created_at"` LastUsedAt int64 `json:"last_used_at"` } @@ -48,6 +49,7 @@ type StaticToken struct { Key string RateLimitRPM int DailyBudgetUSD float64 + MaxConcurrent int } type Store struct { @@ -59,6 +61,11 @@ func NewStore(db *sql.DB, staticTokens []StaticToken) *Store { return &Store{db: db, staticTokens: staticTokens} } +// SetStaticTokens updates the static tokens list (used for config hot-reload). +func (s *Store) SetStaticTokens(tokens []StaticToken) { + s.staticTokens = tokens +} + func (s *Store) HasAnyUser() bool { var count int s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) @@ -287,6 +294,7 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) { KeyPrefix: prefix, RateLimitRPM: st.RateLimitRPM, DailyBudgetUSD: st.DailyBudgetUSD, + MaxConcurrent: st.MaxConcurrent, }, nil } } @@ -297,9 +305,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) { var t APIToken err := s.db.QueryRow( - "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE key_hash = ?", + "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?", keyHash, - ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt) + ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) if err != nil { return nil, err } @@ -320,6 +328,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { KeyPrefix: prefix, RateLimitRPM: st.RateLimitRPM, DailyBudgetUSD: st.DailyBudgetUSD, + MaxConcurrent: st.MaxConcurrent, }) } @@ -327,19 +336,18 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { var rows *sql.Rows var err error if userID == 0 { - // Admin: list all - rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens ORDER BY id") + rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id") } else { - rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) + rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) } if err != nil { - return tokens, nil // return static tokens even if DB query fails + return tokens, nil } defer rows.Close() for rows.Next() { var t APIToken - if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt); err != nil { + if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil { return tokens, nil } tokens = append(tokens, t) @@ -355,9 +363,9 @@ func (s *Store) DeleteAPIToken(id int64) error { func (s *Store) GetAPIToken(id int64) (*APIToken, error) { var t APIToken err := s.db.QueryRow( - "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE id = ?", + "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?", id, - ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt) + ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) if err != nil { return nil, err } diff --git a/llm-gateway/internal/auth/store_test.go b/llm-gateway/internal/auth/store_test.go new file mode 100644 index 0000000..fc4678f --- /dev/null +++ b/llm-gateway/internal/auth/store_test.go @@ -0,0 +1,300 @@ +package auth + +import ( + "database/sql" + "testing" + "time" + + _ "modernc.org/sqlite" +) + +func setupTestDB(t *testing.T) *sql.DB { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("opening test db: %v", err) + } + + // Create tables + _, err = db.Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT UNIQUE NOT NULL, + email TEXT DEFAULT '', + password_hash TEXT NOT NULL, + is_admin INTEGER DEFAULT 0, + totp_secret TEXT DEFAULT '', + totp_enabled INTEGER DEFAULT 0, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL + ); + CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + user_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL + ); + CREATE TABLE api_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + key_hash TEXT NOT NULL, + key_prefix TEXT NOT NULL, + user_id INTEGER NOT NULL, + rate_limit_rpm INTEGER DEFAULT 0, + daily_budget_usd REAL DEFAULT 0, + max_concurrent INTEGER DEFAULT 0, + created_at INTEGER NOT NULL, + last_used_at INTEGER DEFAULT 0 + ); + `) + if err != nil { + t.Fatalf("creating tables: %v", err) + } + return db +} + +func TestCreateUser(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + user, err := store.CreateUser("alice", "password123", true) + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if user.Username != "alice" { + t.Errorf("expected username 'alice', got '%s'", user.Username) + } + if !user.IsAdmin { + t.Error("expected admin user") + } + if user.ID == 0 { + t.Error("expected non-zero ID") + } +} + +func TestGetUserByUsername(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + store.CreateUser("bob", "password123", false) + + user, err := store.GetUserByUsername("bob") + if err != nil { + t.Fatalf("GetUserByUsername: %v", err) + } + if user.Username != "bob" { + t.Errorf("expected 'bob', got '%s'", user.Username) + } + + _, err = store.GetUserByUsername("nonexistent") + if err == nil { + t.Error("expected error for nonexistent user") + } +} + +func TestCheckPassword(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + store.CreateUser("charlie", "correctpassword", false) + user, _ := store.GetUserByUsername("charlie") + + if !store.CheckPassword(user, "correctpassword") { + t.Error("correct password should match") + } + if store.CheckPassword(user, "wrongpassword") { + t.Error("wrong password should not match") + } +} + +func TestUpdatePassword(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + user, _ := store.CreateUser("dave", "oldpass12", false) + + if err := store.UpdatePassword(user.ID, "newpass12"); err != nil { + t.Fatalf("UpdatePassword: %v", err) + } + + user, _ = store.GetUserByUsername("dave") + if store.CheckPassword(user, "oldpass12") { + t.Error("old password should not work") + } + if !store.CheckPassword(user, "newpass12") { + t.Error("new password should work") + } +} + +func TestDeleteUser(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + user1, _ := store.CreateUser("admin1", "password1234", true) + user2, _ := store.CreateUser("user2", "password1234", false) + + // Can delete non-admin + if err := store.DeleteUser(user2.ID); err != nil { + t.Fatalf("DeleteUser: %v", err) + } + + // Cannot delete last admin + if err := store.DeleteUser(user1.ID); err == nil { + t.Error("should not be able to delete last admin") + } +} + +func TestHasAnyUser(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + if store.HasAnyUser() { + t.Error("should have no users initially") + } + + store.CreateUser("first", "password1234", true) + + if !store.HasAnyUser() { + t.Error("should have users after creation") + } +} + +func TestSessionCRUD(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + user, _ := store.CreateUser("sessuser", "password1234", false) + + sessionID, err := store.CreateSession(user.ID, 1*time.Hour) + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + sess, err := store.GetSession(sessionID) + if err != nil { + t.Fatalf("GetSession: %v", err) + } + if sess.UserID != user.ID { + t.Errorf("expected user ID %d, got %d", user.ID, sess.UserID) + } + + if err := store.DeleteSession(sessionID); err != nil { + t.Fatalf("DeleteSession: %v", err) + } + + _, err = store.GetSession(sessionID) + if err == nil { + t.Error("session should be deleted") + } +} + +func TestStaticTokenLookup(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + staticTokens := []StaticToken{ + {Name: "test-token", Key: "sk-test-key-12345678", RateLimitRPM: 60, DailyBudgetUSD: 10.0, MaxConcurrent: 5}, + } + store := NewStore(db, staticTokens) + + token, err := store.LookupAPIToken("sk-test-key-12345678") + if err != nil { + t.Fatalf("LookupAPIToken: %v", err) + } + if token.Name != "test-token" { + t.Errorf("expected 'test-token', got '%s'", token.Name) + } + if token.ID != -1 { + t.Errorf("static token should have ID -1, got %d", token.ID) + } + if token.RateLimitRPM != 60 { + t.Errorf("expected RPM 60, got %d", token.RateLimitRPM) + } + if token.MaxConcurrent != 5 { + t.Errorf("expected max_concurrent 5, got %d", token.MaxConcurrent) + } + + // Non-existent token + _, err = store.LookupAPIToken("nonexistent") + if err == nil { + t.Error("should error on nonexistent token") + } +} + +func TestDBTokenCRUD(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + store := NewStore(db, nil) + + user, _ := store.CreateUser("tokenuser", "password1234", false) + + plainKey, token, err := store.CreateAPIToken(user.ID, "my-token", 100, 5.0) + if err != nil { + t.Fatalf("CreateAPIToken: %v", err) + } + if plainKey == "" { + t.Error("plain key should not be empty") + } + if token.Name != "my-token" { + t.Errorf("expected 'my-token', got '%s'", token.Name) + } + + // Lookup by key + found, err := store.LookupAPIToken(plainKey) + if err != nil { + t.Fatalf("LookupAPIToken: %v", err) + } + if found.Name != "my-token" { + t.Errorf("expected 'my-token', got '%s'", found.Name) + } + + // List tokens + tokens, err := store.ListAPITokens(user.ID) + if err != nil { + t.Fatalf("ListAPITokens: %v", err) + } + if len(tokens) != 1 { + t.Errorf("expected 1 token, got %d", len(tokens)) + } + + // Delete + if err := store.DeleteAPIToken(token.ID); err != nil { + t.Fatalf("DeleteAPIToken: %v", err) + } + + _, err = store.LookupAPIToken(plainKey) + if err == nil { + t.Error("token should be deleted") + } +} + +func TestSetStaticTokens(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + store := NewStore(db, nil) + + _, err := store.LookupAPIToken("key1") + if err == nil { + t.Error("should not find token before setting") + } + + store.SetStaticTokens([]StaticToken{ + {Name: "new-token", Key: "key1"}, + }) + + token, err := store.LookupAPIToken("key1") + if err != nil { + t.Fatalf("after SetStaticTokens: %v", err) + } + if token.Name != "new-token" { + t.Errorf("expected 'new-token', got '%s'", token.Name) + } +} diff --git a/llm-gateway/internal/cache/cache_test.go b/llm-gateway/internal/cache/cache_test.go new file mode 100644 index 0000000..9e72e37 --- /dev/null +++ b/llm-gateway/internal/cache/cache_test.go @@ -0,0 +1,112 @@ +package cache + +import ( + "testing" +) + +func TestCacheKey_Deterministic(t *testing.T) { + c := &Cache{} + + model := "gpt-4" + body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) + + key1 := c.cacheKey(model, body) + key2 := c.cacheKey(model, body) + + if key1 != key2 { + t.Errorf("cache key not deterministic: %s != %s", key1, key2) + } + + if key1 == "" { + t.Error("cache key is empty") + } +} + +func TestCacheKey_DifferentInputs(t *testing.T) { + c := &Cache{} + + body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) + + key1 := c.cacheKey("gpt-4", body) + key2 := c.cacheKey("gpt-3.5", body) + + if key1 == key2 { + t.Error("different models should produce different cache keys") + } + + key3 := c.cacheKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"world"}]}`)) + if key1 == key3 { + t.Error("different bodies should produce different cache keys") + } +} + +func TestCacheKey_HasPrefix(t *testing.T) { + c := &Cache{} + key := c.cacheKey("gpt-4", []byte("test")) + + if len(key) < 7 || key[:7] != "llm-gw:" { + t.Errorf("cache key should start with 'llm-gw:', got: %s", key) + } +} + +func TestParseInfoInt(t *testing.T) { + info := "keyspace_hits:42\nkeyspace_misses:10\n" + + hits := parseInfoInt(info, "keyspace_hits") + if hits != 42 { + t.Errorf("expected 42, got %d", hits) + } + + misses := parseInfoInt(info, "keyspace_misses") + if misses != 10 { + t.Errorf("expected 10, got %d", misses) + } + + unknown := parseInfoInt(info, "nonexistent") + if unknown != 0 { + t.Errorf("expected 0 for unknown key, got %d", unknown) + } +} + +func TestParseInfoString(t *testing.T) { + info := "used_memory_human:1.5M\r\nother:value\r\n" + + mem := parseInfoString(info, "used_memory_human") + if mem != "1.5M" { + t.Errorf("expected '1.5M', got '%s'", mem) + } + + unknown := parseInfoString(info, "nonexistent") + if unknown != "" { + t.Errorf("expected empty for unknown key, got '%s'", unknown) + } +} + +func TestParseKeyspaceKeys(t *testing.T) { + info := "# Keyspace\ndb0:keys=123,expires=45,avg_ttl=6789\n" + + keys := parseKeyspaceKeys(info) + if keys != 123 { + t.Errorf("expected 123, got %d", keys) + } + + empty := parseKeyspaceKeys("# Keyspace\n") + if empty != 0 { + t.Errorf("expected 0 for empty keyspace, got %d", empty) + } +} + +func TestSplitLines(t *testing.T) { + lines := splitLines("a\nb\nc") + if len(lines) != 3 { + t.Errorf("expected 3 lines, got %d", len(lines)) + } + if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" { + t.Errorf("unexpected lines: %v", lines) + } + + single := splitLines("hello") + if len(single) != 1 || single[0] != "hello" { + t.Errorf("single line: %v", single) + } +} diff --git a/llm-gateway/internal/config/config.go b/llm-gateway/internal/config/config.go index 2aaec2c..3800eea 100644 --- a/llm-gateway/internal/config/config.go +++ b/llm-gateway/internal/config/config.go @@ -12,13 +12,17 @@ import ( ) type Config struct { - Server ServerConfig `yaml:"server"` - Database DatabaseConfig `yaml:"database"` - Cache CacheConfig `yaml:"cache"` - Pricing PricingLookupConfig `yaml:"pricing_lookup"` - Providers []ProviderConfig `yaml:"providers"` - Models []ModelConfig `yaml:"models"` - Tokens []TokenConfig `yaml:"tokens"` + Server ServerConfig `yaml:"server"` + Database DatabaseConfig `yaml:"database"` + Cache CacheConfig `yaml:"cache"` + Pricing PricingLookupConfig `yaml:"pricing_lookup"` + CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` + Retry RetryConfig `yaml:"retry"` + Debug DebugConfig `yaml:"debug"` + CORS CORSConfig `yaml:"cors"` + Providers []ProviderConfig `yaml:"providers"` + Models []ModelConfig `yaml:"models"` + Tokens []TokenConfig `yaml:"tokens"` } type PricingLookupConfig struct { @@ -36,14 +40,44 @@ type TokenConfig struct { Key string `yaml:"key"` RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited + MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited } type ServerConfig struct { - Listen string `yaml:"listen"` - RequestTimeout time.Duration `yaml:"request_timeout"` - MaxRequestBodyMB int `yaml:"max_request_body_mb"` - SessionSecret string `yaml:"session_secret"` - DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` + Listen string `yaml:"listen"` + RequestTimeout time.Duration `yaml:"request_timeout"` + StreamingTimeout time.Duration `yaml:"streaming_timeout"` + MaxRequestBodyMB int `yaml:"max_request_body_mb"` + SessionSecret string `yaml:"session_secret"` + DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` +} + +type CircuitBreakerConfig struct { + Enabled bool `yaml:"enabled"` + ErrorThreshold float64 `yaml:"error_threshold"` + MinRequests int `yaml:"min_requests"` + CooldownDuration time.Duration `yaml:"cooldown_duration"` +} + +type RetryConfig struct { + InitialBackoff time.Duration `yaml:"initial_backoff"` + MaxBackoff time.Duration `yaml:"max_backoff"` + Multiplier float64 `yaml:"multiplier"` +} + +type DebugConfig struct { + Enabled bool `yaml:"enabled"` + MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies) + RetentionDays int `yaml:"retention_days"` + DataDir string `yaml:"data_dir"` +} + +type CORSConfig struct { + Enabled bool `yaml:"enabled"` + AllowedOrigins []string `yaml:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers"` + MaxAge int `yaml:"max_age"` } type DatabaseConfig struct { @@ -66,8 +100,10 @@ type ProviderConfig struct { } type ModelConfig struct { - Name string `yaml:"name"` - Routes []RouteConfig `yaml:"routes"` + Name string `yaml:"name"` + Aliases []string `yaml:"aliases"` + Routes []RouteConfig `yaml:"routes"` + LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost } type RouteConfig struct { @@ -128,6 +164,43 @@ func (c *Config) validate() error { c.Pricing.RefreshInterval = 6 * time.Hour } + // Server defaults + if c.Server.StreamingTimeout == 0 { + c.Server.StreamingTimeout = 5 * time.Minute + } + + // Circuit breaker defaults + if c.CircuitBreaker.ErrorThreshold == 0 { + c.CircuitBreaker.ErrorThreshold = 0.5 + } + if c.CircuitBreaker.MinRequests == 0 { + c.CircuitBreaker.MinRequests = 5 + } + if c.CircuitBreaker.CooldownDuration == 0 { + c.CircuitBreaker.CooldownDuration = 30 * time.Second + } + + // Retry defaults + if c.Retry.InitialBackoff == 0 { + c.Retry.InitialBackoff = 100 * time.Millisecond + } + if c.Retry.MaxBackoff == 0 { + c.Retry.MaxBackoff = 5 * time.Second + } + if c.Retry.Multiplier == 0 { + c.Retry.Multiplier = 2.0 + } + + // Debug defaults + if c.Debug.RetentionDays == 0 { + c.Debug.RetentionDays = 90 + } + + // CORS defaults + if c.CORS.MaxAge == 0 { + c.CORS.MaxAge = 300 + } + if len(c.Providers) == 0 { return fmt.Errorf("at least one provider is required") } @@ -160,6 +233,12 @@ func (c *Config) validate() error { return fmt.Errorf("duplicate model name: %s", m.Name) } modelNames[m.Name] = true + for _, alias := range m.Aliases { + if modelNames[alias] { + return fmt.Errorf("model alias %s conflicts with existing model or alias", alias) + } + modelNames[alias] = true + } if len(m.Routes) == 0 { return fmt.Errorf("model %s: at least one route is required", m.Name) } diff --git a/llm-gateway/internal/config/config_test.go b/llm-gateway/internal/config/config_test.go new file mode 100644 index 0000000..d40617a --- /dev/null +++ b/llm-gateway/internal/config/config_test.go @@ -0,0 +1,738 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// writeConfigFile creates a temporary YAML config file and returns its path. +func writeConfigFile(t *testing.T, content string) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "config-*.yaml") + if err != nil { + t.Fatalf("creating temp file: %v", err) + } + if _, err := f.WriteString(content); err != nil { + f.Close() + t.Fatalf("writing temp file: %v", err) + } + f.Close() + return f.Name() +} + +// minimalValidConfig returns a minimal valid YAML config string. +func minimalValidConfig() string { + return ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-test-key + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +` +} + +func TestLoad_ValidConfig(t *testing.T) { + path := writeConfigFile(t, ` +server: + listen: "127.0.0.1:8080" + request_timeout: 60s + streaming_timeout: 120s + max_request_body_mb: 5 + session_secret: "test-secret-1234567890abcdef1234567890abcdef" + +database: + path: "/tmp/test.db" + retention_days: 30 + +pricing_lookup: + url: "https://pricing.example.com" + refresh_interval: 1h + +circuit_breaker: + enabled: true + error_threshold: 0.3 + min_requests: 10 + cooldown_duration: 60s + +retry: + initial_backoff: 200ms + max_backoff: 10s + multiplier: 3.0 + +debug: + enabled: true + max_body_bytes: 65536 + retention_days: 60 + +cors: + enabled: true + allowed_origins: + - "https://example.com" + allowed_methods: + - GET + - POST + allowed_headers: + - Authorization + max_age: 600 + +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-test-key + priority: 2 + timeout: 60s + - name: anthropic + base_url: https://api.anthropic.com/v1 + api_key: sk-ant-test + priority: 1 + timeout: 30s + +models: + - name: gpt-4 + aliases: + - gpt4 + routes: + - provider: openai + model: gpt-4 + pricing: + input: 30.0 + output: 60.0 + load_balancing: first + - name: claude-3 + routes: + - provider: anthropic + model: claude-3-opus-20240229 + +tokens: + - name: test-token + key: tok-abc123 + rate_limit_rpm: 100 + daily_budget_usd: 10.0 + max_concurrent: 5 +`) + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load() returned error: %v", err) + } + + // Server + if cfg.Server.Listen != "127.0.0.1:8080" { + t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080") + } + if cfg.Server.RequestTimeout != 60*time.Second { + t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second) + } + if cfg.Server.StreamingTimeout != 120*time.Second { + t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second) + } + if cfg.Server.MaxRequestBodyMB != 5 { + t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5) + } + if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" { + t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef") + } + + // Database + if cfg.Database.Path != "/tmp/test.db" { + t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db") + } + if cfg.Database.RetentionDays != 30 { + t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30) + } + + // Pricing + if cfg.Pricing.URL != "https://pricing.example.com" { + t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com") + } + if cfg.Pricing.RefreshInterval != 1*time.Hour { + t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour) + } + + // Circuit breaker + if !cfg.CircuitBreaker.Enabled { + t.Error("CircuitBreaker.Enabled = false, want true") + } + if cfg.CircuitBreaker.ErrorThreshold != 0.3 { + t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3) + } + if cfg.CircuitBreaker.MinRequests != 10 { + t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10) + } + if cfg.CircuitBreaker.CooldownDuration != 60*time.Second { + t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second) + } + + // Retry + if cfg.Retry.InitialBackoff != 200*time.Millisecond { + t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond) + } + if cfg.Retry.MaxBackoff != 10*time.Second { + t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second) + } + if cfg.Retry.Multiplier != 3.0 { + t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0) + } + + // Debug + if !cfg.Debug.Enabled { + t.Error("Debug.Enabled = false, want true") + } + if cfg.Debug.MaxBodyBytes != 65536 { + t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536) + } + if cfg.Debug.RetentionDays != 60 { + t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60) + } + + // CORS + if !cfg.CORS.Enabled { + t.Error("CORS.Enabled = false, want true") + } + if cfg.CORS.MaxAge != 600 { + t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600) + } + + // Providers + if len(cfg.Providers) != 2 { + t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers)) + } + if cfg.Providers[0].Name != "openai" { + t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai") + } + if cfg.Providers[0].Timeout != 60*time.Second { + t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second) + } + + // Models + if len(cfg.Models) != 2 { + t.Fatalf("len(Models) = %d, want 2", len(cfg.Models)) + } + if cfg.Models[0].LoadBalancing != "first" { + t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first") + } + if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" { + t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases) + } + if cfg.Models[0].Routes[0].Pricing.Input != 30.0 { + t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input) + } + + // Tokens + if len(cfg.Tokens) != 1 { + t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens)) + } + if cfg.Tokens[0].Name != "test-token" { + t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token") + } + if cfg.Tokens[0].RateLimitRPM != 100 { + t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM) + } +} + +func TestValidate_Defaults(t *testing.T) { + path := writeConfigFile(t, minimalValidConfig()) + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load() returned error: %v", err) + } + + tests := []struct { + name string + got any + want any + }{ + // Server defaults + {"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"}, + {"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second}, + {"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute}, + {"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10}, + + // Database defaults + {"Database.Path", cfg.Database.Path, "gateway.db"}, + {"Database.RetentionDays", cfg.Database.RetentionDays, 90}, + + // Pricing defaults + {"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour}, + + // Circuit breaker defaults + {"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5}, + {"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5}, + {"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second}, + + // Retry defaults + {"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond}, + {"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second}, + {"Retry.Multiplier", cfg.Retry.Multiplier, 2.0}, + + // Debug defaults + {"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0}, + {"Debug.RetentionDays", cfg.Debug.RetentionDays, 90}, + + // CORS defaults + {"CORS.MaxAge", cfg.CORS.MaxAge, 300}, + + // Provider defaults + {"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second}, + {"Providers[0].Priority", cfg.Providers[0].Priority, 1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Compare using formatted strings to handle different numeric types + gotStr := formatValue(tt.got) + wantStr := formatValue(tt.want) + if gotStr != wantStr { + t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want) + } + }) + } + + // SessionSecret should be auto-generated (non-empty, 64 hex chars) + if cfg.Server.SessionSecret == "" { + t.Error("SessionSecret should be auto-generated when empty") + } + if len(cfg.Server.SessionSecret) != 64 { + t.Errorf("SessionSecret length = %d, want 64 hex chars", len(cfg.Server.SessionSecret)) + } +} + +func formatValue(v any) string { + switch val := v.(type) { + case time.Duration: + return val.String() + case float64: + return fmt.Sprintf("%g", val) + case int: + return fmt.Sprintf("%d", val) + case string: + return val + default: + return fmt.Sprintf("%v", val) + } +} + +func TestLoad_FileNotFound(t *testing.T) { + _, err := Load(filepath.Join(t.TempDir(), "nonexistent.yaml")) + if err == nil { + t.Fatal("Load() should return error for nonexistent file") + } +} + +func TestLoad_InvalidYAML(t *testing.T) { + path := writeConfigFile(t, `{{{invalid yaml`) + _, err := Load(path) + if err == nil { + t.Fatal("Load() should return error for invalid YAML") + } +} + +func TestValidate_DuplicateProviderNames(t *testing.T) { + path := writeConfigFile(t, ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key1 + - name: openai + base_url: https://api.openai.com/v2 + api_key: sk-key2 + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +`) + + _, err := Load(path) + if err == nil { + t.Fatal("Load() should return error for duplicate provider names") + } + wantSubstr := "duplicate provider name: openai" + if !strings.Contains(err.Error(), wantSubstr) { + t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr) + } +} + +func TestValidate_DuplicateModelNames(t *testing.T) { + path := writeConfigFile(t, ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key1 + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 + - name: gpt-4 + routes: + - provider: openai + model: gpt-4-turbo +`) + + _, err := Load(path) + if err == nil { + t.Fatal("Load() should return error for duplicate model names") + } + wantSubstr := "duplicate model name: gpt-4" + if !strings.Contains(err.Error(), wantSubstr) { + t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr) + } +} + +func TestValidate_AliasConflicts(t *testing.T) { + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "alias conflicts with model name", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key1 + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 + - name: claude-3 + aliases: + - gpt-4 + routes: + - provider: openai + model: claude-3 +`, + wantErr: "model alias gpt-4 conflicts with existing model or alias", + }, + { + name: "alias conflicts with another alias", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key1 + +models: + - name: gpt-4 + aliases: + - fast-model + routes: + - provider: openai + model: gpt-4 + - name: claude-3 + aliases: + - fast-model + routes: + - provider: openai + model: claude-3 +`, + wantErr: "model alias fast-model conflicts with existing model or alias", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := writeConfigFile(t, tt.config) + _, err := Load(path) + if err == nil { + t.Fatal("Load() should return error for alias conflicts") + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestValidate_MissingRequiredFields(t *testing.T) { + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "no providers", + config: `models: [{name: test, routes: [{provider: x, model: y}]}]`, + wantErr: "at least one provider is required", + }, + { + name: "no models", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key +`, + wantErr: "at least one model is required", + }, + { + name: "provider missing name", + config: ` +providers: + - base_url: https://api.openai.com/v1 + api_key: sk-key + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +`, + wantErr: "provider 0: name, base_url, and api_key are required", + }, + { + name: "provider missing base_url", + config: ` +providers: + - name: openai + api_key: sk-key + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +`, + wantErr: "provider 0: name, base_url, and api_key are required", + }, + { + name: "provider missing api_key", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +`, + wantErr: "provider 0: name, base_url, and api_key are required", + }, + { + name: "model missing name", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key + +models: + - routes: + - provider: openai + model: gpt-4 +`, + wantErr: "model 0: name is required", + }, + { + name: "model missing routes", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key + +models: + - name: gpt-4 +`, + wantErr: "model gpt-4: at least one route is required", + }, + { + name: "route missing provider", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key + +models: + - name: gpt-4 + routes: + - model: gpt-4 +`, + wantErr: "model gpt-4 route 0: provider and model are required", + }, + { + name: "route missing model", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key + +models: + - name: gpt-4 + routes: + - provider: openai +`, + wantErr: "model gpt-4 route 0: provider and model are required", + }, + { + name: "route references unknown provider", + config: ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-key + +models: + - name: gpt-4 + routes: + - provider: anthropic + model: gpt-4 +`, + wantErr: "model gpt-4 route 0: unknown provider anthropic", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := writeConfigFile(t, tt.config) + _, err := Load(path) + if err == nil { + t.Fatalf("Load() should return error, want %q", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr) + } + }) + } +} + +func TestProviderByName(t *testing.T) { + path := writeConfigFile(t, ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: sk-openai + - name: anthropic + base_url: https://api.anthropic.com/v1 + api_key: sk-anthropic + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +`) + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load() returned error: %v", err) + } + + tests := []struct { + name string + lookup string + wantNil bool + wantName string + }{ + {"existing provider openai", "openai", false, "openai"}, + {"existing provider anthropic", "anthropic", false, "anthropic"}, + {"nonexistent provider", "google", true, ""}, + {"empty name", "", true, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := cfg.ProviderByName(tt.lookup) + if tt.wantNil { + if p != nil { + t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p) + } + } else { + if p == nil { + t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup) + } + if p.Name != tt.wantName { + t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName) + } + } + }) + } + + // Verify returned pointer refers to the actual config entry + p := cfg.ProviderByName("openai") + if p.APIKey != "sk-openai" { + t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai") + } +} + +func TestLoad_EnvironmentVariableExpansion(t *testing.T) { + t.Setenv("TEST_API_KEY", "sk-from-env") + t.Setenv("TEST_BASE_URL", "https://env.example.com/v1") + t.Setenv("TEST_PROVIDER_NAME", "env-provider") + + path := writeConfigFile(t, ` +providers: + - name: $TEST_PROVIDER_NAME + base_url: ${TEST_BASE_URL} + api_key: ${TEST_API_KEY} + +models: + - name: test-model + routes: + - provider: env-provider + model: gpt-4 +`) + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load() returned error: %v", err) + } + + if cfg.Providers[0].Name != "env-provider" { + t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider") + } + if cfg.Providers[0].BaseURL != "https://env.example.com/v1" { + t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1") + } + if cfg.Providers[0].APIKey != "sk-from-env" { + t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env") + } +} + +func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) { + // Ensure the variable is not set + t.Setenv("TEST_UNSET_VAR", "") + os.Unsetenv("TEST_UNSET_VAR") + + path := writeConfigFile(t, ` +providers: + - name: openai + base_url: https://api.openai.com/v1 + api_key: ${TEST_UNSET_VAR} + +models: + - name: gpt-4 + routes: + - provider: openai + model: gpt-4 +`) + + _, err := Load(path) + if err == nil { + t.Fatal("Load() should return error when env var expands to empty required field") + } + // api_key will be empty, so validation should catch it + if !strings.Contains(err.Error(), "api_key are required") { + t.Errorf("error = %q, want to contain api_key validation message", err.Error()) + } +} + diff --git a/llm-gateway/internal/config/watcher.go b/llm-gateway/internal/config/watcher.go new file mode 100644 index 0000000..f63a478 --- /dev/null +++ b/llm-gateway/internal/config/watcher.go @@ -0,0 +1,27 @@ +package config + +import ( + "log" + "os" + "os/signal" + "syscall" +) + +// WatchReload listens for SIGHUP and calls the callback with the new config. +func WatchReload(configPath string, callback func(*Config)) { + sighup := make(chan os.Signal, 1) + signal.Notify(sighup, syscall.SIGHUP) + + go func() { + for range sighup { + log.Println("SIGHUP received, reloading config...") + newCfg, err := Load(configPath) + if err != nil { + log.Printf("ERROR: config reload failed: %v", err) + continue + } + callback(newCfg) + log.Println("Config reloaded successfully") + } + }() +} diff --git a/llm-gateway/internal/dashboard/api.go b/llm-gateway/internal/dashboard/api.go index 3bb1b7a..346d703 100644 --- a/llm-gateway/internal/dashboard/api.go +++ b/llm-gateway/internal/dashboard/api.go @@ -7,6 +7,8 @@ import ( "strconv" "time" + "github.com/go-chi/chi/v5" + "llm-gateway/internal/auth" "llm-gateway/internal/cache" "llm-gateway/internal/provider" @@ -58,6 +60,7 @@ type TokenUsageStats struct { // RequestLogEntry represents a single request log row. type RequestLogEntry struct { + RequestID string `json:"request_id"` Timestamp int64 `json:"timestamp"` TokenName string `json:"token_name"` Model string `json:"model"` @@ -104,6 +107,8 @@ type StatsAPI struct { authStore *auth.Store healthTracker *provider.HealthTracker cache *cache.Cache + auditLogger *storage.AuditLogger + debugLogger *storage.DebugLogger } func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI { @@ -120,6 +125,16 @@ func (s *StatsAPI) SetCache(c *cache.Cache) { s.cache = c } +// SetAuditLogger sets the audit logger. +func (s *StatsAPI) SetAuditLogger(al *storage.AuditLogger) { + s.auditLogger = al +} + +// SetDebugLogger sets the debug logger. +func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) { + s.debugLogger = dl +} + // TokenNamesForUser returns the token names that belong to the user. // Admins get nil (no filter), non-admins get their token names. func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string { @@ -325,7 +340,7 @@ func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status s } // Get page - query := `SELECT timestamp, token_name, model, provider, provider_model, + query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, COALESCE(error_message, ''), streaming, cached FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?` @@ -341,7 +356,7 @@ func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status s for rows.Next() { var l RequestLogEntry var streaming, cached int - rows.Scan(&l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel, + rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel, &l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status, &l.ErrorMessage, &streaming, &cached) l.Streaming = streaming == 1 @@ -624,6 +639,79 @@ func (s *StatsAPI) CacheStats(w http.ResponseWriter, r *http.Request) { writeJSON(w, stats) } +// AuditLogs serves the audit log API (admin-only). +func (s *StatsAPI) AuditLogs(w http.ResponseWriter, r *http.Request) { + if s.auditLogger == nil { + writeJSON(w, map[string]any{"entries": []any{}, "total": 0}) + return + } + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + action := r.URL.Query().Get("action") + since := time.Now().AddDate(0, 0, -30).Unix() + if sinceStr := r.URL.Query().Get("since"); sinceStr != "" { + if s, err := strconv.ParseInt(sinceStr, 10, 64); err == nil { + since = s + } + } + result := s.auditLogger.Query(since, action, page, 50) + writeJSON(w, result) +} + +// DebugToggle enables/disables debug logging at runtime. +func (s *StatsAPI) DebugToggle(w http.ResponseWriter, r *http.Request) { + if s.debugLogger == nil { + writeJSON(w, map[string]any{"error": "debug logger not configured"}) + return + } + var req struct { + Enabled bool `json:"enabled"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + writeJSON(w, map[string]string{"error": "invalid JSON"}) + return + } + s.debugLogger.SetEnabled(req.Enabled) + writeJSON(w, map[string]any{"enabled": s.debugLogger.IsEnabled()}) +} + +// DebugStatus returns whether debug logging is enabled. +func (s *StatsAPI) DebugStatus(w http.ResponseWriter, r *http.Request) { + enabled := false + if s.debugLogger != nil { + enabled = s.debugLogger.IsEnabled() + } + writeJSON(w, map[string]any{"enabled": enabled}) +} + +// DebugLogs serves paginated debug log entries. +func (s *StatsAPI) DebugLogs(w http.ResponseWriter, r *http.Request) { + if s.debugLogger == nil { + writeJSON(w, map[string]any{"entries": []any{}, "total": 0}) + return + } + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + result := s.debugLogger.Query(page, 50) + writeJSON(w, result) +} + +// DebugLogByRequestID serves a single debug log entry by request ID. +func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) { + if s.debugLogger == nil { + w.WriteHeader(http.StatusNotFound) + writeJSON(w, map[string]string{"error": "debug logger not configured"}) + return + } + requestID := chi.URLParam(r, "requestID") + entry := s.debugLogger.GetByRequestID(requestID) + if entry == nil { + w.WriteHeader(http.StatusNotFound) + writeJSON(w, map[string]string{"error": "not found"}) + return + } + writeJSON(w, entry) +} + func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(v) diff --git a/llm-gateway/internal/dashboard/export.go b/llm-gateway/internal/dashboard/export.go new file mode 100644 index 0000000..2ac0875 --- /dev/null +++ b/llm-gateway/internal/dashboard/export.go @@ -0,0 +1,297 @@ +package dashboard + +import ( + "encoding/csv" + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "llm-gateway/internal/auth" + "llm-gateway/internal/storage" +) + +type ExportHandler struct { + db *storage.DB + authStore *auth.Store +} + +func NewExportHandler(db *storage.DB, authStore *auth.Store) *ExportHandler { + return &ExportHandler{db: db, authStore: authStore} +} + +// ExportLogs exports request logs as CSV or JSON. +func (e *ExportHandler) ExportLogs(w http.ResponseWriter, r *http.Request) { + format := r.URL.Query().Get("format") + if format == "" { + format = "json" + } + + // Build query + where := "WHERE 1=1" + var args []any + + if from := r.URL.Query().Get("from"); from != "" { + if ts, err := strconv.ParseInt(from, 10, 64); err == nil { + where += " AND timestamp >= ?" + args = append(args, ts) + } + } + if to := r.URL.Query().Get("to"); to != "" { + if ts, err := strconv.ParseInt(to, 10, 64); err == nil { + where += " AND timestamp <= ?" + args = append(args, ts) + } + } + if model := r.URL.Query().Get("model"); model != "" { + where += " AND model = ?" + args = append(args, model) + } + if token := r.URL.Query().Get("token"); token != "" { + where += " AND token_name = ?" + args = append(args, token) + } + if status := r.URL.Query().Get("status"); status != "" { + where += " AND status = ?" + args = append(args, status) + } + + // Token filtering for non-admins + user := auth.UserFromContext(r.Context()) + if user != nil && !user.IsAdmin { + tokens, err := e.authStore.ListAPITokens(user.ID) + if err != nil || len(tokens) == 0 { + where += " AND 1=0" + } else { + where += " AND token_name IN (" + for i, t := range tokens { + if i > 0 { + where += "," + } + where += "?" + args = append(args, t.Name) + } + where += ")" + } + } + + query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model, + input_tokens, output_tokens, cost_usd, latency_ms, status, + COALESCE(error_message, ''), streaming, cached + FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT 100000` + + rows, err := e.db.Query(query, args...) + if err != nil { + http.Error(w, "query failed", http.StatusInternalServerError) + return + } + defer rows.Close() + + type logRow struct { + RequestID string `json:"request_id"` + Timestamp int64 `json:"timestamp"` + TokenName string `json:"token_name"` + Model string `json:"model"` + Provider string `json:"provider"` + ProviderModel string `json:"provider_model"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CostUSD float64 `json:"cost_usd"` + LatencyMS int64 `json:"latency_ms"` + Status string `json:"status"` + ErrorMessage string `json:"error_message"` + Streaming bool `json:"streaming"` + Cached bool `json:"cached"` + } + + var results []logRow + for rows.Next() { + var l logRow + var streaming, cached int + rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel, + &l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status, + &l.ErrorMessage, &streaming, &cached) + l.Streaming = streaming == 1 + l.Cached = cached == 1 + results = append(results, l) + } + + now := time.Now().Format("20060102-150405") + + switch format { + case "csv": + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", now)) + writer := csv.NewWriter(w) + writer.Write([]string{"request_id", "timestamp", "token_name", "model", "provider", "provider_model", + "input_tokens", "output_tokens", "cost_usd", "latency_ms", "status", "error_message", "streaming", "cached"}) + for _, l := range results { + writer.Write([]string{ + l.RequestID, + strconv.FormatInt(l.Timestamp, 10), + l.TokenName, l.Model, l.Provider, l.ProviderModel, + strconv.Itoa(l.InputTokens), strconv.Itoa(l.OutputTokens), + fmt.Sprintf("%.8f", l.CostUSD), + strconv.FormatInt(l.LatencyMS, 10), + l.Status, l.ErrorMessage, + strconv.FormatBool(l.Streaming), strconv.FormatBool(l.Cached), + }) + } + writer.Flush() + default: + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.json", now)) + json.NewEncoder(w).Encode(results) + } +} + +// ExportStats exports aggregated stats as CSV or JSON. +func (e *ExportHandler) ExportStats(w http.ResponseWriter, r *http.Request) { + format := r.URL.Query().Get("format") + if format == "" { + format = "json" + } + statsType := r.URL.Query().Get("type") + if statsType == "" { + statsType = "summary" + } + + now := time.Now().Format("20060102-150405") + since := time.Now().AddDate(0, -1, 0).Unix() + + switch statsType { + case "models": + rows, err := e.db.Query(`SELECT model, COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), + COALESCE(SUM(cost_usd), 0), COALESCE(AVG(latency_ms), 0) + FROM request_logs WHERE timestamp >= ? GROUP BY model ORDER BY requests DESC`, since) + if err != nil { + http.Error(w, "query failed", http.StatusInternalServerError) + return + } + defer rows.Close() + + type modelRow struct { + Model string `json:"model"` + Requests int `json:"requests"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CostUSD float64 `json:"cost_usd"` + AvgLatencyMS float64 `json:"avg_latency_ms"` + } + var results []modelRow + for rows.Next() { + var m modelRow + rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS) + results = append(results, m) + } + + if format == "csv" { + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.csv", now)) + writer := csv.NewWriter(w) + writer.Write([]string{"model", "requests", "input_tokens", "output_tokens", "cost_usd", "avg_latency_ms"}) + for _, m := range results { + writer.Write([]string{m.Model, strconv.Itoa(m.Requests), strconv.Itoa(m.InputTokens), + strconv.Itoa(m.OutputTokens), fmt.Sprintf("%.8f", m.CostUSD), fmt.Sprintf("%.2f", m.AvgLatencyMS)}) + } + writer.Flush() + } else { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.json", now)) + json.NewEncoder(w).Encode(results) + } + + case "providers": + rows, err := e.db.Query(`SELECT provider, COUNT(*) as requests, + COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status='error' THEN 1 ELSE 0 END), 0), + COALESCE(AVG(latency_ms), 0), COALESCE(SUM(cost_usd), 0) + FROM request_logs WHERE timestamp >= ? GROUP BY provider ORDER BY requests DESC`, since) + if err != nil { + http.Error(w, "query failed", http.StatusInternalServerError) + return + } + defer rows.Close() + + type providerRow struct { + Provider string `json:"provider"` + Requests int `json:"requests"` + Successes int `json:"successes"` + Errors int `json:"errors"` + AvgLatencyMS float64 `json:"avg_latency_ms"` + CostUSD float64 `json:"cost_usd"` + } + var results []providerRow + for rows.Next() { + var p providerRow + rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD) + results = append(results, p) + } + + if format == "csv" { + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.csv", now)) + writer := csv.NewWriter(w) + writer.Write([]string{"provider", "requests", "successes", "errors", "avg_latency_ms", "cost_usd"}) + for _, p := range results { + writer.Write([]string{p.Provider, strconv.Itoa(p.Requests), strconv.Itoa(p.Successes), + strconv.Itoa(p.Errors), fmt.Sprintf("%.2f", p.AvgLatencyMS), fmt.Sprintf("%.8f", p.CostUSD)}) + } + writer.Flush() + } else { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.json", now)) + json.NewEncoder(w).Encode(results) + } + + case "tokens": + rows, err := e.db.Query(`SELECT token_name, COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), + COALESCE(SUM(cost_usd), 0) + FROM request_logs WHERE timestamp >= ? GROUP BY token_name ORDER BY requests DESC`, since) + if err != nil { + http.Error(w, "query failed", http.StatusInternalServerError) + return + } + defer rows.Close() + + type tokenRow struct { + TokenName string `json:"token_name"` + Requests int `json:"requests"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CostUSD float64 `json:"cost_usd"` + } + var results []tokenRow + for rows.Next() { + var t tokenRow + rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD) + results = append(results, t) + } + + if format == "csv" { + w.Header().Set("Content-Type", "text/csv") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.csv", now)) + writer := csv.NewWriter(w) + writer.Write([]string{"token_name", "requests", "input_tokens", "output_tokens", "cost_usd"}) + for _, t := range results { + writer.Write([]string{t.TokenName, strconv.Itoa(t.Requests), strconv.Itoa(t.InputTokens), + strconv.Itoa(t.OutputTokens), fmt.Sprintf("%.8f", t.CostUSD)}) + } + writer.Flush() + } else { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.json", now)) + json.NewEncoder(w).Encode(results) + } + + default: // summary + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-summary-%s.json", now)) + statsAPI := NewStatsAPI(e.db, e.authStore) + result := statsAPI.GetSummary(nil) + json.NewEncoder(w).Encode(result) + } +} diff --git a/llm-gateway/internal/dashboard/handler.go b/llm-gateway/internal/dashboard/handler.go index af4e4f1..b637ccd 100644 --- a/llm-gateway/internal/dashboard/handler.go +++ b/llm-gateway/internal/dashboard/handler.go @@ -11,6 +11,7 @@ import ( "llm-gateway/internal/auth" "llm-gateway/internal/cache" "llm-gateway/internal/provider" + "llm-gateway/internal/storage" ) //go:embed templates/*.html templates/partials/*.html @@ -125,15 +126,24 @@ type PageData struct { FilterStatus string // Models routing page data ModelRoutes []provider.ModelRouteInfo + // Audit page data + AuditResult *storage.AuditQueryResult + AuditFilterActions []string + FilterAction string + // Debug page data + DebugResult *storage.DebugLogQueryResult + DebugEnabled bool } // Dashboard serves the HTMX-based dashboard pages. type Dashboard struct { - templates *template.Template - authStore *auth.Store - statsAPI *StatsAPI - registry *provider.Registry - cache *cache.Cache + templates *template.Template + authStore *auth.Store + statsAPI *StatsAPI + registry *provider.Registry + cache *cache.Cache + auditLogger *storage.AuditLogger + debugLogger *storage.DebugLogger } // NewDashboard creates a new Dashboard handler. @@ -162,6 +172,16 @@ func (d *Dashboard) SetCache(c *cache.Cache) { d.cache = c } +// SetAuditLogger sets the audit logger for the audit page. +func (d *Dashboard) SetAuditLogger(al *storage.AuditLogger) { + d.auditLogger = al +} + +// SetDebugLogger sets the debug logger for the debug page. +func (d *Dashboard) SetDebugLogger(dl *storage.DebugLogger) { + d.debugLogger = dl +} + // LoginPage serves the login page. func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) { if !d.authStore.HasAnyUser() { @@ -298,6 +318,62 @@ func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) { }) } +// AuditPage serves the audit log view (admin only). +func (d *Dashboard) AuditPage(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromContext(r.Context()) + + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + if page < 1 { + page = 1 + } + action := r.URL.Query().Get("action") + since := time.Now().AddDate(0, 0, -30).Unix() + + var auditResult *storage.AuditQueryResult + if d.auditLogger != nil { + auditResult = d.auditLogger.Query(since, action, page, 50) + } else { + auditResult = &storage.AuditQueryResult{Entries: []storage.AuditEntry{}, Page: 1, TotalPages: 1} + } + + // Common audit action types for the filter dropdown + actions := []string{"login", "logout", "create_user", "delete_user", "create_token", "delete_token", "change_password", "setup_totp", "disable_totp"} + + d.renderDashboardPage(w, r, "partials/audit.html", PageData{ + ActivePage: "audit", + User: user, + AuditResult: auditResult, + AuditFilterActions: actions, + FilterAction: action, + }) +} + +// DebugPage serves the debug logging view (admin only). +func (d *Dashboard) DebugPage(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromContext(r.Context()) + + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + if page < 1 { + page = 1 + } + + var debugResult *storage.DebugLogQueryResult + debugEnabled := false + if d.debugLogger != nil { + debugResult = d.debugLogger.QueryFull(page, 50) + debugEnabled = d.debugLogger.IsEnabled() + } else { + debugResult = &storage.DebugLogQueryResult{Entries: []storage.DebugLogEntry{}, Page: 1, TotalPages: 1} + } + + d.renderDashboardPage(w, r, "partials/debug.html", PageData{ + ActivePage: "debug", + User: user, + DebugResult: debugResult, + DebugEnabled: debugEnabled, + }) +} + // SettingsPage serves the settings view. func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) { user := auth.UserFromContext(r.Context()) diff --git a/llm-gateway/internal/dashboard/templates/layout.html b/llm-gateway/internal/dashboard/templates/layout.html index ba10c08..4925906 100644 --- a/llm-gateway/internal/dashboard/templates/layout.html +++ b/llm-gateway/internal/dashboard/templates/layout.html @@ -160,6 +160,24 @@ .badge-error { background: var(--accent-red-bg); color: var(--accent-red); } .badge-cached { background: var(--accent-blue-bg); color: var(--accent-blue); } .badge-priority { background: var(--bg-tertiary); color: var(--text-secondary); } + .badge-open { background: var(--accent-red-bg); color: var(--accent-red); } + .badge-half-open { background: var(--accent-yellow-bg); color: var(--accent-yellow); } + + /* Toggle switch */ + .toggle-switch { position: relative; display: inline-block; width: 44px; height: 24px; } + .toggle-switch input { opacity: 0; width: 0; height: 0; } + .toggle-slider { position: absolute; cursor: pointer; top: 0; left: 0; right: 0; bottom: 0; background: var(--bg-tertiary); border-radius: 24px; transition: 0.2s; } + .toggle-slider:before { content: ""; position: absolute; height: 18px; width: 18px; left: 3px; bottom: 3px; background: var(--text-secondary); border-radius: 50%; transition: 0.2s; } + .toggle-switch input:checked + .toggle-slider { background: var(--accent-blue); } + .toggle-switch input:checked + .toggle-slider:before { transform: translateX(20px); background: #fff; } + + /* Code block for debug bodies */ + .code-block { background: var(--bg-primary); border: 1px solid var(--border-color); border-radius: 6px; padding: 12px; font-family: monospace; font-size: 0.8rem; white-space: pre-wrap; word-break: break-all; max-height: 300px; overflow-y: auto; } + + /* Export button inline */ + .export-links { display: inline-flex; gap: 6px; margin-left: 12px; } + .export-links a { font-size: 0.7rem; color: var(--text-muted); text-decoration: none; padding: 2px 6px; border: 1px solid var(--border-color); border-radius: 4px; } + .export-links a:hover { color: var(--text-primary); border-color: var(--text-muted); } .page-header { display: flex; align-items: center; gap: 12px; margin-bottom: 20px; } .page-header h1 { font-size: 1.3rem; color: var(--text-heading); } @@ -268,6 +286,8 @@ window.matchMedia('(prefers-color-scheme: light)').addEventListener('change', fu API Tokens {{if .User.IsAdmin}} Users + Audit Log + Debug {{end}} Settings diff --git a/llm-gateway/internal/dashboard/templates/partials/audit.html b/llm-gateway/internal/dashboard/templates/partials/audit.html new file mode 100644 index 0000000..7697964 --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/partials/audit.html @@ -0,0 +1,83 @@ +{{define "content"}} + + +
+ + +
+ +
+ + + + + + + + + + + + + {{range .AuditResult.Entries}} + + + + + + + + + {{end}} + {{if not .AuditResult.Entries}} + + {{end}} + +
TimeUserActionTargetDetailsIP
{{formatTimeDetail .Timestamp}}{{.Username}}{{.Action}}{{if .TargetType}}{{.TargetType}}{{if .TargetID}}/{{.TargetID}}{{end}}{{else}}-{{end}}{{if .Details}}{{.Details}}{{else}}-{{end}}{{if .IPAddress}}{{.IPAddress}}{{else}}-{{end}}
No audit log entries
+ + {{if gt .AuditResult.TotalPages 1}} + + {{end}} +
+ + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/dashboard.html b/llm-gateway/internal/dashboard/templates/partials/dashboard.html index 2f5063a..1e2d13b 100644 --- a/llm-gateway/internal/dashboard/templates/partials/dashboard.html +++ b/llm-gateway/internal/dashboard/templates/partials/dashboard.html @@ -25,6 +25,8 @@
{{.Provider}} {{.Status}} + {{if eq .CircuitState "open"}}circuit open{{end}} + {{if eq .CircuitState "half-open"}}half-open{{end}} {{printf "%.0f" .AvgLatency}}ms avg | {{formatPct .ErrorRate}} errors
{{end}} @@ -71,7 +73,7 @@ {{if .Models}}
-

Models

+

ModelsCSVJSON

@@ -91,7 +93,7 @@ {{if .Providers}}
-

Providers

+

ProvidersCSVJSON

ModelRequestsTokens (in/out)CostAvg Latency
@@ -112,7 +114,7 @@ {{if .TokenStats}}
-

API Token Usage

+

API Token UsageCSVJSON

ProviderRequestsSuccessErrorsAvg LatencyCost
diff --git a/llm-gateway/internal/dashboard/templates/partials/debug.html b/llm-gateway/internal/dashboard/templates/partials/debug.html new file mode 100644 index 0000000..ffb419b --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/partials/debug.html @@ -0,0 +1,100 @@ +{{define "content"}} + + +
+ Debug Mode + + {{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}} +
+ +
+
TokenRequestsTokens (in/out)Cost
+ + + + + + + + + + + + + {{range $i, $entry := .DebugResult.Entries}} + + + + + + + + + + + + + {{end}} + {{if not .DebugResult.Entries}} + + {{end}} + +
TimeRequest IDTokenModelProviderStatus
+
+
Request Headers:
+
{{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}
+
Request Body:
+
{{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}
+
Response Body:
+
{{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}
+
+
No debug log entries
+ + {{if gt .DebugResult.TotalPages 1}} + + {{end}} +
+ + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/logs.html b/llm-gateway/internal/dashboard/templates/partials/logs.html index 79dca12..bf8c640 100644 --- a/llm-gateway/internal/dashboard/templates/partials/logs.html +++ b/llm-gateway/internal/dashboard/templates/partials/logs.html @@ -20,6 +20,9 @@ + + +
@@ -116,5 +119,15 @@ function toggleExpand(id) { var el = document.getElementById(id); if (el) el.classList.toggle('show'); } +function exportLogs(format) { + var params = ['format=' + format]; + var model = document.getElementById('filter-model').value; + var token = document.getElementById('filter-token').value; + var status = document.getElementById('filter-status').value; + if (model) params.push('model=' + encodeURIComponent(model)); + if (token) params.push('token=' + encodeURIComponent(token)); + if (status) params.push('status=' + encodeURIComponent(status)); + window.open('/api/export/logs?' + params.join('&'), '_blank'); +} {{end}} diff --git a/llm-gateway/internal/metrics/prometheus.go b/llm-gateway/internal/metrics/prometheus.go index 7e8b635..620f51e 100644 --- a/llm-gateway/internal/metrics/prometheus.go +++ b/llm-gateway/internal/metrics/prometheus.go @@ -10,6 +10,8 @@ type Metrics struct { requestDuration *prometheus.HistogramVec tokensTotal *prometheus.CounterVec costTotal *prometheus.CounterVec + cacheHits prometheus.Counter + cacheMisses prometheus.Counter } func New() *Metrics { @@ -34,6 +36,16 @@ func New() *Metrics { Name: "llm_gateway_cost_usd_total", Help: "Total cost in USD", }, []string{"model", "provider", "token_name"}), + + cacheHits: promauto.NewCounter(prometheus.CounterOpts{ + Name: "llm_gateway_cache_hits_total", + Help: "Total number of cache hits", + }), + + cacheMisses: promauto.NewCounter(prometheus.CounterOpts{ + Name: "llm_gateway_cache_misses_total", + Help: "Total number of cache misses", + }), } } @@ -51,3 +63,11 @@ func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, l m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost) } } + +func (m *Metrics) RecordCacheHit() { + m.cacheHits.Inc() +} + +func (m *Metrics) RecordCacheMiss() { + m.cacheMisses.Inc() +} diff --git a/llm-gateway/internal/provider/balancer.go b/llm-gateway/internal/provider/balancer.go new file mode 100644 index 0000000..602d885 --- /dev/null +++ b/llm-gateway/internal/provider/balancer.go @@ -0,0 +1,144 @@ +package provider + +import ( + "math/rand" + "sort" + "sync/atomic" +) + +// LoadBalancer reorders routes for load distribution. +type LoadBalancer interface { + Reorder(routes []Route) []Route +} + +// NewLoadBalancer creates a load balancer by strategy name. +func NewLoadBalancer(strategy string) LoadBalancer { + switch strategy { + case "round-robin": + return &RoundRobinBalancer{} + case "random": + return &RandomBalancer{} + case "least-cost": + return &LeastCostBalancer{} + default: + return &FirstBalancer{} + } +} + +// FirstBalancer is a no-op that preserves original order. +type FirstBalancer struct{} + +func (b *FirstBalancer) Reorder(routes []Route) []Route { + return routes +} + +// RoundRobinBalancer rotates routes within same-priority groups. +type RoundRobinBalancer struct { + counter atomic.Uint64 +} + +func (b *RoundRobinBalancer) Reorder(routes []Route) []Route { + if len(routes) <= 1 { + return routes + } + + result := make([]Route, len(routes)) + copy(result, routes) + + // Group by priority and rotate within each group + groups := groupByPriority(result) + idx := 0 + count := b.counter.Add(1) + for _, group := range groups { + if len(group) > 1 { + offset := int(count) % len(group) + for j := 0; j < len(group); j++ { + result[idx] = group[(j+offset)%len(group)] + idx++ + } + } else { + result[idx] = group[0] + idx++ + } + } + + return result +} + +// RandomBalancer shuffles routes within same-priority groups. +type RandomBalancer struct{} + +func (b *RandomBalancer) Reorder(routes []Route) []Route { + if len(routes) <= 1 { + return routes + } + + result := make([]Route, len(routes)) + copy(result, routes) + + groups := groupByPriority(result) + idx := 0 + for _, group := range groups { + rand.Shuffle(len(group), func(i, j int) { + group[i], group[j] = group[j], group[i] + }) + for _, r := range group { + result[idx] = r + idx++ + } + } + + return result +} + +// LeastCostBalancer sorts by price within same-priority groups. +type LeastCostBalancer struct{} + +func (b *LeastCostBalancer) Reorder(routes []Route) []Route { + if len(routes) <= 1 { + return routes + } + + result := make([]Route, len(routes)) + copy(result, routes) + + groups := groupByPriority(result) + idx := 0 + for _, group := range groups { + sort.Slice(group, func(i, j int) bool { + costI := group[i].InputPrice + group[i].OutputPrice + costJ := group[j].InputPrice + group[j].OutputPrice + return costI < costJ + }) + for _, r := range group { + result[idx] = r + idx++ + } + } + + return result +} + +// groupByPriority splits routes into groups of same priority, preserving order. +func groupByPriority(routes []Route) [][]Route { + if len(routes) == 0 { + return nil + } + + var groups [][]Route + currentPriority := routes[0].Priority + currentGroup := []Route{routes[0]} + + for i := 1; i < len(routes); i++ { + if routes[i].Priority == currentPriority { + currentGroup = append(currentGroup, routes[i]) + } else { + groups = append(groups, currentGroup) + currentPriority = routes[i].Priority + currentGroup = []Route{routes[i]} + } + } + groups = append(groups, currentGroup) + + return groups +} diff --git a/llm-gateway/internal/provider/balancer_test.go b/llm-gateway/internal/provider/balancer_test.go new file mode 100644 index 0000000..cc5378e --- /dev/null +++ b/llm-gateway/internal/provider/balancer_test.go @@ -0,0 +1,294 @@ +package provider + +import ( + "fmt" + "testing" +) + +type routeSpec struct { + name string + priority int + input float64 + output float64 +} + +func makeRoutes(specs ...routeSpec) []Route { + routes := make([]Route, len(specs)) + for i, s := range specs { + routes[i] = Route{ + Provider: &mockProvider{name: s.name}, + ProviderModel: s.name + "-model", + Priority: s.priority, + InputPrice: s.input, + OutputPrice: s.output, + } + } + return routes +} + +func routeNames(routes []Route) []string { + names := make([]string, len(routes)) + for i, r := range routes { + names[i] = r.Provider.Name() + } + return names +} + +func TestFirstBalancer_PreservesOrder(t *testing.T) { + routes := makeRoutes( + routeSpec{"a", 1, 1.0, 1.0}, + routeSpec{"b", 1, 2.0, 2.0}, + routeSpec{"c", 1, 3.0, 3.0}, + ) + + b := &FirstBalancer{} + result := b.Reorder(routes) + + names := routeNames(result) + if names[0] != "a" || names[1] != "b" || names[2] != "c" { + t.Fatalf("expected [a b c], got %v", names) + } +} + +func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) { + routes := makeRoutes( + routeSpec{"a", 1, 1.0, 1.0}, + routeSpec{"b", 1, 1.0, 1.0}, + routeSpec{"c", 1, 1.0, 1.0}, + ) + + b := &RoundRobinBalancer{} + + // Collect the first element from multiple calls + seen := make(map[string]bool) + for i := 0; i < 6; i++ { + result := b.Reorder(routes) + seen[result[0].Provider.Name()] = true + } + + // All routes should have appeared as first at some point + for _, name := range []string{"a", "b", "c"} { + if !seen[name] { + t.Errorf("expected %q to appear as first element in rotation", name) + } + } +} + +func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) { + routes := makeRoutes( + routeSpec{"a", 1, 1.0, 1.0}, + routeSpec{"b", 1, 1.0, 1.0}, + routeSpec{"c", 2, 1.0, 1.0}, + ) + + b := &RoundRobinBalancer{} + + // Priority 2 route should always be last + for i := 0; i < 5; i++ { + result := b.Reorder(routes) + if result[2].Provider.Name() != "c" { + t.Fatalf("expected priority-2 route 'c' at the end, got %q", result[2].Provider.Name()) + } + } +} + +func TestRandomBalancer_AllRoutesPresent(t *testing.T) { + routes := makeRoutes( + routeSpec{"a", 1, 1.0, 1.0}, + routeSpec{"b", 1, 1.0, 1.0}, + routeSpec{"c", 1, 1.0, 1.0}, + ) + + b := &RandomBalancer{} + + for i := 0; i < 10; i++ { + result := b.Reorder(routes) + if len(result) != 3 { + t.Fatalf("expected 3 routes, got %d", len(result)) + } + + names := make(map[string]bool) + for _, r := range result { + names[r.Provider.Name()] = true + } + for _, want := range []string{"a", "b", "c"} { + if !names[want] { + t.Errorf("missing route %q in result", want) + } + } + } +} + +func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) { + routes := makeRoutes( + routeSpec{"a", 1, 1.0, 1.0}, + routeSpec{"b", 1, 1.0, 1.0}, + routeSpec{"c", 2, 1.0, 1.0}, + ) + + b := &RandomBalancer{} + + for i := 0; i < 10; i++ { + result := b.Reorder(routes) + if result[2].Provider.Name() != "c" { + t.Fatalf("expected priority-2 route 'c' last, got %q", result[2].Provider.Name()) + } + } +} + +func TestLeastCostBalancer_SortsByCost(t *testing.T) { + routes := makeRoutes( + routeSpec{"expensive", 1, 10.0, 10.0}, + routeSpec{"cheap", 1, 1.0, 1.0}, + routeSpec{"medium", 1, 5.0, 5.0}, + ) + + b := &LeastCostBalancer{} + result := b.Reorder(routes) + + names := routeNames(result) + expected := []string{"cheap", "medium", "expensive"} + for i, want := range expected { + if names[i] != want { + t.Errorf("position %d: got %q, want %q", i, names[i], want) + } + } +} + +func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) { + routes := makeRoutes( + routeSpec{"expensive-p1", 1, 10.0, 10.0}, + routeSpec{"cheap-p1", 1, 1.0, 1.0}, + routeSpec{"cheap-p2", 2, 0.5, 0.5}, + ) + + b := &LeastCostBalancer{} + result := b.Reorder(routes) + + names := routeNames(result) + // Within priority 1, cheap should come first; priority 2 always last + if names[0] != "cheap-p1" { + t.Errorf("expected cheap-p1 first, got %q", names[0]) + } + if names[1] != "expensive-p1" { + t.Errorf("expected expensive-p1 second, got %q", names[1]) + } + if names[2] != "cheap-p2" { + t.Errorf("expected cheap-p2 last, got %q", names[2]) + } +} + +func TestGroupByPriority(t *testing.T) { + tests := []struct { + name string + priorities []int + wantGroups [][]int + }{ + { + name: "empty", + priorities: nil, + wantGroups: nil, + }, + { + name: "single", + priorities: []int{1}, + wantGroups: [][]int{{1}}, + }, + { + name: "all same", + priorities: []int{1, 1, 1}, + wantGroups: [][]int{{1, 1, 1}}, + }, + { + name: "two groups", + priorities: []int{1, 1, 2, 2}, + wantGroups: [][]int{{1, 1}, {2, 2}}, + }, + { + name: "three groups", + priorities: []int{1, 2, 2, 3}, + wantGroups: [][]int{{1}, {2, 2}, {3}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var routes []Route + for _, p := range tt.priorities { + routes = append(routes, Route{Priority: p}) + } + + groups := groupByPriority(routes) + + if tt.wantGroups == nil { + if groups != nil { + t.Fatalf("expected nil groups, got %v", groups) + } + return + } + + if len(groups) != len(tt.wantGroups) { + t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups)) + } + + for i, wg := range tt.wantGroups { + if len(groups[i]) != len(wg) { + t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i])) + continue + } + for j, wp := range wg { + if groups[i][j].Priority != wp { + t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority) + } + } + } + }) + } +} + +func TestBalancer_SingleRoute(t *testing.T) { + routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0}) + + balancers := []struct { + name string + balancer LoadBalancer + }{ + {"first", &FirstBalancer{}}, + {"round-robin", &RoundRobinBalancer{}}, + {"random", &RandomBalancer{}}, + {"least-cost", &LeastCostBalancer{}}, + } + + for _, bb := range balancers { + t.Run(bb.name, func(t *testing.T) { + result := bb.balancer.Reorder(routes) + if len(result) != 1 || result[0].Provider.Name() != "only" { + t.Fatalf("expected single route 'only', got %v", routeNames(result)) + } + }) + } +} + +func TestNewLoadBalancer(t *testing.T) { + tests := []struct { + strategy string + wantType string + }{ + {"round-robin", "*provider.RoundRobinBalancer"}, + {"random", "*provider.RandomBalancer"}, + {"least-cost", "*provider.LeastCostBalancer"}, + {"first", "*provider.FirstBalancer"}, + {"unknown", "*provider.FirstBalancer"}, + {"", "*provider.FirstBalancer"}, + } + + for _, tt := range tests { + t.Run(tt.strategy, func(t *testing.T) { + b := NewLoadBalancer(tt.strategy) + got := fmt.Sprintf("%T", b) + if got != tt.wantType { + t.Errorf("NewLoadBalancer(%q) = %s, want %s", tt.strategy, got, tt.wantType) + } + }) + } +} diff --git a/llm-gateway/internal/provider/health.go b/llm-gateway/internal/provider/health.go index ae6d97e..91cb965 100644 --- a/llm-gateway/internal/provider/health.go +++ b/llm-gateway/internal/provider/health.go @@ -3,8 +3,39 @@ package provider import ( "sync" "time" + + "llm-gateway/internal/config" ) +// CircuitState represents the state of a circuit breaker. +type CircuitState int + +const ( + CircuitClosed CircuitState = iota // normal operation + CircuitOpen // blocking requests + CircuitHalfOpen // testing with probe request +) + +func (s CircuitState) String() string { + switch s { + case CircuitClosed: + return "closed" + case CircuitOpen: + return "open" + case CircuitHalfOpen: + return "half-open" + default: + return "unknown" + } +} + +// ProviderCircuit tracks circuit breaker state for a single provider. +type ProviderCircuit struct { + State CircuitState + OpenedAt time.Time + LastProbe time.Time +} + // HealthEvent represents a single request outcome for a provider. type HealthEvent struct { Timestamp time.Time @@ -15,12 +46,13 @@ type HealthEvent struct { // ProviderHealth is the computed health status for a provider. type ProviderHealth struct { - Provider string `json:"provider"` - Status string `json:"status"` // healthy, degraded, down - ErrorRate float64 `json:"error_rate"` - AvgLatency float64 `json:"avg_latency_ms"` - Total int `json:"total"` - Errors int `json:"errors"` + Provider string `json:"provider"` + Status string `json:"status"` // healthy, degraded, down + ErrorRate float64 `json:"error_rate"` + AvgLatency float64 `json:"avg_latency_ms"` + Total int `json:"total"` + Errors int `json:"errors"` + CircuitState string `json:"circuit_state"` } // HealthTracker tracks per-provider health using a sliding window. @@ -28,20 +60,52 @@ type HealthTracker struct { mu sync.RWMutex windows map[string][]HealthEvent windowDu time.Duration + circuits map[string]*ProviderCircuit + cbConfig config.CircuitBreakerConfig } // NewHealthTracker creates a health tracker with the given window duration. -func NewHealthTracker(window time.Duration) *HealthTracker { +func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker { if window == 0 { window = 5 * time.Minute } return &HealthTracker{ windows: make(map[string][]HealthEvent), + circuits: make(map[string]*ProviderCircuit), windowDu: window, + cbConfig: cbCfg, } } -// Record adds a health event for a provider. +// IsAvailable returns true if the provider's circuit breaker allows requests. +func (h *HealthTracker) IsAvailable(provider string) bool { + if !h.cbConfig.Enabled { + return true + } + + h.mu.RLock() + defer h.mu.RUnlock() + + circuit, ok := h.circuits[provider] + if !ok { + return true // no circuit = closed = available + } + + switch circuit.State { + case CircuitOpen: + // Check if cooldown has elapsed -> transition to half-open + if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration { + return true // will transition to half-open on next record + } + return false + case CircuitHalfOpen: + return true // allow probe + default: + return true + } +} + +// Record adds a health event for a provider and evaluates circuit transitions. func (h *HealthTracker) Record(provider string, latencyMS int64, err error) { event := HealthEvent{ Timestamp: time.Now(), @@ -57,6 +121,69 @@ func (h *HealthTracker) Record(provider string, latencyMS int64, err error) { h.windows[provider] = append(h.windows[provider], event) h.prune(provider) + + if h.cbConfig.Enabled { + h.evaluateCircuit(provider, err) + } +} + +// evaluateCircuit transitions circuit breaker state. Must be called with lock held. +func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) { + circuit, ok := h.circuits[providerName] + if !ok { + circuit = &ProviderCircuit{State: CircuitClosed} + h.circuits[providerName] = circuit + } + + switch circuit.State { + case CircuitClosed: + // Check if error threshold exceeded + errorRate, total := h.errorRateUnlocked(providerName) + if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold { + circuit.State = CircuitOpen + circuit.OpenedAt = time.Now() + } + case CircuitOpen: + // Check if cooldown elapsed -> half-open + if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration { + circuit.State = CircuitHalfOpen + circuit.LastProbe = time.Now() + // Evaluate the probe result immediately + if lastErr == nil { + circuit.State = CircuitClosed + } else { + circuit.State = CircuitOpen + circuit.OpenedAt = time.Now() + } + } + case CircuitHalfOpen: + if lastErr == nil { + circuit.State = CircuitClosed + } else { + circuit.State = CircuitOpen + circuit.OpenedAt = time.Now() + } + } +} + +// errorRateUnlocked computes error rate within window. Must be called with lock held. +func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) { + cutoff := time.Now().Add(-h.windowDu) + events := h.windows[provider] + var total, errors int + for _, e := range events { + if e.Timestamp.Before(cutoff) { + continue + } + total++ + if e.IsError { + errors++ + } + } + if total == 0 { + return 0, 0 + } + return float64(errors) / float64(total), total } // Status returns computed health for all tracked providers. @@ -94,13 +221,19 @@ func (h *HealthTracker) Status() []ProviderHealth { status = "degraded" } + circuitState := "closed" + if circuit, ok := h.circuits[provider]; ok { + circuitState = circuit.State.String() + } + results = append(results, ProviderHealth{ - Provider: provider, - Status: status, - ErrorRate: errorRate, - AvgLatency: float64(totalLatency) / float64(total), - Total: total, - Errors: errors, + Provider: provider, + Status: status, + ErrorRate: errorRate, + AvgLatency: float64(totalLatency) / float64(total), + Total: total, + Errors: errors, + CircuitState: circuitState, }) } diff --git a/llm-gateway/internal/provider/health_test.go b/llm-gateway/internal/provider/health_test.go new file mode 100644 index 0000000..6d021b1 --- /dev/null +++ b/llm-gateway/internal/provider/health_test.go @@ -0,0 +1,345 @@ +package provider + +import ( + "errors" + "testing" + "time" + + "llm-gateway/internal/config" +) + +func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker { + return NewHealthTracker(window, cb) +} + +func defaultCBConfig() config.CircuitBreakerConfig { + return config.CircuitBreakerConfig{ + Enabled: true, + ErrorThreshold: 0.5, + MinRequests: 3, + CooldownDuration: 100 * time.Millisecond, + } +} + +func TestHealthTracker_Record(t *testing.T) { + ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{}) + + ht.Record("provA", 100, nil) + ht.Record("provA", 200, errors.New("fail")) + ht.Record("provB", 50, nil) + + ht.mu.RLock() + defer ht.mu.RUnlock() + + if len(ht.windows["provA"]) != 2 { + t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"])) + } + if len(ht.windows["provB"]) != 1 { + t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"])) + } + + // Verify event fields + ev := ht.windows["provA"][1] + if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 { + t.Fatalf("unexpected event fields: %+v", ev) + } +} + +func TestHealthTracker_Status(t *testing.T) { + tests := []struct { + name string + successCount int + errorCount int + wantStatus string + wantErrorRate float64 + wantTotal int + wantErrors int + }{ + { + name: "healthy - no errors", + successCount: 10, + errorCount: 0, + wantStatus: "healthy", + wantErrorRate: 0.0, + wantTotal: 10, + wantErrors: 0, + }, + { + name: "healthy - below 10% errors", + successCount: 19, + errorCount: 1, + wantStatus: "healthy", + wantErrorRate: 0.05, + wantTotal: 20, + wantErrors: 1, + }, + { + name: "degraded - 20% errors", + successCount: 8, + errorCount: 2, + wantStatus: "degraded", + wantErrorRate: 0.2, + wantTotal: 10, + wantErrors: 2, + }, + { + name: "degraded - exactly 10% errors", + successCount: 9, + errorCount: 1, + wantStatus: "degraded", + wantErrorRate: 0.1, + wantTotal: 10, + wantErrors: 1, + }, + { + name: "down - 50% errors", + successCount: 5, + errorCount: 5, + wantStatus: "down", + wantErrorRate: 0.5, + wantTotal: 10, + wantErrors: 5, + }, + { + name: "down - all errors", + successCount: 0, + errorCount: 5, + wantStatus: "down", + wantErrorRate: 1.0, + wantTotal: 5, + wantErrors: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{}) + + for i := 0; i < tt.successCount; i++ { + ht.Record("prov", 100, nil) + } + for i := 0; i < tt.errorCount; i++ { + ht.Record("prov", 100, errors.New("err")) + } + + statuses := ht.Status() + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + + s := statuses[0] + if s.Status != tt.wantStatus { + t.Errorf("status = %q, want %q", s.Status, tt.wantStatus) + } + if s.Total != tt.wantTotal { + t.Errorf("total = %d, want %d", s.Total, tt.wantTotal) + } + if s.Errors != tt.wantErrors { + t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors) + } + // Allow small float tolerance + if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 { + t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate) + } + }) + } +} + +func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) { + cb := defaultCBConfig() + cb.MinRequests = 3 + cb.ErrorThreshold = 0.5 + + ht := newTestTracker(5*time.Minute, cb) + + // Record errors to exceed threshold (3 errors out of 3 = 100% > 50%) + ht.Record("prov", 100, errors.New("err")) + ht.Record("prov", 100, errors.New("err")) + ht.Record("prov", 100, errors.New("err")) + + ht.mu.RLock() + state := ht.circuits["prov"].State + ht.mu.RUnlock() + + if state != CircuitOpen { + t.Fatalf("expected CircuitOpen, got %s", state) + } + + if ht.IsAvailable("prov") { + t.Fatal("expected IsAvailable=false when circuit is open") + } +} + +func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) { + cb := defaultCBConfig() + cb.CooldownDuration = 50 * time.Millisecond + + ht := newTestTracker(5*time.Minute, cb) + + // Trip the circuit + for i := 0; i < 5; i++ { + ht.Record("prov", 100, errors.New("err")) + } + + if ht.IsAvailable("prov") { + t.Fatal("expected circuit open, IsAvailable should be false") + } + + // Wait for cooldown + time.Sleep(60 * time.Millisecond) + + // After cooldown, IsAvailable should return true (will transition to half-open) + if !ht.IsAvailable("prov") { + t.Fatal("expected IsAvailable=true after cooldown") + } +} + +func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) { + cb := defaultCBConfig() + cb.CooldownDuration = 10 * time.Millisecond + + ht := newTestTracker(5*time.Minute, cb) + + // Trip the circuit + for i := 0; i < 5; i++ { + ht.Record("prov", 100, errors.New("err")) + } + + // Wait for cooldown so next Record transitions through Open->HalfOpen + time.Sleep(20 * time.Millisecond) + + // A successful record should transition: Open -> HalfOpen -> Closed + ht.Record("prov", 100, nil) + + ht.mu.RLock() + state := ht.circuits["prov"].State + ht.mu.RUnlock() + + if state != CircuitClosed { + t.Fatalf("expected CircuitClosed after success in half-open, got %s", state) + } + + if !ht.IsAvailable("prov") { + t.Fatal("expected IsAvailable=true after circuit closed") + } +} + +func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) { + cb := defaultCBConfig() + cb.CooldownDuration = 10 * time.Millisecond + + ht := newTestTracker(5*time.Minute, cb) + + // Trip the circuit + for i := 0; i < 5; i++ { + ht.Record("prov", 100, errors.New("err")) + } + + // Wait for cooldown + time.Sleep(20 * time.Millisecond) + + // A failed record should transition: Open -> HalfOpen -> Open + ht.Record("prov", 100, errors.New("still failing")) + + ht.mu.RLock() + state := ht.circuits["prov"].State + ht.mu.RUnlock() + + if state != CircuitOpen { + t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state) + } +} + +func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) { + ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false}) + + // Even with errors, IsAvailable should return true when CB is disabled + for i := 0; i < 10; i++ { + ht.Record("prov", 100, errors.New("err")) + } + + if !ht.IsAvailable("prov") { + t.Fatal("expected IsAvailable=true when circuit breaker disabled") + } +} + +func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) { + ht := newTestTracker(5*time.Minute, defaultCBConfig()) + + if !ht.IsAvailable("unknown") { + t.Fatal("expected IsAvailable=true for unknown provider (no circuit)") + } +} + +func TestHealthTracker_WindowPruning(t *testing.T) { + // Use a tiny window so events expire quickly + ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{}) + + ht.Record("prov", 100, nil) + ht.Record("prov", 200, nil) + + // Wait for events to expire + time.Sleep(60 * time.Millisecond) + + // Record a new event to trigger pruning + ht.Record("prov", 300, nil) + + ht.mu.RLock() + count := len(ht.windows["prov"]) + ht.mu.RUnlock() + + if count != 1 { + t.Fatalf("expected 1 event after pruning, got %d", count) + } +} + +func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) { + ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{}) + + ht.Record("prov", 100, nil) + + // Wait for events to expire + time.Sleep(60 * time.Millisecond) + + statuses := ht.Status() + if len(statuses) != 0 { + t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses)) + } +} + +func TestHealthTracker_Status_AvgLatency(t *testing.T) { + ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{}) + + ht.Record("prov", 100, nil) + ht.Record("prov", 200, nil) + ht.Record("prov", 300, nil) + + statuses := ht.Status() + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + + want := 200.0 + if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 { + t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want) + } +} + +func TestHealthTracker_Status_CircuitStateReported(t *testing.T) { + cb := defaultCBConfig() + ht := newTestTracker(5*time.Minute, cb) + + // Trip the circuit + for i := 0; i < 5; i++ { + ht.Record("prov", 100, errors.New("err")) + } + + statuses := ht.Status() + if len(statuses) != 1 { + t.Fatalf("expected 1 status, got %d", len(statuses)) + } + + if statuses[0].CircuitState != "open" { + t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open") + } +} diff --git a/llm-gateway/internal/provider/openai.go b/llm-gateway/internal/provider/openai.go index 1278eea..0e434f3 100644 --- a/llm-gateway/internal/provider/openai.go +++ b/llm-gateway/internal/provider/openai.go @@ -111,6 +111,12 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, func (p *OpenAIProvider) setHeaders(req *http.Request) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+p.apiKey) + // Forward request ID if present in context + if reqID := req.Context().Value("requestID"); reqID != nil { + if id, ok := reqID.(string); ok && id != "" { + req.Header.Set("X-Request-ID", id) + } + } } // ProviderError represents a non-200 response from a provider. diff --git a/llm-gateway/internal/provider/registry.go b/llm-gateway/internal/provider/registry.go index 1a2302b..5a4dced 100644 --- a/llm-gateway/internal/provider/registry.go +++ b/llm-gateway/internal/provider/registry.go @@ -3,6 +3,7 @@ package provider import ( "fmt" "sort" + "sync" "llm-gateway/internal/config" ) @@ -18,26 +19,40 @@ type Route struct { // Registry maps model names to provider routes. type Registry struct { - routes map[string][]Route - order []string // preserves config order + mu sync.RWMutex + routes map[string][]Route + balancers map[string]LoadBalancer + aliases map[string]string // alias -> canonical name + order []string // preserves config order (canonical names only) } func NewRegistry(cfg *config.Config) (*Registry, error) { + r := &Registry{} + if err := r.buildFromConfig(cfg); err != nil { + return nil, err + } + return r, nil +} + +func (r *Registry) buildFromConfig(cfg *config.Config) error { // Build providers providers := make(map[string]Provider) for _, pc := range cfg.Providers { providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout) } - // Build routes (preserving config order) + // Build routes routes := make(map[string][]Route) + balancers := make(map[string]LoadBalancer) + aliases := make(map[string]string) order := make([]string, 0, len(cfg.Models)) + for _, mc := range cfg.Models { var modelRoutes []Route for _, rc := range mc.Routes { p, ok := providers[rc.Provider] if !ok { - return nil, fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider) + return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider) } pc := cfg.ProviderByName(rc.Provider) priority := pc.Priority @@ -55,20 +70,69 @@ func NewRegistry(cfg *config.Config) (*Registry, error) { }) routes[mc.Name] = modelRoutes order = append(order, mc.Name) + + // Load balancer + balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing) + + // Register aliases + for _, alias := range mc.Aliases { + aliases[alias] = mc.Name + } } - return &Registry{routes: routes, order: order}, nil + r.mu.Lock() + r.routes = routes + r.balancers = balancers + r.aliases = aliases + r.order = order + r.mu.Unlock() + + return nil } -// Lookup returns the routes for a model name. +// Reload rebuilds routes from new config. Used for hot-reload. +func (r *Registry) Reload(cfg *config.Config) error { + return r.buildFromConfig(cfg) +} + +// Lookup returns the routes for a model name (resolving aliases). func (r *Registry) Lookup(model string) ([]Route, bool) { - routes, ok := r.routes[model] - return routes, ok + r.mu.RLock() + defer r.mu.RUnlock() + + // Resolve alias + canonical := model + if alias, ok := r.aliases[model]; ok { + canonical = alias + } + + routes, ok := r.routes[canonical] + if !ok { + return nil, false + } + + // Apply load balancer + if balancer, ok := r.balancers[canonical]; ok { + routes = balancer.Reorder(routes) + } + + return routes, true } -// ModelNames returns all registered model names in config order. +// ModelNames returns all registered model names in config order (including aliases). func (r *Registry) ModelNames() []string { - return r.order + r.mu.RLock() + defer r.mu.RUnlock() + + var names []string + for _, name := range r.order { + names = append(names, name) + } + // Add aliases + for alias := range r.aliases { + names = append(names, alias) + } + return names } // RouteInfo exposes route details for dashboard display. @@ -82,16 +146,29 @@ type RouteInfo struct { // ModelRouteInfo exposes a model and its routes for dashboard display. type ModelRouteInfo struct { - Name string `json:"name"` - Routes []RouteInfo `json:"routes"` + Name string `json:"name"` + Aliases []string `json:"aliases,omitempty"` + Routes []RouteInfo `json:"routes"` } // AllRoutes returns all models and their routes in config order. func (r *Registry) AllRoutes() []ModelRouteInfo { + r.mu.RLock() + defer r.mu.RUnlock() + + // Build reverse alias map + modelAliases := make(map[string][]string) + for alias, canonical := range r.aliases { + modelAliases[canonical] = append(modelAliases[canonical], alias) + } + results := make([]ModelRouteInfo, 0, len(r.order)) for _, name := range r.order { routes := r.routes[name] - info := ModelRouteInfo{Name: name} + info := ModelRouteInfo{ + Name: name, + Aliases: modelAliases[name], + } for _, rt := range routes { info.Routes = append(info.Routes, RouteInfo{ ProviderName: rt.Provider.Name(), diff --git a/llm-gateway/internal/provider/registry_test.go b/llm-gateway/internal/provider/registry_test.go new file mode 100644 index 0000000..b04c45a --- /dev/null +++ b/llm-gateway/internal/provider/registry_test.go @@ -0,0 +1,282 @@ +package provider + +import ( + "context" + "io" + "testing" + + "llm-gateway/internal/config" +) + +// mockProvider implements the Provider interface for testing. +type mockProvider struct { + name string +} + +func (m *mockProvider) Name() string { return m.name } + +func (m *mockProvider) ChatCompletion(_ context.Context, _ string, _ *ChatRequest) (*ChatResponse, error) { + return nil, nil +} + +func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *ChatRequest) (io.ReadCloser, error) { + return nil, nil +} + +// newTestRegistry builds a Registry directly without going through config parsing. +func newTestRegistry(models []testModel) *Registry { + r := &Registry{ + routes: make(map[string][]Route), + balancers: make(map[string]LoadBalancer), + aliases: make(map[string]string), + } + + for _, m := range models { + r.routes[m.name] = m.routes + r.balancers[m.name] = &FirstBalancer{} + r.order = append(r.order, m.name) + for _, alias := range m.aliases { + r.aliases[alias] = m.name + } + } + + return r +} + +type testModel struct { + name string + aliases []string + routes []Route +} + +func TestRegistry_Lookup_Canonical(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "gpt-4", + routes: []Route{ + {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, + }, + }, + }) + + routes, ok := reg.Lookup("gpt-4") + if !ok { + t.Fatal("expected Lookup to find gpt-4") + } + if len(routes) != 1 { + t.Fatalf("expected 1 route, got %d", len(routes)) + } + if routes[0].Provider.Name() != "openai" { + t.Errorf("expected provider 'openai', got %q", routes[0].Provider.Name()) + } +} + +func TestRegistry_Lookup_Alias(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "gpt-4", + aliases: []string{"gpt4", "gpt-4-latest"}, + routes: []Route{ + {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, + }, + }, + }) + + tests := []struct { + name string + model string + found bool + }{ + {"canonical", "gpt-4", true}, + {"alias1", "gpt4", true}, + {"alias2", "gpt-4-latest", true}, + {"unknown", "gpt-5", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routes, ok := reg.Lookup(tt.model) + if ok != tt.found { + t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found) + } + if tt.found && len(routes) != 1 { + t.Fatalf("expected 1 route, got %d", len(routes)) + } + }) + } +} + +func TestRegistry_ModelNames_IncludesAliases(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "gpt-4", + aliases: []string{"gpt4"}, + routes: []Route{ + {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, + }, + }, + { + name: "claude-3", + routes: []Route{ + {Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1}, + }, + }, + }) + + names := reg.ModelNames() + + want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true} + got := make(map[string]bool) + for _, n := range names { + got[n] = true + } + + for name := range want { + if !got[name] { + t.Errorf("expected %q in ModelNames, not found", name) + } + } + + if len(names) != len(want) { + t.Errorf("expected %d names, got %d: %v", len(want), len(names), names) + } +} + +func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "gpt-4", + aliases: []string{"gpt4", "gpt-4-latest"}, + routes: []Route{ + {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, + {Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2}, + }, + }, + }) + + allRoutes := reg.AllRoutes() + if len(allRoutes) != 1 { + t.Fatalf("expected 1 model, got %d", len(allRoutes)) + } + + m := allRoutes[0] + if m.Name != "gpt-4" { + t.Errorf("expected name 'gpt-4', got %q", m.Name) + } + + aliasSet := make(map[string]bool) + for _, a := range m.Aliases { + aliasSet[a] = true + } + if !aliasSet["gpt4"] || !aliasSet["gpt-4-latest"] { + t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", m.Aliases) + } + + if len(m.Routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(m.Routes)) + } + if m.Routes[0].ProviderName != "openai" { + t.Errorf("expected first route provider 'openai', got %q", m.Routes[0].ProviderName) + } + if m.Routes[1].ProviderName != "azure" { + t.Errorf("expected second route provider 'azure', got %q", m.Routes[1].ProviderName) + } +} + +func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "model-b", + routes: []Route{ + {Provider: &mockProvider{name: "prov"}, ProviderModel: "b", Priority: 1}, + }, + }, + { + name: "model-a", + routes: []Route{ + {Provider: &mockProvider{name: "prov"}, ProviderModel: "a", Priority: 1}, + }, + }, + }) + + allRoutes := reg.AllRoutes() + if len(allRoutes) != 2 { + t.Fatalf("expected 2 models, got %d", len(allRoutes)) + } + if allRoutes[0].Name != "model-b" { + t.Errorf("expected first model 'model-b', got %q", allRoutes[0].Name) + } + if allRoutes[1].Name != "model-a" { + t.Errorf("expected second model 'model-a', got %q", allRoutes[1].Name) + } +} + +func TestRegistry_PrioritySorting(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "multi-provider", + routes: []Route{ + {Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3}, + {Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1}, + {Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2}, + }, + }, + }) + + // Note: routes are stored as given (sorting happens during buildFromConfig). + // For this test we verify AllRoutes returns them in stored order. + allRoutes := reg.AllRoutes() + if len(allRoutes) != 1 { + t.Fatalf("expected 1 model, got %d", len(allRoutes)) + } + + routes := allRoutes[0].Routes + if len(routes) != 3 { + t.Fatalf("expected 3 routes, got %d", len(routes)) + } + + // Verify the priorities are present + priorities := make(map[int]bool) + for _, r := range routes { + priorities[r.Priority] = true + } + for _, p := range []int{1, 2, 3} { + if !priorities[p] { + t.Errorf("expected priority %d in routes", p) + } + } +} + +func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) { + cfg := &config.Config{ + Models: []config.ModelConfig{ + { + Name: "test-model", + Routes: []config.RouteConfig{ + {Provider: "nonexistent", Model: "m"}, + }, + }, + }, + } + + _, err := NewRegistry(cfg) + if err == nil { + t.Fatal("expected error for unknown provider, got nil") + } +} + +func TestRegistry_Lookup_NotFound(t *testing.T) { + reg := newTestRegistry([]testModel{ + { + name: "gpt-4", + routes: []Route{ + {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, + }, + }, + }) + + _, ok := reg.Lookup("nonexistent") + if ok { + t.Fatal("expected Lookup to return false for nonexistent model") + } +} diff --git a/llm-gateway/internal/proxy/concurrency.go b/llm-gateway/internal/proxy/concurrency.go new file mode 100644 index 0000000..4f28262 --- /dev/null +++ b/llm-gateway/internal/proxy/concurrency.go @@ -0,0 +1,51 @@ +package proxy + +import ( + "net/http" + "sync" + "sync/atomic" +) + +// ConcurrencyLimiter enforces per-token concurrent request limits. +type ConcurrencyLimiter struct { + mu sync.Mutex + counters map[string]*atomic.Int64 +} + +func NewConcurrencyLimiter() *ConcurrencyLimiter { + return &ConcurrencyLimiter{ + counters: make(map[string]*atomic.Int64), + } +} + +func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 { + cl.mu.Lock() + defer cl.mu.Unlock() + c, ok := cl.counters[tokenName] + if !ok { + c = &atomic.Int64{} + cl.counters[tokenName] = c + } + return c +} + +func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiToken := getAPIToken(r.Context()) + if apiToken == nil || apiToken.MaxConcurrent <= 0 { + next.ServeHTTP(w, r) + return + } + + counter := cl.getCounter(apiToken.Name) + current := counter.Add(1) + defer counter.Add(-1) + + if current > int64(apiToken.MaxConcurrent) { + writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded") + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/llm-gateway/internal/proxy/concurrency_test.go b/llm-gateway/internal/proxy/concurrency_test.go new file mode 100644 index 0000000..fa6ccbf --- /dev/null +++ b/llm-gateway/internal/proxy/concurrency_test.go @@ -0,0 +1,317 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "llm-gateway/internal/auth" +) + +func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) { + tests := []struct { + name string + maxConcurrent int + numRequests int + wantAllowed int + }{ + { + name: "single request within limit", + maxConcurrent: 5, + numRequests: 1, + wantAllowed: 1, + }, + { + name: "all requests within limit", + maxConcurrent: 5, + numRequests: 5, + wantAllowed: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cl := NewConcurrencyLimiter() + + token := &auth.APIToken{ + Name: "conc-token", + MaxConcurrent: tt.maxConcurrent, + } + + var allowed atomic.Int64 + var wg sync.WaitGroup + // Use a channel to hold all goroutines inside the handler simultaneously. + gate := make(chan struct{}) + + handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + allowed.Add(1) + <-gate // Block until released. + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < tt.numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), token) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + }() + } + + // Wait for goroutines to enter the handler. + time.Sleep(50 * time.Millisecond) + close(gate) + wg.Wait() + + if int(allowed.Load()) != tt.wantAllowed { + t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed) + } + }) + } +} + +func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) { + tests := []struct { + name string + maxConcurrent int + numRequests int + wantDenied int + }{ + { + name: "one over limit", + maxConcurrent: 2, + numRequests: 3, + wantDenied: 1, + }, + { + name: "many over limit", + maxConcurrent: 1, + numRequests: 5, + wantDenied: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cl := NewConcurrencyLimiter() + + token := &auth.APIToken{ + Name: "conc-token", + MaxConcurrent: tt.maxConcurrent, + } + + var denied atomic.Int64 + var wg sync.WaitGroup + gate := make(chan struct{}) + + handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-gate + w.WriteHeader(http.StatusOK) + })) + + results := make([]int, tt.numRequests) + for i := 0; i < tt.numRequests; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), token) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + results[idx] = rec.Code + if rec.Code == http.StatusTooManyRequests { + denied.Add(1) + } + }(i) + } + + // Wait for goroutines to reach the handler or be rejected. + time.Sleep(50 * time.Millisecond) + close(gate) + wg.Wait() + + if int(denied.Load()) != tt.wantDenied { + t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied) + } + }) + } +} + +func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) { + cl := NewConcurrencyLimiter() + + token := &auth.APIToken{ + Name: "decrement-token", + MaxConcurrent: 1, + } + + handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed and complete, decrementing the counter. + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), token) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("first request: status = %d, want %d", rec.Code, http.StatusOK) + } + + // Counter should have decremented. A second request should also succeed. + rec2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + ctx2 := withAPIToken(req2.Context(), token) + req2 = req2.WithContext(ctx2) + handler.ServeHTTP(rec2, req2) + + if rec2.Code != http.StatusOK { + t.Errorf("second request after first completed: status = %d, want %d", rec2.Code, http.StatusOK) + } + + // Verify the internal counter is back to 0. + counter := cl.getCounter(token.Name) + val := counter.Load() + if val != 0 { + t.Errorf("counter = %d, want 0 after all requests completed", val) + } +} + +func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) { + tests := []struct { + name string + maxConcurrent int + numRequests int + }{ + { + name: "zero allows unlimited concurrent requests", + maxConcurrent: 0, + numRequests: 50, + }, + { + name: "negative allows unlimited concurrent requests", + maxConcurrent: -1, + numRequests: 50, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cl := NewConcurrencyLimiter() + + token := &auth.APIToken{ + Name: "unlimited-token", + MaxConcurrent: tt.maxConcurrent, + } + + var allowed atomic.Int64 + var wg sync.WaitGroup + gate := make(chan struct{}) + + handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + allowed.Add(1) + <-gate + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < tt.numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), token) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + }() + } + + // Give goroutines time to enter the handler. + time.Sleep(100 * time.Millisecond) + close(gate) + wg.Wait() + + if int(allowed.Load()) != tt.numRequests { + t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests) + } + }) + } +} + +func TestConcurrencyLimiter_NoToken(t *testing.T) { + cl := NewConcurrencyLimiter() + + handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // No API token in context. + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (should pass through without token)", rec.Code, http.StatusOK) + } +} + +func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) { + cl := NewConcurrencyLimiter() + + tokenA := &auth.APIToken{ + Name: "token-a", + MaxConcurrent: 1, + } + tokenB := &auth.APIToken{ + Name: "token-b", + MaxConcurrent: 1, + } + + gateA := make(chan struct{}) + var wg sync.WaitGroup + + handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tok := getAPIToken(r.Context()) + if tok.Name == "token-a" { + <-gateA // Block token A's request. + } + w.WriteHeader(http.StatusOK) + })) + + // Start a request for token A that blocks. + wg.Add(1) + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), tokenA) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + }() + + // Give token A's goroutine time to enter handler. + time.Sleep(50 * time.Millisecond) + + // Token B should not be affected by token A's in-flight request. + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), tokenB) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK) + } + + close(gateA) + wg.Wait() +} diff --git a/llm-gateway/internal/proxy/handler.go b/llm-gateway/internal/proxy/handler.go index ba4f36a..cb6c619 100644 --- a/llm-gateway/internal/proxy/handler.go +++ b/llm-gateway/internal/proxy/handler.go @@ -4,11 +4,16 @@ import ( "context" "encoding/json" "errors" + "fmt" "io" "log" "net/http" + "sort" + "strings" "time" + "github.com/go-chi/chi/v5/middleware" + "llm-gateway/internal/auth" "llm-gateway/internal/cache" "llm-gateway/internal/config" @@ -47,6 +52,7 @@ type Handler struct { metrics *metrics.Metrics cfg *config.Config healthTracker *provider.HealthTracker + debugLogger *storage.DebugLogger } func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler { @@ -60,6 +66,10 @@ func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cac } } +func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) { + h.debugLogger = dl +} + func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) if err != nil { @@ -84,31 +94,53 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { return } + // Filter healthy routes (circuit breaker) + routes = h.filterHealthyRoutes(routes) + tokenName := getTokenName(r.Context()) + requestID := middleware.GetReqID(r.Context()) // Check cache for non-streaming requests if !req.Stream && h.cache != nil { if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil { - h.logRequest(tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true) + h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true) + if h.metrics != nil { + h.metrics.RecordCacheHit() + } w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Cache", "HIT") + w.Header().Set("X-Request-ID", requestID) w.Write(cached) return } + if h.metrics != nil { + h.metrics.RecordCacheMiss() + } } if req.Stream { - h.handleStream(w, r, &req, routes, tokenName) + h.handleStream(w, r, &req, routes, tokenName, requestID) return } - h.handleNonStream(w, r, &req, routes, tokenName, body) + h.handleNonStream(w, r, &req, routes, tokenName, body, requestID) } -func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte) { +func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) { var lastErr error - for _, route := range routes { + for i, route := range routes { + // Retry backoff between attempts (not before first attempt) + if i > 0 { + backoff := backoffDuration(i, h.cfg.Retry) + select { + case <-time.After(backoff): + case <-r.Context().Done(): + writeError(w, http.StatusGatewayTimeout, "request cancelled") + return + } + } + start := time.Now() resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req) latency := time.Since(start).Milliseconds() @@ -116,19 +148,19 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p if err != nil { var pe *provider.ProviderError if errors.As(err, &pe) && !pe.IsRetryable() { - // Client error — don't retry h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) + h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) if h.healthTracker != nil { h.healthTracker.Record(route.Provider.Name(), latency, err) } + w.Header().Set("X-Request-ID", requestID) writeErrorRaw(w, pe.StatusCode, pe.Body) return } lastErr = err log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err) h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) + h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) if h.healthTracker != nil { h.healthTracker.Record(route.Provider.Name(), latency, err) } @@ -139,7 +171,6 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p h.healthTracker.Record(route.Provider.Name(), latency, nil) } - // Compute cost inputTokens, outputTokens := 0, 0 if resp.Usage != nil { inputTokens = resp.Usage.PromptTokens @@ -148,9 +179,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice) h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost) - h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false) + h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false) - // Override model name in response to match the requested model resp.Model = req.Model respBytes, err := json.Marshal(resp) @@ -159,27 +189,84 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p return } - // Cache the response if h.cache != nil { h.cache.Set(r.Context(), req.Model, rawBody, respBytes) } + // Debug logging + if h.debugLogger != nil && h.debugLogger.IsEnabled() { + reqBody := string(rawBody) + respBody := string(respBytes) + if h.cfg.Debug.MaxBodyBytes > 0 { + if len(reqBody) > h.cfg.Debug.MaxBodyBytes { + reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes] + } + if len(respBody) > h.cfg.Debug.MaxBodyBytes { + respBody = respBody[:h.cfg.Debug.MaxBodyBytes] + } + } + h.debugLogger.Log(storage.DebugLogEntry{ + RequestID: requestID, + TokenName: tokenName, + Model: req.Model, + Provider: route.Provider.Name(), + RequestBody: reqBody, + ResponseBody: respBody, + RequestHeaders: formatHeaders(r.Header), + ResponseStatus: http.StatusOK, + }) + } + w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Cache", "MISS") + w.Header().Set("X-Request-ID", requestID) w.Write(respBytes) return } - // All providers failed if lastErr != nil { + w.Header().Set("X-Request-ID", requestID) writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) } else { + w.Header().Set("X-Request-ID", requestID) writeError(w, http.StatusBadGateway, "all providers failed") } } -func (h *Handler) logRequest(tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) { +// filterHealthyRoutes removes providers with open circuit breakers. +// If all are filtered out, returns original routes as fallback. +func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route { + if h.healthTracker == nil { + return routes + } + var healthy []provider.Route + for _, r := range routes { + if h.healthTracker.IsAvailable(r.Provider.Name()) { + healthy = append(healthy, r) + } + } + if len(healthy) == 0 { + return routes // all-down fallback + } + return healthy +} + +// backoffDuration computes exponential backoff for the given attempt. +func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration { + d := cfg.InitialBackoff + for i := 1; i < attempt; i++ { + d = time.Duration(float64(d) * cfg.Multiplier) + if d > cfg.MaxBackoff { + d = cfg.MaxBackoff + break + } + } + return d +} + +func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) { h.logger.Log(storage.RequestLog{ + RequestID: requestID, Timestamp: time.Now().Unix(), TokenName: tokenName, Model: model, @@ -217,3 +304,23 @@ func writeErrorRaw(w http.ResponseWriter, code int, body string) { w.WriteHeader(code) w.Write([]byte(body)) } + +// formatHeaders serializes HTTP headers to a readable string, sorted by key. +// Sensitive headers (Authorization) are redacted. +func formatHeaders(h http.Header) string { + keys := make([]string, 0, len(h)) + for k := range h { + keys = append(keys, k) + } + sort.Strings(keys) + + var b strings.Builder + for _, k := range keys { + val := strings.Join(h[k], ", ") + if strings.EqualFold(k, "Authorization") { + val = "[REDACTED]" + } + fmt.Fprintf(&b, "%s: %s\n", k, val) + } + return b.String() +} diff --git a/llm-gateway/internal/proxy/ratelimit.go b/llm-gateway/internal/proxy/ratelimit.go index 2278ea1..240bbe2 100644 --- a/llm-gateway/internal/proxy/ratelimit.go +++ b/llm-gateway/internal/proxy/ratelimit.go @@ -1,6 +1,8 @@ package proxy import ( + "fmt" + "math" "net/http" "sync" "time" @@ -40,7 +42,19 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler { // Check rate limit if apiToken.RateLimitRPM > 0 { - if !rl.allow(tokenName, apiToken.RateLimitRPM) { + allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM) + + // Set rate limit headers on all responses + w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM)) + w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) + w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt)) + + if !allowed { + retryAfter := resetAt - time.Now().Unix() + if retryAfter < 1 { + retryAfter = 1 + } + w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter)) writeError(w, http.StatusTooManyRequests, "rate limit exceeded") return } @@ -59,7 +73,7 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler { }) } -func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool { +func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) { rl.mu.Lock() defer rl.mu.Unlock() @@ -82,9 +96,27 @@ func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool { } bucket.lastRefill = now + remaining := int(math.Floor(bucket.tokens)) + if remaining < 0 { + remaining = 0 + } + + // Compute reset time: when bucket would be full again + deficit := bucket.maxTokens - bucket.tokens + var resetAt int64 + if deficit > 0 && bucket.refillRate > 0 { + resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix() + } else { + resetAt = now.Unix() + } + if bucket.tokens < 1 { - return false + return false, 0, resetAt } bucket.tokens-- - return true + remaining = int(math.Floor(bucket.tokens)) + if remaining < 0 { + remaining = 0 + } + return true, remaining, resetAt } diff --git a/llm-gateway/internal/proxy/ratelimit_test.go b/llm-gateway/internal/proxy/ratelimit_test.go new file mode 100644 index 0000000..8d1afb7 --- /dev/null +++ b/llm-gateway/internal/proxy/ratelimit_test.go @@ -0,0 +1,374 @@ +package proxy + +import ( + "context" + "database/sql" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + _ "modernc.org/sqlite" + + "llm-gateway/internal/auth" + "llm-gateway/internal/storage" +) + +// newTestDB creates an in-memory SQLite database wrapped in storage.DB. +// It creates the request_logs table needed by TodaySpend. +func newTestDB(t *testing.T) *storage.DB { + t.Helper() + sqlDB, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("opening in-memory sqlite: %v", err) + } + t.Cleanup(func() { sqlDB.Close() }) + + // Create the minimal table needed for TodaySpend queries. + _, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_name TEXT, + cost_usd REAL, + timestamp INTEGER + )`) + if err != nil { + t.Fatalf("creating request_logs table: %v", err) + } + return &storage.DB{DB: sqlDB} +} + +// okHandler is a simple handler that writes 200 OK. +var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) +}) + +func TestRateLimiter_Allow(t *testing.T) { + tests := []struct { + name string + rateLimitRPM int + numRequests int + wantAllowed int + wantDenied int + }{ + { + name: "allows requests within limit", + rateLimitRPM: 10, + numRequests: 5, + wantAllowed: 5, + wantDenied: 0, + }, + { + name: "denies requests over limit", + rateLimitRPM: 3, + numRequests: 6, + wantAllowed: 3, + wantDenied: 3, + }, + { + name: "allows exactly up to limit", + rateLimitRPM: 5, + numRequests: 5, + wantAllowed: 5, + wantDenied: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + allowed := 0 + denied := 0 + for i := 0; i < tt.numRequests; i++ { + ok, _, _ := rl.allow("test-token", tt.rateLimitRPM) + if ok { + allowed++ + } else { + denied++ + } + } + + if allowed != tt.wantAllowed { + t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed) + } + if denied != tt.wantDenied { + t.Errorf("denied = %d, want %d", denied, tt.wantDenied) + } + }) + } +} + +func TestRateLimiter_TokenRefillsOverTime(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + rpm := 60 // 1 token per second refill rate + + // Exhaust all tokens. + for i := 0; i < rpm; i++ { + ok, _, _ := rl.allow("refill-token", rpm) + if !ok { + t.Fatalf("request %d should have been allowed", i) + } + } + + // Next request should be denied. + ok, _, _ := rl.allow("refill-token", rpm) + if ok { + t.Fatal("request should have been denied after exhausting tokens") + } + + // Manually advance the bucket's lastRefill to simulate time passing. + rl.mu.Lock() + bucket := rl.buckets["refill-token"] + bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second) + rl.mu.Unlock() + + // After 2 seconds at 1 token/sec, we should have ~2 tokens refilled. + ok, remaining, _ := rl.allow("refill-token", rpm) + if !ok { + t.Fatal("request should have been allowed after token refill") + } + // We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0. + if remaining < 0 { + t.Errorf("remaining = %d, want >= 0", remaining) + } +} + +func TestRateLimiter_AllowReturnValues(t *testing.T) { + tests := []struct { + name string + rateLimitRPM int + numRequests int + wantLastAllowed bool + wantLastRemaining int + }{ + { + name: "remaining decrements correctly", + rateLimitRPM: 5, + numRequests: 1, + wantLastAllowed: true, + wantLastRemaining: 4, + }, + { + name: "remaining is zero at limit", + rateLimitRPM: 3, + numRequests: 3, + wantLastAllowed: true, + wantLastRemaining: 0, + }, + { + name: "denied returns zero remaining", + rateLimitRPM: 2, + numRequests: 3, + wantLastAllowed: false, + wantLastRemaining: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + var allowed bool + var remaining int + for i := 0; i < tt.numRequests; i++ { + allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM) + } + + if allowed != tt.wantLastAllowed { + t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed) + } + if remaining != tt.wantLastRemaining { + t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining) + } + }) + } +} + +func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) { + tests := []struct { + name string + rateLimitRPM int + numRequests int + wantStatusCode int + wantLimitHeader string + wantRetryAfter bool + }{ + { + name: "sets rate limit headers on allowed request", + rateLimitRPM: 10, + numRequests: 1, + wantStatusCode: http.StatusOK, + wantLimitHeader: "10", + wantRetryAfter: false, + }, + { + name: "sets Retry-After header on 429", + rateLimitRPM: 2, + numRequests: 3, + wantStatusCode: http.StatusTooManyRequests, + wantLimitHeader: "2", + wantRetryAfter: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + token := &auth.APIToken{ + Name: "header-test-token", + RateLimitRPM: tt.rateLimitRPM, + } + + handler := rl.Check(okHandler) + + var rec *httptest.ResponseRecorder + for i := 0; i < tt.numRequests; i++ { + rec = httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), token) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + } + + // Check the last response. + if rec.Code != tt.wantStatusCode { + t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode) + } + + // X-RateLimit-Limit header. + limitHeader := rec.Header().Get("X-RateLimit-Limit") + if limitHeader != tt.wantLimitHeader { + t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader) + } + + // X-RateLimit-Remaining header must be present and numeric. + remainingHeader := rec.Header().Get("X-RateLimit-Remaining") + if remainingHeader == "" { + t.Error("X-RateLimit-Remaining header is missing") + } else if _, err := strconv.Atoi(remainingHeader); err != nil { + t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader) + } + + // X-RateLimit-Reset header must be present and numeric. + resetHeader := rec.Header().Get("X-RateLimit-Reset") + if resetHeader == "" { + t.Error("X-RateLimit-Reset header is missing") + } else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil { + t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader) + } + + // Retry-After header. + retryAfter := rec.Header().Get("Retry-After") + if tt.wantRetryAfter && retryAfter == "" { + t.Error("Retry-After header is missing on 429 response") + } + if !tt.wantRetryAfter && retryAfter != "" { + t.Errorf("Retry-After header should not be present, got %q", retryAfter) + } + }) + } +} + +func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + handler := rl.Check(okHandler) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // No API token in context. + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK) + } +} + +func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + token := &auth.APIToken{ + Name: "unlimited-token", + RateLimitRPM: 0, // zero means unlimited + } + + handler := rl.Check(okHandler) + + for i := 0; i < 100; i++ { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx := withAPIToken(req.Context(), token) + req = req.WithContext(ctx) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK) + } + } +} + +func TestRateLimiter_PerTokenIsolation(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + rpm := 2 + + // Exhaust token A. + for i := 0; i < rpm; i++ { + rl.allow("token-a", rpm) + } + ok, _, _ := rl.allow("token-a", rpm) + if ok { + t.Fatal("token-a should be rate limited") + } + + // Token B should still have its own bucket. + ok, _, _ = rl.allow("token-b", rpm) + if !ok { + t.Fatal("token-b should not be affected by token-a's rate limit") + } +} + +func TestRateLimiter_ResetAtIsFuture(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + // Consume one token so there's a deficit. + _, _, resetAt := rl.allow("reset-token", 10) + now := time.Now().Unix() + + if resetAt < now { + t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now) + } +} + +func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) { + db := newTestDB(t) + rl := NewRateLimiter(db) + + token := &auth.APIToken{ + Name: "ctx-token", + RateLimitRPM: 10, + } + + handler := rl.Check(okHandler) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx, cancel := context.WithCancel(req.Context()) + ctx = withAPIToken(ctx, token) + cancel() // Cancel immediately. + req = req.WithContext(ctx) + + // Should still process (rate limiter does not check context cancellation). + handler.ServeHTTP(rec, req) + // The handler itself may or may not respect cancelled context; + // the key point is no panic occurs. +} diff --git a/llm-gateway/internal/proxy/stream.go b/llm-gateway/internal/proxy/stream.go index eb9e225..c1304d5 100644 --- a/llm-gateway/internal/proxy/stream.go +++ b/llm-gateway/internal/proxy/stream.go @@ -2,6 +2,7 @@ package proxy import ( "bufio" + "context" "encoding/json" "errors" "log" @@ -12,7 +13,7 @@ import ( "llm-gateway/internal/provider" ) -func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) { +func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) { flusher, ok := w.(http.Flusher) if !ok { writeError(w, http.StatusInternalServerError, "streaming not supported") @@ -21,7 +22,18 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov var lastErr error - for _, route := range routes { + for i, route := range routes { + // Retry backoff between attempts + if i > 0 { + backoff := backoffDuration(i, h.cfg.Retry) + select { + case <-time.After(backoff): + case <-r.Context().Done(): + writeError(w, http.StatusGatewayTimeout, "request cancelled") + return + } + } + start := time.Now() body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req) @@ -30,67 +42,95 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov if errors.As(err, &pe) && !pe.IsRetryable() { latency := time.Since(start).Milliseconds() h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) + h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) if h.healthTracker != nil { h.healthTracker.Record(route.Provider.Name(), latency, err) } + w.Header().Set("X-Request-ID", requestID) writeErrorRaw(w, pe.StatusCode, pe.Body) return } lastErr = err latency := time.Since(start).Milliseconds() log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err) - h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) + h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) if h.healthTracker != nil { h.healthTracker.Record(route.Provider.Name(), latency, err) } continue } + // Apply streaming timeout + var streamCtx context.Context + var streamCancel context.CancelFunc + if h.cfg.Server.StreamingTimeout > 0 { + streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout) + } else { + streamCtx, streamCancel = context.WithCancel(r.Context()) + } + // Stream the response w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") + w.Header().Set("X-Request-ID", requestID) w.WriteHeader(http.StatusOK) inputTokens, outputTokens := 0, 0 scanner := bufio.NewScanner(body) scanner.Buffer(make([]byte, 64*1024), 256*1024) - for scanner.Scan() { - line := scanner.Text() + scanDone := make(chan struct{}) + go func() { + defer close(scanDone) + for scanner.Scan() { + select { + case <-streamCtx.Done(): + return + default: + } - // Parse usage from the final chunk if available - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - if data != "[DONE]" { - var chunk streamChunk - if json.Unmarshal([]byte(data), &chunk) == nil { - if chunk.Usage != nil { - inputTokens = chunk.Usage.PromptTokens - outputTokens = chunk.Usage.CompletionTokens - } - // Override model name in chunk - if chunk.Model != "" { - chunk.Model = req.Model - if rewritten, err := json.Marshal(chunk); err == nil { - line = "data: " + string(rewritten) + line := scanner.Text() + + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data != "[DONE]" { + var chunk streamChunk + if json.Unmarshal([]byte(data), &chunk) == nil { + if chunk.Usage != nil { + inputTokens = chunk.Usage.PromptTokens + outputTokens = chunk.Usage.CompletionTokens + } + if chunk.Model != "" { + chunk.Model = req.Model + if rewritten, err := json.Marshal(chunk); err == nil { + line = "data: " + string(rewritten) + } } } } } - } - w.Write([]byte(line + "\n")) - flusher.Flush() + w.Write([]byte(line + "\n")) + flusher.Flush() + } + }() + + select { + case <-scanDone: + // Normal completion + case <-streamCtx.Done(): + log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name()) } + body.Close() + streamCancel() latency := time.Since(start).Milliseconds() cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice) h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost) - h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false) + h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false) if h.healthTracker != nil { h.healthTracker.Record(route.Provider.Name(), latency, nil) } @@ -98,6 +138,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov } // All providers failed + w.Header().Set("X-Request-ID", requestID) if lastErr != nil { writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) } else { diff --git a/llm-gateway/internal/storage/audit.go b/llm-gateway/internal/storage/audit.go new file mode 100644 index 0000000..d821f99 --- /dev/null +++ b/llm-gateway/internal/storage/audit.go @@ -0,0 +1,102 @@ +package storage + +import ( + "log" + "time" +) + +type AuditEntry struct { + ID int64 `json:"id"` + Timestamp int64 `json:"timestamp"` + UserID int64 `json:"user_id"` + Username string `json:"username"` + Action string `json:"action"` + TargetType string `json:"target_type"` + TargetID string `json:"target_id"` + Details string `json:"details"` + IPAddress string `json:"ip_address"` + RequestID string `json:"request_id"` +} + +type AuditLogger struct { + db *DB +} + +func NewAuditLogger(db *DB) *AuditLogger { + return &AuditLogger{db: db} +} + +func (a *AuditLogger) Log(entry AuditEntry) { + if entry.Timestamp == 0 { + entry.Timestamp = time.Now().Unix() + } + _, err := a.db.Exec(`INSERT INTO audit_log + (timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + entry.Timestamp, entry.UserID, entry.Username, entry.Action, + entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID, + ) + if err != nil { + log.Printf("ERROR: audit log: %v", err) + } +} + +type AuditQueryResult struct { + Entries []AuditEntry `json:"entries"` + Page int `json:"page"` + TotalPages int `json:"total_pages"` + Total int `json:"total"` +} + +func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult { + if page < 1 { + page = 1 + } + if limit <= 0 { + limit = 50 + } + offset := (page - 1) * limit + + where := "WHERE timestamp >= ?" + args := []any{since} + + if action != "" { + where += " AND action = ?" + args = append(args, action) + } + + var total int + countArgs := make([]any, len(args)) + copy(countArgs, args) + a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total) + + totalPages := (total + limit - 1) / limit + if totalPages < 1 { + totalPages = 1 + } + + query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action, + COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''), + COALESCE(ip_address, ''), COALESCE(request_id, '') + FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?` + args = append(args, limit, offset) + + rows, err := a.db.Query(query, args...) + if err != nil { + return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total} + } + defer rows.Close() + + var entries []AuditEntry + for rows.Next() { + var e AuditEntry + rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action, + &e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID) + entries = append(entries, e) + } + if entries == nil { + entries = []AuditEntry{} + } + + return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total} +} diff --git a/llm-gateway/internal/storage/debuglog.go b/llm-gateway/internal/storage/debuglog.go new file mode 100644 index 0000000..be333dc --- /dev/null +++ b/llm-gateway/internal/storage/debuglog.go @@ -0,0 +1,250 @@ +package storage + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "sort" + "strings" + "sync/atomic" + "time" +) + +type DebugLogEntry struct { + ID int64 `json:"id"` + RequestID string `json:"request_id"` + Timestamp int64 `json:"timestamp"` + TokenName string `json:"token_name"` + Model string `json:"model"` + Provider string `json:"provider"` + RequestBody string `json:"request_body"` + ResponseBody string `json:"response_body"` + RequestHeaders string `json:"request_headers"` + ResponseStatus int `json:"response_status"` + FilePath string `json:"-"` +} + +// debugFile is the JSON structure written to disk. +type debugFile struct { + RequestHeaders string `json:"request_headers"` + RequestBody string `json:"request_body"` + ResponseBody string `json:"response_body"` +} + +type DebugLogger struct { + db *DB + enabled atomic.Bool + dataDir string +} + +func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger { + dl := &DebugLogger{db: db, dataDir: dataDir} + dl.enabled.Store(enabled) + return dl +} + +func (d *DebugLogger) SetEnabled(v bool) { + d.enabled.Store(v) +} + +func (d *DebugLogger) IsEnabled() bool { + return d.enabled.Load() +} + +// debugLogDir returns the base directory for debug log files. +func (d *DebugLogger) debugLogDir() string { + return filepath.Join(d.dataDir, "debug-logs") +} + +// debugFilePath builds the file path for a debug log entry. +func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string { + date := ts.Format("2006-01-02") + return filepath.Join(d.debugLogDir(), date, requestID+".json") +} + +func (d *DebugLogger) Log(entry DebugLogEntry) { + if !d.IsEnabled() { + return + } + if entry.Timestamp == 0 { + entry.Timestamp = time.Now().Unix() + } + + ts := time.Unix(entry.Timestamp, 0) + fp := d.debugFilePath(entry.RequestID, ts) + + // Write body file + if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil { + log.Printf("ERROR: debug log mkdir: %v", err) + return + } + + df := debugFile{ + RequestHeaders: entry.RequestHeaders, + RequestBody: entry.RequestBody, + ResponseBody: entry.ResponseBody, + } + data, err := json.Marshal(df) + if err != nil { + log.Printf("ERROR: debug log marshal: %v", err) + return + } + if err := os.WriteFile(fp, data, 0644); err != nil { + log.Printf("ERROR: debug log write: %v", err) + return + } + + // Insert metadata into DB (no bodies) + _, err = d.db.Exec(`INSERT INTO debug_log + (request_id, timestamp, token_name, model, provider, response_status, file_path) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model, + entry.Provider, entry.ResponseStatus, fp, + ) + if err != nil { + log.Printf("ERROR: debug log db insert: %v", err) + } +} + +type DebugLogQueryResult struct { + Entries []DebugLogEntry `json:"entries"` + Page int `json:"page"` + TotalPages int `json:"total_pages"` + Total int `json:"total"` +} + +// Query returns paginated debug log metadata (no bodies — fast). +func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult { + if page < 1 { + page = 1 + } + if limit <= 0 { + limit = 50 + } + offset := (page - 1) * limit + + var total int + d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total) + + totalPages := (total + limit - 1) / limit + if totalPages < 1 { + totalPages = 1 + } + + rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''), + COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '') + FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset) + if err != nil { + return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total} + } + defer rows.Close() + + var entries []DebugLogEntry + for rows.Next() { + var e DebugLogEntry + rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName, + &e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath) + entries = append(entries, e) + } + if entries == nil { + entries = []DebugLogEntry{} + } + + return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total} +} + +// QueryFull returns paginated debug log entries including request/response bodies read from files. +func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult { + result := d.Query(page, limit) + for i := range result.Entries { + d.populateFromFile(&result.Entries[i]) + } + return result +} + +// GetByRequestID returns a single debug log entry with bodies read from file. +func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry { + var e DebugLogEntry + err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''), + COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '') + FROM debug_log WHERE request_id = ?`, requestID).Scan( + &e.ID, &e.RequestID, &e.Timestamp, &e.TokenName, + &e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath) + if err != nil { + return nil + } + d.populateFromFile(&e) + return &e +} + +// populateFromFile reads body data from the debug file on disk. +// Falls back to DB columns for pre-migration entries that have no file_path. +func (d *DebugLogger) populateFromFile(e *DebugLogEntry) { + if e.FilePath == "" { + // Legacy entry: try reading bodies from DB columns + d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '') + FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders) + return + } + data, err := os.ReadFile(e.FilePath) + if err != nil { + log.Printf("WARN: debug log read file %s: %v", e.FilePath, err) + return + } + var df debugFile + if err := json.Unmarshal(data, &df); err != nil { + log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err) + return + } + e.RequestHeaders = df.RequestHeaders + e.RequestBody = df.RequestBody + e.ResponseBody = df.ResponseBody +} + +// Cleanup removes debug log entries and files older than retentionDays. +func (d *DebugLogger) Cleanup(retentionDays int) error { + cutoff := time.Now().AddDate(0, 0, -retentionDays) + cutoffUnix := cutoff.Unix() + + // Delete old DB rows + result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix) + if err != nil { + return fmt.Errorf("delete old debug rows: %w", err) + } + affected, _ := result.RowsAffected() + if affected > 0 { + log.Printf("Cleaned up %d old debug log entries", affected) + } + + // Remove old date directories + baseDir := d.debugLogDir() + dirs, err := os.ReadDir(baseDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("read debug log dir: %w", err) + } + + cutoffDate := cutoff.Format("2006-01-02") + sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) + + for _, dir := range dirs { + if !dir.IsDir() { + continue + } + // Date directories are named YYYY-MM-DD; string comparison works + if strings.Compare(dir.Name(), cutoffDate) < 0 { + dirPath := filepath.Join(baseDir, dir.Name()) + if err := os.RemoveAll(dirPath); err != nil { + log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err) + } else { + log.Printf("Removed old debug log directory: %s", dir.Name()) + } + } + } + + return nil +} diff --git a/llm-gateway/internal/storage/logger.go b/llm-gateway/internal/storage/logger.go index ad2e829..d832dd4 100644 --- a/llm-gateway/internal/storage/logger.go +++ b/llm-gateway/internal/storage/logger.go @@ -6,6 +6,7 @@ import ( ) type RequestLog struct { + RequestID string Timestamp int64 TokenName string Model string @@ -93,8 +94,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) { } stmt, err := tx.Prepare(`INSERT INTO request_logs - (timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) if err != nil { log.Printf("ERROR: preparing log statement: %v", err) tx.Rollback() @@ -112,7 +113,7 @@ func (l *AsyncLogger) flush(batch []RequestLog) { cached = 1 } _, err := stmt.Exec( - r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel, + r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel, r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS, r.Status, r.ErrorMessage, streaming, cached, ) diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql new file mode 100644 index 0000000..e11bb40 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql @@ -0,0 +1,4 @@ +-- SQLite doesn't support DROP COLUMN in older versions, so we recreate the table +CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens; +DROP TABLE api_tokens; +ALTER TABLE api_tokens_backup RENAME TO api_tokens; diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql new file mode 100644 index 0000000..ccf0549 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql @@ -0,0 +1 @@ +ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0; diff --git a/llm-gateway/internal/storage/migrations/005_request_id.down.sql b/llm-gateway/internal/storage/migrations/005_request_id.down.sql new file mode 100644 index 0000000..819b90b --- /dev/null +++ b/llm-gateway/internal/storage/migrations/005_request_id.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS idx_request_logs_request_id; diff --git a/llm-gateway/internal/storage/migrations/005_request_id.up.sql b/llm-gateway/internal/storage/migrations/005_request_id.up.sql new file mode 100644 index 0000000..ff54384 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/005_request_id.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT ''; +CREATE INDEX idx_request_logs_request_id ON request_logs(request_id); diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.down.sql b/llm-gateway/internal/storage/migrations/006_audit_log.down.sql new file mode 100644 index 0000000..b750c3b --- /dev/null +++ b/llm-gateway/internal/storage/migrations/006_audit_log.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS audit_log; diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.up.sql b/llm-gateway/internal/storage/migrations/006_audit_log.up.sql new file mode 100644 index 0000000..c2fc48a --- /dev/null +++ b/llm-gateway/internal/storage/migrations/006_audit_log.up.sql @@ -0,0 +1,14 @@ +CREATE TABLE audit_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp INTEGER NOT NULL, + user_id INTEGER, + username TEXT NOT NULL DEFAULT '', + action TEXT NOT NULL, + target_type TEXT DEFAULT '', + target_id TEXT DEFAULT '', + details TEXT DEFAULT '', + ip_address TEXT DEFAULT '', + request_id TEXT DEFAULT '' +); +CREATE INDEX idx_audit_timestamp ON audit_log(timestamp); +CREATE INDEX idx_audit_action ON audit_log(action); diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.down.sql b/llm-gateway/internal/storage/migrations/007_debug_log.down.sql new file mode 100644 index 0000000..41353f5 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/007_debug_log.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS debug_log; diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.up.sql b/llm-gateway/internal/storage/migrations/007_debug_log.up.sql new file mode 100644 index 0000000..9a8441a --- /dev/null +++ b/llm-gateway/internal/storage/migrations/007_debug_log.up.sql @@ -0,0 +1,14 @@ +CREATE TABLE debug_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + request_id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + token_name TEXT DEFAULT '', + model TEXT DEFAULT '', + provider TEXT DEFAULT '', + request_body TEXT DEFAULT '', + response_body TEXT DEFAULT '', + request_headers TEXT DEFAULT '', + response_status INTEGER DEFAULT 0 +); +CREATE INDEX idx_debug_request_id ON debug_log(request_id); +CREATE INDEX idx_debug_timestamp ON debug_log(timestamp); diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql new file mode 100644 index 0000000..032a37d --- /dev/null +++ b/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql @@ -0,0 +1 @@ +-- no-op: file_path column is harmless to keep diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql new file mode 100644 index 0000000..7a5bf8b --- /dev/null +++ b/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql @@ -0,0 +1 @@ +ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT '';