Compare commits

..

No commits in common. "28a694744d291f70fb403a0b64e65c0d3eb8aa83" and "f23a7c14c077fc6676e02d25dd5e28e8ad2abd22" have entirely different histories.

53 changed files with 709 additions and 4824 deletions

View file

@ -12,7 +12,7 @@ ADMIN_USERNAME=admin
ADMIN_PASSWORD=change-me-min-8-chars
# Static API tokens (seeded on startup, leave empty to skip)
OPENWEBUI_API_KEY=sk-...
OPENCODE_API_KEY=sk-...
PERSONAL_API_KEY=sk-...
# Provider API keys
OPENROUTER_API_KEY=sk-or-...
SILICONFLOW_API_KEY=sk-...

3
.gitignore vendored
View file

@ -3,8 +3,5 @@
# Environment secrets
.env
# Host-mounted data directories
data/
# SearXNG runtime state
searxng/uwsgi.ini

View file

@ -44,7 +44,7 @@ services:
ports:
- "0.0.0.0:4000:3000"
volumes:
- ./data/llm-gateway:/data
- llm-gateway-data:/data
- ./llm-gateway.yaml:/etc/llm-gateway/config.yaml:ro
environment:
- SESSION_SECRET=${SESSION_SECRET}
@ -161,6 +161,7 @@ services:
volumes:
valkey-data:
chromadb-data:
llm-gateway-data:
open-webui-data:
tailscale-state:
victoriametrics-data:

233
litellm/config.yaml Normal file
View file

@ -0,0 +1,233 @@
# LiteLLM proxy configuration (model routing, fallbacks, caching).
# NOTE(review): nesting reconstructed — the source was extracted with all
# leading whitespace stripped; structure restored per litellm's config schema.
model_list:
  # ═══════════════════════════════════════════════
  # TIER 1: Free providers (try first)
  # ═══════════════════════════════════════════════

  # --- Groq (free tier, very fast) ---
  - model_name: llama-3.3-70b
    litellm_params:
      model: groq/llama-3.3-70b-versatile
      api_key: os.environ/GROQ_API_KEY

  # --- Cerebras (free tier, very fast) ---
  - model_name: llama-3.3-70b-cerebras
    litellm_params:
      model: cerebras/llama-3.3-70b
      api_key: os.environ/CEREBRAS_API_KEY

  # --- OpenRouter free models ---
  - model_name: deepseek-v3-free
    litellm_params:
      model: openrouter/deepseek/deepseek-chat-v3-0324:free
      api_key: os.environ/OPENROUTER_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 2: DeepSeek V3.2 (cheapest first)
  # ═══════════════════════════════════════════════
  # Duplicate model_name entries form a litellm load-balance/failover group.

  # DeepSeek V3.2 via DeepInfra ($0.26 in / $0.38 out per M)
  - model_name: deepseek-v3.2
    litellm_params:
      model: deepinfra/deepseek-ai/DeepSeek-V3.2
      api_key: os.environ/DEEPINFRA_API_KEY

  # DeepSeek V3.2 fallback via SiliconFlow ($0.27 in / $0.42 out per M)
  - model_name: deepseek-v3.2
    litellm_params:
      model: openai/deepseek-ai/DeepSeek-V3.2
      api_base: https://api.siliconflow.com/v1
      api_key: os.environ/SILICONFLOW_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 3: Ultra-cheap DeepInfra models
  # ═══════════════════════════════════════════════

  # GPT-OSS-120B — OpenAI open-weight MoE ($0.05 in / $0.24 out per M)
  - model_name: gpt-oss
    litellm_params:
      model: deepinfra/openai/gpt-oss-120b
      api_key: os.environ/DEEPINFRA_API_KEY

  # GPT-OSS-20B — lower latency variant ($0.04 in / $0.16 out per M)
  - model_name: gpt-oss-20b
    litellm_params:
      model: deepinfra/openai/gpt-oss-20b
      api_key: os.environ/DEEPINFRA_API_KEY

  # Nemotron Super 49B — near-flagship quality ($0.10 in / $0.40 out per M)
  - model_name: nemotron-super
    litellm_params:
      model: deepinfra/nvidia/Llama-3.3-Nemotron-Super-49B-v1.5
      api_key: os.environ/DEEPINFRA_API_KEY

  # Nemotron Nano 9B — dirt cheap for simple tasks ($0.04 in / $0.16 out per M)
  - model_name: nemotron-nano
    litellm_params:
      model: deepinfra/nvidia/NVIDIA-Nemotron-Nano-9B-v2
      api_key: os.environ/DEEPINFRA_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 4: Other DeepInfra models
  # ═══════════════════════════════════════════════
  - model_name: deepseek-r1
    litellm_params:
      model: deepinfra/deepseek-ai/DeepSeek-R1
      api_key: os.environ/DEEPINFRA_API_KEY

  - model_name: devstral
    litellm_params:
      model: deepinfra/mistralai/Devstral-Small-2505
      api_key: os.environ/DEEPINFRA_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 5: GLM models (cheapest first)
  # ═══════════════════════════════════════════════

  # GLM-4.6 via DeepInfra ($0.60 in / $1.90 out per M)
  - model_name: glm-4.6
    litellm_params:
      model: deepinfra/zai-org/GLM-4.6
      api_key: os.environ/DEEPINFRA_API_KEY

  # GLM-4.7 via DeepInfra ($0.40 in / $1.75 out per M)
  - model_name: glm-4.7
    litellm_params:
      model: deepinfra/zai-org/GLM-4.7
      api_key: os.environ/DEEPINFRA_API_KEY

  # GLM-4.7 fallback via SiliconFlow
  # NOTE(review): upstream model here is GLM-4-32B-0414, not GLM-4.7 — a much
  # smaller model. Confirm this is an intentional downgrade fallback.
  - model_name: glm-4.7
    litellm_params:
      model: openai/THUDM/GLM-4-32B-0414
      api_base: https://api.siliconflow.com/v1
      api_key: os.environ/SILICONFLOW_API_KEY

  # GLM-5 via DeepInfra ($0.80 in / $2.56 out per M)
  - model_name: glm-5
    litellm_params:
      model: deepinfra/zai-org/GLM-5
      api_key: os.environ/DEEPINFRA_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 6: Kimi K2 (cheapest first)
  # ═══════════════════════════════════════════════

  # Kimi K2 via DeepInfra ($0.50 in / $2.00 out per M)
  - model_name: kimi-k2
    litellm_params:
      model: deepinfra/moonshotai/Kimi-K2-Instruct-0905
      api_key: os.environ/DEEPINFRA_API_KEY

  # Kimi K2 fallback via SiliconFlow ($0.58 in / $2.29 out per M)
  - model_name: kimi-k2
    litellm_params:
      model: openai/moonshotai/Kimi-K2-Instruct-0905
      api_base: https://api.siliconflow.com/v1
      api_key: os.environ/SILICONFLOW_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 7: SiliconFlow (Qwen)
  # ═══════════════════════════════════════════════

  # Qwen3 Coder 480B MoE via SiliconFlow ($1.14 in / $2.28 out per M)
  - model_name: qwen3-coder
    litellm_params:
      model: openai/Qwen/Qwen3-Coder-480B-A35B-Instruct
      api_base: https://api.siliconflow.com/v1
      api_key: os.environ/SILICONFLOW_API_KEY

  # Qwen3 Coder 30B — cheaper alternative for simpler tasks
  - model_name: qwen3-coder-30b
    litellm_params:
      model: openai/Qwen/Qwen3-Coder-30B-A3B-Instruct
      api_base: https://api.siliconflow.com/v1
      api_key: os.environ/SILICONFLOW_API_KEY

  # ═══════════════════════════════════════════════
  # TIER 8: OpenRouter (most expensive, widest selection)
  # ═══════════════════════════════════════════════

  # Kimi K2.5 — DeepInfra is cheapest ($0.45 in / $2.25 out per M)
  - model_name: kimi-k2.5
    litellm_params:
      model: deepinfra/moonshotai/Kimi-K2.5
      api_key: os.environ/DEEPINFRA_API_KEY

  # Kimi K2.5 fallback via OpenRouter
  - model_name: kimi-k2.5
    litellm_params:
      model: openrouter/moonshotai/kimi-k2.5
      api_key: os.environ/OPENROUTER_API_KEY

  - model_name: minimax-m2.5
    litellm_params:
      model: openrouter/minimax/minimax-m2.5
      api_key: os.environ/OPENROUTER_API_KEY

  - model_name: gpt-4.1-mini
    litellm_params:
      model: openrouter/openai/gpt-4.1-mini
      api_key: os.environ/OPENROUTER_API_KEY

  - model_name: gemini-3-flash-preview
    litellm_params:
      model: openrouter/google/gemini-3-flash-preview
      api_key: os.environ/OPENROUTER_API_KEY

  - model_name: trinity-large-preview
    litellm_params:
      model: openrouter/arcee-ai/trinity-large-preview
      api_key: os.environ/OPENROUTER_API_KEY

  # --- OpenRouter premium models ---
  - model_name: gemini-2.5-pro
    litellm_params:
      model: openrouter/google/gemini-2.5-pro-preview
      api_key: os.environ/OPENROUTER_API_KEY

  - model_name: claude-sonnet
    litellm_params:
      model: openrouter/anthropic/claude-sonnet-4
      api_key: os.environ/OPENROUTER_API_KEY

  - model_name: gpt-4.1
    litellm_params:
      model: openrouter/openai/gpt-4.1
      api_key: os.environ/OPENROUTER_API_KEY

  # DeepSeek V3.2 last-resort fallback via OpenRouter
  - model_name: deepseek-v3.2
    litellm_params:
      model: openrouter/deepseek/deepseek-chat-v3-0324
      api_key: os.environ/OPENROUTER_API_KEY

general_settings:
  master_key: os.environ/LITELLM_MASTER_KEY

  # ── Model group fallbacks (when model misbehaves mid-stream) ──
  # NOTE(review): litellm docs usually place `fallbacks` under
  # router_settings / litellm_settings — confirm it is honored here.
  fallbacks:
    - deepseek-v3.2: [gpt-oss, kimi-k2]
    - gpt-oss: [deepseek-v3.2, nemotron-super]
    - kimi-k2: [deepseek-v3.2, gpt-oss]
    - kimi-k2.5: [deepseek-v3.2, gpt-oss]
    - glm-4.7: [deepseek-v3.2, gpt-oss]
    - glm-5: [deepseek-v3.2, gpt-oss]

litellm_settings:
  drop_params: true
  set_verbose: false
  num_retries: 2
  request_timeout: 600

  # ── Response caching via Valkey (reuses SearXNG's instance) ──
  cache: true
  cache_params:
    type: redis
    host: valkey
    port: 6379
    ttl: 3600

  # ── Budget limit: $3/day to prevent surprise bills ──
  # max_budget: 3.0
  # budget_duration: "1d"

View file

@ -25,12 +25,6 @@ database:
path: "/data/gateway.db"
retention_days: 90
debug:
enabled: true
retention_days: 90
# data_dir: "/data" # defaults to directory of database.path
# max_body_bytes: 0 # 0 = unlimited (save full bodies)
cache:
enabled: true
address: "valkey:6379"

View file

@ -8,9 +8,6 @@ llm-gateway
*.db-wal
*.db-shm
# Debug log files
debug-logs/
# Local config
configs/config.local.yaml

View file

@ -7,13 +7,11 @@ import (
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
gocors "github.com/go-chi/cors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"llm-gateway/internal/auth"
@ -93,7 +91,7 @@ func main() {
log.Printf("Registered %d models", len(cfg.Models))
// Provider health tracker
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
healthTracker := provider.NewHealthTracker(5 * time.Minute)
// Auth store (static tokens checked in-memory, not seeded to DB)
var staticTokens []auth.StaticToken
@ -104,7 +102,6 @@ func main() {
Key: t.Key,
RateLimitRPM: t.RateLimitRPM,
DailyBudgetUSD: t.DailyBudgetUSD,
MaxConcurrent: t.MaxConcurrent,
})
log.Printf("Loaded static token: %s", t.Name)
}
@ -113,17 +110,6 @@ func main() {
authMiddleware := auth.NewMiddleware(authStore)
authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret)
// Audit logger
auditLogger := storage.NewAuditLogger(db)
authHandlers.SetAuditLogger(auditLogger)
// Debug logger
debugDataDir := cfg.Debug.DataDir
if debugDataDir == "" {
debugDataDir = filepath.Dir(cfg.Database.Path)
}
debugLogger := storage.NewDebugLogger(db, cfg.Debug.Enabled, debugDataDir)
// Seed default admin
seedDefaultAdmin(cfg, authStore)
@ -132,43 +118,22 @@ func main() {
// Handlers
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
proxyHandler.SetDebugLogger(debugLogger)
modelsHandler := proxy.NewModelsHandler(registry)
proxyAuth := proxy.NewAuthMiddleware(authStore)
rateLimiter := proxy.NewRateLimiter(db)
concurrencyLimiter := proxy.NewConcurrencyLimiter()
statsAPI := dashboard.NewStatsAPI(db, authStore)
statsAPI.SetHealthTracker(healthTracker)
statsAPI.SetAuditLogger(auditLogger)
statsAPI.SetDebugLogger(debugLogger)
if c != nil {
statsAPI.SetCache(c)
}
dash := dashboard.NewDashboard(authStore, statsAPI)
dash.SetRegistry(registry)
dash.SetAuditLogger(auditLogger)
dash.SetDebugLogger(debugLogger)
if c != nil {
dash.SetCache(c)
}
// Export handler
exportHandler := dashboard.NewExportHandler(db, authStore)
// Router
r := chi.NewRouter()
// CORS (before other middleware)
if cfg.CORS.Enabled {
r.Use(gocors.Handler(gocors.Options{
AllowedOrigins: cfg.CORS.AllowedOrigins,
AllowedMethods: cfg.CORS.AllowedMethods,
AllowedHeaders: cfg.CORS.AllowedHeaders,
MaxAge: cfg.CORS.MaxAge,
AllowCredentials: true,
}))
}
r.Use(middleware.RealIP)
r.Use(middleware.Recoverer)
r.Use(middleware.RequestID)
@ -194,7 +159,6 @@ func main() {
r.Group(func(r chi.Router) {
r.Use(proxyAuth.Authenticate)
r.Use(rateLimiter.Check)
r.Use(concurrencyLimiter.Check)
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
r.Get("/v1/models", modelsHandler.ListModels)
})
@ -228,8 +192,6 @@ func main() {
r.Group(func(r chi.Router) {
r.Use(authMiddleware.RequireAdmin)
r.Get("/users", dash.UsersPage)
r.Get("/audit", dash.AuditPage)
r.Get("/debug", dash.DebugPage)
})
// Auth API
@ -262,29 +224,16 @@ func main() {
r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler)
r.Get("/api/stats/cache", statsAPI.CacheStats)
// Data export
r.Get("/api/export/logs", exportHandler.ExportLogs)
r.Get("/api/export/stats", exportHandler.ExportStats)
// Admin-only: user management, audit, debug
// Admin-only: user management
r.Group(func(r chi.Router) {
r.Use(authMiddleware.RequireAdmin)
r.Get("/api/auth/users", authHandlers.ListUsers)
r.Post("/api/auth/users", authHandlers.CreateUser)
r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser)
// Audit log
r.Get("/api/stats/audit", statsAPI.AuditLogs)
// Debug logging
r.Post("/api/debug/toggle", statsAPI.DebugToggle)
r.Get("/api/debug/status", statsAPI.DebugStatus)
r.Get("/api/debug/logs", statsAPI.DebugLogs)
r.Get("/api/debug/logs/{requestID}", statsAPI.DebugLogByRequestID)
})
})
// Periodic session cleanup and debug log cleanup
// Periodic session cleanup
go func() {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
@ -292,9 +241,6 @@ func main() {
if err := authStore.CleanExpiredSessions(); err != nil {
log.Printf("WARNING: session cleanup failed: %v", err)
}
if err := debugLogger.Cleanup(cfg.Debug.RetentionDays); err != nil {
log.Printf("WARNING: debug log cleanup failed: %v", err)
}
}
}()
@ -307,45 +253,6 @@ func main() {
IdleTimeout: 120 * time.Second,
}
// Config hot-reload via SIGHUP
config.WatchReload(*configPath, func(newCfg *config.Config) {
// Reload registry (models, providers, routes)
if err := registry.Reload(newCfg); err != nil {
log.Printf("ERROR: registry reload failed: %v", err)
return
}
log.Printf("Reloaded %d models", len(newCfg.Models))
// Reload pricing
for i, m := range newCfg.Models {
for j, rt := range m.Routes {
if rt.Pricing.Input == 0 && rt.Pricing.Output == 0 {
pricingLookup.FillMissing(rt.Provider, rt.Model,
&newCfg.Models[i].Routes[j].Pricing.Input,
&newCfg.Models[i].Routes[j].Pricing.Output)
}
}
}
// Reload static tokens
var newStaticTokens []auth.StaticToken
for _, t := range newCfg.Tokens {
if t.Key != "" {
newStaticTokens = append(newStaticTokens, auth.StaticToken{
Name: t.Name,
Key: t.Key,
RateLimitRPM: t.RateLimitRPM,
DailyBudgetUSD: t.DailyBudgetUSD,
MaxConcurrent: t.MaxConcurrent,
})
}
}
authStore.SetStaticTokens(newStaticTokens)
// Update config pointer for retry/debug/etc
cfg = newCfg
})
// Graceful shutdown
done := make(chan os.Signal, 1)
signal.Notify(done, os.Interrupt, syscall.SIGTERM)

View file

@ -19,7 +19,6 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-chi/cors v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect

View file

@ -18,8 +18,6 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=

View file

@ -13,15 +13,12 @@ import (
"time"
"github.com/go-chi/chi/v5"
"llm-gateway/internal/storage"
)
type Handlers struct {
store *Store
sessionSecret string
loginLimiter *loginRateLimiter
auditLogger *storage.AuditLogger
}
func NewHandlers(store *Store, sessionSecret string) *Handlers {
@ -32,36 +29,6 @@ func NewHandlers(store *Store, sessionSecret string) *Handlers {
}
}
func (h *Handlers) SetAuditLogger(al *storage.AuditLogger) {
h.auditLogger = al
}
func (h *Handlers) audit(r *http.Request, action, targetType, targetID, details string) {
if h.auditLogger == nil {
return
}
user := UserFromContext(r.Context())
var userID int64
var username string
if user != nil {
userID = user.ID
username = user.Username
}
ip := r.RemoteAddr
if fwd := r.Header.Get("X-Real-IP"); fwd != "" {
ip = fwd
}
h.auditLogger.Log(storage.AuditEntry{
UserID: userID,
Username: username,
Action: action,
TargetType: targetType,
TargetID: targetID,
Details: details,
IPAddress: ip,
})
}
// Login brute-force protection
type loginRateLimiter struct {
mu sync.Mutex
@ -159,8 +126,6 @@ func (h *Handlers) Setup(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "auth.setup", "user", fmt.Sprintf("%d", user.ID), "initial setup")
h.setSessionCookie(w, sessionID)
writeJSON(w, map[string]any{
"user": map[string]any{
@ -222,8 +187,6 @@ func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "auth.login", "user", fmt.Sprintf("%d", user.ID), user.Username)
h.setSessionCookie(w, sessionID)
writeJSON(w, map[string]any{
"require_totp": false,
@ -293,8 +256,6 @@ func (h *Handlers) LoginTOTP(w http.ResponseWriter, r *http.Request) {
}
func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) {
h.audit(r, "auth.logout", "", "", "")
cookie, err := r.Cookie(sessionCookieName)
if err == nil {
h.store.DeleteSession(cookie.Value)
@ -386,7 +347,6 @@ func (h *Handlers) TOTPVerify(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "totp.enable", "user", fmt.Sprintf("%d", user.ID), "")
writeJSON(w, map[string]string{"status": "totp_enabled"})
}
@ -402,7 +362,6 @@ func (h *Handlers) TOTPDisable(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "totp.disable", "user", fmt.Sprintf("%d", user.ID), "")
writeJSON(w, map[string]string{"status": "totp_disabled"})
}
@ -461,7 +420,6 @@ func (h *Handlers) CreateUser(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "user.create", "user", fmt.Sprintf("%d", user.ID), user.Username)
writeJSON(w, map[string]any{
"id": user.ID,
"username": user.Username,
@ -489,7 +447,6 @@ func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "user.delete", "user", idStr, "")
writeJSON(w, map[string]string{"status": "deleted"})
}
@ -550,7 +507,6 @@ func (h *Handlers) CreateToken(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "token.create", "token", fmt.Sprintf("%d", token.ID), req.Name)
writeJSON(w, map[string]any{
"key": plainKey,
"token": token,
@ -589,7 +545,6 @@ func (h *Handlers) DeleteToken(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "token.delete", "token", idStr, "")
writeJSON(w, map[string]string{"status": "deleted"})
}
@ -632,7 +587,6 @@ func (h *Handlers) ChangePassword(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "password.change", "user", fmt.Sprintf("%d", user.ID), "")
writeJSON(w, map[string]string{"status": "password_updated"})
}
@ -667,7 +621,6 @@ func (h *Handlers) ChangeUsername(w http.ResponseWriter, r *http.Request) {
return
}
h.audit(r, "username.change", "user", fmt.Sprintf("%d", user.ID), req.NewUsername)
writeJSON(w, map[string]string{"status": "username_updated"})
}

View file

@ -38,7 +38,6 @@ type APIToken struct {
UserID int64 `json:"user_id"`
RateLimitRPM int `json:"rate_limit_rpm"`
DailyBudgetUSD float64 `json:"daily_budget_usd"`
MaxConcurrent int `json:"max_concurrent"`
CreatedAt int64 `json:"created_at"`
LastUsedAt int64 `json:"last_used_at"`
}
@ -49,7 +48,6 @@ type StaticToken struct {
Key string
RateLimitRPM int
DailyBudgetUSD float64
MaxConcurrent int
}
type Store struct {
@ -61,11 +59,6 @@ func NewStore(db *sql.DB, staticTokens []StaticToken) *Store {
return &Store{db: db, staticTokens: staticTokens}
}
// SetStaticTokens updates the static tokens list (used for config hot-reload).
func (s *Store) SetStaticTokens(tokens []StaticToken) {
s.staticTokens = tokens
}
func (s *Store) HasAnyUser() bool {
var count int
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
@ -294,7 +287,6 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
KeyPrefix: prefix,
RateLimitRPM: st.RateLimitRPM,
DailyBudgetUSD: st.DailyBudgetUSD,
MaxConcurrent: st.MaxConcurrent,
}, nil
}
}
@ -305,9 +297,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
var t APIToken
err := s.db.QueryRow(
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
keyHash,
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt)
if err != nil {
return nil, err
}
@ -328,7 +320,6 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
KeyPrefix: prefix,
RateLimitRPM: st.RateLimitRPM,
DailyBudgetUSD: st.DailyBudgetUSD,
MaxConcurrent: st.MaxConcurrent,
})
}
@ -336,18 +327,19 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
var rows *sql.Rows
var err error
if userID == 0 {
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
// Admin: list all
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens ORDER BY id")
} else {
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
}
if err != nil {
return tokens, nil
return tokens, nil // return static tokens even if DB query fails
}
defer rows.Close()
for rows.Next() {
var t APIToken
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt); err != nil {
return tokens, nil
}
tokens = append(tokens, t)
@ -363,9 +355,9 @@ func (s *Store) DeleteAPIToken(id int64) error {
func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
var t APIToken
err := s.db.QueryRow(
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE id = ?",
id,
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt)
if err != nil {
return nil, err
}

View file

@ -1,300 +0,0 @@
// Package auth tests: exercise the user/session/API-token Store against an
// in-memory SQLite database.
// NOTE(review): gofmt formatting reconstructed — the source was extracted with
// all leading whitespace stripped; tokens are otherwise unchanged.
package auth

import (
	"database/sql"
	"testing"
	"time"

	_ "modernc.org/sqlite"
)

// setupTestDB opens an in-memory SQLite database and creates the users,
// sessions, and api_tokens tables that Store expects.
func setupTestDB(t *testing.T) *sql.DB {
	t.Helper()
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		t.Fatalf("opening test db: %v", err)
	}
	// Create tables
	_, err = db.Exec(`
	CREATE TABLE users (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		username TEXT UNIQUE NOT NULL,
		email TEXT DEFAULT '',
		password_hash TEXT NOT NULL,
		is_admin INTEGER DEFAULT 0,
		totp_secret TEXT DEFAULT '',
		totp_enabled INTEGER DEFAULT 0,
		created_at INTEGER NOT NULL,
		updated_at INTEGER NOT NULL
	);
	CREATE TABLE sessions (
		id TEXT PRIMARY KEY,
		user_id INTEGER NOT NULL,
		created_at INTEGER NOT NULL,
		expires_at INTEGER NOT NULL
	);
	CREATE TABLE api_tokens (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		name TEXT NOT NULL,
		key_hash TEXT NOT NULL,
		key_prefix TEXT NOT NULL,
		user_id INTEGER NOT NULL,
		rate_limit_rpm INTEGER DEFAULT 0,
		daily_budget_usd REAL DEFAULT 0,
		max_concurrent INTEGER DEFAULT 0,
		created_at INTEGER NOT NULL,
		last_used_at INTEGER DEFAULT 0
	);
	`)
	if err != nil {
		t.Fatalf("creating tables: %v", err)
	}
	return db
}

// TestCreateUser verifies CreateUser returns a populated admin user.
func TestCreateUser(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	user, err := store.CreateUser("alice", "password123", true)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	if user.Username != "alice" {
		t.Errorf("expected username 'alice', got '%s'", user.Username)
	}
	if !user.IsAdmin {
		t.Error("expected admin user")
	}
	if user.ID == 0 {
		t.Error("expected non-zero ID")
	}
}

// TestGetUserByUsername verifies lookup by name and the not-found error path.
func TestGetUserByUsername(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	store.CreateUser("bob", "password123", false)
	user, err := store.GetUserByUsername("bob")
	if err != nil {
		t.Fatalf("GetUserByUsername: %v", err)
	}
	if user.Username != "bob" {
		t.Errorf("expected 'bob', got '%s'", user.Username)
	}
	_, err = store.GetUserByUsername("nonexistent")
	if err == nil {
		t.Error("expected error for nonexistent user")
	}
}

// TestCheckPassword verifies the password hash matches only the correct password.
func TestCheckPassword(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	store.CreateUser("charlie", "correctpassword", false)
	user, _ := store.GetUserByUsername("charlie")
	if !store.CheckPassword(user, "correctpassword") {
		t.Error("correct password should match")
	}
	if store.CheckPassword(user, "wrongpassword") {
		t.Error("wrong password should not match")
	}
}

// TestUpdatePassword verifies the old password stops working after an update.
func TestUpdatePassword(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	user, _ := store.CreateUser("dave", "oldpass12", false)
	if err := store.UpdatePassword(user.ID, "newpass12"); err != nil {
		t.Fatalf("UpdatePassword: %v", err)
	}
	user, _ = store.GetUserByUsername("dave")
	if store.CheckPassword(user, "oldpass12") {
		t.Error("old password should not work")
	}
	if !store.CheckPassword(user, "newpass12") {
		t.Error("new password should work")
	}
}

// TestDeleteUser verifies non-admins can be deleted but the last admin cannot.
func TestDeleteUser(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	user1, _ := store.CreateUser("admin1", "password1234", true)
	user2, _ := store.CreateUser("user2", "password1234", false)
	// Can delete non-admin
	if err := store.DeleteUser(user2.ID); err != nil {
		t.Fatalf("DeleteUser: %v", err)
	}
	// Cannot delete last admin
	if err := store.DeleteUser(user1.ID); err == nil {
		t.Error("should not be able to delete last admin")
	}
}

// TestHasAnyUser verifies HasAnyUser flips from false to true on first user.
func TestHasAnyUser(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	if store.HasAnyUser() {
		t.Error("should have no users initially")
	}
	store.CreateUser("first", "password1234", true)
	if !store.HasAnyUser() {
		t.Error("should have users after creation")
	}
}

// TestSessionCRUD verifies create/get/delete of a session round-trips.
func TestSessionCRUD(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	user, _ := store.CreateUser("sessuser", "password1234", false)
	sessionID, err := store.CreateSession(user.ID, 1*time.Hour)
	if err != nil {
		t.Fatalf("CreateSession: %v", err)
	}
	sess, err := store.GetSession(sessionID)
	if err != nil {
		t.Fatalf("GetSession: %v", err)
	}
	if sess.UserID != user.ID {
		t.Errorf("expected user ID %d, got %d", user.ID, sess.UserID)
	}
	if err := store.DeleteSession(sessionID); err != nil {
		t.Fatalf("DeleteSession: %v", err)
	}
	_, err = store.GetSession(sessionID)
	if err == nil {
		t.Error("session should be deleted")
	}
}

// TestStaticTokenLookup verifies config-defined tokens resolve with ID -1 and
// carry their rate/concurrency limits.
func TestStaticTokenLookup(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	staticTokens := []StaticToken{
		{Name: "test-token", Key: "sk-test-key-12345678", RateLimitRPM: 60, DailyBudgetUSD: 10.0, MaxConcurrent: 5},
	}
	store := NewStore(db, staticTokens)
	token, err := store.LookupAPIToken("sk-test-key-12345678")
	if err != nil {
		t.Fatalf("LookupAPIToken: %v", err)
	}
	if token.Name != "test-token" {
		t.Errorf("expected 'test-token', got '%s'", token.Name)
	}
	if token.ID != -1 {
		t.Errorf("static token should have ID -1, got %d", token.ID)
	}
	if token.RateLimitRPM != 60 {
		t.Errorf("expected RPM 60, got %d", token.RateLimitRPM)
	}
	if token.MaxConcurrent != 5 {
		t.Errorf("expected max_concurrent 5, got %d", token.MaxConcurrent)
	}
	// Non-existent token
	_, err = store.LookupAPIToken("nonexistent")
	if err == nil {
		t.Error("should error on nonexistent token")
	}
}

// TestDBTokenCRUD verifies create/lookup/list/delete of a DB-backed API token.
func TestDBTokenCRUD(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	user, _ := store.CreateUser("tokenuser", "password1234", false)
	plainKey, token, err := store.CreateAPIToken(user.ID, "my-token", 100, 5.0)
	if err != nil {
		t.Fatalf("CreateAPIToken: %v", err)
	}
	if plainKey == "" {
		t.Error("plain key should not be empty")
	}
	if token.Name != "my-token" {
		t.Errorf("expected 'my-token', got '%s'", token.Name)
	}
	// Lookup by key
	found, err := store.LookupAPIToken(plainKey)
	if err != nil {
		t.Fatalf("LookupAPIToken: %v", err)
	}
	if found.Name != "my-token" {
		t.Errorf("expected 'my-token', got '%s'", found.Name)
	}
	// List tokens
	tokens, err := store.ListAPITokens(user.ID)
	if err != nil {
		t.Fatalf("ListAPITokens: %v", err)
	}
	if len(tokens) != 1 {
		t.Errorf("expected 1 token, got %d", len(tokens))
	}
	// Delete
	if err := store.DeleteAPIToken(token.ID); err != nil {
		t.Fatalf("DeleteAPIToken: %v", err)
	}
	_, err = store.LookupAPIToken(plainKey)
	if err == nil {
		t.Error("token should be deleted")
	}
}

// TestSetStaticTokens verifies hot-reloaded static tokens become resolvable.
func TestSetStaticTokens(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)
	_, err := store.LookupAPIToken("key1")
	if err == nil {
		t.Error("should not find token before setting")
	}
	store.SetStaticTokens([]StaticToken{
		{Name: "new-token", Key: "key1"},
	})
	token, err := store.LookupAPIToken("key1")
	if err != nil {
		t.Fatalf("after SetStaticTokens: %v", err)
	}
	if token.Name != "new-token" {
		t.Errorf("expected 'new-token', got '%s'", token.Name)
	}
}

View file

@ -1,112 +0,0 @@
// Package cache tests: cover cache-key derivation and the Valkey/Redis INFO
// parsing helpers.
// NOTE(review): gofmt formatting reconstructed — the source was extracted with
// all leading whitespace stripped; tokens are otherwise unchanged.
package cache

import (
	"testing"
)

// TestCacheKey_Deterministic verifies identical inputs hash to the same key.
func TestCacheKey_Deterministic(t *testing.T) {
	c := &Cache{}
	model := "gpt-4"
	body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
	key1 := c.cacheKey(model, body)
	key2 := c.cacheKey(model, body)
	if key1 != key2 {
		t.Errorf("cache key not deterministic: %s != %s", key1, key2)
	}
	if key1 == "" {
		t.Error("cache key is empty")
	}
}

// TestCacheKey_DifferentInputs verifies model and body both affect the key.
func TestCacheKey_DifferentInputs(t *testing.T) {
	c := &Cache{}
	body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
	key1 := c.cacheKey("gpt-4", body)
	key2 := c.cacheKey("gpt-3.5", body)
	if key1 == key2 {
		t.Error("different models should produce different cache keys")
	}
	key3 := c.cacheKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"world"}]}`))
	if key1 == key3 {
		t.Error("different bodies should produce different cache keys")
	}
}

// TestCacheKey_HasPrefix verifies keys carry the "llm-gw:" namespace prefix.
func TestCacheKey_HasPrefix(t *testing.T) {
	c := &Cache{}
	key := c.cacheKey("gpt-4", []byte("test"))
	if len(key) < 7 || key[:7] != "llm-gw:" {
		t.Errorf("cache key should start with 'llm-gw:', got: %s", key)
	}
}

// TestParseInfoInt verifies integer extraction from INFO output, including
// the zero default for missing keys.
func TestParseInfoInt(t *testing.T) {
	info := "keyspace_hits:42\nkeyspace_misses:10\n"
	hits := parseInfoInt(info, "keyspace_hits")
	if hits != 42 {
		t.Errorf("expected 42, got %d", hits)
	}
	misses := parseInfoInt(info, "keyspace_misses")
	if misses != 10 {
		t.Errorf("expected 10, got %d", misses)
	}
	unknown := parseInfoInt(info, "nonexistent")
	if unknown != 0 {
		t.Errorf("expected 0 for unknown key, got %d", unknown)
	}
}

// TestParseInfoString verifies string extraction handles CRLF-terminated INFO
// lines and returns "" for missing keys.
func TestParseInfoString(t *testing.T) {
	info := "used_memory_human:1.5M\r\nother:value\r\n"
	mem := parseInfoString(info, "used_memory_human")
	if mem != "1.5M" {
		t.Errorf("expected '1.5M', got '%s'", mem)
	}
	unknown := parseInfoString(info, "nonexistent")
	if unknown != "" {
		t.Errorf("expected empty for unknown key, got '%s'", unknown)
	}
}

// TestParseKeyspaceKeys verifies the key count is read from the db0 line and
// an empty keyspace yields zero.
func TestParseKeyspaceKeys(t *testing.T) {
	info := "# Keyspace\ndb0:keys=123,expires=45,avg_ttl=6789\n"
	keys := parseKeyspaceKeys(info)
	if keys != 123 {
		t.Errorf("expected 123, got %d", keys)
	}
	empty := parseKeyspaceKeys("# Keyspace\n")
	if empty != 0 {
		t.Errorf("expected 0 for empty keyspace, got %d", empty)
	}
}

// TestSplitLines verifies newline splitting for multi-line and single-line input.
func TestSplitLines(t *testing.T) {
	lines := splitLines("a\nb\nc")
	if len(lines) != 3 {
		t.Errorf("expected 3 lines, got %d", len(lines))
	}
	if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
		t.Errorf("unexpected lines: %v", lines)
	}
	single := splitLines("hello")
	if len(single) != 1 || single[0] != "hello" {
		t.Errorf("single line: %v", single)
	}
}

View file

@ -12,17 +12,13 @@ import (
)
type Config struct {
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Cache CacheConfig `yaml:"cache"`
Pricing PricingLookupConfig `yaml:"pricing_lookup"`
CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"`
Retry RetryConfig `yaml:"retry"`
Debug DebugConfig `yaml:"debug"`
CORS CORSConfig `yaml:"cors"`
Providers []ProviderConfig `yaml:"providers"`
Models []ModelConfig `yaml:"models"`
Tokens []TokenConfig `yaml:"tokens"`
Server ServerConfig `yaml:"server"`
Database DatabaseConfig `yaml:"database"`
Cache CacheConfig `yaml:"cache"`
Pricing PricingLookupConfig `yaml:"pricing_lookup"`
Providers []ProviderConfig `yaml:"providers"`
Models []ModelConfig `yaml:"models"`
Tokens []TokenConfig `yaml:"tokens"`
}
type PricingLookupConfig struct {
@ -40,44 +36,14 @@ type TokenConfig struct {
Key string `yaml:"key"`
RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
}
type ServerConfig struct {
Listen string `yaml:"listen"`
RequestTimeout time.Duration `yaml:"request_timeout"`
StreamingTimeout time.Duration `yaml:"streaming_timeout"`
MaxRequestBodyMB int `yaml:"max_request_body_mb"`
SessionSecret string `yaml:"session_secret"`
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
}
type CircuitBreakerConfig struct {
Enabled bool `yaml:"enabled"`
ErrorThreshold float64 `yaml:"error_threshold"`
MinRequests int `yaml:"min_requests"`
CooldownDuration time.Duration `yaml:"cooldown_duration"`
}
type RetryConfig struct {
InitialBackoff time.Duration `yaml:"initial_backoff"`
MaxBackoff time.Duration `yaml:"max_backoff"`
Multiplier float64 `yaml:"multiplier"`
}
type DebugConfig struct {
Enabled bool `yaml:"enabled"`
MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
RetentionDays int `yaml:"retention_days"`
DataDir string `yaml:"data_dir"`
}
type CORSConfig struct {
Enabled bool `yaml:"enabled"`
AllowedOrigins []string `yaml:"allowed_origins"`
AllowedMethods []string `yaml:"allowed_methods"`
AllowedHeaders []string `yaml:"allowed_headers"`
MaxAge int `yaml:"max_age"`
Listen string `yaml:"listen"`
RequestTimeout time.Duration `yaml:"request_timeout"`
MaxRequestBodyMB int `yaml:"max_request_body_mb"`
SessionSecret string `yaml:"session_secret"`
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
}
type DatabaseConfig struct {
@ -100,10 +66,8 @@ type ProviderConfig struct {
}
type ModelConfig struct {
Name string `yaml:"name"`
Aliases []string `yaml:"aliases"`
Routes []RouteConfig `yaml:"routes"`
LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
Name string `yaml:"name"`
Routes []RouteConfig `yaml:"routes"`
}
type RouteConfig struct {
@ -164,43 +128,6 @@ func (c *Config) validate() error {
c.Pricing.RefreshInterval = 6 * time.Hour
}
// Server defaults
if c.Server.StreamingTimeout == 0 {
c.Server.StreamingTimeout = 5 * time.Minute
}
// Circuit breaker defaults
if c.CircuitBreaker.ErrorThreshold == 0 {
c.CircuitBreaker.ErrorThreshold = 0.5
}
if c.CircuitBreaker.MinRequests == 0 {
c.CircuitBreaker.MinRequests = 5
}
if c.CircuitBreaker.CooldownDuration == 0 {
c.CircuitBreaker.CooldownDuration = 30 * time.Second
}
// Retry defaults
if c.Retry.InitialBackoff == 0 {
c.Retry.InitialBackoff = 100 * time.Millisecond
}
if c.Retry.MaxBackoff == 0 {
c.Retry.MaxBackoff = 5 * time.Second
}
if c.Retry.Multiplier == 0 {
c.Retry.Multiplier = 2.0
}
// Debug defaults
if c.Debug.RetentionDays == 0 {
c.Debug.RetentionDays = 90
}
// CORS defaults
if c.CORS.MaxAge == 0 {
c.CORS.MaxAge = 300
}
if len(c.Providers) == 0 {
return fmt.Errorf("at least one provider is required")
}
@ -233,12 +160,6 @@ func (c *Config) validate() error {
return fmt.Errorf("duplicate model name: %s", m.Name)
}
modelNames[m.Name] = true
for _, alias := range m.Aliases {
if modelNames[alias] {
return fmt.Errorf("model alias %s conflicts with existing model or alias", alias)
}
modelNames[alias] = true
}
if len(m.Routes) == 0 {
return fmt.Errorf("model %s: at least one route is required", m.Name)
}

View file

@ -1,738 +0,0 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// writeConfigFile creates a temporary YAML config file holding content and
// returns its path. The file lives in t.TempDir(), so the test framework
// removes it automatically; any I/O failure aborts the test via Fatalf.
func writeConfigFile(t *testing.T, content string) string {
	t.Helper()
	tmp, err := os.CreateTemp(t.TempDir(), "config-*.yaml")
	if err != nil {
		t.Fatalf("creating temp file: %v", err)
	}
	defer tmp.Close()
	if _, err := tmp.WriteString(content); err != nil {
		t.Fatalf("writing temp file: %v", err)
	}
	return tmp.Name()
}
// minimalValidConfig returns a minimal valid YAML config string: a single
// provider (openai) and a single model (gpt-4) with one route — the smallest
// configuration that passes Load's validation. Used by tests that exercise
// defaulting rather than parsing.
func minimalValidConfig() string {
return `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-test-key
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`
}
// TestLoad_ValidConfig loads a fully-populated config file and asserts that
// every section (server, database, pricing, circuit breaker, retry, debug,
// CORS, providers, models, tokens) round-trips into the parsed Config struct
// with the exact values written in the YAML fixture.
func TestLoad_ValidConfig(t *testing.T) {
path := writeConfigFile(t, `
server:
listen: "127.0.0.1:8080"
request_timeout: 60s
streaming_timeout: 120s
max_request_body_mb: 5
session_secret: "test-secret-1234567890abcdef1234567890abcdef"
database:
path: "/tmp/test.db"
retention_days: 30
pricing_lookup:
url: "https://pricing.example.com"
refresh_interval: 1h
circuit_breaker:
enabled: true
error_threshold: 0.3
min_requests: 10
cooldown_duration: 60s
retry:
initial_backoff: 200ms
max_backoff: 10s
multiplier: 3.0
debug:
enabled: true
max_body_bytes: 65536
retention_days: 60
cors:
enabled: true
allowed_origins:
- "https://example.com"
allowed_methods:
- GET
- POST
allowed_headers:
- Authorization
max_age: 600
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-test-key
priority: 2
timeout: 60s
- name: anthropic
base_url: https://api.anthropic.com/v1
api_key: sk-ant-test
priority: 1
timeout: 30s
models:
- name: gpt-4
aliases:
- gpt4
routes:
- provider: openai
model: gpt-4
pricing:
input: 30.0
output: 60.0
load_balancing: first
- name: claude-3
routes:
- provider: anthropic
model: claude-3-opus-20240229
tokens:
- name: test-token
key: tok-abc123
rate_limit_rpm: 100
daily_budget_usd: 10.0
max_concurrent: 5
`)
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() returned error: %v", err)
}
// Server
if cfg.Server.Listen != "127.0.0.1:8080" {
t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080")
}
if cfg.Server.RequestTimeout != 60*time.Second {
t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second)
}
if cfg.Server.StreamingTimeout != 120*time.Second {
t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second)
}
if cfg.Server.MaxRequestBodyMB != 5 {
t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5)
}
if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" {
t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef")
}
// Database
if cfg.Database.Path != "/tmp/test.db" {
t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db")
}
if cfg.Database.RetentionDays != 30 {
t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30)
}
// Pricing
if cfg.Pricing.URL != "https://pricing.example.com" {
t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com")
}
if cfg.Pricing.RefreshInterval != 1*time.Hour {
t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour)
}
// Circuit breaker
if !cfg.CircuitBreaker.Enabled {
t.Error("CircuitBreaker.Enabled = false, want true")
}
if cfg.CircuitBreaker.ErrorThreshold != 0.3 {
t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3)
}
if cfg.CircuitBreaker.MinRequests != 10 {
t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10)
}
if cfg.CircuitBreaker.CooldownDuration != 60*time.Second {
t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second)
}
// Retry
if cfg.Retry.InitialBackoff != 200*time.Millisecond {
t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond)
}
if cfg.Retry.MaxBackoff != 10*time.Second {
t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second)
}
if cfg.Retry.Multiplier != 3.0 {
t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0)
}
// Debug
if !cfg.Debug.Enabled {
t.Error("Debug.Enabled = false, want true")
}
if cfg.Debug.MaxBodyBytes != 65536 {
t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536)
}
if cfg.Debug.RetentionDays != 60 {
t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60)
}
// CORS
if !cfg.CORS.Enabled {
t.Error("CORS.Enabled = false, want true")
}
if cfg.CORS.MaxAge != 600 {
t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600)
}
// Providers
if len(cfg.Providers) != 2 {
t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers))
}
if cfg.Providers[0].Name != "openai" {
t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai")
}
if cfg.Providers[0].Timeout != 60*time.Second {
t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second)
}
// Models
if len(cfg.Models) != 2 {
t.Fatalf("len(Models) = %d, want 2", len(cfg.Models))
}
if cfg.Models[0].LoadBalancing != "first" {
t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first")
}
if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" {
t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases)
}
if cfg.Models[0].Routes[0].Pricing.Input != 30.0 {
t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input)
}
// Tokens
if len(cfg.Tokens) != 1 {
t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens))
}
if cfg.Tokens[0].Name != "test-token" {
t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token")
}
if cfg.Tokens[0].RateLimitRPM != 100 {
t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM)
}
}
// TestValidate_Defaults loads the minimal config and checks that validation
// fills in every documented default (server, database, pricing, circuit
// breaker, retry, debug, CORS, provider) and auto-generates a session secret.
func TestValidate_Defaults(t *testing.T) {
path := writeConfigFile(t, minimalValidConfig())
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() returned error: %v", err)
}
// Table of field → expected default; got/want are `any` so durations,
// floats, ints and strings can share one table (compared via formatValue).
tests := []struct {
name string
got any
want any
}{
// Server defaults
{"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"},
{"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second},
{"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute},
{"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10},
// Database defaults
{"Database.Path", cfg.Database.Path, "gateway.db"},
{"Database.RetentionDays", cfg.Database.RetentionDays, 90},
// Pricing defaults
{"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour},
// Circuit breaker defaults
{"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5},
{"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5},
{"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second},
// Retry defaults
{"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond},
{"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second},
{"Retry.Multiplier", cfg.Retry.Multiplier, 2.0},
// Debug defaults
{"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0},
{"Debug.RetentionDays", cfg.Debug.RetentionDays, 90},
// CORS defaults
{"CORS.MaxAge", cfg.CORS.MaxAge, 300},
// Provider defaults
{"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second},
{"Providers[0].Priority", cfg.Providers[0].Priority, 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Compare using formatted strings to handle different numeric types
gotStr := formatValue(tt.got)
wantStr := formatValue(tt.want)
if gotStr != wantStr {
t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
}
})
}
// SessionSecret should be auto-generated (non-empty, 64 hex chars)
if cfg.Server.SessionSecret == "" {
t.Error("SessionSecret should be auto-generated when empty")
}
if len(cfg.Server.SessionSecret) != 64 {
t.Errorf("SessionSecret length = %d, want 64 hex chars", len(cfg.Server.SessionSecret))
}
}
func formatValue(v any) string {
switch val := v.(type) {
case time.Duration:
return val.String()
case float64:
return fmt.Sprintf("%g", val)
case int:
return fmt.Sprintf("%d", val)
case string:
return val
default:
return fmt.Sprintf("%v", val)
}
}
// TestLoad_FileNotFound ensures Load surfaces an error for a missing path.
func TestLoad_FileNotFound(t *testing.T) {
	missing := filepath.Join(t.TempDir(), "nonexistent.yaml")
	if _, err := Load(missing); err == nil {
		t.Fatal("Load() should return error for nonexistent file")
	}
}
// TestLoad_InvalidYAML ensures Load rejects a file that is not parseable YAML.
func TestLoad_InvalidYAML(t *testing.T) {
	badPath := writeConfigFile(t, `{{{invalid yaml`)
	if _, err := Load(badPath); err == nil {
		t.Fatal("Load() should return error for invalid YAML")
	}
}
// TestValidate_DuplicateProviderNames checks that validation rejects a config
// declaring two providers with the same name, and that the error message
// names the offending provider.
func TestValidate_DuplicateProviderNames(t *testing.T) {
path := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
- name: openai
base_url: https://api.openai.com/v2
api_key: sk-key2
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`)
_, err := Load(path)
if err == nil {
t.Fatal("Load() should return error for duplicate provider names")
}
wantSubstr := "duplicate provider name: openai"
if !strings.Contains(err.Error(), wantSubstr) {
t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
}
}
// TestValidate_DuplicateModelNames checks that validation rejects a config
// declaring two models with the same name, and that the error message names
// the offending model.
func TestValidate_DuplicateModelNames(t *testing.T) {
path := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
- name: gpt-4
routes:
- provider: openai
model: gpt-4-turbo
`)
_, err := Load(path)
if err == nil {
t.Fatal("Load() should return error for duplicate model names")
}
wantSubstr := "duplicate model name: gpt-4"
if !strings.Contains(err.Error(), wantSubstr) {
t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
}
}
// TestValidate_AliasConflicts is a table-driven test covering both alias
// collision cases: an alias that shadows an existing model name, and an
// alias declared by two different models. Both must fail validation with a
// message naming the conflicting alias.
func TestValidate_AliasConflicts(t *testing.T) {
tests := []struct {
name string
config string
wantErr string
}{
{
name: "alias conflicts with model name",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
- name: claude-3
aliases:
- gpt-4
routes:
- provider: openai
model: claude-3
`,
wantErr: "model alias gpt-4 conflicts with existing model or alias",
},
{
name: "alias conflicts with another alias",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
models:
- name: gpt-4
aliases:
- fast-model
routes:
- provider: openai
model: gpt-4
- name: claude-3
aliases:
- fast-model
routes:
- provider: openai
model: claude-3
`,
wantErr: "model alias fast-model conflicts with existing model or alias",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := writeConfigFile(t, tt.config)
_, err := Load(path)
if err == nil {
t.Fatal("Load() should return error for alias conflicts")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
}
})
}
}
// TestValidate_MissingRequiredFields is a table-driven test asserting that
// every required field (providers list, models list, provider name/base_url/
// api_key, model name/routes, route provider/model, and known provider
// references) produces a validation error containing the expected message.
func TestValidate_MissingRequiredFields(t *testing.T) {
tests := []struct {
name string
config string
wantErr string
}{
{
name: "no providers",
config: `models: [{name: test, routes: [{provider: x, model: y}]}]`,
wantErr: "at least one provider is required",
},
{
name: "no models",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
`,
wantErr: "at least one model is required",
},
{
name: "provider missing name",
config: `
providers:
- base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`,
wantErr: "provider 0: name, base_url, and api_key are required",
},
{
name: "provider missing base_url",
config: `
providers:
- name: openai
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`,
wantErr: "provider 0: name, base_url, and api_key are required",
},
{
name: "provider missing api_key",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`,
wantErr: "provider 0: name, base_url, and api_key are required",
},
{
name: "model missing name",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- routes:
- provider: openai
model: gpt-4
`,
wantErr: "model 0: name is required",
},
{
name: "model missing routes",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
`,
wantErr: "model gpt-4: at least one route is required",
},
{
name: "route missing provider",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- model: gpt-4
`,
wantErr: "model gpt-4 route 0: provider and model are required",
},
{
name: "route missing model",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: openai
`,
wantErr: "model gpt-4 route 0: provider and model are required",
},
{
name: "route references unknown provider",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: anthropic
model: gpt-4
`,
wantErr: "model gpt-4 route 0: unknown provider anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := writeConfigFile(t, tt.config)
_, err := Load(path)
if err == nil {
t.Fatalf("Load() should return error, want %q", tt.wantErr)
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
}
})
}
}
// TestProviderByName covers lookup of existing providers, a nonexistent
// provider, and the empty-name case, then verifies the returned pointer
// refers to the real config entry rather than a copy.
func TestProviderByName(t *testing.T) {
path := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-openai
- name: anthropic
base_url: https://api.anthropic.com/v1
api_key: sk-anthropic
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`)
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() returned error: %v", err)
}
tests := []struct {
name string
lookup string
wantNil bool
wantName string
}{
{"existing provider openai", "openai", false, "openai"},
{"existing provider anthropic", "anthropic", false, "anthropic"},
{"nonexistent provider", "google", true, ""},
{"empty name", "", true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := cfg.ProviderByName(tt.lookup)
if tt.wantNil {
if p != nil {
t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p)
}
} else {
if p == nil {
t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup)
}
if p.Name != tt.wantName {
t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName)
}
}
})
}
// Verify returned pointer refers to the actual config entry
p := cfg.ProviderByName("openai")
if p.APIKey != "sk-openai" {
t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai")
}
}
// TestLoad_EnvironmentVariableExpansion verifies that both $VAR and ${VAR}
// forms are expanded from the environment into config values. t.Setenv
// scopes the variables to this test and restores them afterwards.
func TestLoad_EnvironmentVariableExpansion(t *testing.T) {
t.Setenv("TEST_API_KEY", "sk-from-env")
t.Setenv("TEST_BASE_URL", "https://env.example.com/v1")
t.Setenv("TEST_PROVIDER_NAME", "env-provider")
path := writeConfigFile(t, `
providers:
- name: $TEST_PROVIDER_NAME
base_url: ${TEST_BASE_URL}
api_key: ${TEST_API_KEY}
models:
- name: test-model
routes:
- provider: env-provider
model: gpt-4
`)
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() returned error: %v", err)
}
if cfg.Providers[0].Name != "env-provider" {
t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider")
}
if cfg.Providers[0].BaseURL != "https://env.example.com/v1" {
t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1")
}
if cfg.Providers[0].APIKey != "sk-from-env" {
t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env")
}
}
// TestLoad_UnsetEnvVarExpandsToEmpty checks that an unset environment
// variable expands to the empty string, which validation then rejects when
// the field (api_key) is required.
func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) {
// Ensure the variable is not set
// (t.Setenv first so the original value is restored after the test,
// then Unsetenv so Load sees it as genuinely absent)
t.Setenv("TEST_UNSET_VAR", "")
os.Unsetenv("TEST_UNSET_VAR")
path := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: ${TEST_UNSET_VAR}
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`)
_, err := Load(path)
if err == nil {
t.Fatal("Load() should return error when env var expands to empty required field")
}
// api_key will be empty, so validation should catch it
if !strings.Contains(err.Error(), "api_key are required") {
t.Errorf("error = %q, want to contain api_key validation message", err.Error())
}
}

View file

@ -1,27 +0,0 @@
package config
import (
"log"
"os"
"os/signal"
"syscall"
)
// WatchReload listens for SIGHUP and calls the callback with the new config.
// Reload failures are logged and skipped, leaving the previous configuration
// in effect; the watcher goroutine runs for the life of the process.
func WatchReload(configPath string, callback func(*Config)) {
	hup := make(chan os.Signal, 1)
	signal.Notify(hup, syscall.SIGHUP)
	go func() {
		for range hup {
			log.Println("SIGHUP received, reloading config...")
			cfg, loadErr := Load(configPath)
			if loadErr != nil {
				log.Printf("ERROR: config reload failed: %v", loadErr)
				continue
			}
			callback(cfg)
			log.Println("Config reloaded successfully")
		}
	}()
}

View file

@ -7,8 +7,6 @@ import (
"strconv"
"time"
"github.com/go-chi/chi/v5"
"llm-gateway/internal/auth"
"llm-gateway/internal/cache"
"llm-gateway/internal/provider"
@ -60,7 +58,6 @@ type TokenUsageStats struct {
// RequestLogEntry represents a single request log row.
type RequestLogEntry struct {
RequestID string `json:"request_id"`
Timestamp int64 `json:"timestamp"`
TokenName string `json:"token_name"`
Model string `json:"model"`
@ -107,8 +104,6 @@ type StatsAPI struct {
authStore *auth.Store
healthTracker *provider.HealthTracker
cache *cache.Cache
auditLogger *storage.AuditLogger
debugLogger *storage.DebugLogger
}
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
@ -125,16 +120,6 @@ func (s *StatsAPI) SetCache(c *cache.Cache) {
s.cache = c
}
// SetAuditLogger sets the audit logger.
func (s *StatsAPI) SetAuditLogger(al *storage.AuditLogger) {
s.auditLogger = al
}
// SetDebugLogger sets the debug logger.
func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
s.debugLogger = dl
}
// TokenNamesForUser returns the token names that belong to the user.
// Admins get nil (no filter), non-admins get their token names.
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
@ -340,7 +325,7 @@ func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status s
}
// Get page
query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
query := `SELECT timestamp, token_name, model, provider, provider_model,
input_tokens, output_tokens, cost_usd, latency_ms, status,
COALESCE(error_message, ''), streaming, cached
FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
@ -356,7 +341,7 @@ func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status s
for rows.Next() {
var l RequestLogEntry
var streaming, cached int
rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
rows.Scan(&l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
&l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
&l.ErrorMessage, &streaming, &cached)
l.Streaming = streaming == 1
@ -639,79 +624,6 @@ func (s *StatsAPI) CacheStats(w http.ResponseWriter, r *http.Request) {
writeJSON(w, stats)
}
// AuditLogs serves the audit log API (admin-only).
func (s *StatsAPI) AuditLogs(w http.ResponseWriter, r *http.Request) {
if s.auditLogger == nil {
writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
return
}
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
action := r.URL.Query().Get("action")
since := time.Now().AddDate(0, 0, -30).Unix()
if sinceStr := r.URL.Query().Get("since"); sinceStr != "" {
if s, err := strconv.ParseInt(sinceStr, 10, 64); err == nil {
since = s
}
}
result := s.auditLogger.Query(since, action, page, 50)
writeJSON(w, result)
}
// DebugToggle enables/disables debug logging at runtime.
func (s *StatsAPI) DebugToggle(w http.ResponseWriter, r *http.Request) {
if s.debugLogger == nil {
writeJSON(w, map[string]any{"error": "debug logger not configured"})
return
}
var req struct {
Enabled bool `json:"enabled"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJSON(w, map[string]string{"error": "invalid JSON"})
return
}
s.debugLogger.SetEnabled(req.Enabled)
writeJSON(w, map[string]any{"enabled": s.debugLogger.IsEnabled()})
}
// DebugStatus returns whether debug logging is enabled.
func (s *StatsAPI) DebugStatus(w http.ResponseWriter, r *http.Request) {
enabled := false
if s.debugLogger != nil {
enabled = s.debugLogger.IsEnabled()
}
writeJSON(w, map[string]any{"enabled": enabled})
}
// DebugLogs serves paginated debug log entries.
func (s *StatsAPI) DebugLogs(w http.ResponseWriter, r *http.Request) {
if s.debugLogger == nil {
writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
return
}
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
result := s.debugLogger.Query(page, 50)
writeJSON(w, result)
}
// DebugLogByRequestID serves a single debug log entry by request ID.
func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
if s.debugLogger == nil {
w.WriteHeader(http.StatusNotFound)
writeJSON(w, map[string]string{"error": "debug logger not configured"})
return
}
requestID := chi.URLParam(r, "requestID")
entry := s.debugLogger.GetByRequestID(requestID)
if entry == nil {
w.WriteHeader(http.StatusNotFound)
writeJSON(w, map[string]string{"error": "not found"})
return
}
writeJSON(w, entry)
}
func writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(v)

View file

@ -1,297 +0,0 @@
package dashboard
import (
"encoding/csv"
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"llm-gateway/internal/auth"
"llm-gateway/internal/storage"
)
// ExportHandler serves downloadable exports of request logs and aggregated
// stats in CSV or JSON format.
type ExportHandler struct {
db *storage.DB // source of the request_logs table
authStore *auth.Store // used to restrict non-admin users to their own tokens
}
// NewExportHandler wires an ExportHandler to its storage and auth backends.
func NewExportHandler(db *storage.DB, authStore *auth.Store) *ExportHandler {
	handler := &ExportHandler{
		db:        db,
		authStore: authStore,
	}
	return handler
}
// ExportLogs exports request logs as CSV or JSON.
// Query parameters: format (csv|json, default json), from/to (unix seconds),
// model, token, status — each adds a filter clause. Non-admin users are
// restricted to logs for their own API tokens; export is capped at 100000
// rows, newest first, and served as an attachment download.
func (e *ExportHandler) ExportLogs(w http.ResponseWriter, r *http.Request) {
format := r.URL.Query().Get("format")
if format == "" {
format = "json"
}
// Build query
// "WHERE 1=1" lets every filter append as " AND ..." unconditionally.
where := "WHERE 1=1"
var args []any
if from := r.URL.Query().Get("from"); from != "" {
if ts, err := strconv.ParseInt(from, 10, 64); err == nil {
where += " AND timestamp >= ?"
args = append(args, ts)
}
}
if to := r.URL.Query().Get("to"); to != "" {
if ts, err := strconv.ParseInt(to, 10, 64); err == nil {
where += " AND timestamp <= ?"
args = append(args, ts)
}
}
if model := r.URL.Query().Get("model"); model != "" {
where += " AND model = ?"
args = append(args, model)
}
if token := r.URL.Query().Get("token"); token != "" {
where += " AND token_name = ?"
args = append(args, token)
}
if status := r.URL.Query().Get("status"); status != "" {
where += " AND status = ?"
args = append(args, status)
}
// Token filtering for non-admins
// If the user's tokens cannot be listed (or they have none), "AND 1=0"
// yields an empty result set rather than leaking other users' logs.
user := auth.UserFromContext(r.Context())
if user != nil && !user.IsAdmin {
tokens, err := e.authStore.ListAPITokens(user.ID)
if err != nil || len(tokens) == 0 {
where += " AND 1=0"
} else {
where += " AND token_name IN ("
for i, t := range tokens {
if i > 0 {
where += ","
}
where += "?"
args = append(args, t.Name)
}
where += ")"
}
}
// Only user-supplied VALUES are bound as ?-placeholders; the clause text
// itself is built from fixed strings above.
query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
input_tokens, output_tokens, cost_usd, latency_ms, status,
COALESCE(error_message, ''), streaming, cached
FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT 100000`
rows, err := e.db.Query(query, args...)
if err != nil {
http.Error(w, "query failed", http.StatusInternalServerError)
return
}
defer rows.Close()
type logRow struct {
RequestID string `json:"request_id"`
Timestamp int64 `json:"timestamp"`
TokenName string `json:"token_name"`
Model string `json:"model"`
Provider string `json:"provider"`
ProviderModel string `json:"provider_model"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CostUSD float64 `json:"cost_usd"`
LatencyMS int64 `json:"latency_ms"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
Streaming bool `json:"streaming"`
Cached bool `json:"cached"`
}
var results []logRow
for rows.Next() {
var l logRow
var streaming, cached int
// NOTE(review): Scan's error and rows.Err() are ignored; a scan failure
// silently yields a zero-valued row — consider surfacing these.
rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
&l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
&l.ErrorMessage, &streaming, &cached)
l.Streaming = streaming == 1
l.Cached = cached == 1
results = append(results, l)
}
// Timestamped filename for the attachment download.
now := time.Now().Format("20060102-150405")
switch format {
case "csv":
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", now))
writer := csv.NewWriter(w)
writer.Write([]string{"request_id", "timestamp", "token_name", "model", "provider", "provider_model",
"input_tokens", "output_tokens", "cost_usd", "latency_ms", "status", "error_message", "streaming", "cached"})
for _, l := range results {
writer.Write([]string{
l.RequestID,
strconv.FormatInt(l.Timestamp, 10),
l.TokenName, l.Model, l.Provider, l.ProviderModel,
strconv.Itoa(l.InputTokens), strconv.Itoa(l.OutputTokens),
fmt.Sprintf("%.8f", l.CostUSD),
strconv.FormatInt(l.LatencyMS, 10),
l.Status, l.ErrorMessage,
strconv.FormatBool(l.Streaming), strconv.FormatBool(l.Cached),
})
}
writer.Flush()
default:
// Any unrecognized format falls back to JSON.
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.json", now))
json.NewEncoder(w).Encode(results)
}
}
// ExportStats exports aggregated stats as CSV or JSON.
func (e *ExportHandler) ExportStats(w http.ResponseWriter, r *http.Request) {
format := r.URL.Query().Get("format")
if format == "" {
format = "json"
}
statsType := r.URL.Query().Get("type")
if statsType == "" {
statsType = "summary"
}
now := time.Now().Format("20060102-150405")
since := time.Now().AddDate(0, -1, 0).Unix()
switch statsType {
case "models":
rows, err := e.db.Query(`SELECT model, COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
COALESCE(SUM(cost_usd), 0), COALESCE(AVG(latency_ms), 0)
FROM request_logs WHERE timestamp >= ? GROUP BY model ORDER BY requests DESC`, since)
if err != nil {
http.Error(w, "query failed", http.StatusInternalServerError)
return
}
defer rows.Close()
type modelRow struct {
Model string `json:"model"`
Requests int `json:"requests"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CostUSD float64 `json:"cost_usd"`
AvgLatencyMS float64 `json:"avg_latency_ms"`
}
var results []modelRow
for rows.Next() {
var m modelRow
rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS)
results = append(results, m)
}
if format == "csv" {
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.csv", now))
writer := csv.NewWriter(w)
writer.Write([]string{"model", "requests", "input_tokens", "output_tokens", "cost_usd", "avg_latency_ms"})
for _, m := range results {
writer.Write([]string{m.Model, strconv.Itoa(m.Requests), strconv.Itoa(m.InputTokens),
strconv.Itoa(m.OutputTokens), fmt.Sprintf("%.8f", m.CostUSD), fmt.Sprintf("%.2f", m.AvgLatencyMS)})
}
writer.Flush()
} else {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.json", now))
json.NewEncoder(w).Encode(results)
}
case "providers":
rows, err := e.db.Query(`SELECT provider, COUNT(*) as requests,
COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0),
COALESCE(SUM(CASE WHEN status='error' THEN 1 ELSE 0 END), 0),
COALESCE(AVG(latency_ms), 0), COALESCE(SUM(cost_usd), 0)
FROM request_logs WHERE timestamp >= ? GROUP BY provider ORDER BY requests DESC`, since)
if err != nil {
http.Error(w, "query failed", http.StatusInternalServerError)
return
}
defer rows.Close()
type providerRow struct {
Provider string `json:"provider"`
Requests int `json:"requests"`
Successes int `json:"successes"`
Errors int `json:"errors"`
AvgLatencyMS float64 `json:"avg_latency_ms"`
CostUSD float64 `json:"cost_usd"`
}
var results []providerRow
for rows.Next() {
var p providerRow
rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD)
results = append(results, p)
}
if format == "csv" {
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.csv", now))
writer := csv.NewWriter(w)
writer.Write([]string{"provider", "requests", "successes", "errors", "avg_latency_ms", "cost_usd"})
for _, p := range results {
writer.Write([]string{p.Provider, strconv.Itoa(p.Requests), strconv.Itoa(p.Successes),
strconv.Itoa(p.Errors), fmt.Sprintf("%.2f", p.AvgLatencyMS), fmt.Sprintf("%.8f", p.CostUSD)})
}
writer.Flush()
} else {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.json", now))
json.NewEncoder(w).Encode(results)
}
case "tokens":
rows, err := e.db.Query(`SELECT token_name, COUNT(*) as requests,
COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
COALESCE(SUM(cost_usd), 0)
FROM request_logs WHERE timestamp >= ? GROUP BY token_name ORDER BY requests DESC`, since)
if err != nil {
http.Error(w, "query failed", http.StatusInternalServerError)
return
}
defer rows.Close()
type tokenRow struct {
TokenName string `json:"token_name"`
Requests int `json:"requests"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CostUSD float64 `json:"cost_usd"`
}
var results []tokenRow
for rows.Next() {
var t tokenRow
rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD)
results = append(results, t)
}
if format == "csv" {
w.Header().Set("Content-Type", "text/csv")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.csv", now))
writer := csv.NewWriter(w)
writer.Write([]string{"token_name", "requests", "input_tokens", "output_tokens", "cost_usd"})
for _, t := range results {
writer.Write([]string{t.TokenName, strconv.Itoa(t.Requests), strconv.Itoa(t.InputTokens),
strconv.Itoa(t.OutputTokens), fmt.Sprintf("%.8f", t.CostUSD)})
}
writer.Flush()
} else {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.json", now))
json.NewEncoder(w).Encode(results)
}
default: // summary
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-summary-%s.json", now))
statsAPI := NewStatsAPI(e.db, e.authStore)
result := statsAPI.GetSummary(nil)
json.NewEncoder(w).Encode(result)
}
}

View file

@ -11,7 +11,6 @@ import (
"llm-gateway/internal/auth"
"llm-gateway/internal/cache"
"llm-gateway/internal/provider"
"llm-gateway/internal/storage"
)
//go:embed templates/*.html templates/partials/*.html
@ -126,24 +125,15 @@ type PageData struct {
FilterStatus string
// Models routing page data
ModelRoutes []provider.ModelRouteInfo
// Audit page data
AuditResult *storage.AuditQueryResult
AuditFilterActions []string
FilterAction string
// Debug page data
DebugResult *storage.DebugLogQueryResult
DebugEnabled bool
}
// Dashboard serves the HTMX-based dashboard pages.
type Dashboard struct {
templates *template.Template
authStore *auth.Store
statsAPI *StatsAPI
registry *provider.Registry
cache *cache.Cache
auditLogger *storage.AuditLogger
debugLogger *storage.DebugLogger
templates *template.Template
authStore *auth.Store
statsAPI *StatsAPI
registry *provider.Registry
cache *cache.Cache
}
// NewDashboard creates a new Dashboard handler.
@ -172,16 +162,6 @@ func (d *Dashboard) SetCache(c *cache.Cache) {
d.cache = c
}
// SetAuditLogger sets the audit logger for the audit page.
func (d *Dashboard) SetAuditLogger(al *storage.AuditLogger) {
d.auditLogger = al
}
// SetDebugLogger sets the debug logger for the debug page.
func (d *Dashboard) SetDebugLogger(dl *storage.DebugLogger) {
d.debugLogger = dl
}
// LoginPage serves the login page.
func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) {
if !d.authStore.HasAnyUser() {
@ -318,62 +298,6 @@ func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) {
})
}
// AuditPage serves the audit log view (admin only).
func (d *Dashboard) AuditPage(w http.ResponseWriter, r *http.Request) {
user := auth.UserFromContext(r.Context())
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
if page < 1 {
page = 1
}
action := r.URL.Query().Get("action")
since := time.Now().AddDate(0, 0, -30).Unix()
var auditResult *storage.AuditQueryResult
if d.auditLogger != nil {
auditResult = d.auditLogger.Query(since, action, page, 50)
} else {
auditResult = &storage.AuditQueryResult{Entries: []storage.AuditEntry{}, Page: 1, TotalPages: 1}
}
// Common audit action types for the filter dropdown
actions := []string{"login", "logout", "create_user", "delete_user", "create_token", "delete_token", "change_password", "setup_totp", "disable_totp"}
d.renderDashboardPage(w, r, "partials/audit.html", PageData{
ActivePage: "audit",
User: user,
AuditResult: auditResult,
AuditFilterActions: actions,
FilterAction: action,
})
}
// DebugPage serves the debug logging view (admin only).
func (d *Dashboard) DebugPage(w http.ResponseWriter, r *http.Request) {
user := auth.UserFromContext(r.Context())
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
if page < 1 {
page = 1
}
var debugResult *storage.DebugLogQueryResult
debugEnabled := false
if d.debugLogger != nil {
debugResult = d.debugLogger.QueryFull(page, 50)
debugEnabled = d.debugLogger.IsEnabled()
} else {
debugResult = &storage.DebugLogQueryResult{Entries: []storage.DebugLogEntry{}, Page: 1, TotalPages: 1}
}
d.renderDashboardPage(w, r, "partials/debug.html", PageData{
ActivePage: "debug",
User: user,
DebugResult: debugResult,
DebugEnabled: debugEnabled,
})
}
// SettingsPage serves the settings view.
func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) {
user := auth.UserFromContext(r.Context())

View file

@ -160,24 +160,6 @@
.badge-error { background: var(--accent-red-bg); color: var(--accent-red); }
.badge-cached { background: var(--accent-blue-bg); color: var(--accent-blue); }
.badge-priority { background: var(--bg-tertiary); color: var(--text-secondary); }
.badge-open { background: var(--accent-red-bg); color: var(--accent-red); }
.badge-half-open { background: var(--accent-yellow-bg); color: var(--accent-yellow); }
/* Toggle switch */
.toggle-switch { position: relative; display: inline-block; width: 44px; height: 24px; }
.toggle-switch input { opacity: 0; width: 0; height: 0; }
.toggle-slider { position: absolute; cursor: pointer; top: 0; left: 0; right: 0; bottom: 0; background: var(--bg-tertiary); border-radius: 24px; transition: 0.2s; }
.toggle-slider:before { content: ""; position: absolute; height: 18px; width: 18px; left: 3px; bottom: 3px; background: var(--text-secondary); border-radius: 50%; transition: 0.2s; }
.toggle-switch input:checked + .toggle-slider { background: var(--accent-blue); }
.toggle-switch input:checked + .toggle-slider:before { transform: translateX(20px); background: #fff; }
/* Code block for debug bodies */
.code-block { background: var(--bg-primary); border: 1px solid var(--border-color); border-radius: 6px; padding: 12px; font-family: monospace; font-size: 0.8rem; white-space: pre-wrap; word-break: break-all; max-height: 300px; overflow-y: auto; }
/* Export button inline */
.export-links { display: inline-flex; gap: 6px; margin-left: 12px; }
.export-links a { font-size: 0.7rem; color: var(--text-muted); text-decoration: none; padding: 2px 6px; border: 1px solid var(--border-color); border-radius: 4px; }
.export-links a:hover { color: var(--text-primary); border-color: var(--text-muted); }
.page-header { display: flex; align-items: center; gap: 12px; margin-bottom: 20px; }
.page-header h1 { font-size: 1.3rem; color: var(--text-heading); }
@ -286,8 +268,6 @@ window.matchMedia('(prefers-color-scheme: light)').addEventListener('change', fu
<a href="/tokens" hx-get="/tokens" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "tokens"}}class="active"{{end}}>API Tokens</a>
{{if .User.IsAdmin}}
<a href="/users" hx-get="/users" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "users"}}class="active"{{end}}>Users</a>
<a href="/audit" hx-get="/audit" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "audit"}}class="active"{{end}}>Audit Log</a>
<a href="/debug" hx-get="/debug" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "debug"}}class="active"{{end}}>Debug</a>
{{end}}
<a href="/settings" hx-get="/settings" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "settings"}}class="active"{{end}}>Settings</a>
</nav>

View file

@ -1,83 +0,0 @@
{{define "content"}}
<div class="page-header">
<h1>Audit Log</h1>
<span style="font-size:0.85rem;color:var(--text-muted)">{{.AuditResult.Total}} total</span>
</div>
<div class="filter-bar">
<select id="filter-action" onchange="applyAuditFilter()">
<option value="">All Actions</option>
{{range .AuditFilterActions}}<option value="{{.}}" {{if eq . $.FilterAction}}selected{{end}}>{{.}}</option>{{end}}
</select>
<button class="btn btn-sm btn-outline" onclick="clearAuditFilter()">Clear</button>
</div>
<div class="section">
<table>
<thead>
<tr>
<th>Time</th>
<th>User</th>
<th>Action</th>
<th>Target</th>
<th>Details</th>
<th>IP</th>
</tr>
</thead>
<tbody>
{{range .AuditResult.Entries}}
<tr>
<td>{{formatTimeDetail .Timestamp}}</td>
<td>{{.Username}}</td>
<td><span class="badge badge-priority">{{.Action}}</span></td>
<td>{{if .TargetType}}{{.TargetType}}{{if .TargetID}}/{{.TargetID}}{{end}}{{else}}-{{end}}</td>
<td style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="{{.Details}}">{{if .Details}}{{.Details}}{{else}}-{{end}}</td>
<td>{{if .IPAddress}}{{.IPAddress}}{{else}}-{{end}}</td>
</tr>
{{end}}
{{if not .AuditResult.Entries}}
<tr><td colspan="6" style="text-align:center;color:var(--text-muted);padding:24px;">No audit log entries</td></tr>
{{end}}
</tbody>
</table>
{{if gt .AuditResult.TotalPages 1}}
<div class="pagination">
<button {{if le .AuditResult.Page 1}}disabled{{end}} onclick="goToAuditPage(1)">First</button>
<button {{if le .AuditResult.Page 1}}disabled{{end}} onclick="goToAuditPage({{subInt .AuditResult.Page 1}})">Prev</button>
{{$page := .AuditResult.Page}}
{{$total := .AuditResult.TotalPages}}
{{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
<button class="{{if eq . $page}}active{{end}}" onclick="goToAuditPage({{.}})">{{.}}</button>
{{end}}
<button {{if ge .AuditResult.Page .AuditResult.TotalPages}}disabled{{end}} onclick="goToAuditPage({{addInt .AuditResult.Page 1}})">Next</button>
<button {{if ge .AuditResult.Page .AuditResult.TotalPages}}disabled{{end}} onclick="goToAuditPage({{.AuditResult.TotalPages}})">Last</button>
<span class="page-info">Page {{.AuditResult.Page}} of {{.AuditResult.TotalPages}}</span>
</div>
{{end}}
</div>
<script>
function buildAuditURL(page) {
var params = [];
var action = document.getElementById('filter-action').value;
if (action) params.push('action=' + encodeURIComponent(action));
if (page > 1) params.push('page=' + page);
return '/audit' + (params.length ? '?' + params.join('&') : '');
}
function applyAuditFilter() {
var url = buildAuditURL(1);
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
history.pushState({}, '', url);
}
function goToAuditPage(page) {
var url = buildAuditURL(page);
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
history.pushState({}, '', url);
}
function clearAuditFilter() {
document.getElementById('filter-action').value = '';
applyAuditFilter();
}
</script>
{{end}}

View file

@ -25,8 +25,6 @@
<div class="health-item">
<span class="provider-name">{{.Provider}}</span>
<span class="badge badge-{{.Status}}">{{.Status}}</span>
{{if eq .CircuitState "open"}}<span class="badge badge-open">circuit open</span>{{end}}
{{if eq .CircuitState "half-open"}}<span class="badge badge-half-open">half-open</span>{{end}}
<span style="font-size:0.75rem;color:var(--text-muted)">{{printf "%.0f" .AvgLatency}}ms avg | {{formatPct .ErrorRate}} errors</span>
</div>
{{end}}
@ -73,7 +71,7 @@
{{if .Models}}
<div class="section">
<h2>Models<span class="export-links"><a href="/api/export/stats?format=csv&type=models" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=models" target="_blank">JSON</a></span></h2>
<h2>Models</h2>
<table>
<thead><tr><th>Model</th><th>Requests</th><th>Tokens (in/out)</th><th>Cost</th><th>Avg Latency</th></tr></thead>
<tbody>
@ -93,7 +91,7 @@
{{if .Providers}}
<div class="section">
<h2>Providers<span class="export-links"><a href="/api/export/stats?format=csv&type=providers" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=providers" target="_blank">JSON</a></span></h2>
<h2>Providers</h2>
<table>
<thead><tr><th>Provider</th><th>Requests</th><th>Success</th><th>Errors</th><th>Avg Latency</th><th>Cost</th></tr></thead>
<tbody>
@ -114,7 +112,7 @@
{{if .TokenStats}}
<div class="section">
<h2>API Token Usage<span class="export-links"><a href="/api/export/stats?format=csv&type=tokens" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=tokens" target="_blank">JSON</a></span></h2>
<h2>API Token Usage</h2>
<table>
<thead><tr><th>Token</th><th>Requests</th><th>Tokens (in/out)</th><th>Cost</th></tr></thead>
<tbody>

View file

@ -1,100 +0,0 @@
{{define "content"}}
<div class="page-header">
<h1>Debug Logging</h1>
<span style="font-size:0.85rem;color:var(--text-muted)">{{.DebugResult.Total}} entries</span>
</div>
<div class="section" style="display:flex;align-items:center;gap:16px;padding:12px 16px;">
<span style="font-size:0.9rem;font-weight:600;">Debug Mode</span>
<label class="toggle-switch">
<input type="checkbox" id="debug-toggle" {{if .DebugEnabled}}checked{{end}} onchange="toggleDebug(this.checked)">
<span class="toggle-slider"></span>
</label>
<span id="debug-status" style="font-size:0.8rem;color:var(--text-muted)">{{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}}</span>
</div>
<div class="section">
<table>
<thead>
<tr>
<th></th>
<th>Time</th>
<th>Request ID</th>
<th>Token</th>
<th>Model</th>
<th>Provider</th>
<th>Status</th>
</tr>
</thead>
<tbody>
{{range $i, $entry := .DebugResult.Entries}}
<tr class="expandable" onclick="toggleDebugExpand('debug-expand-{{$i}}')">
<td style="width:20px;text-align:center;color:var(--text-muted)">&#9654;</td>
<td>{{formatTimeDetail $entry.Timestamp}}</td>
<td><code style="font-size:0.75rem">{{$entry.RequestID}}</code></td>
<td>{{$entry.TokenName}}</td>
<td>{{$entry.Model}}</td>
<td>{{$entry.Provider}}</td>
<td>
{{if and (ge $entry.ResponseStatus 200) (lt $entry.ResponseStatus 300)}}<span class="badge badge-success">{{$entry.ResponseStatus}}</span>
{{else if ge $entry.ResponseStatus 400}}<span class="badge badge-error">{{$entry.ResponseStatus}}</span>
{{else}}<span class="badge">{{$entry.ResponseStatus}}</span>{{end}}
</td>
</tr>
<tr>
<td colspan="7" style="padding:0;">
<div id="debug-expand-{{$i}}" class="expand-content">
<div style="margin-bottom:8px"><strong>Request Headers:</strong></div>
<div class="code-block">{{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}</div>
<div style="margin:8px 0"><strong>Request Body:</strong></div>
<div class="code-block">{{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}</div>
<div style="margin:8px 0"><strong>Response Body:</strong></div>
<div class="code-block">{{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}</div>
</div>
</td>
</tr>
{{end}}
{{if not .DebugResult.Entries}}
<tr><td colspan="7" style="text-align:center;color:var(--text-muted);padding:24px;">No debug log entries</td></tr>
{{end}}
</tbody>
</table>
{{if gt .DebugResult.TotalPages 1}}
<div class="pagination">
<button {{if le .DebugResult.Page 1}}disabled{{end}} onclick="goToDebugPage(1)">First</button>
<button {{if le .DebugResult.Page 1}}disabled{{end}} onclick="goToDebugPage({{subInt .DebugResult.Page 1}})">Prev</button>
{{$page := .DebugResult.Page}}
{{$total := .DebugResult.TotalPages}}
{{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
<button class="{{if eq . $page}}active{{end}}" onclick="goToDebugPage({{.}})">{{.}}</button>
{{end}}
<button {{if ge .DebugResult.Page .DebugResult.TotalPages}}disabled{{end}} onclick="goToDebugPage({{addInt .DebugResult.Page 1}})">Next</button>
<button {{if ge .DebugResult.Page .DebugResult.TotalPages}}disabled{{end}} onclick="goToDebugPage({{.DebugResult.TotalPages}})">Last</button>
<span class="page-info">Page {{.DebugResult.Page}} of {{.DebugResult.TotalPages}}</span>
</div>
{{end}}
</div>
<script>
function toggleDebug(enabled) {
fetch('/api/debug/toggle', {
method: 'POST',
credentials: 'same-origin',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({enabled: enabled})
}).then(function() {
htmx.ajax('GET', '/debug', {target: '#content', swap: 'innerHTML'});
});
}
function toggleDebugExpand(id) {
var el = document.getElementById(id);
if (el) el.classList.toggle('show');
}
function goToDebugPage(page) {
var url = '/debug' + (page > 1 ? '?page=' + page : '');
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
history.pushState({}, '', url);
}
</script>
{{end}}

View file

@ -20,9 +20,6 @@
<option value="cached" {{if eq .FilterStatus "cached"}}selected{{end}}>Cached</option>
</select>
<button class="btn btn-sm btn-outline" onclick="clearLogsFilter()">Clear</button>
<span style="margin-left:auto"></span>
<button class="btn btn-sm btn-outline" onclick="exportLogs('csv')">Export CSV</button>
<button class="btn btn-sm btn-outline" onclick="exportLogs('json')">Export JSON</button>
</div>
<div class="section">
@ -119,15 +116,5 @@ function toggleExpand(id) {
var el = document.getElementById(id);
if (el) el.classList.toggle('show');
}
function exportLogs(format) {
var params = ['format=' + format];
var model = document.getElementById('filter-model').value;
var token = document.getElementById('filter-token').value;
var status = document.getElementById('filter-status').value;
if (model) params.push('model=' + encodeURIComponent(model));
if (token) params.push('token=' + encodeURIComponent(token));
if (status) params.push('status=' + encodeURIComponent(status));
window.open('/api/export/logs?' + params.join('&'), '_blank');
}
</script>
{{end}}

View file

@ -10,8 +10,6 @@ type Metrics struct {
requestDuration *prometheus.HistogramVec
tokensTotal *prometheus.CounterVec
costTotal *prometheus.CounterVec
cacheHits prometheus.Counter
cacheMisses prometheus.Counter
}
func New() *Metrics {
@ -36,16 +34,6 @@ func New() *Metrics {
Name: "llm_gateway_cost_usd_total",
Help: "Total cost in USD",
}, []string{"model", "provider", "token_name"}),
cacheHits: promauto.NewCounter(prometheus.CounterOpts{
Name: "llm_gateway_cache_hits_total",
Help: "Total number of cache hits",
}),
cacheMisses: promauto.NewCounter(prometheus.CounterOpts{
Name: "llm_gateway_cache_misses_total",
Help: "Total number of cache misses",
}),
}
}
@ -63,11 +51,3 @@ func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, l
m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost)
}
}
func (m *Metrics) RecordCacheHit() {
m.cacheHits.Inc()
}
func (m *Metrics) RecordCacheMiss() {
m.cacheMisses.Inc()
}

View file

@ -1,144 +0,0 @@
package provider
import (
"math/rand"
"sort"
"sync/atomic"
)
// LoadBalancer reorders routes for load distribution.
// Reorder receives the candidate routes and returns them in the order the
// caller should attempt them. Implementations that group by priority rely on
// equal-priority routes being consecutive in the input (see groupByPriority);
// callers are expected to pass routes already ordered by priority.
type LoadBalancer interface {
	Reorder(routes []Route) []Route
}
// NewLoadBalancer creates a load balancer for the given strategy name.
// Recognized strategies are "round-robin", "random", and "least-cost";
// any other value (including the empty string) falls back to FirstBalancer,
// which keeps the configured order.
func NewLoadBalancer(strategy string) LoadBalancer {
	switch strategy {
	case "round-robin":
		return &RoundRobinBalancer{}
	case "random":
		return &RandomBalancer{}
	case "least-cost":
		return &LeastCostBalancer{}
	}
	return &FirstBalancer{}
}
// FirstBalancer is a no-op that preserves original order.
type FirstBalancer struct{}

// Reorder returns the input slice unchanged: the first configured route is
// always tried first.
func (b *FirstBalancer) Reorder(routes []Route) []Route {
	return routes
}
// RoundRobinBalancer rotates routes within same-priority groups so that
// successive Reorder calls spread load across equal-priority providers.
type RoundRobinBalancer struct {
	counter atomic.Uint64 // monotonically increasing per-balancer call count
}

// Reorder returns a copy of routes in which each run of equal-priority
// routes is rotated by the shared call counter. The relative order of
// distinct priority groups is preserved; the input slice is not modified.
func (b *RoundRobinBalancer) Reorder(routes []Route) []Route {
	if len(routes) <= 1 {
		return routes
	}
	result := make([]Route, len(routes))
	copy(result, routes)

	// Group by priority and rotate within each group.
	groups := groupByPriority(result)
	idx := 0
	count := b.counter.Add(1)
	for _, group := range groups {
		if len(group) > 1 {
			// Reduce in uint64 space BEFORE converting to int: the original
			// int(count) % len(group) wraps negative once the counter exceeds
			// MaxInt64, which would yield a negative offset and a negative
			// slice index below.
			offset := int(count % uint64(len(group)))
			for j := 0; j < len(group); j++ {
				result[idx] = group[(j+offset)%len(group)]
				idx++
			}
		} else {
			result[idx] = group[0]
			idx++
		}
	}
	return result
}
// RandomBalancer shuffles routes within same-priority groups.
type RandomBalancer struct{}

// Reorder returns a copy of routes in which each run of equal-priority
// routes is randomly permuted; the order of distinct priority groups is
// preserved and the input slice is left untouched.
func (b *RandomBalancer) Reorder(routes []Route) []Route {
	if len(routes) <= 1 {
		return routes
	}
	shuffled := make([]Route, len(routes))
	copy(shuffled, routes)

	pos := 0
	for _, group := range groupByPriority(shuffled) {
		rand.Shuffle(len(group), func(i, j int) {
			group[i], group[j] = group[j], group[i]
		})
		for _, route := range group {
			shuffled[pos] = route
			pos++
		}
	}
	return shuffled
}
// LeastCostBalancer sorts by price within same-priority groups.
type LeastCostBalancer struct{}

// Reorder returns a copy of routes in which each run of equal-priority
// routes is ordered by ascending InputPrice+OutputPrice, so the cheapest
// provider in a priority tier is attempted first.
func (b *LeastCostBalancer) Reorder(routes []Route) []Route {
	if len(routes) <= 1 {
		return routes
	}
	ordered := make([]Route, len(routes))
	copy(ordered, routes)

	pos := 0
	for _, group := range groupByPriority(ordered) {
		sort.Slice(group, func(i, j int) bool {
			return group[i].InputPrice+group[i].OutputPrice <
				group[j].InputPrice+group[j].OutputPrice
		})
		for _, route := range group {
			ordered[pos] = route
			pos++
		}
	}
	return ordered
}
// groupByPriority splits routes into consecutive runs of identical Priority,
// preserving order both across and within groups. It returns nil for empty
// input. Note: non-adjacent routes with equal priority end up in separate
// groups — callers pass routes already sorted by priority.
func groupByPriority(routes []Route) [][]Route {
	if len(routes) == 0 {
		return nil
	}
	var groups [][]Route
	current := []Route{routes[0]}
	for _, route := range routes[1:] {
		if route.Priority == current[0].Priority {
			current = append(current, route)
			continue
		}
		groups = append(groups, current)
		current = []Route{route}
	}
	return append(groups, current)
}

View file

@ -1,294 +0,0 @@
package provider
import (
"fmt"
"testing"
)
// routeSpec is a compact description of a test Route: provider name,
// priority tier, and input/output prices.
type routeSpec struct {
	name     string
	priority int
	input    float64
	output   float64
}

// makeRoutes builds Routes backed by mockProviders from the given specs,
// deriving each ProviderModel as "<name>-model".
func makeRoutes(specs ...routeSpec) []Route {
	routes := make([]Route, 0, len(specs))
	for _, spec := range specs {
		routes = append(routes, Route{
			Provider:      &mockProvider{name: spec.name},
			ProviderModel: spec.name + "-model",
			Priority:      spec.priority,
			InputPrice:    spec.input,
			OutputPrice:   spec.output,
		})
	}
	return routes
}
// routeNames extracts the provider names in route order, for readable
// assertions against expected orderings.
func routeNames(routes []Route) []string {
	names := make([]string, 0, len(routes))
	for _, route := range routes {
		names = append(names, route.Provider.Name())
	}
	return names
}
// TestFirstBalancer_PreservesOrder verifies the no-op balancer keeps the
// configured route order untouched.
func TestFirstBalancer_PreservesOrder(t *testing.T) {
	input := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 2.0, 2.0},
		routeSpec{"c", 1, 3.0, 3.0},
	)
	names := routeNames((&FirstBalancer{}).Reorder(input))
	if names[0] != "a" || names[1] != "b" || names[2] != "c" {
		t.Fatalf("expected [a b c], got %v", names)
	}
}
// TestRoundRobinBalancer_RotatesWithinPriorityGroup checks that, across
// successive Reorder calls, every equal-priority route eventually lands in
// the first position.
func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 1, 1.0, 1.0},
	)
	balancer := &RoundRobinBalancer{}

	// Record which provider ends up in front on each rotation.
	firsts := make(map[string]bool)
	for i := 0; i < 6; i++ {
		firsts[balancer.Reorder(routes)[0].Provider.Name()] = true
	}

	for _, name := range []string{"a", "b", "c"} {
		if !firsts[name] {
			t.Errorf("expected %q to appear as first element in rotation", name)
		}
	}
}
// TestRoundRobinBalancer_PreservesPriorityOrder verifies rotation never
// promotes a lower-priority route past a higher-priority group.
func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 2, 1.0, 1.0},
	)
	balancer := &RoundRobinBalancer{}
	// The lone priority-2 route must remain last regardless of how many
	// rotations have occurred.
	for i := 0; i < 5; i++ {
		got := balancer.Reorder(routes)
		if name := got[2].Provider.Name(); name != "c" {
			t.Fatalf("expected priority-2 route 'c' at the end, got %q", name)
		}
	}
}
// TestRandomBalancer_AllRoutesPresent verifies shuffling neither drops nor
// duplicates routes.
func TestRandomBalancer_AllRoutesPresent(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 1, 1.0, 1.0},
	)
	balancer := &RandomBalancer{}
	for i := 0; i < 10; i++ {
		shuffled := balancer.Reorder(routes)
		if len(shuffled) != 3 {
			t.Fatalf("expected 3 routes, got %d", len(shuffled))
		}
		present := make(map[string]bool)
		for _, route := range shuffled {
			present[route.Provider.Name()] = true
		}
		for _, want := range []string{"a", "b", "c"} {
			if !present[want] {
				t.Errorf("missing route %q in result", want)
			}
		}
	}
}
// TestRandomBalancer_PreservesPriorityOrder verifies the shuffle only
// permutes within a priority group, never across groups.
func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 2, 1.0, 1.0},
	)
	balancer := &RandomBalancer{}
	for i := 0; i < 10; i++ {
		got := balancer.Reorder(routes)
		if name := got[2].Provider.Name(); name != "c" {
			t.Fatalf("expected priority-2 route 'c' last, got %q", name)
		}
	}
}
// TestLeastCostBalancer_SortsByCost verifies equal-priority routes come out
// ordered by ascending combined (input + output) price.
func TestLeastCostBalancer_SortsByCost(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"expensive", 1, 10.0, 10.0},
		routeSpec{"cheap", 1, 1.0, 1.0},
		routeSpec{"medium", 1, 5.0, 5.0},
	)
	names := routeNames((&LeastCostBalancer{}).Reorder(routes))
	for i, want := range []string{"cheap", "medium", "expensive"} {
		if names[i] != want {
			t.Errorf("position %d: got %q, want %q", i, names[i], want)
		}
	}
}
// TestLeastCostBalancer_PreservesPriorityOrder verifies cost sorting happens
// only inside a priority group: a cheaper priority-2 route never jumps ahead
// of priority-1 routes.
func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"expensive-p1", 1, 10.0, 10.0},
		routeSpec{"cheap-p1", 1, 1.0, 1.0},
		routeSpec{"cheap-p2", 2, 0.5, 0.5},
	)
	names := routeNames((&LeastCostBalancer{}).Reorder(routes))
	// Within priority 1, cheap should come first; priority 2 always last.
	if names[0] != "cheap-p1" {
		t.Errorf("expected cheap-p1 first, got %q", names[0])
	}
	if names[1] != "expensive-p1" {
		t.Errorf("expected expensive-p1 second, got %q", names[1])
	}
	if names[2] != "cheap-p2" {
		t.Errorf("expected cheap-p2 last, got %q", names[2])
	}
}
// TestGroupByPriority exercises the priority-run grouping helper across
// empty, single-element, uniform, and multi-group inputs.
func TestGroupByPriority(t *testing.T) {
	tests := []struct {
		name       string
		priorities []int   // priorities of the input routes, in order
		wantGroups [][]int // expected priorities of each resulting group
	}{
		{
			name:       "empty",
			priorities: nil,
			wantGroups: nil,
		},
		{
			name:       "single",
			priorities: []int{1},
			wantGroups: [][]int{{1}},
		},
		{
			name:       "all same",
			priorities: []int{1, 1, 1},
			wantGroups: [][]int{{1, 1, 1}},
		},
		{
			name:       "two groups",
			priorities: []int{1, 1, 2, 2},
			wantGroups: [][]int{{1, 1}, {2, 2}},
		},
		{
			name:       "three groups",
			priorities: []int{1, 2, 2, 3},
			wantGroups: [][]int{{1}, {2, 2}, {3}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build routes carrying only Priority; grouping ignores other fields.
			var routes []Route
			for _, p := range tt.priorities {
				routes = append(routes, Route{Priority: p})
			}
			groups := groupByPriority(routes)
			// A nil expectation means the helper must return nil, not empty.
			if tt.wantGroups == nil {
				if groups != nil {
					t.Fatalf("expected nil groups, got %v", groups)
				}
				return
			}
			if len(groups) != len(tt.wantGroups) {
				t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups))
			}
			// Compare group shapes, then per-route priorities within each group.
			for i, wg := range tt.wantGroups {
				if len(groups[i]) != len(wg) {
					t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i]))
					continue
				}
				for j, wp := range wg {
					if groups[i][j].Priority != wp {
						t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority)
					}
				}
			}
		})
	}
}
// TestBalancer_SingleRoute verifies every balancer strategy passes a
// single-route slice through unchanged.
func TestBalancer_SingleRoute(t *testing.T) {
	routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0})
	cases := []struct {
		name     string
		balancer LoadBalancer
	}{
		{"first", &FirstBalancer{}},
		{"round-robin", &RoundRobinBalancer{}},
		{"random", &RandomBalancer{}},
		{"least-cost", &LeastCostBalancer{}},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := tc.balancer.Reorder(routes)
			if len(got) != 1 || got[0].Provider.Name() != "only" {
				t.Fatalf("expected single route 'only', got %v", routeNames(got))
			}
		})
	}
}
// TestNewLoadBalancer verifies strategy-name dispatch, including the
// FirstBalancer fallback for unknown and empty strategies.
func TestNewLoadBalancer(t *testing.T) {
	cases := []struct {
		strategy string
		wantType string
	}{
		{"round-robin", "*provider.RoundRobinBalancer"},
		{"random", "*provider.RandomBalancer"},
		{"least-cost", "*provider.LeastCostBalancer"},
		{"first", "*provider.FirstBalancer"},
		{"unknown", "*provider.FirstBalancer"},
		{"", "*provider.FirstBalancer"},
	}
	for _, tc := range cases {
		t.Run(tc.strategy, func(t *testing.T) {
			if got := fmt.Sprintf("%T", NewLoadBalancer(tc.strategy)); got != tc.wantType {
				t.Errorf("NewLoadBalancer(%q) = %s, want %s", tc.strategy, got, tc.wantType)
			}
		})
	}
}

View file

@ -3,39 +3,8 @@ package provider
import (
"sync"
"time"
"llm-gateway/internal/config"
)
// CircuitState represents the state of a circuit breaker.
type CircuitState int
const (
CircuitClosed CircuitState = iota // normal operation
CircuitOpen // blocking requests
CircuitHalfOpen // testing with probe request
)
func (s CircuitState) String() string {
switch s {
case CircuitClosed:
return "closed"
case CircuitOpen:
return "open"
case CircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}
// ProviderCircuit tracks circuit breaker state for a single provider.
type ProviderCircuit struct {
State CircuitState
OpenedAt time.Time
LastProbe time.Time
}
// HealthEvent represents a single request outcome for a provider.
type HealthEvent struct {
Timestamp time.Time
@ -46,13 +15,12 @@ type HealthEvent struct {
// ProviderHealth is the computed health status for a provider.
type ProviderHealth struct {
Provider string `json:"provider"`
Status string `json:"status"` // healthy, degraded, down
ErrorRate float64 `json:"error_rate"`
AvgLatency float64 `json:"avg_latency_ms"`
Total int `json:"total"`
Errors int `json:"errors"`
CircuitState string `json:"circuit_state"`
Provider string `json:"provider"`
Status string `json:"status"` // healthy, degraded, down
ErrorRate float64 `json:"error_rate"`
AvgLatency float64 `json:"avg_latency_ms"`
Total int `json:"total"`
Errors int `json:"errors"`
}
// HealthTracker tracks per-provider health using a sliding window.
@ -60,52 +28,20 @@ type HealthTracker struct {
mu sync.RWMutex
windows map[string][]HealthEvent
windowDu time.Duration
circuits map[string]*ProviderCircuit
cbConfig config.CircuitBreakerConfig
}
// NewHealthTracker creates a health tracker with the given window duration.
func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker {
func NewHealthTracker(window time.Duration) *HealthTracker {
if window == 0 {
window = 5 * time.Minute
}
return &HealthTracker{
windows: make(map[string][]HealthEvent),
circuits: make(map[string]*ProviderCircuit),
windowDu: window,
cbConfig: cbCfg,
}
}
// IsAvailable returns true if the provider's circuit breaker allows requests.
func (h *HealthTracker) IsAvailable(provider string) bool {
if !h.cbConfig.Enabled {
return true
}
h.mu.RLock()
defer h.mu.RUnlock()
circuit, ok := h.circuits[provider]
if !ok {
return true // no circuit = closed = available
}
switch circuit.State {
case CircuitOpen:
// Check if cooldown has elapsed -> transition to half-open
if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
return true // will transition to half-open on next record
}
return false
case CircuitHalfOpen:
return true // allow probe
default:
return true
}
}
// Record adds a health event for a provider and evaluates circuit transitions.
// Record adds a health event for a provider.
func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
event := HealthEvent{
Timestamp: time.Now(),
@ -121,69 +57,6 @@ func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
h.windows[provider] = append(h.windows[provider], event)
h.prune(provider)
if h.cbConfig.Enabled {
h.evaluateCircuit(provider, err)
}
}
// evaluateCircuit transitions circuit breaker state. Must be called with lock held.
func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
circuit, ok := h.circuits[providerName]
if !ok {
circuit = &ProviderCircuit{State: CircuitClosed}
h.circuits[providerName] = circuit
}
switch circuit.State {
case CircuitClosed:
// Check if error threshold exceeded
errorRate, total := h.errorRateUnlocked(providerName)
if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold {
circuit.State = CircuitOpen
circuit.OpenedAt = time.Now()
}
case CircuitOpen:
// Check if cooldown elapsed -> half-open
if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
circuit.State = CircuitHalfOpen
circuit.LastProbe = time.Now()
// Evaluate the probe result immediately
if lastErr == nil {
circuit.State = CircuitClosed
} else {
circuit.State = CircuitOpen
circuit.OpenedAt = time.Now()
}
}
case CircuitHalfOpen:
if lastErr == nil {
circuit.State = CircuitClosed
} else {
circuit.State = CircuitOpen
circuit.OpenedAt = time.Now()
}
}
}
// errorRateUnlocked computes error rate within window. Must be called with lock held.
func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) {
cutoff := time.Now().Add(-h.windowDu)
events := h.windows[provider]
var total, errors int
for _, e := range events {
if e.Timestamp.Before(cutoff) {
continue
}
total++
if e.IsError {
errors++
}
}
if total == 0 {
return 0, 0
}
return float64(errors) / float64(total), total
}
// Status returns computed health for all tracked providers.
@ -221,19 +94,13 @@ func (h *HealthTracker) Status() []ProviderHealth {
status = "degraded"
}
circuitState := "closed"
if circuit, ok := h.circuits[provider]; ok {
circuitState = circuit.State.String()
}
results = append(results, ProviderHealth{
Provider: provider,
Status: status,
ErrorRate: errorRate,
AvgLatency: float64(totalLatency) / float64(total),
Total: total,
Errors: errors,
CircuitState: circuitState,
Provider: provider,
Status: status,
ErrorRate: errorRate,
AvgLatency: float64(totalLatency) / float64(total),
Total: total,
Errors: errors,
})
}

View file

@ -1,345 +0,0 @@
package provider
import (
"errors"
"testing"
"time"
"llm-gateway/internal/config"
)
func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker {
return NewHealthTracker(window, cb)
}
func defaultCBConfig() config.CircuitBreakerConfig {
return config.CircuitBreakerConfig{
Enabled: true,
ErrorThreshold: 0.5,
MinRequests: 3,
CooldownDuration: 100 * time.Millisecond,
}
}
func TestHealthTracker_Record(t *testing.T) {
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
ht.Record("provA", 100, nil)
ht.Record("provA", 200, errors.New("fail"))
ht.Record("provB", 50, nil)
ht.mu.RLock()
defer ht.mu.RUnlock()
if len(ht.windows["provA"]) != 2 {
t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"]))
}
if len(ht.windows["provB"]) != 1 {
t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"]))
}
// Verify event fields
ev := ht.windows["provA"][1]
if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 {
t.Fatalf("unexpected event fields: %+v", ev)
}
}
func TestHealthTracker_Status(t *testing.T) {
tests := []struct {
name string
successCount int
errorCount int
wantStatus string
wantErrorRate float64
wantTotal int
wantErrors int
}{
{
name: "healthy - no errors",
successCount: 10,
errorCount: 0,
wantStatus: "healthy",
wantErrorRate: 0.0,
wantTotal: 10,
wantErrors: 0,
},
{
name: "healthy - below 10% errors",
successCount: 19,
errorCount: 1,
wantStatus: "healthy",
wantErrorRate: 0.05,
wantTotal: 20,
wantErrors: 1,
},
{
name: "degraded - 20% errors",
successCount: 8,
errorCount: 2,
wantStatus: "degraded",
wantErrorRate: 0.2,
wantTotal: 10,
wantErrors: 2,
},
{
name: "degraded - exactly 10% errors",
successCount: 9,
errorCount: 1,
wantStatus: "degraded",
wantErrorRate: 0.1,
wantTotal: 10,
wantErrors: 1,
},
{
name: "down - 50% errors",
successCount: 5,
errorCount: 5,
wantStatus: "down",
wantErrorRate: 0.5,
wantTotal: 10,
wantErrors: 5,
},
{
name: "down - all errors",
successCount: 0,
errorCount: 5,
wantStatus: "down",
wantErrorRate: 1.0,
wantTotal: 5,
wantErrors: 5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
for i := 0; i < tt.successCount; i++ {
ht.Record("prov", 100, nil)
}
for i := 0; i < tt.errorCount; i++ {
ht.Record("prov", 100, errors.New("err"))
}
statuses := ht.Status()
if len(statuses) != 1 {
t.Fatalf("expected 1 status, got %d", len(statuses))
}
s := statuses[0]
if s.Status != tt.wantStatus {
t.Errorf("status = %q, want %q", s.Status, tt.wantStatus)
}
if s.Total != tt.wantTotal {
t.Errorf("total = %d, want %d", s.Total, tt.wantTotal)
}
if s.Errors != tt.wantErrors {
t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors)
}
// Allow small float tolerance
if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 {
t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate)
}
})
}
}
func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) {
cb := defaultCBConfig()
cb.MinRequests = 3
cb.ErrorThreshold = 0.5
ht := newTestTracker(5*time.Minute, cb)
// Record errors to exceed threshold (3 errors out of 3 = 100% > 50%)
ht.Record("prov", 100, errors.New("err"))
ht.Record("prov", 100, errors.New("err"))
ht.Record("prov", 100, errors.New("err"))
ht.mu.RLock()
state := ht.circuits["prov"].State
ht.mu.RUnlock()
if state != CircuitOpen {
t.Fatalf("expected CircuitOpen, got %s", state)
}
if ht.IsAvailable("prov") {
t.Fatal("expected IsAvailable=false when circuit is open")
}
}
func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) {
cb := defaultCBConfig()
cb.CooldownDuration = 50 * time.Millisecond
ht := newTestTracker(5*time.Minute, cb)
// Trip the circuit
for i := 0; i < 5; i++ {
ht.Record("prov", 100, errors.New("err"))
}
if ht.IsAvailable("prov") {
t.Fatal("expected circuit open, IsAvailable should be false")
}
// Wait for cooldown
time.Sleep(60 * time.Millisecond)
// After cooldown, IsAvailable should return true (will transition to half-open)
if !ht.IsAvailable("prov") {
t.Fatal("expected IsAvailable=true after cooldown")
}
}
func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
cb := defaultCBConfig()
cb.CooldownDuration = 10 * time.Millisecond
ht := newTestTracker(5*time.Minute, cb)
// Trip the circuit
for i := 0; i < 5; i++ {
ht.Record("prov", 100, errors.New("err"))
}
// Wait for cooldown so next Record transitions through Open->HalfOpen
time.Sleep(20 * time.Millisecond)
// A successful record should transition: Open -> HalfOpen -> Closed
ht.Record("prov", 100, nil)
ht.mu.RLock()
state := ht.circuits["prov"].State
ht.mu.RUnlock()
if state != CircuitClosed {
t.Fatalf("expected CircuitClosed after success in half-open, got %s", state)
}
if !ht.IsAvailable("prov") {
t.Fatal("expected IsAvailable=true after circuit closed")
}
}
func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
cb := defaultCBConfig()
cb.CooldownDuration = 10 * time.Millisecond
ht := newTestTracker(5*time.Minute, cb)
// Trip the circuit
for i := 0; i < 5; i++ {
ht.Record("prov", 100, errors.New("err"))
}
// Wait for cooldown
time.Sleep(20 * time.Millisecond)
// A failed record should transition: Open -> HalfOpen -> Open
ht.Record("prov", 100, errors.New("still failing"))
ht.mu.RLock()
state := ht.circuits["prov"].State
ht.mu.RUnlock()
if state != CircuitOpen {
t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state)
}
}
func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) {
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false})
// Even with errors, IsAvailable should return true when CB is disabled
for i := 0; i < 10; i++ {
ht.Record("prov", 100, errors.New("err"))
}
if !ht.IsAvailable("prov") {
t.Fatal("expected IsAvailable=true when circuit breaker disabled")
}
}
func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) {
ht := newTestTracker(5*time.Minute, defaultCBConfig())
if !ht.IsAvailable("unknown") {
t.Fatal("expected IsAvailable=true for unknown provider (no circuit)")
}
}
func TestHealthTracker_WindowPruning(t *testing.T) {
// Use a tiny window so events expire quickly
ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
ht.Record("prov", 100, nil)
ht.Record("prov", 200, nil)
// Wait for events to expire
time.Sleep(60 * time.Millisecond)
// Record a new event to trigger pruning
ht.Record("prov", 300, nil)
ht.mu.RLock()
count := len(ht.windows["prov"])
ht.mu.RUnlock()
if count != 1 {
t.Fatalf("expected 1 event after pruning, got %d", count)
}
}
func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) {
ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
ht.Record("prov", 100, nil)
// Wait for events to expire
time.Sleep(60 * time.Millisecond)
statuses := ht.Status()
if len(statuses) != 0 {
t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses))
}
}
func TestHealthTracker_Status_AvgLatency(t *testing.T) {
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
ht.Record("prov", 100, nil)
ht.Record("prov", 200, nil)
ht.Record("prov", 300, nil)
statuses := ht.Status()
if len(statuses) != 1 {
t.Fatalf("expected 1 status, got %d", len(statuses))
}
want := 200.0
if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 {
t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want)
}
}
func TestHealthTracker_Status_CircuitStateReported(t *testing.T) {
cb := defaultCBConfig()
ht := newTestTracker(5*time.Minute, cb)
// Trip the circuit
for i := 0; i < 5; i++ {
ht.Record("prov", 100, errors.New("err"))
}
statuses := ht.Status()
if len(statuses) != 1 {
t.Fatalf("expected 1 status, got %d", len(statuses))
}
if statuses[0].CircuitState != "open" {
t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open")
}
}

View file

@ -111,12 +111,6 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string,
func (p *OpenAIProvider) setHeaders(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
// Forward request ID if present in context
if reqID := req.Context().Value("requestID"); reqID != nil {
if id, ok := reqID.(string); ok && id != "" {
req.Header.Set("X-Request-ID", id)
}
}
}
// ProviderError represents a non-200 response from a provider.

View file

@ -3,7 +3,6 @@ package provider
import (
"fmt"
"sort"
"sync"
"llm-gateway/internal/config"
)
@ -19,40 +18,26 @@ type Route struct {
// Registry maps model names to provider routes.
type Registry struct {
mu sync.RWMutex
routes map[string][]Route
balancers map[string]LoadBalancer
aliases map[string]string // alias -> canonical name
order []string // preserves config order (canonical names only)
routes map[string][]Route
order []string // preserves config order
}
func NewRegistry(cfg *config.Config) (*Registry, error) {
r := &Registry{}
if err := r.buildFromConfig(cfg); err != nil {
return nil, err
}
return r, nil
}
func (r *Registry) buildFromConfig(cfg *config.Config) error {
// Build providers
providers := make(map[string]Provider)
for _, pc := range cfg.Providers {
providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout)
}
// Build routes
// Build routes (preserving config order)
routes := make(map[string][]Route)
balancers := make(map[string]LoadBalancer)
aliases := make(map[string]string)
order := make([]string, 0, len(cfg.Models))
for _, mc := range cfg.Models {
var modelRoutes []Route
for _, rc := range mc.Routes {
p, ok := providers[rc.Provider]
if !ok {
return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
return nil, fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
}
pc := cfg.ProviderByName(rc.Provider)
priority := pc.Priority
@ -70,69 +55,20 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
})
routes[mc.Name] = modelRoutes
order = append(order, mc.Name)
// Load balancer
balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
// Register aliases
for _, alias := range mc.Aliases {
aliases[alias] = mc.Name
}
}
r.mu.Lock()
r.routes = routes
r.balancers = balancers
r.aliases = aliases
r.order = order
r.mu.Unlock()
return nil
return &Registry{routes: routes, order: order}, nil
}
// Reload rebuilds routes from new config. Used for hot-reload.
func (r *Registry) Reload(cfg *config.Config) error {
return r.buildFromConfig(cfg)
}
// Lookup returns the routes for a model name (resolving aliases).
// Lookup returns the routes for a model name.
func (r *Registry) Lookup(model string) ([]Route, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
// Resolve alias
canonical := model
if alias, ok := r.aliases[model]; ok {
canonical = alias
}
routes, ok := r.routes[canonical]
if !ok {
return nil, false
}
// Apply load balancer
if balancer, ok := r.balancers[canonical]; ok {
routes = balancer.Reorder(routes)
}
return routes, true
routes, ok := r.routes[model]
return routes, ok
}
// ModelNames returns all registered model names in config order (including aliases).
// ModelNames returns all registered model names in config order.
func (r *Registry) ModelNames() []string {
r.mu.RLock()
defer r.mu.RUnlock()
var names []string
for _, name := range r.order {
names = append(names, name)
}
// Add aliases
for alias := range r.aliases {
names = append(names, alias)
}
return names
return r.order
}
// RouteInfo exposes route details for dashboard display.
@ -146,29 +82,16 @@ type RouteInfo struct {
// ModelRouteInfo exposes a model and its routes for dashboard display.
type ModelRouteInfo struct {
Name string `json:"name"`
Aliases []string `json:"aliases,omitempty"`
Routes []RouteInfo `json:"routes"`
Name string `json:"name"`
Routes []RouteInfo `json:"routes"`
}
// AllRoutes returns all models and their routes in config order.
func (r *Registry) AllRoutes() []ModelRouteInfo {
r.mu.RLock()
defer r.mu.RUnlock()
// Build reverse alias map
modelAliases := make(map[string][]string)
for alias, canonical := range r.aliases {
modelAliases[canonical] = append(modelAliases[canonical], alias)
}
results := make([]ModelRouteInfo, 0, len(r.order))
for _, name := range r.order {
routes := r.routes[name]
info := ModelRouteInfo{
Name: name,
Aliases: modelAliases[name],
}
info := ModelRouteInfo{Name: name}
for _, rt := range routes {
info.Routes = append(info.Routes, RouteInfo{
ProviderName: rt.Provider.Name(),

View file

@ -1,282 +0,0 @@
package provider
import (
"context"
"io"
"testing"
"llm-gateway/internal/config"
)
// mockProvider implements the Provider interface for testing.
type mockProvider struct {
name string
}
func (m *mockProvider) Name() string { return m.name }
func (m *mockProvider) ChatCompletion(_ context.Context, _ string, _ *ChatRequest) (*ChatResponse, error) {
return nil, nil
}
func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *ChatRequest) (io.ReadCloser, error) {
return nil, nil
}
// newTestRegistry builds a Registry directly without going through config parsing.
func newTestRegistry(models []testModel) *Registry {
r := &Registry{
routes: make(map[string][]Route),
balancers: make(map[string]LoadBalancer),
aliases: make(map[string]string),
}
for _, m := range models {
r.routes[m.name] = m.routes
r.balancers[m.name] = &FirstBalancer{}
r.order = append(r.order, m.name)
for _, alias := range m.aliases {
r.aliases[alias] = m.name
}
}
return r
}
type testModel struct {
name string
aliases []string
routes []Route
}
func TestRegistry_Lookup_Canonical(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "gpt-4",
routes: []Route{
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
},
},
})
routes, ok := reg.Lookup("gpt-4")
if !ok {
t.Fatal("expected Lookup to find gpt-4")
}
if len(routes) != 1 {
t.Fatalf("expected 1 route, got %d", len(routes))
}
if routes[0].Provider.Name() != "openai" {
t.Errorf("expected provider 'openai', got %q", routes[0].Provider.Name())
}
}
func TestRegistry_Lookup_Alias(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "gpt-4",
aliases: []string{"gpt4", "gpt-4-latest"},
routes: []Route{
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
},
},
})
tests := []struct {
name string
model string
found bool
}{
{"canonical", "gpt-4", true},
{"alias1", "gpt4", true},
{"alias2", "gpt-4-latest", true},
{"unknown", "gpt-5", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routes, ok := reg.Lookup(tt.model)
if ok != tt.found {
t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found)
}
if tt.found && len(routes) != 1 {
t.Fatalf("expected 1 route, got %d", len(routes))
}
})
}
}
func TestRegistry_ModelNames_IncludesAliases(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "gpt-4",
aliases: []string{"gpt4"},
routes: []Route{
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
},
},
{
name: "claude-3",
routes: []Route{
{Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1},
},
},
})
names := reg.ModelNames()
want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true}
got := make(map[string]bool)
for _, n := range names {
got[n] = true
}
for name := range want {
if !got[name] {
t.Errorf("expected %q in ModelNames, not found", name)
}
}
if len(names) != len(want) {
t.Errorf("expected %d names, got %d: %v", len(want), len(names), names)
}
}
func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "gpt-4",
aliases: []string{"gpt4", "gpt-4-latest"},
routes: []Route{
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
{Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2},
},
},
})
allRoutes := reg.AllRoutes()
if len(allRoutes) != 1 {
t.Fatalf("expected 1 model, got %d", len(allRoutes))
}
m := allRoutes[0]
if m.Name != "gpt-4" {
t.Errorf("expected name 'gpt-4', got %q", m.Name)
}
aliasSet := make(map[string]bool)
for _, a := range m.Aliases {
aliasSet[a] = true
}
if !aliasSet["gpt4"] || !aliasSet["gpt-4-latest"] {
t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", m.Aliases)
}
if len(m.Routes) != 2 {
t.Fatalf("expected 2 routes, got %d", len(m.Routes))
}
if m.Routes[0].ProviderName != "openai" {
t.Errorf("expected first route provider 'openai', got %q", m.Routes[0].ProviderName)
}
if m.Routes[1].ProviderName != "azure" {
t.Errorf("expected second route provider 'azure', got %q", m.Routes[1].ProviderName)
}
}
func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "model-b",
routes: []Route{
{Provider: &mockProvider{name: "prov"}, ProviderModel: "b", Priority: 1},
},
},
{
name: "model-a",
routes: []Route{
{Provider: &mockProvider{name: "prov"}, ProviderModel: "a", Priority: 1},
},
},
})
allRoutes := reg.AllRoutes()
if len(allRoutes) != 2 {
t.Fatalf("expected 2 models, got %d", len(allRoutes))
}
if allRoutes[0].Name != "model-b" {
t.Errorf("expected first model 'model-b', got %q", allRoutes[0].Name)
}
if allRoutes[1].Name != "model-a" {
t.Errorf("expected second model 'model-a', got %q", allRoutes[1].Name)
}
}
func TestRegistry_PrioritySorting(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "multi-provider",
routes: []Route{
{Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3},
{Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1},
{Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2},
},
},
})
// Note: routes are stored as given (sorting happens during buildFromConfig).
// For this test we verify AllRoutes returns them in stored order.
allRoutes := reg.AllRoutes()
if len(allRoutes) != 1 {
t.Fatalf("expected 1 model, got %d", len(allRoutes))
}
routes := allRoutes[0].Routes
if len(routes) != 3 {
t.Fatalf("expected 3 routes, got %d", len(routes))
}
// Verify the priorities are present
priorities := make(map[int]bool)
for _, r := range routes {
priorities[r.Priority] = true
}
for _, p := range []int{1, 2, 3} {
if !priorities[p] {
t.Errorf("expected priority %d in routes", p)
}
}
}
func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) {
cfg := &config.Config{
Models: []config.ModelConfig{
{
Name: "test-model",
Routes: []config.RouteConfig{
{Provider: "nonexistent", Model: "m"},
},
},
},
}
_, err := NewRegistry(cfg)
if err == nil {
t.Fatal("expected error for unknown provider, got nil")
}
}
func TestRegistry_Lookup_NotFound(t *testing.T) {
reg := newTestRegistry([]testModel{
{
name: "gpt-4",
routes: []Route{
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
},
},
})
_, ok := reg.Lookup("nonexistent")
if ok {
t.Fatal("expected Lookup to return false for nonexistent model")
}
}

View file

@ -1,51 +0,0 @@
package proxy
import (
"net/http"
"sync"
"sync/atomic"
)
// ConcurrencyLimiter enforces per-token concurrent request limits.
type ConcurrencyLimiter struct {
mu sync.Mutex
counters map[string]*atomic.Int64
}
func NewConcurrencyLimiter() *ConcurrencyLimiter {
return &ConcurrencyLimiter{
counters: make(map[string]*atomic.Int64),
}
}
func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 {
cl.mu.Lock()
defer cl.mu.Unlock()
c, ok := cl.counters[tokenName]
if !ok {
c = &atomic.Int64{}
cl.counters[tokenName] = c
}
return c
}
func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiToken := getAPIToken(r.Context())
if apiToken == nil || apiToken.MaxConcurrent <= 0 {
next.ServeHTTP(w, r)
return
}
counter := cl.getCounter(apiToken.Name)
current := counter.Add(1)
defer counter.Add(-1)
if current > int64(apiToken.MaxConcurrent) {
writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded")
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -1,317 +0,0 @@
package proxy
import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"llm-gateway/internal/auth"
)
func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) {
tests := []struct {
name string
maxConcurrent int
numRequests int
wantAllowed int
}{
{
name: "single request within limit",
maxConcurrent: 5,
numRequests: 1,
wantAllowed: 1,
},
{
name: "all requests within limit",
maxConcurrent: 5,
numRequests: 5,
wantAllowed: 5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cl := NewConcurrencyLimiter()
token := &auth.APIToken{
Name: "conc-token",
MaxConcurrent: tt.maxConcurrent,
}
var allowed atomic.Int64
var wg sync.WaitGroup
// Use a channel to hold all goroutines inside the handler simultaneously.
gate := make(chan struct{})
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
allowed.Add(1)
<-gate // Block until released.
w.WriteHeader(http.StatusOK)
}))
for i := 0; i < tt.numRequests; i++ {
wg.Add(1)
go func() {
defer wg.Done()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := withAPIToken(req.Context(), token)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
}()
}
// Wait for goroutines to enter the handler.
time.Sleep(50 * time.Millisecond)
close(gate)
wg.Wait()
if int(allowed.Load()) != tt.wantAllowed {
t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed)
}
})
}
}
func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) {
tests := []struct {
name string
maxConcurrent int
numRequests int
wantDenied int
}{
{
name: "one over limit",
maxConcurrent: 2,
numRequests: 3,
wantDenied: 1,
},
{
name: "many over limit",
maxConcurrent: 1,
numRequests: 5,
wantDenied: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cl := NewConcurrencyLimiter()
token := &auth.APIToken{
Name: "conc-token",
MaxConcurrent: tt.maxConcurrent,
}
var denied atomic.Int64
var wg sync.WaitGroup
gate := make(chan struct{})
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
<-gate
w.WriteHeader(http.StatusOK)
}))
results := make([]int, tt.numRequests)
for i := 0; i < tt.numRequests; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := withAPIToken(req.Context(), token)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
results[idx] = rec.Code
if rec.Code == http.StatusTooManyRequests {
denied.Add(1)
}
}(i)
}
// Wait for goroutines to reach the handler or be rejected.
time.Sleep(50 * time.Millisecond)
close(gate)
wg.Wait()
if int(denied.Load()) != tt.wantDenied {
t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied)
}
})
}
}
func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) {
cl := NewConcurrencyLimiter()
token := &auth.APIToken{
Name: "decrement-token",
MaxConcurrent: 1,
}
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request should succeed and complete, decrementing the counter.
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := withAPIToken(req.Context(), token)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("first request: status = %d, want %d", rec.Code, http.StatusOK)
}
// Counter should have decremented. A second request should also succeed.
rec2 := httptest.NewRecorder()
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
ctx2 := withAPIToken(req2.Context(), token)
req2 = req2.WithContext(ctx2)
handler.ServeHTTP(rec2, req2)
if rec2.Code != http.StatusOK {
t.Errorf("second request after first completed: status = %d, want %d", rec2.Code, http.StatusOK)
}
// Verify the internal counter is back to 0.
counter := cl.getCounter(token.Name)
val := counter.Load()
if val != 0 {
t.Errorf("counter = %d, want 0 after all requests completed", val)
}
}
func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) {
tests := []struct {
name string
maxConcurrent int
numRequests int
}{
{
name: "zero allows unlimited concurrent requests",
maxConcurrent: 0,
numRequests: 50,
},
{
name: "negative allows unlimited concurrent requests",
maxConcurrent: -1,
numRequests: 50,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cl := NewConcurrencyLimiter()
token := &auth.APIToken{
Name: "unlimited-token",
MaxConcurrent: tt.maxConcurrent,
}
var allowed atomic.Int64
var wg sync.WaitGroup
gate := make(chan struct{})
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
allowed.Add(1)
<-gate
w.WriteHeader(http.StatusOK)
}))
for i := 0; i < tt.numRequests; i++ {
wg.Add(1)
go func() {
defer wg.Done()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := withAPIToken(req.Context(), token)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
}()
}
// Give goroutines time to enter the handler.
time.Sleep(100 * time.Millisecond)
close(gate)
wg.Wait()
if int(allowed.Load()) != tt.numRequests {
t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests)
}
})
}
}
func TestConcurrencyLimiter_NoToken(t *testing.T) {
cl := NewConcurrencyLimiter()
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
// No API token in context.
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("status = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
}
}
func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) {
cl := NewConcurrencyLimiter()
tokenA := &auth.APIToken{
Name: "token-a",
MaxConcurrent: 1,
}
tokenB := &auth.APIToken{
Name: "token-b",
MaxConcurrent: 1,
}
gateA := make(chan struct{})
var wg sync.WaitGroup
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tok := getAPIToken(r.Context())
if tok.Name == "token-a" {
<-gateA // Block token A's request.
}
w.WriteHeader(http.StatusOK)
}))
// Start a request for token A that blocks.
wg.Add(1)
go func() {
defer wg.Done()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := withAPIToken(req.Context(), tokenA)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
}()
// Give token A's goroutine time to enter handler.
time.Sleep(50 * time.Millisecond)
// Token B should not be affected by token A's in-flight request.
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
ctx := withAPIToken(req.Context(), tokenB)
req = req.WithContext(ctx)
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK)
}
close(gateA)
wg.Wait()
}

View file

@ -4,16 +4,11 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"sort"
"strings"
"time"
"github.com/go-chi/chi/v5/middleware"
"llm-gateway/internal/auth"
"llm-gateway/internal/cache"
"llm-gateway/internal/config"
@ -52,7 +47,6 @@ type Handler struct {
metrics *metrics.Metrics
cfg *config.Config
healthTracker *provider.HealthTracker
debugLogger *storage.DebugLogger
}
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
@ -66,10 +60,6 @@ func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cac
}
}
func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
h.debugLogger = dl
}
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
if err != nil {
@ -94,53 +84,31 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
return
}
// Filter healthy routes (circuit breaker)
routes = h.filterHealthyRoutes(routes)
tokenName := getTokenName(r.Context())
requestID := middleware.GetReqID(r.Context())
// Check cache for non-streaming requests
if !req.Stream && h.cache != nil {
if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil {
h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
if h.metrics != nil {
h.metrics.RecordCacheHit()
}
h.logRequest(tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Cache", "HIT")
w.Header().Set("X-Request-ID", requestID)
w.Write(cached)
return
}
if h.metrics != nil {
h.metrics.RecordCacheMiss()
}
}
if req.Stream {
h.handleStream(w, r, &req, routes, tokenName, requestID)
h.handleStream(w, r, &req, routes, tokenName)
return
}
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
h.handleNonStream(w, r, &req, routes, tokenName, body)
}
func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) {
func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte) {
var lastErr error
for i, route := range routes {
// Retry backoff between attempts (not before first attempt)
if i > 0 {
backoff := backoffDuration(i, h.cfg.Retry)
select {
case <-time.After(backoff):
case <-r.Context().Done():
writeError(w, http.StatusGatewayTimeout, "request cancelled")
return
}
}
for _, route := range routes {
start := time.Now()
resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req)
latency := time.Since(start).Milliseconds()
@ -148,19 +116,19 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
if err != nil {
var pe *provider.ProviderError
if errors.As(err, &pe) && !pe.IsRetryable() {
// Client error — don't retry
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
w.Header().Set("X-Request-ID", requestID)
writeErrorRaw(w, pe.StatusCode, pe.Body)
return
}
lastErr = err
log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err)
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
@ -171,6 +139,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
h.healthTracker.Record(route.Provider.Name(), latency, nil)
}
// Compute cost
inputTokens, outputTokens := 0, 0
if resp.Usage != nil {
inputTokens = resp.Usage.PromptTokens
@ -179,8 +148,9 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
// Override model name in response to match the requested model
resp.Model = req.Model
respBytes, err := json.Marshal(resp)
@ -189,84 +159,27 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
return
}
// Cache the response
if h.cache != nil {
h.cache.Set(r.Context(), req.Model, rawBody, respBytes)
}
// Debug logging
if h.debugLogger != nil && h.debugLogger.IsEnabled() {
reqBody := string(rawBody)
respBody := string(respBytes)
if h.cfg.Debug.MaxBodyBytes > 0 {
if len(reqBody) > h.cfg.Debug.MaxBodyBytes {
reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes]
}
if len(respBody) > h.cfg.Debug.MaxBodyBytes {
respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
}
}
h.debugLogger.Log(storage.DebugLogEntry{
RequestID: requestID,
TokenName: tokenName,
Model: req.Model,
Provider: route.Provider.Name(),
RequestBody: reqBody,
ResponseBody: respBody,
RequestHeaders: formatHeaders(r.Header),
ResponseStatus: http.StatusOK,
})
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Cache", "MISS")
w.Header().Set("X-Request-ID", requestID)
w.Write(respBytes)
return
}
// All providers failed
if lastErr != nil {
w.Header().Set("X-Request-ID", requestID)
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
} else {
w.Header().Set("X-Request-ID", requestID)
writeError(w, http.StatusBadGateway, "all providers failed")
}
}
// filterHealthyRoutes returns only the routes whose provider currently has a
// closed (healthy) circuit breaker according to h.healthTracker.
// If no tracker is configured, or if every provider is filtered out, the
// original slice is returned unchanged so the request still has somewhere to go.
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
	tracker := h.healthTracker
	if tracker == nil {
		return routes
	}
	available := make([]provider.Route, 0, len(routes))
	for _, route := range routes {
		if tracker.IsAvailable(route.Provider.Name()) {
			available = append(available, route)
		}
	}
	// Every provider looks down — fall back to trying all of them anyway.
	if len(available) == 0 {
		return routes
	}
	return available
}
// backoffDuration returns the delay to wait before retry number `attempt`
// (attempt >= 1). The first retry waits cfg.InitialBackoff; each subsequent
// retry multiplies the delay by cfg.Multiplier, capped at cfg.MaxBackoff.
func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration {
	delay := cfg.InitialBackoff
	for i := 1; i < attempt; i++ {
		delay = time.Duration(float64(delay) * cfg.Multiplier)
		if delay > cfg.MaxBackoff {
			// Cap reached — further attempts all wait MaxBackoff.
			return cfg.MaxBackoff
		}
	}
	return delay
}
func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
func (h *Handler) logRequest(tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
h.logger.Log(storage.RequestLog{
RequestID: requestID,
Timestamp: time.Now().Unix(),
TokenName: tokenName,
Model: model,
@ -304,23 +217,3 @@ func writeErrorRaw(w http.ResponseWriter, code int, body string) {
w.WriteHeader(code)
w.Write([]byte(body))
}
// formatHeaders serializes HTTP headers to a readable string, sorted by key.
// Sensitive headers (Authorization) are redacted.
func formatHeaders(h http.Header) string {
keys := make([]string, 0, len(h))
for k := range h {
keys = append(keys, k)
}
sort.Strings(keys)
var b strings.Builder
for _, k := range keys {
val := strings.Join(h[k], ", ")
if strings.EqualFold(k, "Authorization") {
val = "[REDACTED]"
}
fmt.Fprintf(&b, "%s: %s\n", k, val)
}
return b.String()
}

View file

@ -1,8 +1,6 @@
package proxy
import (
"fmt"
"math"
"net/http"
"sync"
"time"
@ -42,19 +40,7 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
// Check rate limit
if apiToken.RateLimitRPM > 0 {
allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM)
// Set rate limit headers on all responses
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM))
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt))
if !allowed {
retryAfter := resetAt - time.Now().Unix()
if retryAfter < 1 {
retryAfter = 1
}
w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
if !rl.allow(tokenName, apiToken.RateLimitRPM) {
writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
return
}
@ -73,7 +59,7 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
})
}
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -96,27 +82,9 @@ func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int
}
bucket.lastRefill = now
remaining := int(math.Floor(bucket.tokens))
if remaining < 0 {
remaining = 0
}
// Compute reset time: when bucket would be full again
deficit := bucket.maxTokens - bucket.tokens
var resetAt int64
if deficit > 0 && bucket.refillRate > 0 {
resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix()
} else {
resetAt = now.Unix()
}
if bucket.tokens < 1 {
return false, 0, resetAt
return false
}
bucket.tokens--
remaining = int(math.Floor(bucket.tokens))
if remaining < 0 {
remaining = 0
}
return true, remaining, resetAt
return true
}

View file

@ -1,374 +0,0 @@
package proxy
import (
"context"
"database/sql"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
_ "modernc.org/sqlite"
"llm-gateway/internal/auth"
"llm-gateway/internal/storage"
)
// newTestDB creates an in-memory SQLite database wrapped in storage.DB.
// It creates only the request_logs table (the single table the rate limiter's
// TodaySpend query touches), so tests stay fast and schema-independent.
// The connection is closed automatically via t.Cleanup.
func newTestDB(t *testing.T) *storage.DB {
	t.Helper()
	sqlDB, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		t.Fatalf("opening in-memory sqlite: %v", err)
	}
	t.Cleanup(func() { sqlDB.Close() })
	// Create the minimal table needed for TodaySpend queries.
	_, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		token_name TEXT,
		cost_usd REAL,
		timestamp INTEGER
	)`)
	if err != nil {
		t.Fatalf("creating request_logs table: %v", err)
	}
	return &storage.DB{DB: sqlDB}
}
// okHandler is a trivial terminal handler that always writes 200 OK; used as
// the `next` handler when exercising the rate-limit middleware.
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
})
// TestRateLimiter_Allow verifies the token-bucket admission decision: a fresh
// bucket admits up to rateLimitRPM back-to-back requests and denies the rest.
func TestRateLimiter_Allow(t *testing.T) {
	tests := []struct {
		name         string
		rateLimitRPM int
		numRequests  int
		wantAllowed  int
		wantDenied   int
	}{
		{
			name:         "allows requests within limit",
			rateLimitRPM: 10,
			numRequests:  5,
			wantAllowed:  5,
			wantDenied:   0,
		},
		{
			name:         "denies requests over limit",
			rateLimitRPM: 3,
			numRequests:  6,
			wantAllowed:  3,
			wantDenied:   3,
		},
		{
			name:         "allows exactly up to limit",
			rateLimitRPM: 5,
			numRequests:  5,
			wantAllowed:  5,
			wantDenied:   0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := newTestDB(t)
			rl := NewRateLimiter(db)
			allowed := 0
			denied := 0
			for i := 0; i < tt.numRequests; i++ {
				// allow returns (allowed, remaining, resetAt); only the
				// admission decision matters in this test.
				ok, _, _ := rl.allow("test-token", tt.rateLimitRPM)
				if ok {
					allowed++
				} else {
					denied++
				}
			}
			if allowed != tt.wantAllowed {
				t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed)
			}
			if denied != tt.wantDenied {
				t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
			}
		})
	}
}
// TestRateLimiter_TokenRefillsOverTime verifies that an exhausted bucket
// regains tokens as time passes. Instead of sleeping, the test rewinds the
// bucket's lastRefill timestamp to simulate elapsed time deterministically.
func TestRateLimiter_TokenRefillsOverTime(t *testing.T) {
	db := newTestDB(t)
	rl := NewRateLimiter(db)
	rpm := 60 // 1 token per second refill rate
	// Exhaust all tokens.
	for i := 0; i < rpm; i++ {
		ok, _, _ := rl.allow("refill-token", rpm)
		if !ok {
			t.Fatalf("request %d should have been allowed", i)
		}
	}
	// Next request should be denied.
	ok, _, _ := rl.allow("refill-token", rpm)
	if ok {
		t.Fatal("request should have been denied after exhausting tokens")
	}
	// Manually advance the bucket's lastRefill to simulate time passing.
	rl.mu.Lock()
	bucket := rl.buckets["refill-token"]
	bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second)
	rl.mu.Unlock()
	// After 2 seconds at 1 token/sec, we should have ~2 tokens refilled.
	ok, remaining, _ := rl.allow("refill-token", rpm)
	if !ok {
		t.Fatal("request should have been allowed after token refill")
	}
	// We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0.
	if remaining < 0 {
		t.Errorf("remaining = %d, want >= 0", remaining)
	}
}
// TestRateLimiter_AllowReturnValues pins the secondary return values of
// allow(): `remaining` must count down as tokens are consumed and report 0
// once the bucket is empty or a request is denied.
func TestRateLimiter_AllowReturnValues(t *testing.T) {
	tests := []struct {
		name              string
		rateLimitRPM      int
		numRequests       int
		wantLastAllowed   bool
		wantLastRemaining int
	}{
		{
			name:              "remaining decrements correctly",
			rateLimitRPM:      5,
			numRequests:       1,
			wantLastAllowed:   true,
			wantLastRemaining: 4,
		},
		{
			name:              "remaining is zero at limit",
			rateLimitRPM:      3,
			numRequests:       3,
			wantLastAllowed:   true,
			wantLastRemaining: 0,
		},
		{
			name:              "denied returns zero remaining",
			rateLimitRPM:      2,
			numRequests:       3,
			wantLastAllowed:   false,
			wantLastRemaining: 0,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := newTestDB(t)
			rl := NewRateLimiter(db)
			var allowed bool
			var remaining int
			// Only the final call's return values are asserted.
			for i := 0; i < tt.numRequests; i++ {
				allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM)
			}
			if allowed != tt.wantLastAllowed {
				t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed)
			}
			if remaining != tt.wantLastRemaining {
				t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining)
			}
		})
	}
}
// TestRateLimiter_CheckMiddleware_Headers verifies the HTTP middleware layer:
// X-RateLimit-Limit/Remaining/Reset headers must appear on every response,
// and a Retry-After header must accompany a 429 (and only a 429).
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
	tests := []struct {
		name            string
		rateLimitRPM    int
		numRequests     int
		wantStatusCode  int
		wantLimitHeader string
		wantRetryAfter  bool
	}{
		{
			name:            "sets rate limit headers on allowed request",
			rateLimitRPM:    10,
			numRequests:     1,
			wantStatusCode:  http.StatusOK,
			wantLimitHeader: "10",
			wantRetryAfter:  false,
		},
		{
			name:            "sets Retry-After header on 429",
			rateLimitRPM:    2,
			numRequests:     3,
			wantStatusCode:  http.StatusTooManyRequests,
			wantLimitHeader: "2",
			wantRetryAfter:  true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := newTestDB(t)
			rl := NewRateLimiter(db)
			token := &auth.APIToken{
				Name:         "header-test-token",
				RateLimitRPM: tt.rateLimitRPM,
			}
			handler := rl.Check(okHandler)
			var rec *httptest.ResponseRecorder
			// Fire numRequests requests; only the final recorder is inspected.
			for i := 0; i < tt.numRequests; i++ {
				rec = httptest.NewRecorder()
				req := httptest.NewRequest(http.MethodGet, "/", nil)
				ctx := withAPIToken(req.Context(), token)
				req = req.WithContext(ctx)
				handler.ServeHTTP(rec, req)
			}
			// Check the last response.
			if rec.Code != tt.wantStatusCode {
				t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode)
			}
			// X-RateLimit-Limit header.
			limitHeader := rec.Header().Get("X-RateLimit-Limit")
			if limitHeader != tt.wantLimitHeader {
				t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader)
			}
			// X-RateLimit-Remaining header must be present and numeric.
			remainingHeader := rec.Header().Get("X-RateLimit-Remaining")
			if remainingHeader == "" {
				t.Error("X-RateLimit-Remaining header is missing")
			} else if _, err := strconv.Atoi(remainingHeader); err != nil {
				t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader)
			}
			// X-RateLimit-Reset header must be present and numeric.
			resetHeader := rec.Header().Get("X-RateLimit-Reset")
			if resetHeader == "" {
				t.Error("X-RateLimit-Reset header is missing")
			} else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil {
				t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader)
			}
			// Retry-After header.
			retryAfter := rec.Header().Get("Retry-After")
			if tt.wantRetryAfter && retryAfter == "" {
				t.Error("Retry-After header is missing on 429 response")
			}
			if !tt.wantRetryAfter && retryAfter != "" {
				t.Errorf("Retry-After header should not be present, got %q", retryAfter)
			}
		})
	}
}
// TestRateLimiter_CheckMiddleware_NoToken verifies that a request carrying no
// API token in its context passes through the middleware unthrottled.
func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) {
	db := newTestDB(t)
	rl := NewRateLimiter(db)
	handler := rl.Check(okHandler)
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	// No API token in context.
	handler.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
	}
}
// TestRateLimiter_CheckMiddleware_ZeroRPM verifies that RateLimitRPM == 0 is
// treated as "unlimited": many consecutive requests all succeed.
func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) {
	db := newTestDB(t)
	rl := NewRateLimiter(db)
	token := &auth.APIToken{
		Name:         "unlimited-token",
		RateLimitRPM: 0, // zero means unlimited
	}
	handler := rl.Check(okHandler)
	for i := 0; i < 100; i++ {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		ctx := withAPIToken(req.Context(), token)
		req = req.WithContext(ctx)
		handler.ServeHTTP(rec, req)
		if rec.Code != http.StatusOK {
			t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK)
		}
	}
}
// TestRateLimiter_PerTokenIsolation verifies each token name gets its own
// bucket: exhausting one token's budget must not throttle a different token.
func TestRateLimiter_PerTokenIsolation(t *testing.T) {
	db := newTestDB(t)
	rl := NewRateLimiter(db)
	rpm := 2
	// Exhaust token A.
	for i := 0; i < rpm; i++ {
		rl.allow("token-a", rpm)
	}
	ok, _, _ := rl.allow("token-a", rpm)
	if ok {
		t.Fatal("token-a should be rate limited")
	}
	// Token B should still have its own bucket.
	ok, _, _ = rl.allow("token-b", rpm)
	if !ok {
		t.Fatal("token-b should not be affected by token-a's rate limit")
	}
}
// TestRateLimiter_ResetAtIsFuture verifies that once a token has been
// consumed (creating a refill deficit), the reported resetAt timestamp is
// never in the past.
func TestRateLimiter_ResetAtIsFuture(t *testing.T) {
	db := newTestDB(t)
	rl := NewRateLimiter(db)
	// Consume one token so there's a deficit.
	_, _, resetAt := rl.allow("reset-token", 10)
	now := time.Now().Unix()
	if resetAt < now {
		t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now)
	}
}
// TestRateLimiter_CheckMiddleware_ContextCancelled is a smoke test: serving a
// request whose context is already cancelled must not panic. No status code
// is asserted because the middleware does not inspect context cancellation.
func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) {
	db := newTestDB(t)
	rl := NewRateLimiter(db)
	token := &auth.APIToken{
		Name:         "ctx-token",
		RateLimitRPM: 10,
	}
	handler := rl.Check(okHandler)
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	ctx, cancel := context.WithCancel(req.Context())
	ctx = withAPIToken(ctx, token)
	cancel() // Cancel immediately.
	req = req.WithContext(ctx)
	// Should still process (rate limiter does not check context cancellation).
	handler.ServeHTTP(rec, req)
	// The handler itself may or may not respect cancelled context;
	// the key point is no panic occurs.
}

View file

@ -2,7 +2,6 @@ package proxy
import (
"bufio"
"context"
"encoding/json"
"errors"
"log"
@ -13,7 +12,7 @@ import (
"llm-gateway/internal/provider"
)
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) {
flusher, ok := w.(http.Flusher)
if !ok {
writeError(w, http.StatusInternalServerError, "streaming not supported")
@ -22,18 +21,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
var lastErr error
for i, route := range routes {
// Retry backoff between attempts
if i > 0 {
backoff := backoffDuration(i, h.cfg.Retry)
select {
case <-time.After(backoff):
case <-r.Context().Done():
writeError(w, http.StatusGatewayTimeout, "request cancelled")
return
}
}
for _, route := range routes {
start := time.Now()
body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
@ -42,95 +30,67 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
if errors.As(err, &pe) && !pe.IsRetryable() {
latency := time.Since(start).Milliseconds()
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
w.Header().Set("X-Request-ID", requestID)
writeErrorRaw(w, pe.StatusCode, pe.Body)
return
}
lastErr = err
latency := time.Since(start).Milliseconds()
log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
continue
}
// Apply streaming timeout
var streamCtx context.Context
var streamCancel context.CancelFunc
if h.cfg.Server.StreamingTimeout > 0 {
streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
} else {
streamCtx, streamCancel = context.WithCancel(r.Context())
}
// Stream the response
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
w.Header().Set("X-Request-ID", requestID)
w.WriteHeader(http.StatusOK)
inputTokens, outputTokens := 0, 0
scanner := bufio.NewScanner(body)
scanner.Buffer(make([]byte, 64*1024), 256*1024)
scanDone := make(chan struct{})
go func() {
defer close(scanDone)
for scanner.Scan() {
select {
case <-streamCtx.Done():
return
default:
}
for scanner.Scan() {
line := scanner.Text()
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data != "[DONE]" {
var chunk streamChunk
if json.Unmarshal([]byte(data), &chunk) == nil {
if chunk.Usage != nil {
inputTokens = chunk.Usage.PromptTokens
outputTokens = chunk.Usage.CompletionTokens
}
if chunk.Model != "" {
chunk.Model = req.Model
if rewritten, err := json.Marshal(chunk); err == nil {
line = "data: " + string(rewritten)
}
// Parse usage from the final chunk if available
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data != "[DONE]" {
var chunk streamChunk
if json.Unmarshal([]byte(data), &chunk) == nil {
if chunk.Usage != nil {
inputTokens = chunk.Usage.PromptTokens
outputTokens = chunk.Usage.CompletionTokens
}
// Override model name in chunk
if chunk.Model != "" {
chunk.Model = req.Model
if rewritten, err := json.Marshal(chunk); err == nil {
line = "data: " + string(rewritten)
}
}
}
}
w.Write([]byte(line + "\n"))
flusher.Flush()
}
}()
select {
case <-scanDone:
// Normal completion
case <-streamCtx.Done():
log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
w.Write([]byte(line + "\n"))
flusher.Flush()
}
body.Close()
streamCancel()
latency := time.Since(start).Milliseconds()
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, nil)
}
@ -138,7 +98,6 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
}
// All providers failed
w.Header().Set("X-Request-ID", requestID)
if lastErr != nil {
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
} else {

View file

@ -1,102 +0,0 @@
package storage
import (
"log"
"time"
)
// AuditEntry is one row of the audit_log table: a record of an administrative
// action (who did what, to which object, from where).
type AuditEntry struct {
	ID         int64  `json:"id"`
	Timestamp  int64  `json:"timestamp"` // unix seconds
	UserID     int64  `json:"user_id"`
	Username   string `json:"username"`
	Action     string `json:"action"`
	TargetType string `json:"target_type"`
	TargetID   string `json:"target_id"`
	Details    string `json:"details"`
	IPAddress  string `json:"ip_address"`
	RequestID  string `json:"request_id"`
}
// AuditLogger writes and queries audit_log rows through the given DB handle.
type AuditLogger struct {
	db *DB
}

// NewAuditLogger returns an AuditLogger backed by db.
func NewAuditLogger(db *DB) *AuditLogger {
	return &AuditLogger{db: db}
}
// Log inserts one audit entry, defaulting Timestamp to now when unset.
// Logging is best-effort: an insert failure is printed but not returned, so
// audit problems never fail the operation being audited.
func (a *AuditLogger) Log(entry AuditEntry) {
	if entry.Timestamp == 0 {
		entry.Timestamp = time.Now().Unix()
	}
	_, err := a.db.Exec(`INSERT INTO audit_log
		(timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		entry.Timestamp, entry.UserID, entry.Username, entry.Action,
		entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID,
	)
	if err != nil {
		log.Printf("ERROR: audit log: %v", err)
	}
}
// AuditQueryResult is one page of audit entries plus pagination metadata.
type AuditQueryResult struct {
	Entries    []AuditEntry `json:"entries"`
	Page       int          `json:"page"`
	TotalPages int          `json:"total_pages"`
	Total      int          `json:"total"`
}
// Query returns a page of audit entries with timestamp >= since, newest
// first, optionally filtered by exact action name. Page numbers are 1-based;
// limit defaults to 50. Query never returns an error: on DB failure it yields
// an empty (non-nil) page so callers can render it directly.
func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult {
	// Normalize pagination inputs.
	if page < 1 {
		page = 1
	}
	if limit <= 0 {
		limit = 50
	}
	offset := (page - 1) * limit
	// Build the WHERE clause shared by the count and the page queries.
	where := "WHERE timestamp >= ?"
	args := []any{since}
	if action != "" {
		where += " AND action = ?"
		args = append(args, action)
	}
	var total int
	// Copy args so appending LIMIT/OFFSET below cannot alias the count args.
	countArgs := make([]any, len(args))
	copy(countArgs, args)
	a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total)
	totalPages := (total + limit - 1) / limit
	if totalPages < 1 {
		totalPages = 1
	}
	// COALESCE guards against NULLs in optional columns.
	query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action,
		COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''),
		COALESCE(ip_address, ''), COALESCE(request_id, '')
		FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
	args = append(args, limit, offset)
	rows, err := a.db.Query(query, args...)
	if err != nil {
		return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total}
	}
	defer rows.Close()
	var entries []AuditEntry
	for rows.Next() {
		var e AuditEntry
		// Scan errors are deliberately ignored: a bad row becomes zero values.
		rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action,
			&e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID)
		entries = append(entries, e)
	}
	// Normalize nil to an empty slice so JSON encodes [] rather than null.
	if entries == nil {
		entries = []AuditEntry{}
	}
	return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
}

View file

@ -1,250 +0,0 @@
package storage
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"
"sync/atomic"
"time"
)
// DebugLogEntry is one captured request/response pair. Metadata lives in the
// debug_log table; the bulky bodies/headers live in a JSON file on disk
// referenced by FilePath (which is never serialized to clients).
type DebugLogEntry struct {
	ID             int64  `json:"id"`
	RequestID      string `json:"request_id"`
	Timestamp      int64  `json:"timestamp"` // unix seconds
	TokenName      string `json:"token_name"`
	Model          string `json:"model"`
	Provider       string `json:"provider"`
	RequestBody    string `json:"request_body"`
	ResponseBody   string `json:"response_body"`
	RequestHeaders string `json:"request_headers"`
	ResponseStatus int    `json:"response_status"`
	FilePath       string `json:"-"` // on-disk location of the body file
}
// debugFile is the JSON structure written to disk for one debug entry; it
// holds the large payloads that are deliberately kept out of the database.
type debugFile struct {
	RequestHeaders string `json:"request_headers"`
	RequestBody    string `json:"request_body"`
	ResponseBody   string `json:"response_body"`
}
// DebugLogger captures full request/response payloads for troubleshooting.
// It can be toggled at runtime; the enabled flag is atomic so handlers may
// check it without locking.
type DebugLogger struct {
	db      *DB
	enabled atomic.Bool
	dataDir string // base directory under which debug-logs/ is created
}

// NewDebugLogger returns a DebugLogger writing files under dataDir, with the
// given initial enabled state.
func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger {
	dl := &DebugLogger{db: db, dataDir: dataDir}
	dl.enabled.Store(enabled)
	return dl
}

// SetEnabled toggles debug capture at runtime.
func (d *DebugLogger) SetEnabled(v bool) {
	d.enabled.Store(v)
}

// IsEnabled reports whether debug capture is currently on.
func (d *DebugLogger) IsEnabled() bool {
	return d.enabled.Load()
}
// debugLogDir returns the base directory for debug log files
// (<dataDir>/debug-logs).
func (d *DebugLogger) debugLogDir() string {
	return filepath.Join(d.dataDir, "debug-logs")
}
// debugFilePath returns where the body file for a request lives:
// <debugLogDir>/<YYYY-MM-DD>/<requestID>.json, sharded by the entry's date so
// Cleanup can drop whole day directories.
func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string {
	return filepath.Join(d.debugLogDir(), ts.Format("2006-01-02"), requestID+".json")
}
// Log persists one debug entry: bodies/headers go to a JSON file on disk,
// then a metadata row (no bodies) is inserted into debug_log pointing at that
// file. All failures are logged and swallowed — debug capture must never fail
// the request being captured. A no-op when the logger is disabled.
// Note the ordering: if the file write fails, no DB row is created, so the
// table never references a missing file.
func (d *DebugLogger) Log(entry DebugLogEntry) {
	if !d.IsEnabled() {
		return
	}
	if entry.Timestamp == 0 {
		entry.Timestamp = time.Now().Unix()
	}
	ts := time.Unix(entry.Timestamp, 0)
	fp := d.debugFilePath(entry.RequestID, ts)
	// Write body file
	if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil {
		log.Printf("ERROR: debug log mkdir: %v", err)
		return
	}
	df := debugFile{
		RequestHeaders: entry.RequestHeaders,
		RequestBody:    entry.RequestBody,
		ResponseBody:   entry.ResponseBody,
	}
	data, err := json.Marshal(df)
	if err != nil {
		log.Printf("ERROR: debug log marshal: %v", err)
		return
	}
	if err := os.WriteFile(fp, data, 0644); err != nil {
		log.Printf("ERROR: debug log write: %v", err)
		return
	}
	// Insert metadata into DB (no bodies)
	_, err = d.db.Exec(`INSERT INTO debug_log
		(request_id, timestamp, token_name, model, provider, response_status, file_path)
		VALUES (?, ?, ?, ?, ?, ?, ?)`,
		entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model,
		entry.Provider, entry.ResponseStatus, fp,
	)
	if err != nil {
		log.Printf("ERROR: debug log db insert: %v", err)
	}
}
// DebugLogQueryResult is one page of debug entries plus pagination metadata.
type DebugLogQueryResult struct {
	Entries    []DebugLogEntry `json:"entries"`
	Page       int             `json:"page"`
	TotalPages int             `json:"total_pages"`
	Total      int             `json:"total"`
}
// Query returns one page of debug log metadata, newest first, WITHOUT reading
// body files from disk — cheap enough for list views. Page is 1-based; limit
// defaults to 50. Never errors: a DB failure yields an empty (non-nil) page.
func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult {
	if page < 1 {
		page = 1
	}
	if limit <= 0 {
		limit = 50
	}
	offset := (page - 1) * limit
	var total int
	d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total)
	totalPages := (total + limit - 1) / limit
	if totalPages < 1 {
		totalPages = 1
	}
	// COALESCE guards against NULLs in optional columns.
	rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
		COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
		FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset)
	if err != nil {
		return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
	}
	defer rows.Close()
	var entries []DebugLogEntry
	for rows.Next() {
		var e DebugLogEntry
		// Scan errors deliberately ignored: a bad row becomes zero values.
		rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
			&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
		entries = append(entries, e)
	}
	// Normalize nil to an empty slice so JSON encodes [] rather than null.
	if entries == nil {
		entries = []DebugLogEntry{}
	}
	return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
}
// QueryFull returns the same page as Query, additionally loading each entry's
// request/response bodies from its on-disk file (one file read per entry).
func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult {
	result := d.Query(page, limit)
	for i := range result.Entries {
		d.populateFromFile(&result.Entries[i])
	}
	return result
}
// GetByRequestID returns the single debug entry for requestID with its bodies
// loaded from disk, or nil when no such row exists (or the row cannot be
// scanned).
func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry {
	var e DebugLogEntry
	err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
		COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
		FROM debug_log WHERE request_id = ?`, requestID).Scan(
		&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
		&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
	if err != nil {
		return nil
	}
	d.populateFromFile(&e)
	return &e
}
// populateFromFile fills e's body/header fields from its on-disk JSON file.
// Entries written before the file-based migration have an empty FilePath; for
// those the bodies are read from the legacy DB columns instead. Read/parse
// failures are logged and leave the fields empty — best-effort by design.
func (d *DebugLogger) populateFromFile(e *DebugLogEntry) {
	if e.FilePath == "" {
		// Legacy entry: try reading bodies from DB columns
		d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '')
			FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders)
		return
	}
	data, err := os.ReadFile(e.FilePath)
	if err != nil {
		log.Printf("WARN: debug log read file %s: %v", e.FilePath, err)
		return
	}
	var df debugFile
	if err := json.Unmarshal(data, &df); err != nil {
		log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err)
		return
	}
	e.RequestHeaders = df.RequestHeaders
	e.RequestBody = df.RequestBody
	e.ResponseBody = df.ResponseBody
}
// Cleanup deletes debug log DB rows older than retentionDays, then removes
// any whole day directories (named YYYY-MM-DD) older than the cutoff date.
// A missing debug-logs directory is not an error; individual directory
// removal failures are logged but do not abort the sweep.
func (d *DebugLogger) Cleanup(retentionDays int) error {
	cutoff := time.Now().AddDate(0, 0, -retentionDays)
	cutoffUnix := cutoff.Unix()
	// Delete old DB rows
	result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix)
	if err != nil {
		return fmt.Errorf("delete old debug rows: %w", err)
	}
	affected, _ := result.RowsAffected()
	if affected > 0 {
		log.Printf("Cleaned up %d old debug log entries", affected)
	}
	// Remove old date directories
	baseDir := d.debugLogDir()
	dirs, err := os.ReadDir(baseDir)
	if err != nil {
		if os.IsNotExist(err) {
			return nil
		}
		return fmt.Errorf("read debug log dir: %w", err)
	}
	cutoffDate := cutoff.Format("2006-01-02")
	sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
	for _, dir := range dirs {
		if !dir.IsDir() {
			continue
		}
		// Date directories are named YYYY-MM-DD; string comparison works
		if strings.Compare(dir.Name(), cutoffDate) < 0 {
			dirPath := filepath.Join(baseDir, dir.Name())
			if err := os.RemoveAll(dirPath); err != nil {
				log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err)
			} else {
				log.Printf("Removed old debug log directory: %s", dir.Name())
			}
		}
	}
	return nil
}

View file

@ -6,7 +6,6 @@ import (
)
type RequestLog struct {
RequestID string
Timestamp int64
TokenName string
Model string
@ -94,8 +93,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
}
stmt, err := tx.Prepare(`INSERT INTO request_logs
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
(timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
log.Printf("ERROR: preparing log statement: %v", err)
tx.Rollback()
@ -113,7 +112,7 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
cached = 1
}
_, err := stmt.Exec(
r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
r.Status, r.ErrorMessage, streaming, cached,
)

View file

@ -1,4 +0,0 @@
-- SQLite doesn't support DROP COLUMN in older versions, so we recreate the table
CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens;
DROP TABLE api_tokens;
ALTER TABLE api_tokens_backup RENAME TO api_tokens;

View file

@ -1 +0,0 @@
ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0;

View file

@ -1 +0,0 @@
DROP INDEX IF EXISTS idx_request_logs_request_id;

View file

@ -1,2 +0,0 @@
ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT '';
CREATE INDEX idx_request_logs_request_id ON request_logs(request_id);

View file

@ -1 +0,0 @@
DROP TABLE IF EXISTS audit_log;

View file

@ -1,14 +0,0 @@
-- Audit trail of administrative actions. One row per action; request_id
-- correlates an entry with the request_logs/debug_log tables.
CREATE TABLE audit_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp INTEGER NOT NULL,        -- when the action happened (epoch units per writer)
    user_id INTEGER,                   -- nullable: system-initiated actions have no user
    username TEXT NOT NULL DEFAULT '', -- denormalized so entries survive user deletion
    action TEXT NOT NULL,              -- what was done (verb)
    target_type TEXT DEFAULT '',       -- kind of object acted on
    target_id TEXT DEFAULT '',         -- identifier of that object
    details TEXT DEFAULT '',           -- free-form extra context
    ip_address TEXT DEFAULT '',
    request_id TEXT DEFAULT ''
);
-- Indexes for the two dominant query patterns: time-range scans and
-- filtering by action.
CREATE INDEX idx_audit_timestamp ON audit_log(timestamp);
CREATE INDEX idx_audit_action ON audit_log(action);

View file

@ -1 +0,0 @@
DROP TABLE IF EXISTS debug_log;

View file

@ -1,14 +0,0 @@
-- Per-request debug capture: raw request/response bodies and headers,
-- keyed by request_id for correlation with request_logs.
CREATE TABLE debug_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    request_id TEXT NOT NULL,          -- joins to request_logs.request_id
    timestamp INTEGER NOT NULL,
    token_name TEXT DEFAULT '',
    model TEXT DEFAULT '',             -- client-facing model name
    provider TEXT DEFAULT '',
    request_body TEXT DEFAULT '',      -- raw upstream request payload
    response_body TEXT DEFAULT '',     -- raw upstream response payload
    request_headers TEXT DEFAULT '',
    response_status INTEGER DEFAULT 0  -- HTTP status from the provider
);
-- request_id lookups (single-request inspection) and timestamp scans
-- (retention cleanup) are the two access paths.
CREATE INDEX idx_debug_request_id ON debug_log(request_id);
CREATE INDEX idx_debug_timestamp ON debug_log(timestamp);

View file

@ -1 +0,0 @@
-- no-op: file_path column is harmless to keep

View file

@ -1 +0,0 @@
ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT '';

147
new-api/CHANNELS.md Normal file
View file

@ -0,0 +1,147 @@
# new-api Channel Configuration
After first start, access the new-api web UI at `http://<server>:4000` to configure channels.
Default admin credentials: `root` / `123456` — **change immediately**.
## API Token for Open WebUI
Create an API token in new-api's token management. Use this token as `OPENWEBUI_API_KEY` in `.env`.
## Channels to Create
Configure each channel via **Channels > Add Channel** in the web UI.
### 1. DeepInfra (Priority 1)
| Field | Value |
|---|---|
| Name | DeepInfra |
| Type | OpenAI |
| Base URL | `https://api.deepinfra.com/v1/openai` |
| Key | `$DEEPINFRA_API_KEY` |
| Priority | 1 |
| Models | See model mapping below |
### 2. SiliconFlow (Priority 2)
| Field | Value |
|---|---|
| Name | SiliconFlow |
| Type | OpenAI |
| Base URL | `https://api.siliconflow.com/v1` |
| Key | `$SILICONFLOW_API_KEY` |
| Priority | 2 |
| Models | See model mapping below |
### 3. OpenRouter (Priority 3)
| Field | Value |
|---|---|
| Name | OpenRouter |
| Type | OpenAI |
| Base URL | `https://openrouter.ai/api/v1` |
| Key | `$OPENROUTER_API_KEY` |
| Priority | 3 |
| Models | See model mapping below |
### 4. Groq (Priority 1)
| Field | Value |
|---|---|
| Name | Groq |
| Type | OpenAI |
| Base URL | `https://api.groq.com/openai/v1` |
| Key | `$GROQ_API_KEY` |
| Priority | 1 |
| Models | `llama-3.3-70b` |
### 5. Cerebras (Priority 1)
| Field | Value |
|---|---|
| Name | Cerebras |
| Type | OpenAI |
| Base URL | `https://api.cerebras.ai/v1` |
| Key | `$CEREBRAS_API_KEY` |
| Priority | 1 |
| Models | `llama-3.3-70b-cerebras` |
## Model Mapping per Channel
new-api uses model aliasing: the "model name" is what clients see, the "actual model" is what's sent to the provider.
### DeepInfra Models
| Client Model Name | Actual Provider Model |
|---|---|
| `deepseek-v3.2` | `deepseek-ai/DeepSeek-V3.2` |
| `deepseek-r1` | `deepseek-ai/DeepSeek-R1` |
| `gpt-oss` | `openai/gpt-oss-120b` |
| `gpt-oss-20b` | `openai/gpt-oss-20b` |
| `nemotron-super` | `nvidia/Llama-3.3-Nemotron-Super-49B-v1.5` |
| `nemotron-nano` | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` |
| `devstral` | `mistralai/Devstral-Small-2505` |
| `glm-4.6` | `zai-org/GLM-4.6` |
| `glm-4.7` | `zai-org/GLM-4.7` |
| `glm-5` | `zai-org/GLM-5` |
| `kimi-k2` | `moonshotai/Kimi-K2-Instruct-0905` |
| `kimi-k2.5` | `moonshotai/Kimi-K2.5` |
| `deepseek-v3-free` | `deepseek-ai/DeepSeek-V3` |
### SiliconFlow Models
| Client Model Name | Actual Provider Model |
|---|---|
| `deepseek-v3.2` | `deepseek-ai/DeepSeek-V3.2` |
| `glm-4.7` | `THUDM/GLM-4-32B-0414` |
| `kimi-k2` | `moonshotai/Kimi-K2-Instruct-0905` |
| `qwen3-coder` | `Qwen/Qwen3-Coder-480B-A35B-Instruct` |
| `qwen3-coder-30b` | `Qwen/Qwen3-Coder-30B-A3B-Instruct` |
### OpenRouter Models
| Client Model Name | Actual Provider Model |
|---|---|
| `deepseek-v3.2` | `deepseek/deepseek-chat-v3-0324` |
| `deepseek-v3-free` | `deepseek/deepseek-chat-v3-0324:free` |
| `kimi-k2.5` | `moonshotai/kimi-k2.5` |
| `minimax-m2.5` | `minimax/minimax-m2.5` |
| `gpt-4.1-mini` | `openai/gpt-4.1-mini` |
| `gpt-4.1` | `openai/gpt-4.1` |
| `gemini-3-flash-preview` | `google/gemini-3-flash-preview` |
| `gemini-2.5-pro` | `google/gemini-2.5-pro-preview` |
| `claude-sonnet` | `anthropic/claude-sonnet-4` |
| `trinity-large-preview` | `arcee-ai/trinity-large-preview` |
### Groq Models
| Client Model Name | Actual Provider Model |
|---|---|
| `llama-3.3-70b` | `llama-3.3-70b-versatile` |
### Cerebras Models
| Client Model Name | Actual Provider Model |
|---|---|
| `llama-3.3-70b-cerebras` | `llama-3.3-70b` |
## Fallback Behavior
new-api handles fallbacks via priority levels:
- When a model exists on multiple channels, the highest priority (lowest number) channel is tried first
- If it fails, it automatically falls back to the next priority level
For example, `deepseek-v3.2` exists on:
1. DeepInfra (priority 1) — tried first
2. SiliconFlow (priority 2) — fallback
3. OpenRouter (priority 3) — last resort
## Grafana Setup
After first start, access Grafana at `http://<server>:3001`:
1. Login with `admin` / `$GRAFANA_ADMIN_PASSWORD`
2. Add data source: **Prometheus** with URL `http://victoriametrics:8428`
3. Import dashboards:
- Node Exporter Full: dashboard ID `1860`
- Redis: dashboard ID `763`

218
new-api/init-channels.sh Executable file
View file

@ -0,0 +1,218 @@
#!/usr/bin/env bash
# Configures new-api channels and API token via the admin API.
# Run once after first boot: ./new-api/init-channels.sh
#
# Requires these env vars (or .env file in project root):
#   NEW_API_PASSWORD    - admin password
#   DEEPINFRA_API_KEY
#   SILICONFLOW_API_KEY
#   OPENROUTER_API_KEY
#   GROQ_API_KEY
#   CEREBRAS_API_KEY
#
# Optional:
#   NEW_API_USERNAME    - admin username (default: root)
#   NEW_API_BASE        - API base URL (default: http://localhost:4000)

# Fail fast: abort on errors, unset variables, and failed pipeline stages.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ENV_FILE="${SCRIPT_DIR}/../.env"

# Load .env if present. `set -a` exports everything the file assigns so the
# python3 helpers and curl invocations below can see the keys.
if [[ -f "$ENV_FILE" ]]; then
  set -a
  # shellcheck disable=SC1090
  source "$ENV_FILE"
  set +a
fi

API_BASE="${NEW_API_BASE:-http://localhost:4000}"
USERNAME="${NEW_API_USERNAME:-root}"
# `:?` makes the script abort immediately with this message when unset.
PASSWORD="${NEW_API_PASSWORD:?Set NEW_API_PASSWORD to the admin password}"

# Session cookie file; USER_ID is filled in by login(). The trap removes the
# cookie jar on any exit path, including errors.
COOKIE_JAR=$(mktemp)
USER_ID=""
trap 'rm -f "$COOKIE_JAR"' EXIT
# ── Login and get user ID ───────────────────────────────
# Authenticate against the new-api admin endpoint. Stores the session cookie
# in $COOKIE_JAR and the numeric admin user id in the global USER_ID.
# Exits with status 1 when the credentials are rejected.
login() {
  echo "Logging in as ${USERNAME}..."
  local creds body ok
  # Build the credential JSON via python3 so special characters are escaped.
  creds=$(python3 -c "
import json, sys
print(json.dumps({'username': sys.argv[1], 'password': sys.argv[2]}))
" "$USERNAME" "$PASSWORD")
  body=$(curl -s -c "$COOKIE_JAR" "${API_BASE}/api/user/login" \
    -H "Content-Type: application/json" \
    -d "$creds")
  ok=$(echo "$body" | python3 -c "import sys,json; print(json.load(sys.stdin).get('success', False))")
  if [[ "$ok" != "True" ]]; then
    echo "ERROR: Login failed: ${body}"
    exit 1
  fi
  USER_ID=$(echo "$body" | python3 -c "import sys,json; print(json.load(sys.stdin)['data']['id'])")
  echo " Logged in (user ID: ${USER_ID})."
}
# ── API call helper (cookie + New-Api-User header) ─────
# Issue an authenticated request. $1 is the endpoint path; any remaining
# arguments are forwarded to curl verbatim (e.g. -d for a POST body).
api_call() {
  local path="$1"
  shift
  local url="${API_BASE}${path}"
  curl -s -b "$COOKIE_JAR" \
    -H "New-Api-User: ${USER_ID}" \
    -H "Content-Type: application/json" \
    "$url" "$@"
}
# ── Create channel ──────────────────────────────────────
# Create one provider channel via the admin API.
# Args:
#   $1 name           - display name
#   $2 type           - new-api channel type code
#   $3 key            - provider API key
#   $4 base_url       - provider base URL
#   $5 priority       - routing priority (lower tried first)
#   $6 models         - comma-separated client-visible model names
#   $7 model_mapping  - JSON object: client name -> provider model
# Prints "OK" or a truncated "FAILED" line; a failed channel does not abort
# the script.
create_channel() {
  local name="$1" type="$2" key="$3" base_url="$4" priority="$5" models="$6" model_mapping="$7"
  echo "Creating channel: ${name} (priority ${priority})..."
  local payload
  # Build the JSON payload with python3 so keys and values are escaped safely.
  payload=$(python3 -c "
import json, sys
print(json.dumps({
    'type': int(sys.argv[1]),
    'name': sys.argv[2],
    'key': sys.argv[3],
    'base_url': sys.argv[4],
    'models': sys.argv[5],
    'model_mapping': sys.argv[6],
    'priority': int(sys.argv[7]),
    'status': 1,
    'group': 'default',
    'weight': 1,
    'auto_ban': 1
}))
" "$type" "$name" "$key" "$base_url" "$models" "$model_mapping" "$priority")
  local resp success
  resp=$(api_call "/api/channel/" -d "$payload")
  success=$(echo "$resp" | python3 -c "import sys,json; print(json.load(sys.stdin).get('success', False))")
  if [[ "$success" == "True" ]]; then
    echo " OK"
  else
    # Truncate via parameter expansion instead of `echo ... | head -c 500`:
    # under `set -euo pipefail` a large error body could kill the whole
    # script when head exits early and echo takes SIGPIPE (pipeline -> 141).
    echo " FAILED: ${resp:0:500}"
    echo
  fi
}
# ── Wait for new-api ────────────────────────────────────
# Poll the root URL until it answers, up to 30 attempts 2s apart (~60s).
echo "Waiting for new-api at ${API_BASE}..."
attempt=1
while :; do
  if curl -sf "${API_BASE}/" > /dev/null 2>&1; then
    echo "new-api is ready."
    break
  fi
  # Give up on the 30th failed probe without sleeping one extra interval.
  if (( attempt == 30 )); then
    echo "ERROR: new-api did not become ready in time."
    exit 1
  fi
  attempt=$((attempt + 1))
  sleep 2
done
# ── Login ───────────────────────────────────────────────
# Establishes the session cookie and sets the global USER_ID.
login
# ── Generate system access token for future use ─────────
# Best-effort: the token is only printed for the operator; the rest of the
# script keeps using the session cookie either way.
echo ""
echo "Generating system access token..."
ACCESS_TOKEN_RESP=$(api_call "/api/user/token")
# Empty output (or a parse failure, swallowed by `|| echo ""`) means no token.
ACCESS_TOKEN=$(echo "$ACCESS_TOKEN_RESP" | python3 -c "
import sys, json
data = json.load(sys.stdin)
if data.get('success'):
    print(data.get('data', ''))
else:
    print('')
" 2>/dev/null || echo "")
if [[ -n "$ACCESS_TOKEN" ]]; then
  echo " Access token: ${ACCESS_TOKEN}"
  echo " Save as NEW_API_ACCESS_TOKEN in .env for future API use."
  echo " Usage: -H 'Authorization: Bearer ${ACCESS_TOKEN}' -H 'New-Api-User: ${USER_ID}'"
else
  echo " Could not generate access token (non-critical, using session)."
fi
# ── Channels ────────────────────────────────────────────
# Arguments per call: name, type, key, base_url, priority, models, mapping.
# Priority 1 channels are tried first; higher numbers are fallbacks.

# DeepInfra: primary paid provider (priority 1).
# NOTE(review): CHANNELS.md also lists deepseek-v3-free on DeepInfra, but it
# is absent from this models list — confirm which is intended.
create_channel "DeepInfra" 1 \
  "${DEEPINFRA_API_KEY:?}" \
  "https://api.deepinfra.com/v1/openai" \
  1 \
  "deepseek-v3.2,deepseek-r1,gpt-oss,gpt-oss-20b,nemotron-super,nemotron-nano,devstral,glm-4.6,glm-4.7,glm-5,kimi-k2,kimi-k2.5" \
  '{"deepseek-v3.2":"deepseek-ai/DeepSeek-V3.2","deepseek-r1":"deepseek-ai/DeepSeek-R1","gpt-oss":"openai/gpt-oss-120b","gpt-oss-20b":"openai/gpt-oss-20b","nemotron-super":"nvidia/Llama-3.3-Nemotron-Super-49B-v1.5","nemotron-nano":"nvidia/NVIDIA-Nemotron-Nano-9B-v2","devstral":"mistralai/Devstral-Small-2505","glm-4.6":"zai-org/GLM-4.6","glm-4.7":"zai-org/GLM-4.7","glm-5":"zai-org/GLM-5","kimi-k2":"moonshotai/Kimi-K2-Instruct-0905","kimi-k2.5":"moonshotai/Kimi-K2.5"}'

# SiliconFlow: first fallback (priority 2).
# NOTE(review): glm-4.7 maps to THUDM/GLM-4-32B-0414 here, a different model
# family than the zai-org/GLM-4.7 used on DeepInfra — verify this is intended.
create_channel "SiliconFlow" 1 \
  "${SILICONFLOW_API_KEY:?}" \
  "https://api.siliconflow.com/v1" \
  2 \
  "deepseek-v3.2,glm-4.7,kimi-k2,qwen3-coder,qwen3-coder-30b" \
  '{"deepseek-v3.2":"deepseek-ai/DeepSeek-V3.2","glm-4.7":"THUDM/GLM-4-32B-0414","kimi-k2":"moonshotai/Kimi-K2-Instruct-0905","qwen3-coder":"Qwen/Qwen3-Coder-480B-A35B-Instruct","qwen3-coder-30b":"Qwen/Qwen3-Coder-30B-A3B-Instruct"}'

# OpenRouter: last-resort fallback (priority 3); also hosts the frontier and
# free models not available elsewhere.
create_channel "OpenRouter" 1 \
  "${OPENROUTER_API_KEY:?}" \
  "https://openrouter.ai/api/v1" \
  3 \
  "deepseek-v3.2,deepseek-v3-free,kimi-k2.5,minimax-m2.5,gpt-4.1-mini,gpt-4.1,gemini-3-flash-preview,gemini-2.5-pro,claude-sonnet,trinity-large-preview" \
  '{"deepseek-v3.2":"deepseek/deepseek-chat-v3-0324","deepseek-v3-free":"deepseek/deepseek-chat-v3-0324:free","kimi-k2.5":"moonshotai/kimi-k2.5","minimax-m2.5":"minimax/minimax-m2.5","gpt-4.1-mini":"openai/gpt-4.1-mini","gpt-4.1":"openai/gpt-4.1","gemini-3-flash-preview":"google/gemini-3-flash-preview","gemini-2.5-pro":"google/gemini-2.5-pro-preview","claude-sonnet":"anthropic/claude-sonnet-4","trinity-large-preview":"arcee-ai/trinity-large-preview"}'

# Groq: free tier, single fast Llama model (priority 1).
create_channel "Groq" 1 \
  "${GROQ_API_KEY:?}" \
  "https://api.groq.com/openai/v1" \
  1 \
  "llama-3.3-70b" \
  '{"llama-3.3-70b":"llama-3.3-70b-versatile"}'

# Cerebras: free tier; distinct client name avoids colliding with Groq's.
create_channel "Cerebras" 1 \
  "${CEREBRAS_API_KEY:?}" \
  "https://api.cerebras.ai/v1" \
  1 \
  "llama-3.3-70b-cerebras" \
  '{"llama-3.3-70b-cerebras":"llama-3.3-70b"}'
# ── Create API token for Open WebUI ─────────────────────
# Creates an unlimited-quota consumer token and prints it for the operator
# to copy into .env as OPENWEBUI_API_KEY.
echo ""
echo "Creating API token for Open WebUI..."
TOKEN_RESP=$(api_call "/api/token/" -d "$(python3 -c "
import json
print(json.dumps({
    'name': 'open-webui',
    'remain_quota': 0,
    'unlimited_quota': True
}))
")")
# On any parse/API failure TOKEN_KEY starts with "FAILED", which the summary
# below branches on.
TOKEN_KEY=$(echo "$TOKEN_RESP" | python3 -c "
import sys, json
data = json.load(sys.stdin)
if data.get('success'):
    print(data['data']['key'])
else:
    print('FAILED: ' + data.get('message', 'unknown error'))
" 2>/dev/null || echo "FAILED: could not parse response")
echo ""
echo "══════════════════════════════════════"
echo "Setup complete!"
echo ""
if [[ "$TOKEN_KEY" != FAILED* ]]; then
  echo "Open WebUI API key: ${TOKEN_KEY}"
  echo "Set OPENWEBUI_API_KEY=${TOKEN_KEY} in your .env"
  echo ""
  echo "Test:"
  echo " curl ${API_BASE}/v1/chat/completions \\"
  echo " -H 'Authorization: Bearer ${TOKEN_KEY}' \\"
  echo " -H 'Content-Type: application/json' \\"
  echo " -d '{\"model\":\"deepseek-v3.2\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}'"
else
  echo "Token creation: ${TOKEN_KEY}"
  echo "Create a token manually in the new-api UI."
fi
echo "══════════════════════════════════════"