Compare commits
3 commits
f23a7c14c0
...
28a694744d
| Author | SHA1 | Date | |
|---|---|---|---|
| 28a694744d | |||
| 90adf6f3a8 | |||
| c17ffa41f5 |
53 changed files with 4825 additions and 710 deletions
|
|
@ -12,7 +12,7 @@ ADMIN_USERNAME=admin
|
|||
ADMIN_PASSWORD=change-me-min-8-chars
|
||||
# Static API tokens (seeded on startup, leave empty to skip)
|
||||
OPENWEBUI_API_KEY=sk-...
|
||||
PERSONAL_API_KEY=sk-...
|
||||
OPENCODE_API_KEY=sk-...
|
||||
# Provider API keys
|
||||
OPENROUTER_API_KEY=sk-or-...
|
||||
SILICONFLOW_API_KEY=sk-...
|
||||
|
|
|
|||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -3,5 +3,8 @@
|
|||
# Environment secrets
|
||||
.env
|
||||
|
||||
# Host-mounted data directories
|
||||
data/
|
||||
|
||||
# SearXNG runtime state
|
||||
searxng/uwsgi.ini
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ services:
|
|||
ports:
|
||||
- "0.0.0.0:4000:3000"
|
||||
volumes:
|
||||
- llm-gateway-data:/data
|
||||
- ./data/llm-gateway:/data
|
||||
- ./llm-gateway.yaml:/etc/llm-gateway/config.yaml:ro
|
||||
environment:
|
||||
- SESSION_SECRET=${SESSION_SECRET}
|
||||
|
|
@ -161,7 +161,6 @@ services:
|
|||
volumes:
|
||||
valkey-data:
|
||||
chromadb-data:
|
||||
llm-gateway-data:
|
||||
open-webui-data:
|
||||
tailscale-state:
|
||||
victoriametrics-data:
|
||||
|
|
|
|||
|
|
@ -1,233 +0,0 @@
|
|||
model_list:
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 1: Free providers (try first)
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# --- Groq (free tier, very fast) ---
|
||||
- model_name: llama-3.3-70b
|
||||
litellm_params:
|
||||
model: groq/llama-3.3-70b-versatile
|
||||
api_key: os.environ/GROQ_API_KEY
|
||||
|
||||
# --- Cerebras (free tier, very fast) ---
|
||||
- model_name: llama-3.3-70b-cerebras
|
||||
litellm_params:
|
||||
model: cerebras/llama-3.3-70b
|
||||
api_key: os.environ/CEREBRAS_API_KEY
|
||||
|
||||
# --- OpenRouter free models ---
|
||||
- model_name: deepseek-v3-free
|
||||
litellm_params:
|
||||
model: openrouter/deepseek/deepseek-chat-v3-0324:free
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 2: DeepSeek V3.2 (cheapest first)
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# DeepSeek V3.2 via DeepInfra ($0.26 in / $0.38 out per M)
|
||||
- model_name: deepseek-v3.2
|
||||
litellm_params:
|
||||
model: deepinfra/deepseek-ai/DeepSeek-V3.2
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# DeepSeek V3.2 fallback via SiliconFlow ($0.27 in / $0.42 out per M)
|
||||
- model_name: deepseek-v3.2
|
||||
litellm_params:
|
||||
model: openai/deepseek-ai/DeepSeek-V3.2
|
||||
api_base: https://api.siliconflow.com/v1
|
||||
api_key: os.environ/SILICONFLOW_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 3: Ultra-cheap DeepInfra models
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# GPT-OSS-120B — OpenAI open-weight MoE ($0.05 in / $0.24 out per M)
|
||||
- model_name: gpt-oss
|
||||
litellm_params:
|
||||
model: deepinfra/openai/gpt-oss-120b
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# GPT-OSS-20B — lower latency variant ($0.04 in / $0.16 out per M)
|
||||
- model_name: gpt-oss-20b
|
||||
litellm_params:
|
||||
model: deepinfra/openai/gpt-oss-20b
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# Nemotron Super 49B — near-flagship quality ($0.10 in / $0.40 out per M)
|
||||
- model_name: nemotron-super
|
||||
litellm_params:
|
||||
model: deepinfra/nvidia/Llama-3.3-Nemotron-Super-49B-v1.5
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# Nemotron Nano 9B — dirt cheap for simple tasks ($0.04 in / $0.16 out per M)
|
||||
- model_name: nemotron-nano
|
||||
litellm_params:
|
||||
model: deepinfra/nvidia/NVIDIA-Nemotron-Nano-9B-v2
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 4: Other DeepInfra models
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
- model_name: deepseek-r1
|
||||
litellm_params:
|
||||
model: deepinfra/deepseek-ai/DeepSeek-R1
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
- model_name: devstral
|
||||
litellm_params:
|
||||
model: deepinfra/mistralai/Devstral-Small-2505
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 5: GLM models (cheapest first)
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# GLM-4.6 via DeepInfra ($0.60 in / $1.90 out per M)
|
||||
- model_name: glm-4.6
|
||||
litellm_params:
|
||||
model: deepinfra/zai-org/GLM-4.6
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# GLM-4.7 via DeepInfra ($0.40 in / $1.75 out per M)
|
||||
- model_name: glm-4.7
|
||||
litellm_params:
|
||||
model: deepinfra/zai-org/GLM-4.7
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# GLM-4.7 fallback via SiliconFlow
|
||||
- model_name: glm-4.7
|
||||
litellm_params:
|
||||
model: openai/THUDM/GLM-4-32B-0414
|
||||
api_base: https://api.siliconflow.com/v1
|
||||
api_key: os.environ/SILICONFLOW_API_KEY
|
||||
|
||||
# GLM-5 via DeepInfra ($0.80 in / $2.56 out per M)
|
||||
- model_name: glm-5
|
||||
litellm_params:
|
||||
model: deepinfra/zai-org/GLM-5
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 6: Kimi K2 (cheapest first)
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# Kimi K2 via DeepInfra ($0.50 in / $2.00 out per M)
|
||||
- model_name: kimi-k2
|
||||
litellm_params:
|
||||
model: deepinfra/moonshotai/Kimi-K2-Instruct-0905
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# Kimi K2 fallback via SiliconFlow ($0.58 in / $2.29 out per M)
|
||||
- model_name: kimi-k2
|
||||
litellm_params:
|
||||
model: openai/moonshotai/Kimi-K2-Instruct-0905
|
||||
api_base: https://api.siliconflow.com/v1
|
||||
api_key: os.environ/SILICONFLOW_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 7: SiliconFlow (Qwen)
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# Qwen3 Coder 480B MoE via SiliconFlow ($1.14 in / $2.28 out per M)
|
||||
- model_name: qwen3-coder
|
||||
litellm_params:
|
||||
model: openai/Qwen/Qwen3-Coder-480B-A35B-Instruct
|
||||
api_base: https://api.siliconflow.com/v1
|
||||
api_key: os.environ/SILICONFLOW_API_KEY
|
||||
|
||||
# Qwen3 Coder 30B — cheaper alternative for simpler tasks
|
||||
- model_name: qwen3-coder-30b
|
||||
litellm_params:
|
||||
model: openai/Qwen/Qwen3-Coder-30B-A3B-Instruct
|
||||
api_base: https://api.siliconflow.com/v1
|
||||
api_key: os.environ/SILICONFLOW_API_KEY
|
||||
|
||||
# ═══════════════════════════════════════════════
|
||||
# TIER 8: OpenRouter (most expensive, widest selection)
|
||||
# ═══════════════════════════════════════════════
|
||||
|
||||
# Kimi K2.5 — DeepInfra is cheapest ($0.45 in / $2.25 out per M)
|
||||
- model_name: kimi-k2.5
|
||||
litellm_params:
|
||||
model: deepinfra/moonshotai/Kimi-K2.5
|
||||
api_key: os.environ/DEEPINFRA_API_KEY
|
||||
|
||||
# Kimi K2.5 fallback via OpenRouter
|
||||
- model_name: kimi-k2.5
|
||||
litellm_params:
|
||||
model: openrouter/moonshotai/kimi-k2.5
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
- model_name: minimax-m2.5
|
||||
litellm_params:
|
||||
model: openrouter/minimax/minimax-m2.5
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
- model_name: gpt-4.1-mini
|
||||
litellm_params:
|
||||
model: openrouter/openai/gpt-4.1-mini
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
- model_name: gemini-3-flash-preview
|
||||
litellm_params:
|
||||
model: openrouter/google/gemini-3-flash-preview
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
- model_name: trinity-large-preview
|
||||
litellm_params:
|
||||
model: openrouter/arcee-ai/trinity-large-preview
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
# --- OpenRouter premium models ---
|
||||
- model_name: gemini-2.5-pro
|
||||
litellm_params:
|
||||
model: openrouter/google/gemini-2.5-pro-preview
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
- model_name: claude-sonnet
|
||||
litellm_params:
|
||||
model: openrouter/anthropic/claude-sonnet-4
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
- model_name: gpt-4.1
|
||||
litellm_params:
|
||||
model: openrouter/openai/gpt-4.1
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
# DeepSeek V3.2 last-resort fallback via OpenRouter
|
||||
- model_name: deepseek-v3.2
|
||||
litellm_params:
|
||||
model: openrouter/deepseek/deepseek-chat-v3-0324
|
||||
api_key: os.environ/OPENROUTER_API_KEY
|
||||
|
||||
general_settings:
|
||||
master_key: os.environ/LITELLM_MASTER_KEY
|
||||
|
||||
# ── Model group fallbacks (when model misbehaves mid-stream) ──
|
||||
fallbacks:
|
||||
- deepseek-v3.2: [gpt-oss, kimi-k2]
|
||||
- gpt-oss: [deepseek-v3.2, nemotron-super]
|
||||
- kimi-k2: [deepseek-v3.2, gpt-oss]
|
||||
- kimi-k2.5: [deepseek-v3.2, gpt-oss]
|
||||
- glm-4.7: [deepseek-v3.2, gpt-oss]
|
||||
- glm-5: [deepseek-v3.2, gpt-oss]
|
||||
|
||||
litellm_settings:
|
||||
drop_params: true
|
||||
set_verbose: false
|
||||
num_retries: 2
|
||||
request_timeout: 600
|
||||
|
||||
# ── Response caching via Valkey (reuses SearXNG's instance) ──
|
||||
cache: true
|
||||
cache_params:
|
||||
type: redis
|
||||
host: valkey
|
||||
port: 6379
|
||||
ttl: 3600
|
||||
|
||||
# ── Budget limit: $3/day to prevent surprise bills ──
|
||||
# max_budget: 3.0
|
||||
# budget_duration: "1d"
|
||||
|
|
@ -25,6 +25,12 @@ database:
|
|||
path: "/data/gateway.db"
|
||||
retention_days: 90
|
||||
|
||||
debug:
|
||||
enabled: true
|
||||
retention_days: 90
|
||||
# data_dir: "/data" # defaults to directory of database.path
|
||||
# max_body_bytes: 0 # 0 = unlimited (save full bodies)
|
||||
|
||||
cache:
|
||||
enabled: true
|
||||
address: "valkey:6379"
|
||||
|
|
|
|||
3
llm-gateway/.gitignore
vendored
3
llm-gateway/.gitignore
vendored
|
|
@ -8,6 +8,9 @@ llm-gateway
|
|||
*.db-wal
|
||||
*.db-shm
|
||||
|
||||
# Debug log files
|
||||
debug-logs/
|
||||
|
||||
# Local config
|
||||
configs/config.local.yaml
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,13 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
gocors "github.com/go-chi/cors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
"llm-gateway/internal/auth"
|
||||
|
|
@ -91,7 +93,7 @@ func main() {
|
|||
log.Printf("Registered %d models", len(cfg.Models))
|
||||
|
||||
// Provider health tracker
|
||||
healthTracker := provider.NewHealthTracker(5 * time.Minute)
|
||||
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
|
||||
|
||||
// Auth store (static tokens checked in-memory, not seeded to DB)
|
||||
var staticTokens []auth.StaticToken
|
||||
|
|
@ -102,6 +104,7 @@ func main() {
|
|||
Key: t.Key,
|
||||
RateLimitRPM: t.RateLimitRPM,
|
||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
||||
MaxConcurrent: t.MaxConcurrent,
|
||||
})
|
||||
log.Printf("Loaded static token: %s", t.Name)
|
||||
}
|
||||
|
|
@ -110,6 +113,17 @@ func main() {
|
|||
authMiddleware := auth.NewMiddleware(authStore)
|
||||
authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret)
|
||||
|
||||
// Audit logger
|
||||
auditLogger := storage.NewAuditLogger(db)
|
||||
authHandlers.SetAuditLogger(auditLogger)
|
||||
|
||||
// Debug logger
|
||||
debugDataDir := cfg.Debug.DataDir
|
||||
if debugDataDir == "" {
|
||||
debugDataDir = filepath.Dir(cfg.Database.Path)
|
||||
}
|
||||
debugLogger := storage.NewDebugLogger(db, cfg.Debug.Enabled, debugDataDir)
|
||||
|
||||
// Seed default admin
|
||||
seedDefaultAdmin(cfg, authStore)
|
||||
|
||||
|
|
@ -118,22 +132,43 @@ func main() {
|
|||
|
||||
// Handlers
|
||||
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
|
||||
proxyHandler.SetDebugLogger(debugLogger)
|
||||
modelsHandler := proxy.NewModelsHandler(registry)
|
||||
proxyAuth := proxy.NewAuthMiddleware(authStore)
|
||||
rateLimiter := proxy.NewRateLimiter(db)
|
||||
concurrencyLimiter := proxy.NewConcurrencyLimiter()
|
||||
statsAPI := dashboard.NewStatsAPI(db, authStore)
|
||||
statsAPI.SetHealthTracker(healthTracker)
|
||||
statsAPI.SetAuditLogger(auditLogger)
|
||||
statsAPI.SetDebugLogger(debugLogger)
|
||||
if c != nil {
|
||||
statsAPI.SetCache(c)
|
||||
}
|
||||
dash := dashboard.NewDashboard(authStore, statsAPI)
|
||||
dash.SetRegistry(registry)
|
||||
dash.SetAuditLogger(auditLogger)
|
||||
dash.SetDebugLogger(debugLogger)
|
||||
if c != nil {
|
||||
dash.SetCache(c)
|
||||
}
|
||||
|
||||
// Export handler
|
||||
exportHandler := dashboard.NewExportHandler(db, authStore)
|
||||
|
||||
// Router
|
||||
r := chi.NewRouter()
|
||||
|
||||
// CORS (before other middleware)
|
||||
if cfg.CORS.Enabled {
|
||||
r.Use(gocors.Handler(gocors.Options{
|
||||
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
||||
AllowedMethods: cfg.CORS.AllowedMethods,
|
||||
AllowedHeaders: cfg.CORS.AllowedHeaders,
|
||||
MaxAge: cfg.CORS.MaxAge,
|
||||
AllowCredentials: true,
|
||||
}))
|
||||
}
|
||||
|
||||
r.Use(middleware.RealIP)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.RequestID)
|
||||
|
|
@ -159,6 +194,7 @@ func main() {
|
|||
r.Group(func(r chi.Router) {
|
||||
r.Use(proxyAuth.Authenticate)
|
||||
r.Use(rateLimiter.Check)
|
||||
r.Use(concurrencyLimiter.Check)
|
||||
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
|
||||
r.Get("/v1/models", modelsHandler.ListModels)
|
||||
})
|
||||
|
|
@ -192,6 +228,8 @@ func main() {
|
|||
r.Group(func(r chi.Router) {
|
||||
r.Use(authMiddleware.RequireAdmin)
|
||||
r.Get("/users", dash.UsersPage)
|
||||
r.Get("/audit", dash.AuditPage)
|
||||
r.Get("/debug", dash.DebugPage)
|
||||
})
|
||||
|
||||
// Auth API
|
||||
|
|
@ -224,16 +262,29 @@ func main() {
|
|||
r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler)
|
||||
r.Get("/api/stats/cache", statsAPI.CacheStats)
|
||||
|
||||
// Admin-only: user management
|
||||
// Data export
|
||||
r.Get("/api/export/logs", exportHandler.ExportLogs)
|
||||
r.Get("/api/export/stats", exportHandler.ExportStats)
|
||||
|
||||
// Admin-only: user management, audit, debug
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(authMiddleware.RequireAdmin)
|
||||
r.Get("/api/auth/users", authHandlers.ListUsers)
|
||||
r.Post("/api/auth/users", authHandlers.CreateUser)
|
||||
r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser)
|
||||
|
||||
// Audit log
|
||||
r.Get("/api/stats/audit", statsAPI.AuditLogs)
|
||||
|
||||
// Debug logging
|
||||
r.Post("/api/debug/toggle", statsAPI.DebugToggle)
|
||||
r.Get("/api/debug/status", statsAPI.DebugStatus)
|
||||
r.Get("/api/debug/logs", statsAPI.DebugLogs)
|
||||
r.Get("/api/debug/logs/{requestID}", statsAPI.DebugLogByRequestID)
|
||||
})
|
||||
})
|
||||
|
||||
// Periodic session cleanup
|
||||
// Periodic session cleanup and debug log cleanup
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
|
@ -241,6 +292,9 @@ func main() {
|
|||
if err := authStore.CleanExpiredSessions(); err != nil {
|
||||
log.Printf("WARNING: session cleanup failed: %v", err)
|
||||
}
|
||||
if err := debugLogger.Cleanup(cfg.Debug.RetentionDays); err != nil {
|
||||
log.Printf("WARNING: debug log cleanup failed: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
@ -253,6 +307,45 @@ func main() {
|
|||
IdleTimeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
// Config hot-reload via SIGHUP
|
||||
config.WatchReload(*configPath, func(newCfg *config.Config) {
|
||||
// Reload registry (models, providers, routes)
|
||||
if err := registry.Reload(newCfg); err != nil {
|
||||
log.Printf("ERROR: registry reload failed: %v", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Reloaded %d models", len(newCfg.Models))
|
||||
|
||||
// Reload pricing
|
||||
for i, m := range newCfg.Models {
|
||||
for j, rt := range m.Routes {
|
||||
if rt.Pricing.Input == 0 && rt.Pricing.Output == 0 {
|
||||
pricingLookup.FillMissing(rt.Provider, rt.Model,
|
||||
&newCfg.Models[i].Routes[j].Pricing.Input,
|
||||
&newCfg.Models[i].Routes[j].Pricing.Output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reload static tokens
|
||||
var newStaticTokens []auth.StaticToken
|
||||
for _, t := range newCfg.Tokens {
|
||||
if t.Key != "" {
|
||||
newStaticTokens = append(newStaticTokens, auth.StaticToken{
|
||||
Name: t.Name,
|
||||
Key: t.Key,
|
||||
RateLimitRPM: t.RateLimitRPM,
|
||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
||||
MaxConcurrent: t.MaxConcurrent,
|
||||
})
|
||||
}
|
||||
}
|
||||
authStore.SetStaticTokens(newStaticTokens)
|
||||
|
||||
// Update config pointer for retry/debug/etc
|
||||
cfg = newCfg
|
||||
})
|
||||
|
||||
// Graceful shutdown
|
||||
done := make(chan os.Signal, 1)
|
||||
signal.Notify(done, os.Interrupt, syscall.SIGTERM)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ require (
|
|||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-chi/cors v1.2.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
|
|||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
||||
github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
|
||||
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
||||
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
|
|
|
|||
|
|
@ -13,12 +13,15 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"llm-gateway/internal/storage"
|
||||
)
|
||||
|
||||
type Handlers struct {
|
||||
store *Store
|
||||
sessionSecret string
|
||||
loginLimiter *loginRateLimiter
|
||||
auditLogger *storage.AuditLogger
|
||||
}
|
||||
|
||||
func NewHandlers(store *Store, sessionSecret string) *Handlers {
|
||||
|
|
@ -29,6 +32,36 @@ func NewHandlers(store *Store, sessionSecret string) *Handlers {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Handlers) SetAuditLogger(al *storage.AuditLogger) {
|
||||
h.auditLogger = al
|
||||
}
|
||||
|
||||
func (h *Handlers) audit(r *http.Request, action, targetType, targetID, details string) {
|
||||
if h.auditLogger == nil {
|
||||
return
|
||||
}
|
||||
user := UserFromContext(r.Context())
|
||||
var userID int64
|
||||
var username string
|
||||
if user != nil {
|
||||
userID = user.ID
|
||||
username = user.Username
|
||||
}
|
||||
ip := r.RemoteAddr
|
||||
if fwd := r.Header.Get("X-Real-IP"); fwd != "" {
|
||||
ip = fwd
|
||||
}
|
||||
h.auditLogger.Log(storage.AuditEntry{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Action: action,
|
||||
TargetType: targetType,
|
||||
TargetID: targetID,
|
||||
Details: details,
|
||||
IPAddress: ip,
|
||||
})
|
||||
}
|
||||
|
||||
// Login brute-force protection
|
||||
type loginRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
|
|
@ -126,6 +159,8 @@ func (h *Handlers) Setup(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "auth.setup", "user", fmt.Sprintf("%d", user.ID), "initial setup")
|
||||
|
||||
h.setSessionCookie(w, sessionID)
|
||||
writeJSON(w, map[string]any{
|
||||
"user": map[string]any{
|
||||
|
|
@ -187,6 +222,8 @@ func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "auth.login", "user", fmt.Sprintf("%d", user.ID), user.Username)
|
||||
|
||||
h.setSessionCookie(w, sessionID)
|
||||
writeJSON(w, map[string]any{
|
||||
"require_totp": false,
|
||||
|
|
@ -256,6 +293,8 @@ func (h *Handlers) LoginTOTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
h.audit(r, "auth.logout", "", "", "")
|
||||
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
h.store.DeleteSession(cookie.Value)
|
||||
|
|
@ -347,6 +386,7 @@ func (h *Handlers) TOTPVerify(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "totp.enable", "user", fmt.Sprintf("%d", user.ID), "")
|
||||
writeJSON(w, map[string]string{"status": "totp_enabled"})
|
||||
}
|
||||
|
||||
|
|
@ -362,6 +402,7 @@ func (h *Handlers) TOTPDisable(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "totp.disable", "user", fmt.Sprintf("%d", user.ID), "")
|
||||
writeJSON(w, map[string]string{"status": "totp_disabled"})
|
||||
}
|
||||
|
||||
|
|
@ -420,6 +461,7 @@ func (h *Handlers) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "user.create", "user", fmt.Sprintf("%d", user.ID), user.Username)
|
||||
writeJSON(w, map[string]any{
|
||||
"id": user.ID,
|
||||
"username": user.Username,
|
||||
|
|
@ -447,6 +489,7 @@ func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "user.delete", "user", idStr, "")
|
||||
writeJSON(w, map[string]string{"status": "deleted"})
|
||||
}
|
||||
|
||||
|
|
@ -507,6 +550,7 @@ func (h *Handlers) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "token.create", "token", fmt.Sprintf("%d", token.ID), req.Name)
|
||||
writeJSON(w, map[string]any{
|
||||
"key": plainKey,
|
||||
"token": token,
|
||||
|
|
@ -545,6 +589,7 @@ func (h *Handlers) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "token.delete", "token", idStr, "")
|
||||
writeJSON(w, map[string]string{"status": "deleted"})
|
||||
}
|
||||
|
||||
|
|
@ -587,6 +632,7 @@ func (h *Handlers) ChangePassword(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "password.change", "user", fmt.Sprintf("%d", user.ID), "")
|
||||
writeJSON(w, map[string]string{"status": "password_updated"})
|
||||
}
|
||||
|
||||
|
|
@ -621,6 +667,7 @@ func (h *Handlers) ChangeUsername(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
h.audit(r, "username.change", "user", fmt.Sprintf("%d", user.ID), req.NewUsername)
|
||||
writeJSON(w, map[string]string{"status": "username_updated"})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ type APIToken struct {
|
|||
UserID int64 `json:"user_id"`
|
||||
RateLimitRPM int `json:"rate_limit_rpm"`
|
||||
DailyBudgetUSD float64 `json:"daily_budget_usd"`
|
||||
MaxConcurrent int `json:"max_concurrent"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
LastUsedAt int64 `json:"last_used_at"`
|
||||
}
|
||||
|
|
@ -48,6 +49,7 @@ type StaticToken struct {
|
|||
Key string
|
||||
RateLimitRPM int
|
||||
DailyBudgetUSD float64
|
||||
MaxConcurrent int
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
|
|
@ -59,6 +61,11 @@ func NewStore(db *sql.DB, staticTokens []StaticToken) *Store {
|
|||
return &Store{db: db, staticTokens: staticTokens}
|
||||
}
|
||||
|
||||
// SetStaticTokens updates the static tokens list (used for config hot-reload).
|
||||
func (s *Store) SetStaticTokens(tokens []StaticToken) {
|
||||
s.staticTokens = tokens
|
||||
}
|
||||
|
||||
func (s *Store) HasAnyUser() bool {
|
||||
var count int
|
||||
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
|
||||
|
|
@ -287,6 +294,7 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
|||
KeyPrefix: prefix,
|
||||
RateLimitRPM: st.RateLimitRPM,
|
||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
||||
MaxConcurrent: st.MaxConcurrent,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
|
@ -297,9 +305,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
|||
|
||||
var t APIToken
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
||||
keyHash,
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt)
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -320,6 +328,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
|||
KeyPrefix: prefix,
|
||||
RateLimitRPM: st.RateLimitRPM,
|
||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
||||
MaxConcurrent: st.MaxConcurrent,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -327,19 +336,18 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
|||
var rows *sql.Rows
|
||||
var err error
|
||||
if userID == 0 {
|
||||
// Admin: list all
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens ORDER BY id")
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
|
||||
} else {
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
||||
}
|
||||
if err != nil {
|
||||
return tokens, nil // return static tokens even if DB query fails
|
||||
return tokens, nil
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var t APIToken
|
||||
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
||||
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
||||
return tokens, nil
|
||||
}
|
||||
tokens = append(tokens, t)
|
||||
|
|
@ -355,9 +363,9 @@ func (s *Store) DeleteAPIToken(id int64) error {
|
|||
func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
|
||||
var t APIToken
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE id = ?",
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
|
||||
id,
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt)
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
300
llm-gateway/internal/auth/store_test.go
Normal file
300
llm-gateway/internal/auth/store_test.go
Normal file
|
|
@ -0,0 +1,300 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("opening test db: %v", err)
|
||||
}
|
||||
|
||||
// Create tables
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
email TEXT DEFAULT '',
|
||||
password_hash TEXT NOT NULL,
|
||||
is_admin INTEGER DEFAULT 0,
|
||||
totp_secret TEXT DEFAULT '',
|
||||
totp_enabled INTEGER DEFAULT 0,
|
||||
created_at INTEGER NOT NULL,
|
||||
updated_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE TABLE sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
created_at INTEGER NOT NULL,
|
||||
expires_at INTEGER NOT NULL
|
||||
);
|
||||
CREATE TABLE api_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
key_hash TEXT NOT NULL,
|
||||
key_prefix TEXT NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
rate_limit_rpm INTEGER DEFAULT 0,
|
||||
daily_budget_usd REAL DEFAULT 0,
|
||||
max_concurrent INTEGER DEFAULT 0,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_used_at INTEGER DEFAULT 0
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("creating tables: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
user, err := store.CreateUser("alice", "password123", true)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateUser: %v", err)
|
||||
}
|
||||
if user.Username != "alice" {
|
||||
t.Errorf("expected username 'alice', got '%s'", user.Username)
|
||||
}
|
||||
if !user.IsAdmin {
|
||||
t.Error("expected admin user")
|
||||
}
|
||||
if user.ID == 0 {
|
||||
t.Error("expected non-zero ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByUsername(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
store.CreateUser("bob", "password123", false)
|
||||
|
||||
user, err := store.GetUserByUsername("bob")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserByUsername: %v", err)
|
||||
}
|
||||
if user.Username != "bob" {
|
||||
t.Errorf("expected 'bob', got '%s'", user.Username)
|
||||
}
|
||||
|
||||
_, err = store.GetUserByUsername("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent user")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPassword(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
store.CreateUser("charlie", "correctpassword", false)
|
||||
user, _ := store.GetUserByUsername("charlie")
|
||||
|
||||
if !store.CheckPassword(user, "correctpassword") {
|
||||
t.Error("correct password should match")
|
||||
}
|
||||
if store.CheckPassword(user, "wrongpassword") {
|
||||
t.Error("wrong password should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatePassword(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
user, _ := store.CreateUser("dave", "oldpass12", false)
|
||||
|
||||
if err := store.UpdatePassword(user.ID, "newpass12"); err != nil {
|
||||
t.Fatalf("UpdatePassword: %v", err)
|
||||
}
|
||||
|
||||
user, _ = store.GetUserByUsername("dave")
|
||||
if store.CheckPassword(user, "oldpass12") {
|
||||
t.Error("old password should not work")
|
||||
}
|
||||
if !store.CheckPassword(user, "newpass12") {
|
||||
t.Error("new password should work")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
user1, _ := store.CreateUser("admin1", "password1234", true)
|
||||
user2, _ := store.CreateUser("user2", "password1234", false)
|
||||
|
||||
// Can delete non-admin
|
||||
if err := store.DeleteUser(user2.ID); err != nil {
|
||||
t.Fatalf("DeleteUser: %v", err)
|
||||
}
|
||||
|
||||
// Cannot delete last admin
|
||||
if err := store.DeleteUser(user1.ID); err == nil {
|
||||
t.Error("should not be able to delete last admin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasAnyUser(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
if store.HasAnyUser() {
|
||||
t.Error("should have no users initially")
|
||||
}
|
||||
|
||||
store.CreateUser("first", "password1234", true)
|
||||
|
||||
if !store.HasAnyUser() {
|
||||
t.Error("should have users after creation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCRUD(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
user, _ := store.CreateUser("sessuser", "password1234", false)
|
||||
|
||||
sessionID, err := store.CreateSession(user.ID, 1*time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sess, err := store.GetSession(sessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if sess.UserID != user.ID {
|
||||
t.Errorf("expected user ID %d, got %d", user.ID, sess.UserID)
|
||||
}
|
||||
|
||||
if err := store.DeleteSession(sessionID); err != nil {
|
||||
t.Fatalf("DeleteSession: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.GetSession(sessionID)
|
||||
if err == nil {
|
||||
t.Error("session should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticTokenLookup(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
staticTokens := []StaticToken{
|
||||
{Name: "test-token", Key: "sk-test-key-12345678", RateLimitRPM: 60, DailyBudgetUSD: 10.0, MaxConcurrent: 5},
|
||||
}
|
||||
store := NewStore(db, staticTokens)
|
||||
|
||||
token, err := store.LookupAPIToken("sk-test-key-12345678")
|
||||
if err != nil {
|
||||
t.Fatalf("LookupAPIToken: %v", err)
|
||||
}
|
||||
if token.Name != "test-token" {
|
||||
t.Errorf("expected 'test-token', got '%s'", token.Name)
|
||||
}
|
||||
if token.ID != -1 {
|
||||
t.Errorf("static token should have ID -1, got %d", token.ID)
|
||||
}
|
||||
if token.RateLimitRPM != 60 {
|
||||
t.Errorf("expected RPM 60, got %d", token.RateLimitRPM)
|
||||
}
|
||||
if token.MaxConcurrent != 5 {
|
||||
t.Errorf("expected max_concurrent 5, got %d", token.MaxConcurrent)
|
||||
}
|
||||
|
||||
// Non-existent token
|
||||
_, err = store.LookupAPIToken("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("should error on nonexistent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBTokenCRUD(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
store := NewStore(db, nil)
|
||||
|
||||
user, _ := store.CreateUser("tokenuser", "password1234", false)
|
||||
|
||||
plainKey, token, err := store.CreateAPIToken(user.ID, "my-token", 100, 5.0)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAPIToken: %v", err)
|
||||
}
|
||||
if plainKey == "" {
|
||||
t.Error("plain key should not be empty")
|
||||
}
|
||||
if token.Name != "my-token" {
|
||||
t.Errorf("expected 'my-token', got '%s'", token.Name)
|
||||
}
|
||||
|
||||
// Lookup by key
|
||||
found, err := store.LookupAPIToken(plainKey)
|
||||
if err != nil {
|
||||
t.Fatalf("LookupAPIToken: %v", err)
|
||||
}
|
||||
if found.Name != "my-token" {
|
||||
t.Errorf("expected 'my-token', got '%s'", found.Name)
|
||||
}
|
||||
|
||||
// List tokens
|
||||
tokens, err := store.ListAPITokens(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAPITokens: %v", err)
|
||||
}
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("expected 1 token, got %d", len(tokens))
|
||||
}
|
||||
|
||||
// Delete
|
||||
if err := store.DeleteAPIToken(token.ID); err != nil {
|
||||
t.Fatalf("DeleteAPIToken: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.LookupAPIToken(plainKey)
|
||||
if err == nil {
|
||||
t.Error("token should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetStaticTokens(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
store := NewStore(db, nil)
|
||||
|
||||
_, err := store.LookupAPIToken("key1")
|
||||
if err == nil {
|
||||
t.Error("should not find token before setting")
|
||||
}
|
||||
|
||||
store.SetStaticTokens([]StaticToken{
|
||||
{Name: "new-token", Key: "key1"},
|
||||
})
|
||||
|
||||
token, err := store.LookupAPIToken("key1")
|
||||
if err != nil {
|
||||
t.Fatalf("after SetStaticTokens: %v", err)
|
||||
}
|
||||
if token.Name != "new-token" {
|
||||
t.Errorf("expected 'new-token', got '%s'", token.Name)
|
||||
}
|
||||
}
|
||||
112
llm-gateway/internal/cache/cache_test.go
vendored
Normal file
112
llm-gateway/internal/cache/cache_test.go
vendored
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCacheKey_Deterministic(t *testing.T) {
|
||||
c := &Cache{}
|
||||
|
||||
model := "gpt-4"
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
key1 := c.cacheKey(model, body)
|
||||
key2 := c.cacheKey(model, body)
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("cache key not deterministic: %s != %s", key1, key2)
|
||||
}
|
||||
|
||||
if key1 == "" {
|
||||
t.Error("cache key is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheKey_DifferentInputs(t *testing.T) {
|
||||
c := &Cache{}
|
||||
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
|
||||
|
||||
key1 := c.cacheKey("gpt-4", body)
|
||||
key2 := c.cacheKey("gpt-3.5", body)
|
||||
|
||||
if key1 == key2 {
|
||||
t.Error("different models should produce different cache keys")
|
||||
}
|
||||
|
||||
key3 := c.cacheKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"world"}]}`))
|
||||
if key1 == key3 {
|
||||
t.Error("different bodies should produce different cache keys")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheKey_HasPrefix(t *testing.T) {
|
||||
c := &Cache{}
|
||||
key := c.cacheKey("gpt-4", []byte("test"))
|
||||
|
||||
if len(key) < 7 || key[:7] != "llm-gw:" {
|
||||
t.Errorf("cache key should start with 'llm-gw:', got: %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInfoInt(t *testing.T) {
|
||||
info := "keyspace_hits:42\nkeyspace_misses:10\n"
|
||||
|
||||
hits := parseInfoInt(info, "keyspace_hits")
|
||||
if hits != 42 {
|
||||
t.Errorf("expected 42, got %d", hits)
|
||||
}
|
||||
|
||||
misses := parseInfoInt(info, "keyspace_misses")
|
||||
if misses != 10 {
|
||||
t.Errorf("expected 10, got %d", misses)
|
||||
}
|
||||
|
||||
unknown := parseInfoInt(info, "nonexistent")
|
||||
if unknown != 0 {
|
||||
t.Errorf("expected 0 for unknown key, got %d", unknown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInfoString(t *testing.T) {
|
||||
info := "used_memory_human:1.5M\r\nother:value\r\n"
|
||||
|
||||
mem := parseInfoString(info, "used_memory_human")
|
||||
if mem != "1.5M" {
|
||||
t.Errorf("expected '1.5M', got '%s'", mem)
|
||||
}
|
||||
|
||||
unknown := parseInfoString(info, "nonexistent")
|
||||
if unknown != "" {
|
||||
t.Errorf("expected empty for unknown key, got '%s'", unknown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKeyspaceKeys(t *testing.T) {
|
||||
info := "# Keyspace\ndb0:keys=123,expires=45,avg_ttl=6789\n"
|
||||
|
||||
keys := parseKeyspaceKeys(info)
|
||||
if keys != 123 {
|
||||
t.Errorf("expected 123, got %d", keys)
|
||||
}
|
||||
|
||||
empty := parseKeyspaceKeys("# Keyspace\n")
|
||||
if empty != 0 {
|
||||
t.Errorf("expected 0 for empty keyspace, got %d", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitLines(t *testing.T) {
|
||||
lines := splitLines("a\nb\nc")
|
||||
if len(lines) != 3 {
|
||||
t.Errorf("expected 3 lines, got %d", len(lines))
|
||||
}
|
||||
if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
|
||||
t.Errorf("unexpected lines: %v", lines)
|
||||
}
|
||||
|
||||
single := splitLines("hello")
|
||||
if len(single) != 1 || single[0] != "hello" {
|
||||
t.Errorf("single line: %v", single)
|
||||
}
|
||||
}
|
||||
|
|
@ -12,13 +12,17 @@ import (
|
|||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Cache CacheConfig `yaml:"cache"`
|
||||
Pricing PricingLookupConfig `yaml:"pricing_lookup"`
|
||||
Providers []ProviderConfig `yaml:"providers"`
|
||||
Models []ModelConfig `yaml:"models"`
|
||||
Tokens []TokenConfig `yaml:"tokens"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Cache CacheConfig `yaml:"cache"`
|
||||
Pricing PricingLookupConfig `yaml:"pricing_lookup"`
|
||||
CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"`
|
||||
Retry RetryConfig `yaml:"retry"`
|
||||
Debug DebugConfig `yaml:"debug"`
|
||||
CORS CORSConfig `yaml:"cors"`
|
||||
Providers []ProviderConfig `yaml:"providers"`
|
||||
Models []ModelConfig `yaml:"models"`
|
||||
Tokens []TokenConfig `yaml:"tokens"`
|
||||
}
|
||||
|
||||
type PricingLookupConfig struct {
|
||||
|
|
@ -36,14 +40,44 @@ type TokenConfig struct {
|
|||
Key string `yaml:"key"`
|
||||
RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
|
||||
DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
|
||||
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Listen string `yaml:"listen"`
|
||||
RequestTimeout time.Duration `yaml:"request_timeout"`
|
||||
MaxRequestBodyMB int `yaml:"max_request_body_mb"`
|
||||
SessionSecret string `yaml:"session_secret"`
|
||||
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
|
||||
Listen string `yaml:"listen"`
|
||||
RequestTimeout time.Duration `yaml:"request_timeout"`
|
||||
StreamingTimeout time.Duration `yaml:"streaming_timeout"`
|
||||
MaxRequestBodyMB int `yaml:"max_request_body_mb"`
|
||||
SessionSecret string `yaml:"session_secret"`
|
||||
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
|
||||
}
|
||||
|
||||
type CircuitBreakerConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
ErrorThreshold float64 `yaml:"error_threshold"`
|
||||
MinRequests int `yaml:"min_requests"`
|
||||
CooldownDuration time.Duration `yaml:"cooldown_duration"`
|
||||
}
|
||||
|
||||
type RetryConfig struct {
|
||||
InitialBackoff time.Duration `yaml:"initial_backoff"`
|
||||
MaxBackoff time.Duration `yaml:"max_backoff"`
|
||||
Multiplier float64 `yaml:"multiplier"`
|
||||
}
|
||||
|
||||
type DebugConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
|
||||
RetentionDays int `yaml:"retention_days"`
|
||||
DataDir string `yaml:"data_dir"`
|
||||
}
|
||||
|
||||
type CORSConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
AllowedOrigins []string `yaml:"allowed_origins"`
|
||||
AllowedMethods []string `yaml:"allowed_methods"`
|
||||
AllowedHeaders []string `yaml:"allowed_headers"`
|
||||
MaxAge int `yaml:"max_age"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
|
|
@ -66,8 +100,10 @@ type ProviderConfig struct {
|
|||
}
|
||||
|
||||
type ModelConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Routes []RouteConfig `yaml:"routes"`
|
||||
Name string `yaml:"name"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Routes []RouteConfig `yaml:"routes"`
|
||||
LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
|
||||
}
|
||||
|
||||
type RouteConfig struct {
|
||||
|
|
@ -128,6 +164,43 @@ func (c *Config) validate() error {
|
|||
c.Pricing.RefreshInterval = 6 * time.Hour
|
||||
}
|
||||
|
||||
// Server defaults
|
||||
if c.Server.StreamingTimeout == 0 {
|
||||
c.Server.StreamingTimeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
// Circuit breaker defaults
|
||||
if c.CircuitBreaker.ErrorThreshold == 0 {
|
||||
c.CircuitBreaker.ErrorThreshold = 0.5
|
||||
}
|
||||
if c.CircuitBreaker.MinRequests == 0 {
|
||||
c.CircuitBreaker.MinRequests = 5
|
||||
}
|
||||
if c.CircuitBreaker.CooldownDuration == 0 {
|
||||
c.CircuitBreaker.CooldownDuration = 30 * time.Second
|
||||
}
|
||||
|
||||
// Retry defaults
|
||||
if c.Retry.InitialBackoff == 0 {
|
||||
c.Retry.InitialBackoff = 100 * time.Millisecond
|
||||
}
|
||||
if c.Retry.MaxBackoff == 0 {
|
||||
c.Retry.MaxBackoff = 5 * time.Second
|
||||
}
|
||||
if c.Retry.Multiplier == 0 {
|
||||
c.Retry.Multiplier = 2.0
|
||||
}
|
||||
|
||||
// Debug defaults
|
||||
if c.Debug.RetentionDays == 0 {
|
||||
c.Debug.RetentionDays = 90
|
||||
}
|
||||
|
||||
// CORS defaults
|
||||
if c.CORS.MaxAge == 0 {
|
||||
c.CORS.MaxAge = 300
|
||||
}
|
||||
|
||||
if len(c.Providers) == 0 {
|
||||
return fmt.Errorf("at least one provider is required")
|
||||
}
|
||||
|
|
@ -160,6 +233,12 @@ func (c *Config) validate() error {
|
|||
return fmt.Errorf("duplicate model name: %s", m.Name)
|
||||
}
|
||||
modelNames[m.Name] = true
|
||||
for _, alias := range m.Aliases {
|
||||
if modelNames[alias] {
|
||||
return fmt.Errorf("model alias %s conflicts with existing model or alias", alias)
|
||||
}
|
||||
modelNames[alias] = true
|
||||
}
|
||||
if len(m.Routes) == 0 {
|
||||
return fmt.Errorf("model %s: at least one route is required", m.Name)
|
||||
}
|
||||
|
|
|
|||
738
llm-gateway/internal/config/config_test.go
Normal file
738
llm-gateway/internal/config/config_test.go
Normal file
|
|
@ -0,0 +1,738 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// writeConfigFile writes content to a fresh temp YAML file under the
// test's temp dir and returns its path. The file is cleaned up with the
// test via t.TempDir.
func writeConfigFile(t *testing.T, content string) string {
	t.Helper()
	tmp, err := os.CreateTemp(t.TempDir(), "config-*.yaml")
	if err != nil {
		t.Fatalf("creating temp file: %v", err)
	}
	_, writeErr := tmp.WriteString(content)
	tmp.Close()
	if writeErr != nil {
		t.Fatalf("writing temp file: %v", writeErr)
	}
	return tmp.Name()
}
|
||||
|
||||
// minimalValidConfig returns a minimal valid YAML config string: one
// provider plus one model with a single route, the smallest shape the
// validation in Load accepts (per the required-field tests below).
func minimalValidConfig() string {
	return `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-test-key

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`
}
|
||||
|
||||
// TestLoad_ValidConfig loads a fully-populated YAML config file and
// verifies that every section round-trips into the Config struct with
// the exact values from the fixture.
func TestLoad_ValidConfig(t *testing.T) {
	path := writeConfigFile(t, `
server:
  listen: "127.0.0.1:8080"
  request_timeout: 60s
  streaming_timeout: 120s
  max_request_body_mb: 5
  session_secret: "test-secret-1234567890abcdef1234567890abcdef"

database:
  path: "/tmp/test.db"
  retention_days: 30

pricing_lookup:
  url: "https://pricing.example.com"
  refresh_interval: 1h

circuit_breaker:
  enabled: true
  error_threshold: 0.3
  min_requests: 10
  cooldown_duration: 60s

retry:
  initial_backoff: 200ms
  max_backoff: 10s
  multiplier: 3.0

debug:
  enabled: true
  max_body_bytes: 65536
  retention_days: 60

cors:
  enabled: true
  allowed_origins:
    - "https://example.com"
  allowed_methods:
    - GET
    - POST
  allowed_headers:
    - Authorization
  max_age: 600

providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-test-key
    priority: 2
    timeout: 60s
  - name: anthropic
    base_url: https://api.anthropic.com/v1
    api_key: sk-ant-test
    priority: 1
    timeout: 30s

models:
  - name: gpt-4
    aliases:
      - gpt4
    routes:
      - provider: openai
        model: gpt-4
        pricing:
          input: 30.0
          output: 60.0
    load_balancing: first
  - name: claude-3
    routes:
      - provider: anthropic
        model: claude-3-opus-20240229

tokens:
  - name: test-token
    key: tok-abc123
    rate_limit_rpm: 100
    daily_budget_usd: 10.0
    max_concurrent: 5
`)

	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}

	// Server
	if cfg.Server.Listen != "127.0.0.1:8080" {
		t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080")
	}
	if cfg.Server.RequestTimeout != 60*time.Second {
		t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second)
	}
	if cfg.Server.StreamingTimeout != 120*time.Second {
		t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second)
	}
	if cfg.Server.MaxRequestBodyMB != 5 {
		t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5)
	}
	if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" {
		t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef")
	}

	// Database
	if cfg.Database.Path != "/tmp/test.db" {
		t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db")
	}
	if cfg.Database.RetentionDays != 30 {
		t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30)
	}

	// Pricing
	if cfg.Pricing.URL != "https://pricing.example.com" {
		t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com")
	}
	if cfg.Pricing.RefreshInterval != 1*time.Hour {
		t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour)
	}

	// Circuit breaker
	if !cfg.CircuitBreaker.Enabled {
		t.Error("CircuitBreaker.Enabled = false, want true")
	}
	if cfg.CircuitBreaker.ErrorThreshold != 0.3 {
		t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3)
	}
	if cfg.CircuitBreaker.MinRequests != 10 {
		t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10)
	}
	if cfg.CircuitBreaker.CooldownDuration != 60*time.Second {
		t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second)
	}

	// Retry
	if cfg.Retry.InitialBackoff != 200*time.Millisecond {
		t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond)
	}
	if cfg.Retry.MaxBackoff != 10*time.Second {
		t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second)
	}
	if cfg.Retry.Multiplier != 3.0 {
		t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0)
	}

	// Debug
	if !cfg.Debug.Enabled {
		t.Error("Debug.Enabled = false, want true")
	}
	if cfg.Debug.MaxBodyBytes != 65536 {
		t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536)
	}
	if cfg.Debug.RetentionDays != 60 {
		t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60)
	}

	// CORS
	if !cfg.CORS.Enabled {
		t.Error("CORS.Enabled = false, want true")
	}
	if cfg.CORS.MaxAge != 600 {
		t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600)
	}

	// Providers
	if len(cfg.Providers) != 2 {
		t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers))
	}
	if cfg.Providers[0].Name != "openai" {
		t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai")
	}
	if cfg.Providers[0].Timeout != 60*time.Second {
		t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second)
	}

	// Models
	if len(cfg.Models) != 2 {
		t.Fatalf("len(Models) = %d, want 2", len(cfg.Models))
	}
	if cfg.Models[0].LoadBalancing != "first" {
		t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first")
	}
	if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" {
		t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases)
	}
	if cfg.Models[0].Routes[0].Pricing.Input != 30.0 {
		t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input)
	}

	// Tokens
	if len(cfg.Tokens) != 1 {
		t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens))
	}
	if cfg.Tokens[0].Name != "test-token" {
		t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token")
	}
	if cfg.Tokens[0].RateLimitRPM != 100 {
		t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM)
	}
}
|
||||
|
||||
// TestValidate_Defaults loads a minimal config and asserts every default
// that validate() fills in when a field is omitted.
func TestValidate_Defaults(t *testing.T) {
	path := writeConfigFile(t, minimalValidConfig())
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}

	tests := []struct {
		name string
		got  any
		want any
	}{
		// Server defaults
		{"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"},
		{"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second},
		{"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute},
		{"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10},

		// Database defaults
		{"Database.Path", cfg.Database.Path, "gateway.db"},
		{"Database.RetentionDays", cfg.Database.RetentionDays, 90},

		// Pricing defaults
		{"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour},

		// Circuit breaker defaults
		{"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5},
		{"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5},
		{"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second},

		// Retry defaults
		{"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond},
		{"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second},
		{"Retry.Multiplier", cfg.Retry.Multiplier, 2.0},

		// Debug defaults
		{"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0},
		{"Debug.RetentionDays", cfg.Debug.RetentionDays, 90},

		// CORS defaults
		{"CORS.MaxAge", cfg.CORS.MaxAge, 300},

		// Provider defaults
		{"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second},
		{"Providers[0].Priority", cfg.Providers[0].Priority, 1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Compare using formatted strings to handle different numeric types
			// stored in the `any` table fields (int vs float64 vs Duration).
			gotStr := formatValue(tt.got)
			wantStr := formatValue(tt.want)
			if gotStr != wantStr {
				t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
			}
		})
	}

	// SessionSecret should be auto-generated (non-empty, 64 hex chars)
	if cfg.Server.SessionSecret == "" {
		t.Error("SessionSecret should be auto-generated when empty")
	}
	if len(cfg.Server.SessionSecret) != 64 {
		t.Errorf("SessionSecret length = %d, want 64 hex chars", len(cfg.Server.SessionSecret))
	}
}
|
||||
|
||||
func formatValue(v any) string {
|
||||
switch val := v.(type) {
|
||||
case time.Duration:
|
||||
return val.String()
|
||||
case float64:
|
||||
return fmt.Sprintf("%g", val)
|
||||
case int:
|
||||
return fmt.Sprintf("%d", val)
|
||||
case string:
|
||||
return val
|
||||
default:
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_FileNotFound(t *testing.T) {
|
||||
_, err := Load(filepath.Join(t.TempDir(), "nonexistent.yaml"))
|
||||
if err == nil {
|
||||
t.Fatal("Load() should return error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_InvalidYAML(t *testing.T) {
|
||||
path := writeConfigFile(t, `{{{invalid yaml`)
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("Load() should return error for invalid YAML")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_DuplicateProviderNames(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
providers:
|
||||
- name: openai
|
||||
base_url: https://api.openai.com/v1
|
||||
api_key: sk-key1
|
||||
- name: openai
|
||||
base_url: https://api.openai.com/v2
|
||||
api_key: sk-key2
|
||||
|
||||
models:
|
||||
- name: gpt-4
|
||||
routes:
|
||||
- provider: openai
|
||||
model: gpt-4
|
||||
`)
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("Load() should return error for duplicate provider names")
|
||||
}
|
||||
wantSubstr := "duplicate provider name: openai"
|
||||
if !strings.Contains(err.Error(), wantSubstr) {
|
||||
t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_DuplicateModelNames(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
providers:
|
||||
- name: openai
|
||||
base_url: https://api.openai.com/v1
|
||||
api_key: sk-key1
|
||||
|
||||
models:
|
||||
- name: gpt-4
|
||||
routes:
|
||||
- provider: openai
|
||||
model: gpt-4
|
||||
- name: gpt-4
|
||||
routes:
|
||||
- provider: openai
|
||||
model: gpt-4-turbo
|
||||
`)
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("Load() should return error for duplicate model names")
|
||||
}
|
||||
wantSubstr := "duplicate model name: gpt-4"
|
||||
if !strings.Contains(err.Error(), wantSubstr) {
|
||||
t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidate_AliasConflicts verifies an alias may not collide with an
// existing model name nor with another model's alias; both cases must
// fail validation with a descriptive error.
func TestValidate_AliasConflicts(t *testing.T) {
	tests := []struct {
		name    string
		config  string
		wantErr string
	}{
		{
			name: "alias conflicts with model name",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key1

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
  - name: claude-3
    aliases:
      - gpt-4
    routes:
      - provider: openai
        model: claude-3
`,
			wantErr: "model alias gpt-4 conflicts with existing model or alias",
		},
		{
			name: "alias conflicts with another alias",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key1

models:
  - name: gpt-4
    aliases:
      - fast-model
    routes:
      - provider: openai
        model: gpt-4
  - name: claude-3
    aliases:
      - fast-model
    routes:
      - provider: openai
        model: claude-3
`,
			wantErr: "model alias fast-model conflicts with existing model or alias",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			path := writeConfigFile(t, tt.config)
			_, err := Load(path)
			if err == nil {
				t.Fatal("Load() should return error for alias conflicts")
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
			}
		})
	}
}
|
||||
|
||||
// TestValidate_MissingRequiredFields exercises each required-field check
// in config validation: one table case per missing field, asserting the
// specific error message each omission produces.
func TestValidate_MissingRequiredFields(t *testing.T) {
	tests := []struct {
		name    string
		config  string
		wantErr string
	}{
		{
			name:    "no providers",
			config:  `models: [{name: test, routes: [{provider: x, model: y}]}]`,
			wantErr: "at least one provider is required",
		},
		{
			name: "no models",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key
`,
			wantErr: "at least one model is required",
		},
		{
			name: "provider missing name",
			config: `
providers:
  - base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "provider 0: name, base_url, and api_key are required",
		},
		{
			name: "provider missing base_url",
			config: `
providers:
  - name: openai
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "provider 0: name, base_url, and api_key are required",
		},
		{
			name: "provider missing api_key",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "provider 0: name, base_url, and api_key are required",
		},
		{
			name: "model missing name",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "model 0: name is required",
		},
		{
			name: "model missing routes",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
`,
			wantErr: "model gpt-4: at least one route is required",
		},
		{
			name: "route missing provider",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - model: gpt-4
`,
			wantErr: "model gpt-4 route 0: provider and model are required",
		},
		{
			name: "route missing model",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: openai
`,
			wantErr: "model gpt-4 route 0: provider and model are required",
		},
		{
			name: "route references unknown provider",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: anthropic
        model: gpt-4
`,
			wantErr: "model gpt-4 route 0: unknown provider anthropic",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			path := writeConfigFile(t, tt.config)
			_, err := Load(path)
			if err == nil {
				t.Fatalf("Load() should return error, want %q", tt.wantErr)
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
			}
		})
	}
}
|
||||
|
||||
func TestProviderByName(t *testing.T) {
|
||||
path := writeConfigFile(t, `
|
||||
providers:
|
||||
- name: openai
|
||||
base_url: https://api.openai.com/v1
|
||||
api_key: sk-openai
|
||||
- name: anthropic
|
||||
base_url: https://api.anthropic.com/v1
|
||||
api_key: sk-anthropic
|
||||
|
||||
models:
|
||||
- name: gpt-4
|
||||
routes:
|
||||
- provider: openai
|
||||
model: gpt-4
|
||||
`)
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() returned error: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
lookup string
|
||||
wantNil bool
|
||||
wantName string
|
||||
}{
|
||||
{"existing provider openai", "openai", false, "openai"},
|
||||
{"existing provider anthropic", "anthropic", false, "anthropic"},
|
||||
{"nonexistent provider", "google", true, ""},
|
||||
{"empty name", "", true, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := cfg.ProviderByName(tt.lookup)
|
||||
if tt.wantNil {
|
||||
if p != nil {
|
||||
t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p)
|
||||
}
|
||||
} else {
|
||||
if p == nil {
|
||||
t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup)
|
||||
}
|
||||
if p.Name != tt.wantName {
|
||||
t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Verify returned pointer refers to the actual config entry
|
||||
p := cfg.ProviderByName("openai")
|
||||
if p.APIKey != "sk-openai" {
|
||||
t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_EnvironmentVariableExpansion(t *testing.T) {
|
||||
t.Setenv("TEST_API_KEY", "sk-from-env")
|
||||
t.Setenv("TEST_BASE_URL", "https://env.example.com/v1")
|
||||
t.Setenv("TEST_PROVIDER_NAME", "env-provider")
|
||||
|
||||
path := writeConfigFile(t, `
|
||||
providers:
|
||||
- name: $TEST_PROVIDER_NAME
|
||||
base_url: ${TEST_BASE_URL}
|
||||
api_key: ${TEST_API_KEY}
|
||||
|
||||
models:
|
||||
- name: test-model
|
||||
routes:
|
||||
- provider: env-provider
|
||||
model: gpt-4
|
||||
`)
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() returned error: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Providers[0].Name != "env-provider" {
|
||||
t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider")
|
||||
}
|
||||
if cfg.Providers[0].BaseURL != "https://env.example.com/v1" {
|
||||
t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1")
|
||||
}
|
||||
if cfg.Providers[0].APIKey != "sk-from-env" {
|
||||
t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) {
|
||||
// Ensure the variable is not set
|
||||
t.Setenv("TEST_UNSET_VAR", "")
|
||||
os.Unsetenv("TEST_UNSET_VAR")
|
||||
|
||||
path := writeConfigFile(t, `
|
||||
providers:
|
||||
- name: openai
|
||||
base_url: https://api.openai.com/v1
|
||||
api_key: ${TEST_UNSET_VAR}
|
||||
|
||||
models:
|
||||
- name: gpt-4
|
||||
routes:
|
||||
- provider: openai
|
||||
model: gpt-4
|
||||
`)
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("Load() should return error when env var expands to empty required field")
|
||||
}
|
||||
// api_key will be empty, so validation should catch it
|
||||
if !strings.Contains(err.Error(), "api_key are required") {
|
||||
t.Errorf("error = %q, want to contain api_key validation message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
27
llm-gateway/internal/config/watcher.go
Normal file
27
llm-gateway/internal/config/watcher.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// WatchReload listens for SIGHUP and calls the callback with the new config.
|
||||
func WatchReload(configPath string, callback func(*Config)) {
|
||||
sighup := make(chan os.Signal, 1)
|
||||
signal.Notify(sighup, syscall.SIGHUP)
|
||||
|
||||
go func() {
|
||||
for range sighup {
|
||||
log.Println("SIGHUP received, reloading config...")
|
||||
newCfg, err := Load(configPath)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: config reload failed: %v", err)
|
||||
continue
|
||||
}
|
||||
callback(newCfg)
|
||||
log.Println("Config reloaded successfully")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
@ -7,6 +7,8 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"llm-gateway/internal/auth"
|
||||
"llm-gateway/internal/cache"
|
||||
"llm-gateway/internal/provider"
|
||||
|
|
@ -58,6 +60,7 @@ type TokenUsageStats struct {
|
|||
|
||||
// RequestLogEntry represents a single request log row.
|
||||
type RequestLogEntry struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
TokenName string `json:"token_name"`
|
||||
Model string `json:"model"`
|
||||
|
|
@ -104,6 +107,8 @@ type StatsAPI struct {
|
|||
authStore *auth.Store
|
||||
healthTracker *provider.HealthTracker
|
||||
cache *cache.Cache
|
||||
auditLogger *storage.AuditLogger
|
||||
debugLogger *storage.DebugLogger
|
||||
}
|
||||
|
||||
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
|
||||
|
|
@ -120,6 +125,16 @@ func (s *StatsAPI) SetCache(c *cache.Cache) {
|
|||
s.cache = c
|
||||
}
|
||||
|
||||
// SetAuditLogger sets the audit logger.
|
||||
func (s *StatsAPI) SetAuditLogger(al *storage.AuditLogger) {
|
||||
s.auditLogger = al
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger.
|
||||
func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
|
||||
s.debugLogger = dl
|
||||
}
|
||||
|
||||
// TokenNamesForUser returns the token names that belong to the user.
|
||||
// Admins get nil (no filter), non-admins get their token names.
|
||||
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
|
||||
|
|
@ -325,7 +340,7 @@ func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status s
|
|||
}
|
||||
|
||||
// Get page
|
||||
query := `SELECT timestamp, token_name, model, provider, provider_model,
|
||||
query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
|
||||
input_tokens, output_tokens, cost_usd, latency_ms, status,
|
||||
COALESCE(error_message, ''), streaming, cached
|
||||
FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
|
||||
|
|
@ -341,7 +356,7 @@ func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status s
|
|||
for rows.Next() {
|
||||
var l RequestLogEntry
|
||||
var streaming, cached int
|
||||
rows.Scan(&l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
|
||||
rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
|
||||
&l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
|
||||
&l.ErrorMessage, &streaming, &cached)
|
||||
l.Streaming = streaming == 1
|
||||
|
|
@ -624,6 +639,79 @@ func (s *StatsAPI) CacheStats(w http.ResponseWriter, r *http.Request) {
|
|||
writeJSON(w, stats)
|
||||
}
|
||||
|
||||
// AuditLogs serves the audit log API (admin-only).
|
||||
func (s *StatsAPI) AuditLogs(w http.ResponseWriter, r *http.Request) {
|
||||
if s.auditLogger == nil {
|
||||
writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
|
||||
return
|
||||
}
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
action := r.URL.Query().Get("action")
|
||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
||||
if sinceStr := r.URL.Query().Get("since"); sinceStr != "" {
|
||||
if s, err := strconv.ParseInt(sinceStr, 10, 64); err == nil {
|
||||
since = s
|
||||
}
|
||||
}
|
||||
result := s.auditLogger.Query(since, action, page, 50)
|
||||
writeJSON(w, result)
|
||||
}
|
||||
|
||||
// DebugToggle enables/disables debug logging at runtime.
|
||||
func (s *StatsAPI) DebugToggle(w http.ResponseWriter, r *http.Request) {
|
||||
if s.debugLogger == nil {
|
||||
writeJSON(w, map[string]any{"error": "debug logger not configured"})
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
writeJSON(w, map[string]string{"error": "invalid JSON"})
|
||||
return
|
||||
}
|
||||
s.debugLogger.SetEnabled(req.Enabled)
|
||||
writeJSON(w, map[string]any{"enabled": s.debugLogger.IsEnabled()})
|
||||
}
|
||||
|
||||
// DebugStatus returns whether debug logging is enabled.
|
||||
func (s *StatsAPI) DebugStatus(w http.ResponseWriter, r *http.Request) {
|
||||
enabled := false
|
||||
if s.debugLogger != nil {
|
||||
enabled = s.debugLogger.IsEnabled()
|
||||
}
|
||||
writeJSON(w, map[string]any{"enabled": enabled})
|
||||
}
|
||||
|
||||
// DebugLogs serves paginated debug log entries.
|
||||
func (s *StatsAPI) DebugLogs(w http.ResponseWriter, r *http.Request) {
|
||||
if s.debugLogger == nil {
|
||||
writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
|
||||
return
|
||||
}
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
result := s.debugLogger.Query(page, 50)
|
||||
writeJSON(w, result)
|
||||
}
|
||||
|
||||
// DebugLogByRequestID serves a single debug log entry by request ID.
|
||||
func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
|
||||
if s.debugLogger == nil {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
writeJSON(w, map[string]string{"error": "debug logger not configured"})
|
||||
return
|
||||
}
|
||||
requestID := chi.URLParam(r, "requestID")
|
||||
entry := s.debugLogger.GetByRequestID(requestID)
|
||||
if entry == nil {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
writeJSON(w, map[string]string{"error": "not found"})
|
||||
return
|
||||
}
|
||||
writeJSON(w, entry)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v)
|
||||
|
|
|
|||
297
llm-gateway/internal/dashboard/export.go
Normal file
297
llm-gateway/internal/dashboard/export.go
Normal file
|
|
@ -0,0 +1,297 @@
|
|||
package dashboard
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"llm-gateway/internal/auth"
|
||||
"llm-gateway/internal/storage"
|
||||
)
|
||||
|
||||
type ExportHandler struct {
|
||||
db *storage.DB
|
||||
authStore *auth.Store
|
||||
}
|
||||
|
||||
func NewExportHandler(db *storage.DB, authStore *auth.Store) *ExportHandler {
|
||||
return &ExportHandler{db: db, authStore: authStore}
|
||||
}
|
||||
|
||||
// ExportLogs exports request logs as CSV or JSON.
|
||||
func (e *ExportHandler) ExportLogs(w http.ResponseWriter, r *http.Request) {
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
format = "json"
|
||||
}
|
||||
|
||||
// Build query
|
||||
where := "WHERE 1=1"
|
||||
var args []any
|
||||
|
||||
if from := r.URL.Query().Get("from"); from != "" {
|
||||
if ts, err := strconv.ParseInt(from, 10, 64); err == nil {
|
||||
where += " AND timestamp >= ?"
|
||||
args = append(args, ts)
|
||||
}
|
||||
}
|
||||
if to := r.URL.Query().Get("to"); to != "" {
|
||||
if ts, err := strconv.ParseInt(to, 10, 64); err == nil {
|
||||
where += " AND timestamp <= ?"
|
||||
args = append(args, ts)
|
||||
}
|
||||
}
|
||||
if model := r.URL.Query().Get("model"); model != "" {
|
||||
where += " AND model = ?"
|
||||
args = append(args, model)
|
||||
}
|
||||
if token := r.URL.Query().Get("token"); token != "" {
|
||||
where += " AND token_name = ?"
|
||||
args = append(args, token)
|
||||
}
|
||||
if status := r.URL.Query().Get("status"); status != "" {
|
||||
where += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
|
||||
// Token filtering for non-admins
|
||||
user := auth.UserFromContext(r.Context())
|
||||
if user != nil && !user.IsAdmin {
|
||||
tokens, err := e.authStore.ListAPITokens(user.ID)
|
||||
if err != nil || len(tokens) == 0 {
|
||||
where += " AND 1=0"
|
||||
} else {
|
||||
where += " AND token_name IN ("
|
||||
for i, t := range tokens {
|
||||
if i > 0 {
|
||||
where += ","
|
||||
}
|
||||
where += "?"
|
||||
args = append(args, t.Name)
|
||||
}
|
||||
where += ")"
|
||||
}
|
||||
}
|
||||
|
||||
query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
|
||||
input_tokens, output_tokens, cost_usd, latency_ms, status,
|
||||
COALESCE(error_message, ''), streaming, cached
|
||||
FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT 100000`
|
||||
|
||||
rows, err := e.db.Query(query, args...)
|
||||
if err != nil {
|
||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type logRow struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
TokenName string `json:"token_name"`
|
||||
Model string `json:"model"`
|
||||
Provider string `json:"provider"`
|
||||
ProviderModel string `json:"provider_model"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CostUSD float64 `json:"cost_usd"`
|
||||
LatencyMS int64 `json:"latency_ms"`
|
||||
Status string `json:"status"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
Streaming bool `json:"streaming"`
|
||||
Cached bool `json:"cached"`
|
||||
}
|
||||
|
||||
var results []logRow
|
||||
for rows.Next() {
|
||||
var l logRow
|
||||
var streaming, cached int
|
||||
rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
|
||||
&l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
|
||||
&l.ErrorMessage, &streaming, &cached)
|
||||
l.Streaming = streaming == 1
|
||||
l.Cached = cached == 1
|
||||
results = append(results, l)
|
||||
}
|
||||
|
||||
now := time.Now().Format("20060102-150405")
|
||||
|
||||
switch format {
|
||||
case "csv":
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", now))
|
||||
writer := csv.NewWriter(w)
|
||||
writer.Write([]string{"request_id", "timestamp", "token_name", "model", "provider", "provider_model",
|
||||
"input_tokens", "output_tokens", "cost_usd", "latency_ms", "status", "error_message", "streaming", "cached"})
|
||||
for _, l := range results {
|
||||
writer.Write([]string{
|
||||
l.RequestID,
|
||||
strconv.FormatInt(l.Timestamp, 10),
|
||||
l.TokenName, l.Model, l.Provider, l.ProviderModel,
|
||||
strconv.Itoa(l.InputTokens), strconv.Itoa(l.OutputTokens),
|
||||
fmt.Sprintf("%.8f", l.CostUSD),
|
||||
strconv.FormatInt(l.LatencyMS, 10),
|
||||
l.Status, l.ErrorMessage,
|
||||
strconv.FormatBool(l.Streaming), strconv.FormatBool(l.Cached),
|
||||
})
|
||||
}
|
||||
writer.Flush()
|
||||
default:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.json", now))
|
||||
json.NewEncoder(w).Encode(results)
|
||||
}
|
||||
}
|
||||
|
||||
// ExportStats exports aggregated stats as CSV or JSON.
|
||||
func (e *ExportHandler) ExportStats(w http.ResponseWriter, r *http.Request) {
|
||||
format := r.URL.Query().Get("format")
|
||||
if format == "" {
|
||||
format = "json"
|
||||
}
|
||||
statsType := r.URL.Query().Get("type")
|
||||
if statsType == "" {
|
||||
statsType = "summary"
|
||||
}
|
||||
|
||||
now := time.Now().Format("20060102-150405")
|
||||
since := time.Now().AddDate(0, -1, 0).Unix()
|
||||
|
||||
switch statsType {
|
||||
case "models":
|
||||
rows, err := e.db.Query(`SELECT model, COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
|
||||
COALESCE(SUM(cost_usd), 0), COALESCE(AVG(latency_ms), 0)
|
||||
FROM request_logs WHERE timestamp >= ? GROUP BY model ORDER BY requests DESC`, since)
|
||||
if err != nil {
|
||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type modelRow struct {
|
||||
Model string `json:"model"`
|
||||
Requests int `json:"requests"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CostUSD float64 `json:"cost_usd"`
|
||||
AvgLatencyMS float64 `json:"avg_latency_ms"`
|
||||
}
|
||||
var results []modelRow
|
||||
for rows.Next() {
|
||||
var m modelRow
|
||||
rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS)
|
||||
results = append(results, m)
|
||||
}
|
||||
|
||||
if format == "csv" {
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.csv", now))
|
||||
writer := csv.NewWriter(w)
|
||||
writer.Write([]string{"model", "requests", "input_tokens", "output_tokens", "cost_usd", "avg_latency_ms"})
|
||||
for _, m := range results {
|
||||
writer.Write([]string{m.Model, strconv.Itoa(m.Requests), strconv.Itoa(m.InputTokens),
|
||||
strconv.Itoa(m.OutputTokens), fmt.Sprintf("%.8f", m.CostUSD), fmt.Sprintf("%.2f", m.AvgLatencyMS)})
|
||||
}
|
||||
writer.Flush()
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.json", now))
|
||||
json.NewEncoder(w).Encode(results)
|
||||
}
|
||||
|
||||
case "providers":
|
||||
rows, err := e.db.Query(`SELECT provider, COUNT(*) as requests,
|
||||
COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0),
|
||||
COALESCE(SUM(CASE WHEN status='error' THEN 1 ELSE 0 END), 0),
|
||||
COALESCE(AVG(latency_ms), 0), COALESCE(SUM(cost_usd), 0)
|
||||
FROM request_logs WHERE timestamp >= ? GROUP BY provider ORDER BY requests DESC`, since)
|
||||
if err != nil {
|
||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type providerRow struct {
|
||||
Provider string `json:"provider"`
|
||||
Requests int `json:"requests"`
|
||||
Successes int `json:"successes"`
|
||||
Errors int `json:"errors"`
|
||||
AvgLatencyMS float64 `json:"avg_latency_ms"`
|
||||
CostUSD float64 `json:"cost_usd"`
|
||||
}
|
||||
var results []providerRow
|
||||
for rows.Next() {
|
||||
var p providerRow
|
||||
rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD)
|
||||
results = append(results, p)
|
||||
}
|
||||
|
||||
if format == "csv" {
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.csv", now))
|
||||
writer := csv.NewWriter(w)
|
||||
writer.Write([]string{"provider", "requests", "successes", "errors", "avg_latency_ms", "cost_usd"})
|
||||
for _, p := range results {
|
||||
writer.Write([]string{p.Provider, strconv.Itoa(p.Requests), strconv.Itoa(p.Successes),
|
||||
strconv.Itoa(p.Errors), fmt.Sprintf("%.2f", p.AvgLatencyMS), fmt.Sprintf("%.8f", p.CostUSD)})
|
||||
}
|
||||
writer.Flush()
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.json", now))
|
||||
json.NewEncoder(w).Encode(results)
|
||||
}
|
||||
|
||||
case "tokens":
|
||||
rows, err := e.db.Query(`SELECT token_name, COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
|
||||
COALESCE(SUM(cost_usd), 0)
|
||||
FROM request_logs WHERE timestamp >= ? GROUP BY token_name ORDER BY requests DESC`, since)
|
||||
if err != nil {
|
||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type tokenRow struct {
|
||||
TokenName string `json:"token_name"`
|
||||
Requests int `json:"requests"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CostUSD float64 `json:"cost_usd"`
|
||||
}
|
||||
var results []tokenRow
|
||||
for rows.Next() {
|
||||
var t tokenRow
|
||||
rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD)
|
||||
results = append(results, t)
|
||||
}
|
||||
|
||||
if format == "csv" {
|
||||
w.Header().Set("Content-Type", "text/csv")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.csv", now))
|
||||
writer := csv.NewWriter(w)
|
||||
writer.Write([]string{"token_name", "requests", "input_tokens", "output_tokens", "cost_usd"})
|
||||
for _, t := range results {
|
||||
writer.Write([]string{t.TokenName, strconv.Itoa(t.Requests), strconv.Itoa(t.InputTokens),
|
||||
strconv.Itoa(t.OutputTokens), fmt.Sprintf("%.8f", t.CostUSD)})
|
||||
}
|
||||
writer.Flush()
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.json", now))
|
||||
json.NewEncoder(w).Encode(results)
|
||||
}
|
||||
|
||||
default: // summary
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-summary-%s.json", now))
|
||||
statsAPI := NewStatsAPI(e.db, e.authStore)
|
||||
result := statsAPI.GetSummary(nil)
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
}
|
||||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"llm-gateway/internal/auth"
|
||||
"llm-gateway/internal/cache"
|
||||
"llm-gateway/internal/provider"
|
||||
"llm-gateway/internal/storage"
|
||||
)
|
||||
|
||||
//go:embed templates/*.html templates/partials/*.html
|
||||
|
|
@ -125,15 +126,24 @@ type PageData struct {
|
|||
FilterStatus string
|
||||
// Models routing page data
|
||||
ModelRoutes []provider.ModelRouteInfo
|
||||
// Audit page data
|
||||
AuditResult *storage.AuditQueryResult
|
||||
AuditFilterActions []string
|
||||
FilterAction string
|
||||
// Debug page data
|
||||
DebugResult *storage.DebugLogQueryResult
|
||||
DebugEnabled bool
|
||||
}
|
||||
|
||||
// Dashboard serves the HTMX-based dashboard pages.
|
||||
type Dashboard struct {
|
||||
templates *template.Template
|
||||
authStore *auth.Store
|
||||
statsAPI *StatsAPI
|
||||
registry *provider.Registry
|
||||
cache *cache.Cache
|
||||
templates *template.Template
|
||||
authStore *auth.Store
|
||||
statsAPI *StatsAPI
|
||||
registry *provider.Registry
|
||||
cache *cache.Cache
|
||||
auditLogger *storage.AuditLogger
|
||||
debugLogger *storage.DebugLogger
|
||||
}
|
||||
|
||||
// NewDashboard creates a new Dashboard handler.
|
||||
|
|
@ -162,6 +172,16 @@ func (d *Dashboard) SetCache(c *cache.Cache) {
|
|||
d.cache = c
|
||||
}
|
||||
|
||||
// SetAuditLogger sets the audit logger for the audit page.
|
||||
func (d *Dashboard) SetAuditLogger(al *storage.AuditLogger) {
|
||||
d.auditLogger = al
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the debug page.
|
||||
func (d *Dashboard) SetDebugLogger(dl *storage.DebugLogger) {
|
||||
d.debugLogger = dl
|
||||
}
|
||||
|
||||
// LoginPage serves the login page.
|
||||
func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
if !d.authStore.HasAnyUser() {
|
||||
|
|
@ -298,6 +318,62 @@ func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) {
|
|||
})
|
||||
}
|
||||
|
||||
// AuditPage serves the audit log view (admin only).
|
||||
func (d *Dashboard) AuditPage(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.UserFromContext(r.Context())
|
||||
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
action := r.URL.Query().Get("action")
|
||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
||||
|
||||
var auditResult *storage.AuditQueryResult
|
||||
if d.auditLogger != nil {
|
||||
auditResult = d.auditLogger.Query(since, action, page, 50)
|
||||
} else {
|
||||
auditResult = &storage.AuditQueryResult{Entries: []storage.AuditEntry{}, Page: 1, TotalPages: 1}
|
||||
}
|
||||
|
||||
// Common audit action types for the filter dropdown
|
||||
actions := []string{"login", "logout", "create_user", "delete_user", "create_token", "delete_token", "change_password", "setup_totp", "disable_totp"}
|
||||
|
||||
d.renderDashboardPage(w, r, "partials/audit.html", PageData{
|
||||
ActivePage: "audit",
|
||||
User: user,
|
||||
AuditResult: auditResult,
|
||||
AuditFilterActions: actions,
|
||||
FilterAction: action,
|
||||
})
|
||||
}
|
||||
|
||||
// DebugPage serves the debug logging view (admin only).
|
||||
func (d *Dashboard) DebugPage(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.UserFromContext(r.Context())
|
||||
|
||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
var debugResult *storage.DebugLogQueryResult
|
||||
debugEnabled := false
|
||||
if d.debugLogger != nil {
|
||||
debugResult = d.debugLogger.QueryFull(page, 50)
|
||||
debugEnabled = d.debugLogger.IsEnabled()
|
||||
} else {
|
||||
debugResult = &storage.DebugLogQueryResult{Entries: []storage.DebugLogEntry{}, Page: 1, TotalPages: 1}
|
||||
}
|
||||
|
||||
d.renderDashboardPage(w, r, "partials/debug.html", PageData{
|
||||
ActivePage: "debug",
|
||||
User: user,
|
||||
DebugResult: debugResult,
|
||||
DebugEnabled: debugEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
// SettingsPage serves the settings view.
|
||||
func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) {
|
||||
user := auth.UserFromContext(r.Context())
|
||||
|
|
|
|||
|
|
@ -160,6 +160,24 @@
|
|||
.badge-error { background: var(--accent-red-bg); color: var(--accent-red); }
|
||||
.badge-cached { background: var(--accent-blue-bg); color: var(--accent-blue); }
|
||||
.badge-priority { background: var(--bg-tertiary); color: var(--text-secondary); }
|
||||
.badge-open { background: var(--accent-red-bg); color: var(--accent-red); }
|
||||
.badge-half-open { background: var(--accent-yellow-bg); color: var(--accent-yellow); }
|
||||
|
||||
/* Toggle switch */
|
||||
.toggle-switch { position: relative; display: inline-block; width: 44px; height: 24px; }
|
||||
.toggle-switch input { opacity: 0; width: 0; height: 0; }
|
||||
.toggle-slider { position: absolute; cursor: pointer; top: 0; left: 0; right: 0; bottom: 0; background: var(--bg-tertiary); border-radius: 24px; transition: 0.2s; }
|
||||
.toggle-slider:before { content: ""; position: absolute; height: 18px; width: 18px; left: 3px; bottom: 3px; background: var(--text-secondary); border-radius: 50%; transition: 0.2s; }
|
||||
.toggle-switch input:checked + .toggle-slider { background: var(--accent-blue); }
|
||||
.toggle-switch input:checked + .toggle-slider:before { transform: translateX(20px); background: #fff; }
|
||||
|
||||
/* Code block for debug bodies */
|
||||
.code-block { background: var(--bg-primary); border: 1px solid var(--border-color); border-radius: 6px; padding: 12px; font-family: monospace; font-size: 0.8rem; white-space: pre-wrap; word-break: break-all; max-height: 300px; overflow-y: auto; }
|
||||
|
||||
/* Export button inline */
|
||||
.export-links { display: inline-flex; gap: 6px; margin-left: 12px; }
|
||||
.export-links a { font-size: 0.7rem; color: var(--text-muted); text-decoration: none; padding: 2px 6px; border: 1px solid var(--border-color); border-radius: 4px; }
|
||||
.export-links a:hover { color: var(--text-primary); border-color: var(--text-muted); }
|
||||
|
||||
.page-header { display: flex; align-items: center; gap: 12px; margin-bottom: 20px; }
|
||||
.page-header h1 { font-size: 1.3rem; color: var(--text-heading); }
|
||||
|
|
@ -268,6 +286,8 @@ window.matchMedia('(prefers-color-scheme: light)').addEventListener('change', fu
|
|||
<a href="/tokens" hx-get="/tokens" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "tokens"}}class="active"{{end}}>API Tokens</a>
|
||||
{{if .User.IsAdmin}}
|
||||
<a href="/users" hx-get="/users" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "users"}}class="active"{{end}}>Users</a>
|
||||
<a href="/audit" hx-get="/audit" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "audit"}}class="active"{{end}}>Audit Log</a>
|
||||
<a href="/debug" hx-get="/debug" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "debug"}}class="active"{{end}}>Debug</a>
|
||||
{{end}}
|
||||
<a href="/settings" hx-get="/settings" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "settings"}}class="active"{{end}}>Settings</a>
|
||||
</nav>
|
||||
|
|
|
|||
83
llm-gateway/internal/dashboard/templates/partials/audit.html
Normal file
83
llm-gateway/internal/dashboard/templates/partials/audit.html
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
{{define "content"}}
|
||||
<div class="page-header">
|
||||
<h1>Audit Log</h1>
|
||||
<span style="font-size:0.85rem;color:var(--text-muted)">{{.AuditResult.Total}} total</span>
|
||||
</div>
|
||||
|
||||
<div class="filter-bar">
|
||||
<select id="filter-action" onchange="applyAuditFilter()">
|
||||
<option value="">All Actions</option>
|
||||
{{range .AuditFilterActions}}<option value="{{.}}" {{if eq . $.FilterAction}}selected{{end}}>{{.}}</option>{{end}}
|
||||
</select>
|
||||
<button class="btn btn-sm btn-outline" onclick="clearAuditFilter()">Clear</button>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Time</th>
|
||||
<th>User</th>
|
||||
<th>Action</th>
|
||||
<th>Target</th>
|
||||
<th>Details</th>
|
||||
<th>IP</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .AuditResult.Entries}}
|
||||
<tr>
|
||||
<td>{{formatTimeDetail .Timestamp}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td><span class="badge badge-priority">{{.Action}}</span></td>
|
||||
<td>{{if .TargetType}}{{.TargetType}}{{if .TargetID}}/{{.TargetID}}{{end}}{{else}}-{{end}}</td>
|
||||
<td style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="{{.Details}}">{{if .Details}}{{.Details}}{{else}}-{{end}}</td>
|
||||
<td>{{if .IPAddress}}{{.IPAddress}}{{else}}-{{end}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
{{if not .AuditResult.Entries}}
|
||||
<tr><td colspan="6" style="text-align:center;color:var(--text-muted);padding:24px;">No audit log entries</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
{{if gt .AuditResult.TotalPages 1}}
|
||||
<div class="pagination">
|
||||
<button {{if le .AuditResult.Page 1}}disabled{{end}} onclick="goToAuditPage(1)">First</button>
|
||||
<button {{if le .AuditResult.Page 1}}disabled{{end}} onclick="goToAuditPage({{subInt .AuditResult.Page 1}})">Prev</button>
|
||||
{{$page := .AuditResult.Page}}
|
||||
{{$total := .AuditResult.TotalPages}}
|
||||
{{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
|
||||
<button class="{{if eq . $page}}active{{end}}" onclick="goToAuditPage({{.}})">{{.}}</button>
|
||||
{{end}}
|
||||
<button {{if ge .AuditResult.Page .AuditResult.TotalPages}}disabled{{end}} onclick="goToAuditPage({{addInt .AuditResult.Page 1}})">Next</button>
|
||||
<button {{if ge .AuditResult.Page .AuditResult.TotalPages}}disabled{{end}} onclick="goToAuditPage({{.AuditResult.TotalPages}})">Last</button>
|
||||
<span class="page-info">Page {{.AuditResult.Page}} of {{.AuditResult.TotalPages}}</span>
|
||||
</div>
|
||||
{{end}}
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Build the audit-log URL for a given page, carrying the current
// action filter. Page 1 is canonical and omits the page parameter.
function buildAuditURL(page) {
  var parts = [];
  var action = document.getElementById('filter-action').value;
  if (action) parts.push('action=' + encodeURIComponent(action));
  if (page > 1) parts.push('page=' + page);
  return parts.length ? '/audit?' + parts.join('&') : '/audit';
}
// Swap the audit fragment into #content via htmx and mirror the URL
// into browser history so back/forward and deep links keep working.
function navigateAudit(url) {
  htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
  history.pushState({}, '', url);
}
// Re-query from page 1 with the currently selected filter.
function applyAuditFilter() {
  navigateAudit(buildAuditURL(1));
}
// Jump to a specific page, preserving the active filter.
function goToAuditPage(page) {
  navigateAudit(buildAuditURL(page));
}
// Reset the action filter and reload the unfiltered first page.
function clearAuditFilter() {
  document.getElementById('filter-action').value = '';
  applyAuditFilter();
}
|
||||
</script>
|
||||
{{end}}
|
||||
|
|
@ -25,6 +25,8 @@
|
|||
<div class="health-item">
|
||||
<span class="provider-name">{{.Provider}}</span>
|
||||
<span class="badge badge-{{.Status}}">{{.Status}}</span>
|
||||
{{if eq .CircuitState "open"}}<span class="badge badge-open">circuit open</span>{{end}}
|
||||
{{if eq .CircuitState "half-open"}}<span class="badge badge-half-open">half-open</span>{{end}}
|
||||
<span style="font-size:0.75rem;color:var(--text-muted)">{{printf "%.0f" .AvgLatency}}ms avg | {{formatPct .ErrorRate}} errors</span>
|
||||
</div>
|
||||
{{end}}
|
||||
|
|
@ -71,7 +73,7 @@
|
|||
|
||||
{{if .Models}}
|
||||
<div class="section">
|
||||
<h2>Models</h2>
|
||||
<h2>Models<span class="export-links"><a href="/api/export/stats?format=csv&type=models" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=models" target="_blank">JSON</a></span></h2>
|
||||
<table>
|
||||
<thead><tr><th>Model</th><th>Requests</th><th>Tokens (in/out)</th><th>Cost</th><th>Avg Latency</th></tr></thead>
|
||||
<tbody>
|
||||
|
|
@ -91,7 +93,7 @@
|
|||
|
||||
{{if .Providers}}
|
||||
<div class="section">
|
||||
<h2>Providers</h2>
|
||||
<h2>Providers<span class="export-links"><a href="/api/export/stats?format=csv&type=providers" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=providers" target="_blank">JSON</a></span></h2>
|
||||
<table>
|
||||
<thead><tr><th>Provider</th><th>Requests</th><th>Success</th><th>Errors</th><th>Avg Latency</th><th>Cost</th></tr></thead>
|
||||
<tbody>
|
||||
|
|
@ -112,7 +114,7 @@
|
|||
|
||||
{{if .TokenStats}}
|
||||
<div class="section">
|
||||
<h2>API Token Usage</h2>
|
||||
<h2>API Token Usage<span class="export-links"><a href="/api/export/stats?format=csv&type=tokens" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=tokens" target="_blank">JSON</a></span></h2>
|
||||
<table>
|
||||
<thead><tr><th>Token</th><th>Requests</th><th>Tokens (in/out)</th><th>Cost</th></tr></thead>
|
||||
<tbody>
|
||||
|
|
|
|||
100
llm-gateway/internal/dashboard/templates/partials/debug.html
Normal file
100
llm-gateway/internal/dashboard/templates/partials/debug.html
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
{{define "content"}}
{{/* Debug Logging page: a toggle that enables/disables request capture,
     plus a paginated table of captured request/response pairs. Each row
     expands in place to show headers and bodies. */}}
<div class="page-header">
  <h1>Debug Logging</h1>
  <span style="font-size:0.85rem;color:var(--text-muted)">{{.DebugResult.Total}} entries</span>
</div>

{{/* Debug-mode switch; posts to /api/debug/toggle via toggleDebug() below. */}}
<div class="section" style="display:flex;align-items:center;gap:16px;padding:12px 16px;">
  <span style="font-size:0.9rem;font-weight:600;">Debug Mode</span>
  <label class="toggle-switch">
    <input type="checkbox" id="debug-toggle" {{if .DebugEnabled}}checked{{end}} onchange="toggleDebug(this.checked)">
    <span class="toggle-slider"></span>
  </label>
  <span id="debug-status" style="font-size:0.8rem;color:var(--text-muted)">{{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}}</span>
</div>

<div class="section">
  <table>
    <thead>
      <tr>
        <th></th>
        <th>Time</th>
        <th>Request ID</th>
        <th>Token</th>
        <th>Model</th>
        <th>Provider</th>
        <th>Status</th>
      </tr>
    </thead>
    <tbody>
      {{/* Each entry renders two rows: a clickable summary row and a hidden
           detail row toggled by toggleDebugExpand(), keyed by range index. */}}
      {{range $i, $entry := .DebugResult.Entries}}
      <tr class="expandable" onclick="toggleDebugExpand('debug-expand-{{$i}}')">
        <td style="width:20px;text-align:center;color:var(--text-muted)">▶</td>
        <td>{{formatTimeDetail $entry.Timestamp}}</td>
        <td><code style="font-size:0.75rem">{{$entry.RequestID}}</code></td>
        <td>{{$entry.TokenName}}</td>
        <td>{{$entry.Model}}</td>
        <td>{{$entry.Provider}}</td>
        <td>
          {{/* Badge color by HTTP class: 2xx success, >=400 error, else neutral. */}}
          {{if and (ge $entry.ResponseStatus 200) (lt $entry.ResponseStatus 300)}}<span class="badge badge-success">{{$entry.ResponseStatus}}</span>
          {{else if ge $entry.ResponseStatus 400}}<span class="badge badge-error">{{$entry.ResponseStatus}}</span>
          {{else}}<span class="badge">{{$entry.ResponseStatus}}</span>{{end}}
        </td>
      </tr>
      <tr>
        <td colspan="7" style="padding:0;">
          <div id="debug-expand-{{$i}}" class="expand-content">
            <div style="margin-bottom:8px"><strong>Request Headers:</strong></div>
            <div class="code-block">{{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}</div>
            <div style="margin:8px 0"><strong>Request Body:</strong></div>
            <div class="code-block">{{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}</div>
            <div style="margin:8px 0"><strong>Response Body:</strong></div>
            <div class="code-block">{{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}</div>
          </div>
        </td>
      </tr>
      {{end}}
      {{if not .DebugResult.Entries}}
      <tr><td colspan="7" style="text-align:center;color:var(--text-muted);padding:24px;">No debug log entries</td></tr>
      {{end}}
    </tbody>
  </table>

  {{/* Pagination, shown only when more than one page exists. seq/paginationStart/
       paginationEnd/addInt/subInt are template helpers registered by the app. */}}
  {{if gt .DebugResult.TotalPages 1}}
  <div class="pagination">
    <button {{if le .DebugResult.Page 1}}disabled{{end}} onclick="goToDebugPage(1)">First</button>
    <button {{if le .DebugResult.Page 1}}disabled{{end}} onclick="goToDebugPage({{subInt .DebugResult.Page 1}})">Prev</button>
    {{$page := .DebugResult.Page}}
    {{$total := .DebugResult.TotalPages}}
    {{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
    <button class="{{if eq . $page}}active{{end}}" onclick="goToDebugPage({{.}})">{{.}}</button>
    {{end}}
    <button {{if ge .DebugResult.Page .DebugResult.TotalPages}}disabled{{end}} onclick="goToDebugPage({{addInt .DebugResult.Page 1}})">Next</button>
    <button {{if ge .DebugResult.Page .DebugResult.TotalPages}}disabled{{end}} onclick="goToDebugPage({{.DebugResult.TotalPages}})">Last</button>
    <span class="page-info">Page {{.DebugResult.Page}} of {{.DebugResult.TotalPages}}</span>
  </div>
  {{end}}
</div>

<script>
// Persist the new debug-mode state server-side, then re-render this page
// so the status text and entry list reflect the change.
function toggleDebug(enabled) {
  fetch('/api/debug/toggle', {
    method: 'POST',
    credentials: 'same-origin',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({enabled: enabled})
  }).then(function() {
    htmx.ajax('GET', '/debug', {target: '#content', swap: 'innerHTML'});
  });
}
// Show/hide one entry's detail row (headers + bodies).
function toggleDebugExpand(id) {
  var el = document.getElementById(id);
  if (el) el.classList.toggle('show');
}
// Navigate to a debug-log page via htmx and keep the URL in history.
function goToDebugPage(page) {
  var url = '/debug' + (page > 1 ? '?page=' + page : '');
  htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
  history.pushState({}, '', url);
}
</script>
{{end}}
|
||||
|
|
@ -20,6 +20,9 @@
|
|||
<option value="cached" {{if eq .FilterStatus "cached"}}selected{{end}}>Cached</option>
|
||||
</select>
|
||||
<button class="btn btn-sm btn-outline" onclick="clearLogsFilter()">Clear</button>
|
||||
<span style="margin-left:auto"></span>
|
||||
<button class="btn btn-sm btn-outline" onclick="exportLogs('csv')">Export CSV</button>
|
||||
<button class="btn btn-sm btn-outline" onclick="exportLogs('json')">Export JSON</button>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
|
|
@ -116,5 +119,15 @@ function toggleExpand(id) {
|
|||
var el = document.getElementById(id);
|
||||
if (el) el.classList.toggle('show');
|
||||
}
|
||||
// Open the log-export endpoint in a new tab, forwarding the requested
// format plus whichever filter inputs currently hold a value.
function exportLogs(format) {
  var query = ['format=' + format];
  var filters = [
    ['model', 'filter-model'],
    ['token', 'filter-token'],
    ['status', 'filter-status']
  ];
  filters.forEach(function (pair) {
    var value = document.getElementById(pair[1]).value;
    if (value) query.push(pair[0] + '=' + encodeURIComponent(value));
  });
  window.open('/api/export/logs?' + query.join('&'), '_blank');
}
|
||||
</script>
|
||||
{{end}}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ type Metrics struct {
|
|||
requestDuration *prometheus.HistogramVec
|
||||
tokensTotal *prometheus.CounterVec
|
||||
costTotal *prometheus.CounterVec
|
||||
cacheHits prometheus.Counter
|
||||
cacheMisses prometheus.Counter
|
||||
}
|
||||
|
||||
func New() *Metrics {
|
||||
|
|
@ -34,6 +36,16 @@ func New() *Metrics {
|
|||
Name: "llm_gateway_cost_usd_total",
|
||||
Help: "Total cost in USD",
|
||||
}, []string{"model", "provider", "token_name"}),
|
||||
|
||||
cacheHits: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "llm_gateway_cache_hits_total",
|
||||
Help: "Total number of cache hits",
|
||||
}),
|
||||
|
||||
cacheMisses: promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "llm_gateway_cache_misses_total",
|
||||
Help: "Total number of cache misses",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -51,3 +63,11 @@ func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, l
|
|||
m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordCacheHit() {
|
||||
m.cacheHits.Inc()
|
||||
}
|
||||
|
||||
func (m *Metrics) RecordCacheMiss() {
|
||||
m.cacheMisses.Inc()
|
||||
}
|
||||
|
|
|
|||
144
llm-gateway/internal/provider/balancer.go
Normal file
144
llm-gateway/internal/provider/balancer.go
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// LoadBalancer reorders routes for load distribution.
|
||||
type LoadBalancer interface {
|
||||
Reorder(routes []Route) []Route
|
||||
}
|
||||
|
||||
// NewLoadBalancer creates a load balancer by strategy name.
|
||||
func NewLoadBalancer(strategy string) LoadBalancer {
|
||||
switch strategy {
|
||||
case "round-robin":
|
||||
return &RoundRobinBalancer{}
|
||||
case "random":
|
||||
return &RandomBalancer{}
|
||||
case "least-cost":
|
||||
return &LeastCostBalancer{}
|
||||
default:
|
||||
return &FirstBalancer{}
|
||||
}
|
||||
}
|
||||
|
||||
// FirstBalancer is a no-op that preserves original order.
|
||||
type FirstBalancer struct{}
|
||||
|
||||
func (b *FirstBalancer) Reorder(routes []Route) []Route {
|
||||
return routes
|
||||
}
|
||||
|
||||
// RoundRobinBalancer rotates routes within same-priority groups.
|
||||
type RoundRobinBalancer struct {
|
||||
counter atomic.Uint64
|
||||
}
|
||||
|
||||
func (b *RoundRobinBalancer) Reorder(routes []Route) []Route {
|
||||
if len(routes) <= 1 {
|
||||
return routes
|
||||
}
|
||||
|
||||
result := make([]Route, len(routes))
|
||||
copy(result, routes)
|
||||
|
||||
// Group by priority and rotate within each group
|
||||
groups := groupByPriority(result)
|
||||
idx := 0
|
||||
count := b.counter.Add(1)
|
||||
for _, group := range groups {
|
||||
if len(group) > 1 {
|
||||
offset := int(count) % len(group)
|
||||
for j := 0; j < len(group); j++ {
|
||||
result[idx] = group[(j+offset)%len(group)]
|
||||
idx++
|
||||
}
|
||||
} else {
|
||||
result[idx] = group[0]
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// RandomBalancer shuffles routes within same-priority groups.
|
||||
type RandomBalancer struct{}
|
||||
|
||||
func (b *RandomBalancer) Reorder(routes []Route) []Route {
|
||||
if len(routes) <= 1 {
|
||||
return routes
|
||||
}
|
||||
|
||||
result := make([]Route, len(routes))
|
||||
copy(result, routes)
|
||||
|
||||
groups := groupByPriority(result)
|
||||
idx := 0
|
||||
for _, group := range groups {
|
||||
rand.Shuffle(len(group), func(i, j int) {
|
||||
group[i], group[j] = group[j], group[i]
|
||||
})
|
||||
for _, r := range group {
|
||||
result[idx] = r
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// LeastCostBalancer sorts by price within same-priority groups.
|
||||
type LeastCostBalancer struct{}
|
||||
|
||||
func (b *LeastCostBalancer) Reorder(routes []Route) []Route {
|
||||
if len(routes) <= 1 {
|
||||
return routes
|
||||
}
|
||||
|
||||
result := make([]Route, len(routes))
|
||||
copy(result, routes)
|
||||
|
||||
groups := groupByPriority(result)
|
||||
idx := 0
|
||||
for _, group := range groups {
|
||||
sort.Slice(group, func(i, j int) bool {
|
||||
costI := group[i].InputPrice + group[i].OutputPrice
|
||||
costJ := group[j].InputPrice + group[j].OutputPrice
|
||||
return costI < costJ
|
||||
})
|
||||
for _, r := range group {
|
||||
result[idx] = r
|
||||
idx++
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// groupByPriority splits routes into groups of same priority, preserving order.
|
||||
func groupByPriority(routes []Route) [][]Route {
|
||||
if len(routes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var groups [][]Route
|
||||
currentPriority := routes[0].Priority
|
||||
currentGroup := []Route{routes[0]}
|
||||
|
||||
for i := 1; i < len(routes); i++ {
|
||||
if routes[i].Priority == currentPriority {
|
||||
currentGroup = append(currentGroup, routes[i])
|
||||
} else {
|
||||
groups = append(groups, currentGroup)
|
||||
currentPriority = routes[i].Priority
|
||||
currentGroup = []Route{routes[i]}
|
||||
}
|
||||
}
|
||||
groups = append(groups, currentGroup)
|
||||
|
||||
return groups
|
||||
}
|
||||
294
llm-gateway/internal/provider/balancer_test.go
Normal file
294
llm-gateway/internal/provider/balancer_test.go
Normal file
|
|
@ -0,0 +1,294 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type routeSpec struct {
|
||||
name string
|
||||
priority int
|
||||
input float64
|
||||
output float64
|
||||
}
|
||||
|
||||
func makeRoutes(specs ...routeSpec) []Route {
|
||||
routes := make([]Route, len(specs))
|
||||
for i, s := range specs {
|
||||
routes[i] = Route{
|
||||
Provider: &mockProvider{name: s.name},
|
||||
ProviderModel: s.name + "-model",
|
||||
Priority: s.priority,
|
||||
InputPrice: s.input,
|
||||
OutputPrice: s.output,
|
||||
}
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
func routeNames(routes []Route) []string {
|
||||
names := make([]string, len(routes))
|
||||
for i, r := range routes {
|
||||
names[i] = r.Provider.Name()
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func TestFirstBalancer_PreservesOrder(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"a", 1, 1.0, 1.0},
|
||||
routeSpec{"b", 1, 2.0, 2.0},
|
||||
routeSpec{"c", 1, 3.0, 3.0},
|
||||
)
|
||||
|
||||
b := &FirstBalancer{}
|
||||
result := b.Reorder(routes)
|
||||
|
||||
names := routeNames(result)
|
||||
if names[0] != "a" || names[1] != "b" || names[2] != "c" {
|
||||
t.Fatalf("expected [a b c], got %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"a", 1, 1.0, 1.0},
|
||||
routeSpec{"b", 1, 1.0, 1.0},
|
||||
routeSpec{"c", 1, 1.0, 1.0},
|
||||
)
|
||||
|
||||
b := &RoundRobinBalancer{}
|
||||
|
||||
// Collect the first element from multiple calls
|
||||
seen := make(map[string]bool)
|
||||
for i := 0; i < 6; i++ {
|
||||
result := b.Reorder(routes)
|
||||
seen[result[0].Provider.Name()] = true
|
||||
}
|
||||
|
||||
// All routes should have appeared as first at some point
|
||||
for _, name := range []string{"a", "b", "c"} {
|
||||
if !seen[name] {
|
||||
t.Errorf("expected %q to appear as first element in rotation", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"a", 1, 1.0, 1.0},
|
||||
routeSpec{"b", 1, 1.0, 1.0},
|
||||
routeSpec{"c", 2, 1.0, 1.0},
|
||||
)
|
||||
|
||||
b := &RoundRobinBalancer{}
|
||||
|
||||
// Priority 2 route should always be last
|
||||
for i := 0; i < 5; i++ {
|
||||
result := b.Reorder(routes)
|
||||
if result[2].Provider.Name() != "c" {
|
||||
t.Fatalf("expected priority-2 route 'c' at the end, got %q", result[2].Provider.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomBalancer_AllRoutesPresent(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"a", 1, 1.0, 1.0},
|
||||
routeSpec{"b", 1, 1.0, 1.0},
|
||||
routeSpec{"c", 1, 1.0, 1.0},
|
||||
)
|
||||
|
||||
b := &RandomBalancer{}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
result := b.Reorder(routes)
|
||||
if len(result) != 3 {
|
||||
t.Fatalf("expected 3 routes, got %d", len(result))
|
||||
}
|
||||
|
||||
names := make(map[string]bool)
|
||||
for _, r := range result {
|
||||
names[r.Provider.Name()] = true
|
||||
}
|
||||
for _, want := range []string{"a", "b", "c"} {
|
||||
if !names[want] {
|
||||
t.Errorf("missing route %q in result", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"a", 1, 1.0, 1.0},
|
||||
routeSpec{"b", 1, 1.0, 1.0},
|
||||
routeSpec{"c", 2, 1.0, 1.0},
|
||||
)
|
||||
|
||||
b := &RandomBalancer{}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
result := b.Reorder(routes)
|
||||
if result[2].Provider.Name() != "c" {
|
||||
t.Fatalf("expected priority-2 route 'c' last, got %q", result[2].Provider.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastCostBalancer_SortsByCost(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"expensive", 1, 10.0, 10.0},
|
||||
routeSpec{"cheap", 1, 1.0, 1.0},
|
||||
routeSpec{"medium", 1, 5.0, 5.0},
|
||||
)
|
||||
|
||||
b := &LeastCostBalancer{}
|
||||
result := b.Reorder(routes)
|
||||
|
||||
names := routeNames(result)
|
||||
expected := []string{"cheap", "medium", "expensive"}
|
||||
for i, want := range expected {
|
||||
if names[i] != want {
|
||||
t.Errorf("position %d: got %q, want %q", i, names[i], want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) {
|
||||
routes := makeRoutes(
|
||||
routeSpec{"expensive-p1", 1, 10.0, 10.0},
|
||||
routeSpec{"cheap-p1", 1, 1.0, 1.0},
|
||||
routeSpec{"cheap-p2", 2, 0.5, 0.5},
|
||||
)
|
||||
|
||||
b := &LeastCostBalancer{}
|
||||
result := b.Reorder(routes)
|
||||
|
||||
names := routeNames(result)
|
||||
// Within priority 1, cheap should come first; priority 2 always last
|
||||
if names[0] != "cheap-p1" {
|
||||
t.Errorf("expected cheap-p1 first, got %q", names[0])
|
||||
}
|
||||
if names[1] != "expensive-p1" {
|
||||
t.Errorf("expected expensive-p1 second, got %q", names[1])
|
||||
}
|
||||
if names[2] != "cheap-p2" {
|
||||
t.Errorf("expected cheap-p2 last, got %q", names[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroupByPriority(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
priorities []int
|
||||
wantGroups [][]int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
priorities: nil,
|
||||
wantGroups: nil,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
priorities: []int{1},
|
||||
wantGroups: [][]int{{1}},
|
||||
},
|
||||
{
|
||||
name: "all same",
|
||||
priorities: []int{1, 1, 1},
|
||||
wantGroups: [][]int{{1, 1, 1}},
|
||||
},
|
||||
{
|
||||
name: "two groups",
|
||||
priorities: []int{1, 1, 2, 2},
|
||||
wantGroups: [][]int{{1, 1}, {2, 2}},
|
||||
},
|
||||
{
|
||||
name: "three groups",
|
||||
priorities: []int{1, 2, 2, 3},
|
||||
wantGroups: [][]int{{1}, {2, 2}, {3}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var routes []Route
|
||||
for _, p := range tt.priorities {
|
||||
routes = append(routes, Route{Priority: p})
|
||||
}
|
||||
|
||||
groups := groupByPriority(routes)
|
||||
|
||||
if tt.wantGroups == nil {
|
||||
if groups != nil {
|
||||
t.Fatalf("expected nil groups, got %v", groups)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(groups) != len(tt.wantGroups) {
|
||||
t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups))
|
||||
}
|
||||
|
||||
for i, wg := range tt.wantGroups {
|
||||
if len(groups[i]) != len(wg) {
|
||||
t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i]))
|
||||
continue
|
||||
}
|
||||
for j, wp := range wg {
|
||||
if groups[i][j].Priority != wp {
|
||||
t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBalancer_SingleRoute(t *testing.T) {
|
||||
routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0})
|
||||
|
||||
balancers := []struct {
|
||||
name string
|
||||
balancer LoadBalancer
|
||||
}{
|
||||
{"first", &FirstBalancer{}},
|
||||
{"round-robin", &RoundRobinBalancer{}},
|
||||
{"random", &RandomBalancer{}},
|
||||
{"least-cost", &LeastCostBalancer{}},
|
||||
}
|
||||
|
||||
for _, bb := range balancers {
|
||||
t.Run(bb.name, func(t *testing.T) {
|
||||
result := bb.balancer.Reorder(routes)
|
||||
if len(result) != 1 || result[0].Provider.Name() != "only" {
|
||||
t.Fatalf("expected single route 'only', got %v", routeNames(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLoadBalancer(t *testing.T) {
|
||||
tests := []struct {
|
||||
strategy string
|
||||
wantType string
|
||||
}{
|
||||
{"round-robin", "*provider.RoundRobinBalancer"},
|
||||
{"random", "*provider.RandomBalancer"},
|
||||
{"least-cost", "*provider.LeastCostBalancer"},
|
||||
{"first", "*provider.FirstBalancer"},
|
||||
{"unknown", "*provider.FirstBalancer"},
|
||||
{"", "*provider.FirstBalancer"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.strategy, func(t *testing.T) {
|
||||
b := NewLoadBalancer(tt.strategy)
|
||||
got := fmt.Sprintf("%T", b)
|
||||
if got != tt.wantType {
|
||||
t.Errorf("NewLoadBalancer(%q) = %s, want %s", tt.strategy, got, tt.wantType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -3,8 +3,39 @@ package provider
|
|||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"llm-gateway/internal/config"
|
||||
)
|
||||
|
||||
// CircuitState represents the state of a circuit breaker.
type CircuitState int

const (
	CircuitClosed   CircuitState = iota // normal operation
	CircuitOpen                         // blocking requests
	CircuitHalfOpen                     // testing with probe request
)

// String returns the lowercase human-readable name of the state;
// values outside the known range render as "unknown".
func (s CircuitState) String() string {
	names := [...]string{
		CircuitClosed:   "closed",
		CircuitOpen:     "open",
		CircuitHalfOpen: "half-open",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
|
||||
|
||||
// ProviderCircuit tracks circuit breaker state for a single provider.
type ProviderCircuit struct {
	State     CircuitState // current breaker state (closed/open/half-open)
	OpenedAt  time.Time    // when the circuit last transitioned to open
	LastProbe time.Time    // when the last half-open probe was admitted
}
|
||||
|
||||
// HealthEvent represents a single request outcome for a provider.
|
||||
type HealthEvent struct {
|
||||
Timestamp time.Time
|
||||
|
|
@ -15,12 +46,13 @@ type HealthEvent struct {
|
|||
|
||||
// ProviderHealth is the computed health status for a provider, as
// reported by HealthTracker.Status and serialized in the dashboard API.
type ProviderHealth struct {
	Provider     string  `json:"provider"`
	Status       string  `json:"status"` // healthy, degraded, down
	ErrorRate    float64 `json:"error_rate"`
	AvgLatency   float64 `json:"avg_latency_ms"`
	Total        int     `json:"total"`
	Errors       int     `json:"errors"`
	CircuitState string  `json:"circuit_state"`
}
|
||||
|
||||
// HealthTracker tracks per-provider health using a sliding window.
|
||||
|
|
@ -28,20 +60,52 @@ type HealthTracker struct {
|
|||
mu sync.RWMutex
|
||||
windows map[string][]HealthEvent
|
||||
windowDu time.Duration
|
||||
circuits map[string]*ProviderCircuit
|
||||
cbConfig config.CircuitBreakerConfig
|
||||
}
|
||||
|
||||
// NewHealthTracker creates a health tracker with the given window duration.
|
||||
func NewHealthTracker(window time.Duration) *HealthTracker {
|
||||
func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker {
|
||||
if window == 0 {
|
||||
window = 5 * time.Minute
|
||||
}
|
||||
return &HealthTracker{
|
||||
windows: make(map[string][]HealthEvent),
|
||||
circuits: make(map[string]*ProviderCircuit),
|
||||
windowDu: window,
|
||||
cbConfig: cbCfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Record adds a health event for a provider.
|
||||
// IsAvailable returns true if the provider's circuit breaker allows requests.
|
||||
func (h *HealthTracker) IsAvailable(provider string) bool {
|
||||
if !h.cbConfig.Enabled {
|
||||
return true
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
|
||||
circuit, ok := h.circuits[provider]
|
||||
if !ok {
|
||||
return true // no circuit = closed = available
|
||||
}
|
||||
|
||||
switch circuit.State {
|
||||
case CircuitOpen:
|
||||
// Check if cooldown has elapsed -> transition to half-open
|
||||
if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
|
||||
return true // will transition to half-open on next record
|
||||
}
|
||||
return false
|
||||
case CircuitHalfOpen:
|
||||
return true // allow probe
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Record adds a health event for a provider and evaluates circuit transitions.
|
||||
func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
|
||||
event := HealthEvent{
|
||||
Timestamp: time.Now(),
|
||||
|
|
@ -57,6 +121,69 @@ func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
|
|||
|
||||
h.windows[provider] = append(h.windows[provider], event)
|
||||
h.prune(provider)
|
||||
|
||||
if h.cbConfig.Enabled {
|
||||
h.evaluateCircuit(provider, err)
|
||||
}
|
||||
}
|
||||
|
||||
// evaluateCircuit transitions circuit breaker state. Must be called with lock held.
//
// State machine:
//   - closed -> open      when the windowed error rate crosses ErrorThreshold
//     and at least MinRequests events are in the window.
//   - open   -> half-open once CooldownDuration has elapsed; the event being
//     recorded (lastErr) is treated as the probe and resolved immediately,
//     so half-open is transient within this call.
//   - half-open -> closed on a successful probe; back to open (resetting
//     OpenedAt, restarting the cooldown) on a failed one.
func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
	// Lazily create the circuit record on first sight of this provider.
	circuit, ok := h.circuits[providerName]
	if !ok {
		circuit = &ProviderCircuit{State: CircuitClosed}
		h.circuits[providerName] = circuit
	}

	switch circuit.State {
	case CircuitClosed:
		// Check if error threshold exceeded
		errorRate, total := h.errorRateUnlocked(providerName)
		if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold {
			circuit.State = CircuitOpen
			circuit.OpenedAt = time.Now()
		}
	case CircuitOpen:
		// Check if cooldown elapsed -> half-open
		if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
			circuit.State = CircuitHalfOpen
			circuit.LastProbe = time.Now()
			// Evaluate the probe result immediately
			if lastErr == nil {
				circuit.State = CircuitClosed
			} else {
				circuit.State = CircuitOpen
				circuit.OpenedAt = time.Now()
			}
		}
	case CircuitHalfOpen:
		// A pending probe result arrived: success closes the circuit,
		// failure re-opens it and restarts the cooldown clock.
		if lastErr == nil {
			circuit.State = CircuitClosed
		} else {
			circuit.State = CircuitOpen
			circuit.OpenedAt = time.Now()
		}
	}
}
|
||||
|
||||
// errorRateUnlocked computes error rate within window. Must be called with lock held.
|
||||
func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) {
|
||||
cutoff := time.Now().Add(-h.windowDu)
|
||||
events := h.windows[provider]
|
||||
var total, errors int
|
||||
for _, e := range events {
|
||||
if e.Timestamp.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
if e.IsError {
|
||||
errors++
|
||||
}
|
||||
}
|
||||
if total == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
return float64(errors) / float64(total), total
|
||||
}
|
||||
|
||||
// Status returns computed health for all tracked providers.
|
||||
|
|
@ -94,13 +221,19 @@ func (h *HealthTracker) Status() []ProviderHealth {
|
|||
status = "degraded"
|
||||
}
|
||||
|
||||
circuitState := "closed"
|
||||
if circuit, ok := h.circuits[provider]; ok {
|
||||
circuitState = circuit.State.String()
|
||||
}
|
||||
|
||||
results = append(results, ProviderHealth{
|
||||
Provider: provider,
|
||||
Status: status,
|
||||
ErrorRate: errorRate,
|
||||
AvgLatency: float64(totalLatency) / float64(total),
|
||||
Total: total,
|
||||
Errors: errors,
|
||||
Provider: provider,
|
||||
Status: status,
|
||||
ErrorRate: errorRate,
|
||||
AvgLatency: float64(totalLatency) / float64(total),
|
||||
Total: total,
|
||||
Errors: errors,
|
||||
CircuitState: circuitState,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
345
llm-gateway/internal/provider/health_test.go
Normal file
345
llm-gateway/internal/provider/health_test.go
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"llm-gateway/internal/config"
|
||||
)
|
||||
|
||||
func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker {
|
||||
return NewHealthTracker(window, cb)
|
||||
}
|
||||
|
||||
func defaultCBConfig() config.CircuitBreakerConfig {
|
||||
return config.CircuitBreakerConfig{
|
||||
Enabled: true,
|
||||
ErrorThreshold: 0.5,
|
||||
MinRequests: 3,
|
||||
CooldownDuration: 100 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_Record(t *testing.T) {
|
||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
|
||||
|
||||
ht.Record("provA", 100, nil)
|
||||
ht.Record("provA", 200, errors.New("fail"))
|
||||
ht.Record("provB", 50, nil)
|
||||
|
||||
ht.mu.RLock()
|
||||
defer ht.mu.RUnlock()
|
||||
|
||||
if len(ht.windows["provA"]) != 2 {
|
||||
t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"]))
|
||||
}
|
||||
if len(ht.windows["provB"]) != 1 {
|
||||
t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"]))
|
||||
}
|
||||
|
||||
// Verify event fields
|
||||
ev := ht.windows["provA"][1]
|
||||
if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 {
|
||||
t.Fatalf("unexpected event fields: %+v", ev)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_Status(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
successCount int
|
||||
errorCount int
|
||||
wantStatus string
|
||||
wantErrorRate float64
|
||||
wantTotal int
|
||||
wantErrors int
|
||||
}{
|
||||
{
|
||||
name: "healthy - no errors",
|
||||
successCount: 10,
|
||||
errorCount: 0,
|
||||
wantStatus: "healthy",
|
||||
wantErrorRate: 0.0,
|
||||
wantTotal: 10,
|
||||
wantErrors: 0,
|
||||
},
|
||||
{
|
||||
name: "healthy - below 10% errors",
|
||||
successCount: 19,
|
||||
errorCount: 1,
|
||||
wantStatus: "healthy",
|
||||
wantErrorRate: 0.05,
|
||||
wantTotal: 20,
|
||||
wantErrors: 1,
|
||||
},
|
||||
{
|
||||
name: "degraded - 20% errors",
|
||||
successCount: 8,
|
||||
errorCount: 2,
|
||||
wantStatus: "degraded",
|
||||
wantErrorRate: 0.2,
|
||||
wantTotal: 10,
|
||||
wantErrors: 2,
|
||||
},
|
||||
{
|
||||
name: "degraded - exactly 10% errors",
|
||||
successCount: 9,
|
||||
errorCount: 1,
|
||||
wantStatus: "degraded",
|
||||
wantErrorRate: 0.1,
|
||||
wantTotal: 10,
|
||||
wantErrors: 1,
|
||||
},
|
||||
{
|
||||
name: "down - 50% errors",
|
||||
successCount: 5,
|
||||
errorCount: 5,
|
||||
wantStatus: "down",
|
||||
wantErrorRate: 0.5,
|
||||
wantTotal: 10,
|
||||
wantErrors: 5,
|
||||
},
|
||||
{
|
||||
name: "down - all errors",
|
||||
successCount: 0,
|
||||
errorCount: 5,
|
||||
wantStatus: "down",
|
||||
wantErrorRate: 1.0,
|
||||
wantTotal: 5,
|
||||
wantErrors: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
|
||||
|
||||
for i := 0; i < tt.successCount; i++ {
|
||||
ht.Record("prov", 100, nil)
|
||||
}
|
||||
for i := 0; i < tt.errorCount; i++ {
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
}
|
||||
|
||||
statuses := ht.Status()
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected 1 status, got %d", len(statuses))
|
||||
}
|
||||
|
||||
s := statuses[0]
|
||||
if s.Status != tt.wantStatus {
|
||||
t.Errorf("status = %q, want %q", s.Status, tt.wantStatus)
|
||||
}
|
||||
if s.Total != tt.wantTotal {
|
||||
t.Errorf("total = %d, want %d", s.Total, tt.wantTotal)
|
||||
}
|
||||
if s.Errors != tt.wantErrors {
|
||||
t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors)
|
||||
}
|
||||
// Allow small float tolerance
|
||||
if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 {
|
||||
t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) {
|
||||
cb := defaultCBConfig()
|
||||
cb.MinRequests = 3
|
||||
cb.ErrorThreshold = 0.5
|
||||
|
||||
ht := newTestTracker(5*time.Minute, cb)
|
||||
|
||||
// Record errors to exceed threshold (3 errors out of 3 = 100% > 50%)
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
|
||||
ht.mu.RLock()
|
||||
state := ht.circuits["prov"].State
|
||||
ht.mu.RUnlock()
|
||||
|
||||
if state != CircuitOpen {
|
||||
t.Fatalf("expected CircuitOpen, got %s", state)
|
||||
}
|
||||
|
||||
if ht.IsAvailable("prov") {
|
||||
t.Fatal("expected IsAvailable=false when circuit is open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) {
|
||||
cb := defaultCBConfig()
|
||||
cb.CooldownDuration = 50 * time.Millisecond
|
||||
|
||||
ht := newTestTracker(5*time.Minute, cb)
|
||||
|
||||
// Trip the circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
}
|
||||
|
||||
if ht.IsAvailable("prov") {
|
||||
t.Fatal("expected circuit open, IsAvailable should be false")
|
||||
}
|
||||
|
||||
// Wait for cooldown
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// After cooldown, IsAvailable should return true (will transition to half-open)
|
||||
if !ht.IsAvailable("prov") {
|
||||
t.Fatal("expected IsAvailable=true after cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
||||
cb := defaultCBConfig()
|
||||
cb.CooldownDuration = 10 * time.Millisecond
|
||||
|
||||
ht := newTestTracker(5*time.Minute, cb)
|
||||
|
||||
// Trip the circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
}
|
||||
|
||||
// Wait for cooldown so next Record transitions through Open->HalfOpen
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// A successful record should transition: Open -> HalfOpen -> Closed
|
||||
ht.Record("prov", 100, nil)
|
||||
|
||||
ht.mu.RLock()
|
||||
state := ht.circuits["prov"].State
|
||||
ht.mu.RUnlock()
|
||||
|
||||
if state != CircuitClosed {
|
||||
t.Fatalf("expected CircuitClosed after success in half-open, got %s", state)
|
||||
}
|
||||
|
||||
if !ht.IsAvailable("prov") {
|
||||
t.Fatal("expected IsAvailable=true after circuit closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
||||
cb := defaultCBConfig()
|
||||
cb.CooldownDuration = 10 * time.Millisecond
|
||||
|
||||
ht := newTestTracker(5*time.Minute, cb)
|
||||
|
||||
// Trip the circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
}
|
||||
|
||||
// Wait for cooldown
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// A failed record should transition: Open -> HalfOpen -> Open
|
||||
ht.Record("prov", 100, errors.New("still failing"))
|
||||
|
||||
ht.mu.RLock()
|
||||
state := ht.circuits["prov"].State
|
||||
ht.mu.RUnlock()
|
||||
|
||||
if state != CircuitOpen {
|
||||
t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) {
|
||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false})
|
||||
|
||||
// Even with errors, IsAvailable should return true when CB is disabled
|
||||
for i := 0; i < 10; i++ {
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
}
|
||||
|
||||
if !ht.IsAvailable("prov") {
|
||||
t.Fatal("expected IsAvailable=true when circuit breaker disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) {
|
||||
ht := newTestTracker(5*time.Minute, defaultCBConfig())
|
||||
|
||||
if !ht.IsAvailable("unknown") {
|
||||
t.Fatal("expected IsAvailable=true for unknown provider (no circuit)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_WindowPruning(t *testing.T) {
|
||||
// Use a tiny window so events expire quickly
|
||||
ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
|
||||
|
||||
ht.Record("prov", 100, nil)
|
||||
ht.Record("prov", 200, nil)
|
||||
|
||||
// Wait for events to expire
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// Record a new event to trigger pruning
|
||||
ht.Record("prov", 300, nil)
|
||||
|
||||
ht.mu.RLock()
|
||||
count := len(ht.windows["prov"])
|
||||
ht.mu.RUnlock()
|
||||
|
||||
if count != 1 {
|
||||
t.Fatalf("expected 1 event after pruning, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) {
|
||||
ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
|
||||
|
||||
ht.Record("prov", 100, nil)
|
||||
|
||||
// Wait for events to expire
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
statuses := ht.Status()
|
||||
if len(statuses) != 0 {
|
||||
t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_Status_AvgLatency(t *testing.T) {
|
||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
|
||||
|
||||
ht.Record("prov", 100, nil)
|
||||
ht.Record("prov", 200, nil)
|
||||
ht.Record("prov", 300, nil)
|
||||
|
||||
statuses := ht.Status()
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected 1 status, got %d", len(statuses))
|
||||
}
|
||||
|
||||
want := 200.0
|
||||
if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 {
|
||||
t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthTracker_Status_CircuitStateReported(t *testing.T) {
|
||||
cb := defaultCBConfig()
|
||||
ht := newTestTracker(5*time.Minute, cb)
|
||||
|
||||
// Trip the circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
ht.Record("prov", 100, errors.New("err"))
|
||||
}
|
||||
|
||||
statuses := ht.Status()
|
||||
if len(statuses) != 1 {
|
||||
t.Fatalf("expected 1 status, got %d", len(statuses))
|
||||
}
|
||||
|
||||
if statuses[0].CircuitState != "open" {
|
||||
t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open")
|
||||
}
|
||||
}
|
||||
|
|
@ -111,6 +111,12 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string,
|
|||
func (p *OpenAIProvider) setHeaders(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
// Forward request ID if present in context
|
||||
if reqID := req.Context().Value("requestID"); reqID != nil {
|
||||
if id, ok := reqID.(string); ok && id != "" {
|
||||
req.Header.Set("X-Request-ID", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProviderError represents a non-200 response from a provider.
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package provider
|
|||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"llm-gateway/internal/config"
|
||||
)
|
||||
|
|
@ -18,26 +19,40 @@ type Route struct {
|
|||
|
||||
// Registry maps model names to provider routes.
|
||||
type Registry struct {
|
||||
routes map[string][]Route
|
||||
order []string // preserves config order
|
||||
mu sync.RWMutex
|
||||
routes map[string][]Route
|
||||
balancers map[string]LoadBalancer
|
||||
aliases map[string]string // alias -> canonical name
|
||||
order []string // preserves config order (canonical names only)
|
||||
}
|
||||
|
||||
func NewRegistry(cfg *config.Config) (*Registry, error) {
|
||||
r := &Registry{}
|
||||
if err := r.buildFromConfig(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
||||
// Build providers
|
||||
providers := make(map[string]Provider)
|
||||
for _, pc := range cfg.Providers {
|
||||
providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout)
|
||||
}
|
||||
|
||||
// Build routes (preserving config order)
|
||||
// Build routes
|
||||
routes := make(map[string][]Route)
|
||||
balancers := make(map[string]LoadBalancer)
|
||||
aliases := make(map[string]string)
|
||||
order := make([]string, 0, len(cfg.Models))
|
||||
|
||||
for _, mc := range cfg.Models {
|
||||
var modelRoutes []Route
|
||||
for _, rc := range mc.Routes {
|
||||
p, ok := providers[rc.Provider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
|
||||
return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
|
||||
}
|
||||
pc := cfg.ProviderByName(rc.Provider)
|
||||
priority := pc.Priority
|
||||
|
|
@ -55,20 +70,69 @@ func NewRegistry(cfg *config.Config) (*Registry, error) {
|
|||
})
|
||||
routes[mc.Name] = modelRoutes
|
||||
order = append(order, mc.Name)
|
||||
|
||||
// Load balancer
|
||||
balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
|
||||
|
||||
// Register aliases
|
||||
for _, alias := range mc.Aliases {
|
||||
aliases[alias] = mc.Name
|
||||
}
|
||||
}
|
||||
|
||||
return &Registry{routes: routes, order: order}, nil
|
||||
r.mu.Lock()
|
||||
r.routes = routes
|
||||
r.balancers = balancers
|
||||
r.aliases = aliases
|
||||
r.order = order
|
||||
r.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lookup returns the routes for a model name.
|
||||
// Reload rebuilds routes from new config. Used for hot-reload.
|
||||
func (r *Registry) Reload(cfg *config.Config) error {
|
||||
return r.buildFromConfig(cfg)
|
||||
}
|
||||
|
||||
// Lookup returns the routes for a model name (resolving aliases).
|
||||
func (r *Registry) Lookup(model string) ([]Route, bool) {
|
||||
routes, ok := r.routes[model]
|
||||
return routes, ok
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Resolve alias
|
||||
canonical := model
|
||||
if alias, ok := r.aliases[model]; ok {
|
||||
canonical = alias
|
||||
}
|
||||
|
||||
routes, ok := r.routes[canonical]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Apply load balancer
|
||||
if balancer, ok := r.balancers[canonical]; ok {
|
||||
routes = balancer.Reorder(routes)
|
||||
}
|
||||
|
||||
return routes, true
|
||||
}
|
||||
|
||||
// ModelNames returns all registered model names in config order.
|
||||
// ModelNames returns all registered model names in config order (including aliases).
|
||||
func (r *Registry) ModelNames() []string {
|
||||
return r.order
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var names []string
|
||||
for _, name := range r.order {
|
||||
names = append(names, name)
|
||||
}
|
||||
// Add aliases
|
||||
for alias := range r.aliases {
|
||||
names = append(names, alias)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RouteInfo exposes route details for dashboard display.
|
||||
|
|
@ -82,16 +146,29 @@ type RouteInfo struct {
|
|||
|
||||
// ModelRouteInfo exposes a model and its routes for dashboard display.
|
||||
type ModelRouteInfo struct {
|
||||
Name string `json:"name"`
|
||||
Routes []RouteInfo `json:"routes"`
|
||||
Name string `json:"name"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Routes []RouteInfo `json:"routes"`
|
||||
}
|
||||
|
||||
// AllRoutes returns all models and their routes in config order.
|
||||
func (r *Registry) AllRoutes() []ModelRouteInfo {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Build reverse alias map
|
||||
modelAliases := make(map[string][]string)
|
||||
for alias, canonical := range r.aliases {
|
||||
modelAliases[canonical] = append(modelAliases[canonical], alias)
|
||||
}
|
||||
|
||||
results := make([]ModelRouteInfo, 0, len(r.order))
|
||||
for _, name := range r.order {
|
||||
routes := r.routes[name]
|
||||
info := ModelRouteInfo{Name: name}
|
||||
info := ModelRouteInfo{
|
||||
Name: name,
|
||||
Aliases: modelAliases[name],
|
||||
}
|
||||
for _, rt := range routes {
|
||||
info.Routes = append(info.Routes, RouteInfo{
|
||||
ProviderName: rt.Provider.Name(),
|
||||
|
|
|
|||
282
llm-gateway/internal/provider/registry_test.go
Normal file
282
llm-gateway/internal/provider/registry_test.go
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"llm-gateway/internal/config"
|
||||
)
|
||||
|
||||
// mockProvider implements the Provider interface for testing. Only Name is
// meaningful; the completion methods are no-op stubs so registry tests need
// no real backend.
type mockProvider struct {
	name string // value returned by Name()
}

// Name returns the mock's configured provider name.
func (m *mockProvider) Name() string { return m.name }

// ChatCompletion is a stub; registry tests never call it.
func (m *mockProvider) ChatCompletion(_ context.Context, _ string, _ *ChatRequest) (*ChatResponse, error) {
	return nil, nil
}

// ChatCompletionStream is a stub; registry tests never call it.
func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *ChatRequest) (io.ReadCloser, error) {
	return nil, nil
}
||||
|
||||
// newTestRegistry builds a Registry directly without going through config parsing.
|
||||
func newTestRegistry(models []testModel) *Registry {
|
||||
r := &Registry{
|
||||
routes: make(map[string][]Route),
|
||||
balancers: make(map[string]LoadBalancer),
|
||||
aliases: make(map[string]string),
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
r.routes[m.name] = m.routes
|
||||
r.balancers[m.name] = &FirstBalancer{}
|
||||
r.order = append(r.order, m.name)
|
||||
for _, alias := range m.aliases {
|
||||
r.aliases[alias] = m.name
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// testModel describes one model entry for newTestRegistry: its canonical
// name, optional aliases, and pre-built routes.
type testModel struct {
	name    string   // canonical model name
	aliases []string // alternative names that resolve to this model
	routes  []Route  // routes registered for the model
}
|
||||
|
||||
func TestRegistry_Lookup_Canonical(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "gpt-4",
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
routes, ok := reg.Lookup("gpt-4")
|
||||
if !ok {
|
||||
t.Fatal("expected Lookup to find gpt-4")
|
||||
}
|
||||
if len(routes) != 1 {
|
||||
t.Fatalf("expected 1 route, got %d", len(routes))
|
||||
}
|
||||
if routes[0].Provider.Name() != "openai" {
|
||||
t.Errorf("expected provider 'openai', got %q", routes[0].Provider.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Lookup_Alias(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "gpt-4",
|
||||
aliases: []string{"gpt4", "gpt-4-latest"},
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
found bool
|
||||
}{
|
||||
{"canonical", "gpt-4", true},
|
||||
{"alias1", "gpt4", true},
|
||||
{"alias2", "gpt-4-latest", true},
|
||||
{"unknown", "gpt-5", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
routes, ok := reg.Lookup(tt.model)
|
||||
if ok != tt.found {
|
||||
t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found)
|
||||
}
|
||||
if tt.found && len(routes) != 1 {
|
||||
t.Fatalf("expected 1 route, got %d", len(routes))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_ModelNames_IncludesAliases(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "gpt-4",
|
||||
aliases: []string{"gpt4"},
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "claude-3",
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
names := reg.ModelNames()
|
||||
|
||||
want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true}
|
||||
got := make(map[string]bool)
|
||||
for _, n := range names {
|
||||
got[n] = true
|
||||
}
|
||||
|
||||
for name := range want {
|
||||
if !got[name] {
|
||||
t.Errorf("expected %q in ModelNames, not found", name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(names) != len(want) {
|
||||
t.Errorf("expected %d names, got %d: %v", len(want), len(names), names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "gpt-4",
|
||||
aliases: []string{"gpt4", "gpt-4-latest"},
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
||||
{Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
allRoutes := reg.AllRoutes()
|
||||
if len(allRoutes) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(allRoutes))
|
||||
}
|
||||
|
||||
m := allRoutes[0]
|
||||
if m.Name != "gpt-4" {
|
||||
t.Errorf("expected name 'gpt-4', got %q", m.Name)
|
||||
}
|
||||
|
||||
aliasSet := make(map[string]bool)
|
||||
for _, a := range m.Aliases {
|
||||
aliasSet[a] = true
|
||||
}
|
||||
if !aliasSet["gpt4"] || !aliasSet["gpt-4-latest"] {
|
||||
t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", m.Aliases)
|
||||
}
|
||||
|
||||
if len(m.Routes) != 2 {
|
||||
t.Fatalf("expected 2 routes, got %d", len(m.Routes))
|
||||
}
|
||||
if m.Routes[0].ProviderName != "openai" {
|
||||
t.Errorf("expected first route provider 'openai', got %q", m.Routes[0].ProviderName)
|
||||
}
|
||||
if m.Routes[1].ProviderName != "azure" {
|
||||
t.Errorf("expected second route provider 'azure', got %q", m.Routes[1].ProviderName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "model-b",
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "prov"}, ProviderModel: "b", Priority: 1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model-a",
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "prov"}, ProviderModel: "a", Priority: 1},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
allRoutes := reg.AllRoutes()
|
||||
if len(allRoutes) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(allRoutes))
|
||||
}
|
||||
if allRoutes[0].Name != "model-b" {
|
||||
t.Errorf("expected first model 'model-b', got %q", allRoutes[0].Name)
|
||||
}
|
||||
if allRoutes[1].Name != "model-a" {
|
||||
t.Errorf("expected second model 'model-a', got %q", allRoutes[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_PrioritySorting(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "multi-provider",
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3},
|
||||
{Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1},
|
||||
{Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// Note: routes are stored as given (sorting happens during buildFromConfig).
|
||||
// For this test we verify AllRoutes returns them in stored order.
|
||||
allRoutes := reg.AllRoutes()
|
||||
if len(allRoutes) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(allRoutes))
|
||||
}
|
||||
|
||||
routes := allRoutes[0].Routes
|
||||
if len(routes) != 3 {
|
||||
t.Fatalf("expected 3 routes, got %d", len(routes))
|
||||
}
|
||||
|
||||
// Verify the priorities are present
|
||||
priorities := make(map[int]bool)
|
||||
for _, r := range routes {
|
||||
priorities[r.Priority] = true
|
||||
}
|
||||
for _, p := range []int{1, 2, 3} {
|
||||
if !priorities[p] {
|
||||
t.Errorf("expected priority %d in routes", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Models: []config.ModelConfig{
|
||||
{
|
||||
Name: "test-model",
|
||||
Routes: []config.RouteConfig{
|
||||
{Provider: "nonexistent", Model: "m"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewRegistry(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Lookup_NotFound(t *testing.T) {
|
||||
reg := newTestRegistry([]testModel{
|
||||
{
|
||||
name: "gpt-4",
|
||||
routes: []Route{
|
||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
_, ok := reg.Lookup("nonexistent")
|
||||
if ok {
|
||||
t.Fatal("expected Lookup to return false for nonexistent model")
|
||||
}
|
||||
}
|
||||
51
llm-gateway/internal/proxy/concurrency.go
Normal file
51
llm-gateway/internal/proxy/concurrency.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConcurrencyLimiter enforces per-token concurrent request limits.
// Each token name maps to an atomic in-flight request counter.
type ConcurrencyLimiter struct {
	mu       sync.Mutex               // guards the counters map itself, not the counters
	counters map[string]*atomic.Int64 // per-token in-flight request counts
}
|
||||
|
||||
// NewConcurrencyLimiter returns a limiter with an empty per-token counter map;
// counters are created lazily on first use.
func NewConcurrencyLimiter() *ConcurrencyLimiter {
	return &ConcurrencyLimiter{
		counters: make(map[string]*atomic.Int64),
	}
}
|
||||
|
||||
func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 {
|
||||
cl.mu.Lock()
|
||||
defer cl.mu.Unlock()
|
||||
c, ok := cl.counters[tokenName]
|
||||
if !ok {
|
||||
c = &atomic.Int64{}
|
||||
cl.counters[tokenName] = c
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
apiToken := getAPIToken(r.Context())
|
||||
if apiToken == nil || apiToken.MaxConcurrent <= 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
counter := cl.getCounter(apiToken.Name)
|
||||
current := counter.Add(1)
|
||||
defer counter.Add(-1)
|
||||
|
||||
if current > int64(apiToken.MaxConcurrent) {
|
||||
writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
317
llm-gateway/internal/proxy/concurrency_test.go
Normal file
317
llm-gateway/internal/proxy/concurrency_test.go
Normal file
|
|
@ -0,0 +1,317 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"llm-gateway/internal/auth"
|
||||
)
|
||||
|
||||
func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxConcurrent int
|
||||
numRequests int
|
||||
wantAllowed int
|
||||
}{
|
||||
{
|
||||
name: "single request within limit",
|
||||
maxConcurrent: 5,
|
||||
numRequests: 1,
|
||||
wantAllowed: 1,
|
||||
},
|
||||
{
|
||||
name: "all requests within limit",
|
||||
maxConcurrent: 5,
|
||||
numRequests: 5,
|
||||
wantAllowed: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cl := NewConcurrencyLimiter()
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "conc-token",
|
||||
MaxConcurrent: tt.maxConcurrent,
|
||||
}
|
||||
|
||||
var allowed atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
// Use a channel to hold all goroutines inside the handler simultaneously.
|
||||
gate := make(chan struct{})
|
||||
|
||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
allowed.Add(1)
|
||||
<-gate // Block until released.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
for i := 0; i < tt.numRequests; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), token)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for goroutines to enter the handler.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(gate)
|
||||
wg.Wait()
|
||||
|
||||
if int(allowed.Load()) != tt.wantAllowed {
|
||||
t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxConcurrent int
|
||||
numRequests int
|
||||
wantDenied int
|
||||
}{
|
||||
{
|
||||
name: "one over limit",
|
||||
maxConcurrent: 2,
|
||||
numRequests: 3,
|
||||
wantDenied: 1,
|
||||
},
|
||||
{
|
||||
name: "many over limit",
|
||||
maxConcurrent: 1,
|
||||
numRequests: 5,
|
||||
wantDenied: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cl := NewConcurrencyLimiter()
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "conc-token",
|
||||
MaxConcurrent: tt.maxConcurrent,
|
||||
}
|
||||
|
||||
var denied atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
gate := make(chan struct{})
|
||||
|
||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-gate
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
results := make([]int, tt.numRequests)
|
||||
for i := 0; i < tt.numRequests; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), token)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
results[idx] = rec.Code
|
||||
if rec.Code == http.StatusTooManyRequests {
|
||||
denied.Add(1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for goroutines to reach the handler or be rejected.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
close(gate)
|
||||
wg.Wait()
|
||||
|
||||
if int(denied.Load()) != tt.wantDenied {
|
||||
t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) {
|
||||
cl := NewConcurrencyLimiter()
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "decrement-token",
|
||||
MaxConcurrent: 1,
|
||||
}
|
||||
|
||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request should succeed and complete, decrementing the counter.
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), token)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("first request: status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Counter should have decremented. A second request should also succeed.
|
||||
rec2 := httptest.NewRecorder()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx2 := withAPIToken(req2.Context(), token)
|
||||
req2 = req2.WithContext(ctx2)
|
||||
handler.ServeHTTP(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Errorf("second request after first completed: status = %d, want %d", rec2.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Verify the internal counter is back to 0.
|
||||
counter := cl.getCounter(token.Name)
|
||||
val := counter.Load()
|
||||
if val != 0 {
|
||||
t.Errorf("counter = %d, want 0 after all requests completed", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxConcurrent int
|
||||
numRequests int
|
||||
}{
|
||||
{
|
||||
name: "zero allows unlimited concurrent requests",
|
||||
maxConcurrent: 0,
|
||||
numRequests: 50,
|
||||
},
|
||||
{
|
||||
name: "negative allows unlimited concurrent requests",
|
||||
maxConcurrent: -1,
|
||||
numRequests: 50,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cl := NewConcurrencyLimiter()
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "unlimited-token",
|
||||
MaxConcurrent: tt.maxConcurrent,
|
||||
}
|
||||
|
||||
var allowed atomic.Int64
|
||||
var wg sync.WaitGroup
|
||||
gate := make(chan struct{})
|
||||
|
||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
allowed.Add(1)
|
||||
<-gate
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
for i := 0; i < tt.numRequests; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), token)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
}()
|
||||
}
|
||||
|
||||
// Give goroutines time to enter the handler.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(gate)
|
||||
wg.Wait()
|
||||
|
||||
if int(allowed.Load()) != tt.numRequests {
|
||||
t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyLimiter_NoToken(t *testing.T) {
|
||||
cl := NewConcurrencyLimiter()
|
||||
|
||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// No API token in context.
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) {
|
||||
cl := NewConcurrencyLimiter()
|
||||
|
||||
tokenA := &auth.APIToken{
|
||||
Name: "token-a",
|
||||
MaxConcurrent: 1,
|
||||
}
|
||||
tokenB := &auth.APIToken{
|
||||
Name: "token-b",
|
||||
MaxConcurrent: 1,
|
||||
}
|
||||
|
||||
gateA := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tok := getAPIToken(r.Context())
|
||||
if tok.Name == "token-a" {
|
||||
<-gateA // Block token A's request.
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Start a request for token A that blocks.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), tokenA)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
}()
|
||||
|
||||
// Give token A's goroutine time to enter handler.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Token B should not be affected by token A's in-flight request.
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), tokenB)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
close(gateA)
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
@ -4,11 +4,16 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
"llm-gateway/internal/auth"
|
||||
"llm-gateway/internal/cache"
|
||||
"llm-gateway/internal/config"
|
||||
|
|
@ -47,6 +52,7 @@ type Handler struct {
|
|||
metrics *metrics.Metrics
|
||||
cfg *config.Config
|
||||
healthTracker *provider.HealthTracker
|
||||
debugLogger *storage.DebugLogger
|
||||
}
|
||||
|
||||
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
|
||||
|
|
@ -60,6 +66,10 @@ func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cac
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
|
||||
h.debugLogger = dl
|
||||
}
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
||||
if err != nil {
|
||||
|
|
@ -84,31 +94,53 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
// Filter healthy routes (circuit breaker)
|
||||
routes = h.filterHealthyRoutes(routes)
|
||||
|
||||
tokenName := getTokenName(r.Context())
|
||||
requestID := middleware.GetReqID(r.Context())
|
||||
|
||||
// Check cache for non-streaming requests
|
||||
if !req.Stream && h.cache != nil {
|
||||
if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil {
|
||||
h.logRequest(tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
|
||||
h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
|
||||
if h.metrics != nil {
|
||||
h.metrics.RecordCacheHit()
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Cache", "HIT")
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
w.Write(cached)
|
||||
return
|
||||
}
|
||||
if h.metrics != nil {
|
||||
h.metrics.RecordCacheMiss()
|
||||
}
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStream(w, r, &req, routes, tokenName)
|
||||
h.handleStream(w, r, &req, routes, tokenName, requestID)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleNonStream(w, r, &req, routes, tokenName, body)
|
||||
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte) {
|
||||
func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) {
|
||||
var lastErr error
|
||||
|
||||
for _, route := range routes {
|
||||
for i, route := range routes {
|
||||
// Retry backoff between attempts (not before first attempt)
|
||||
if i > 0 {
|
||||
backoff := backoffDuration(i, h.cfg.Retry)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-r.Context().Done():
|
||||
writeError(w, http.StatusGatewayTimeout, "request cancelled")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req)
|
||||
latency := time.Since(start).Milliseconds()
|
||||
|
|
@ -116,19 +148,19 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
|
|||
if err != nil {
|
||||
var pe *provider.ProviderError
|
||||
if errors.As(err, &pe) && !pe.IsRetryable() {
|
||||
// Client error — don't retry
|
||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
||||
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
|
||||
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
|
||||
if h.healthTracker != nil {
|
||||
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
||||
}
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
writeErrorRaw(w, pe.StatusCode, pe.Body)
|
||||
return
|
||||
}
|
||||
lastErr = err
|
||||
log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err)
|
||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
||||
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
|
||||
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
|
||||
if h.healthTracker != nil {
|
||||
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
||||
}
|
||||
|
|
@ -139,7 +171,6 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
|
|||
h.healthTracker.Record(route.Provider.Name(), latency, nil)
|
||||
}
|
||||
|
||||
// Compute cost
|
||||
inputTokens, outputTokens := 0, 0
|
||||
if resp.Usage != nil {
|
||||
inputTokens = resp.Usage.PromptTokens
|
||||
|
|
@ -148,9 +179,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
|
|||
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
|
||||
|
||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
|
||||
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
|
||||
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
|
||||
|
||||
// Override model name in response to match the requested model
|
||||
resp.Model = req.Model
|
||||
|
||||
respBytes, err := json.Marshal(resp)
|
||||
|
|
@ -159,27 +189,84 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
|
|||
return
|
||||
}
|
||||
|
||||
// Cache the response
|
||||
if h.cache != nil {
|
||||
h.cache.Set(r.Context(), req.Model, rawBody, respBytes)
|
||||
}
|
||||
|
||||
// Debug logging
|
||||
if h.debugLogger != nil && h.debugLogger.IsEnabled() {
|
||||
reqBody := string(rawBody)
|
||||
respBody := string(respBytes)
|
||||
if h.cfg.Debug.MaxBodyBytes > 0 {
|
||||
if len(reqBody) > h.cfg.Debug.MaxBodyBytes {
|
||||
reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes]
|
||||
}
|
||||
if len(respBody) > h.cfg.Debug.MaxBodyBytes {
|
||||
respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
|
||||
}
|
||||
}
|
||||
h.debugLogger.Log(storage.DebugLogEntry{
|
||||
RequestID: requestID,
|
||||
TokenName: tokenName,
|
||||
Model: req.Model,
|
||||
Provider: route.Provider.Name(),
|
||||
RequestBody: reqBody,
|
||||
ResponseBody: respBody,
|
||||
RequestHeaders: formatHeaders(r.Header),
|
||||
ResponseStatus: http.StatusOK,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Cache", "MISS")
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
w.Write(respBytes)
|
||||
return
|
||||
}
|
||||
|
||||
// All providers failed
|
||||
if lastErr != nil {
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
|
||||
} else {
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
writeError(w, http.StatusBadGateway, "all providers failed")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) logRequest(tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
|
||||
// filterHealthyRoutes removes providers with open circuit breakers.
|
||||
// If all are filtered out, returns original routes as fallback.
|
||||
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
|
||||
if h.healthTracker == nil {
|
||||
return routes
|
||||
}
|
||||
var healthy []provider.Route
|
||||
for _, r := range routes {
|
||||
if h.healthTracker.IsAvailable(r.Provider.Name()) {
|
||||
healthy = append(healthy, r)
|
||||
}
|
||||
}
|
||||
if len(healthy) == 0 {
|
||||
return routes // all-down fallback
|
||||
}
|
||||
return healthy
|
||||
}
|
||||
|
||||
// backoffDuration computes exponential backoff for the given attempt.
|
||||
func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration {
|
||||
d := cfg.InitialBackoff
|
||||
for i := 1; i < attempt; i++ {
|
||||
d = time.Duration(float64(d) * cfg.Multiplier)
|
||||
if d > cfg.MaxBackoff {
|
||||
d = cfg.MaxBackoff
|
||||
break
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
|
||||
h.logger.Log(storage.RequestLog{
|
||||
RequestID: requestID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
TokenName: tokenName,
|
||||
Model: model,
|
||||
|
|
@ -217,3 +304,23 @@ func writeErrorRaw(w http.ResponseWriter, code int, body string) {
|
|||
w.WriteHeader(code)
|
||||
w.Write([]byte(body))
|
||||
}
|
||||
|
||||
// formatHeaders serializes HTTP headers to a readable string, sorted by key.
|
||||
// Sensitive headers (Authorization) are redacted.
|
||||
func formatHeaders(h http.Header) string {
|
||||
keys := make([]string, 0, len(h))
|
||||
for k := range h {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var b strings.Builder
|
||||
for _, k := range keys {
|
||||
val := strings.Join(h[k], ", ")
|
||||
if strings.EqualFold(k, "Authorization") {
|
||||
val = "[REDACTED]"
|
||||
}
|
||||
fmt.Fprintf(&b, "%s: %s\n", k, val)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -40,7 +42,19 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
|||
|
||||
// Check rate limit
|
||||
if apiToken.RateLimitRPM > 0 {
|
||||
if !rl.allow(tokenName, apiToken.RateLimitRPM) {
|
||||
allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM)
|
||||
|
||||
// Set rate limit headers on all responses
|
||||
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM))
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
|
||||
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt))
|
||||
|
||||
if !allowed {
|
||||
retryAfter := resetAt - time.Now().Unix()
|
||||
if retryAfter < 1 {
|
||||
retryAfter = 1
|
||||
}
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
|
||||
writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
|
||||
return
|
||||
}
|
||||
|
|
@ -59,7 +73,7 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool {
|
||||
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
|
|
@ -82,9 +96,27 @@ func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool {
|
|||
}
|
||||
bucket.lastRefill = now
|
||||
|
||||
remaining := int(math.Floor(bucket.tokens))
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
// Compute reset time: when bucket would be full again
|
||||
deficit := bucket.maxTokens - bucket.tokens
|
||||
var resetAt int64
|
||||
if deficit > 0 && bucket.refillRate > 0 {
|
||||
resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix()
|
||||
} else {
|
||||
resetAt = now.Unix()
|
||||
}
|
||||
|
||||
if bucket.tokens < 1 {
|
||||
return false
|
||||
return false, 0, resetAt
|
||||
}
|
||||
bucket.tokens--
|
||||
return true
|
||||
remaining = int(math.Floor(bucket.tokens))
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
return true, remaining, resetAt
|
||||
}
|
||||
|
|
|
|||
374
llm-gateway/internal/proxy/ratelimit_test.go
Normal file
374
llm-gateway/internal/proxy/ratelimit_test.go
Normal file
|
|
@ -0,0 +1,374 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"llm-gateway/internal/auth"
|
||||
"llm-gateway/internal/storage"
|
||||
)
|
||||
|
||||
// newTestDB creates an in-memory SQLite database wrapped in storage.DB.
|
||||
// It creates the request_logs table needed by TodaySpend.
|
||||
func newTestDB(t *testing.T) *storage.DB {
|
||||
t.Helper()
|
||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("opening in-memory sqlite: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { sqlDB.Close() })
|
||||
|
||||
// Create the minimal table needed for TodaySpend queries.
|
||||
_, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
token_name TEXT,
|
||||
cost_usd REAL,
|
||||
timestamp INTEGER
|
||||
)`)
|
||||
if err != nil {
|
||||
t.Fatalf("creating request_logs table: %v", err)
|
||||
}
|
||||
return &storage.DB{DB: sqlDB}
|
||||
}
|
||||
|
||||
// okHandler is a simple handler that writes 200 OK.
|
||||
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
func TestRateLimiter_Allow(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rateLimitRPM int
|
||||
numRequests int
|
||||
wantAllowed int
|
||||
wantDenied int
|
||||
}{
|
||||
{
|
||||
name: "allows requests within limit",
|
||||
rateLimitRPM: 10,
|
||||
numRequests: 5,
|
||||
wantAllowed: 5,
|
||||
wantDenied: 0,
|
||||
},
|
||||
{
|
||||
name: "denies requests over limit",
|
||||
rateLimitRPM: 3,
|
||||
numRequests: 6,
|
||||
wantAllowed: 3,
|
||||
wantDenied: 3,
|
||||
},
|
||||
{
|
||||
name: "allows exactly up to limit",
|
||||
rateLimitRPM: 5,
|
||||
numRequests: 5,
|
||||
wantAllowed: 5,
|
||||
wantDenied: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
allowed := 0
|
||||
denied := 0
|
||||
for i := 0; i < tt.numRequests; i++ {
|
||||
ok, _, _ := rl.allow("test-token", tt.rateLimitRPM)
|
||||
if ok {
|
||||
allowed++
|
||||
} else {
|
||||
denied++
|
||||
}
|
||||
}
|
||||
|
||||
if allowed != tt.wantAllowed {
|
||||
t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed)
|
||||
}
|
||||
if denied != tt.wantDenied {
|
||||
t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_TokenRefillsOverTime(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
rpm := 60 // 1 token per second refill rate
|
||||
|
||||
// Exhaust all tokens.
|
||||
for i := 0; i < rpm; i++ {
|
||||
ok, _, _ := rl.allow("refill-token", rpm)
|
||||
if !ok {
|
||||
t.Fatalf("request %d should have been allowed", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Next request should be denied.
|
||||
ok, _, _ := rl.allow("refill-token", rpm)
|
||||
if ok {
|
||||
t.Fatal("request should have been denied after exhausting tokens")
|
||||
}
|
||||
|
||||
// Manually advance the bucket's lastRefill to simulate time passing.
|
||||
rl.mu.Lock()
|
||||
bucket := rl.buckets["refill-token"]
|
||||
bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second)
|
||||
rl.mu.Unlock()
|
||||
|
||||
// After 2 seconds at 1 token/sec, we should have ~2 tokens refilled.
|
||||
ok, remaining, _ := rl.allow("refill-token", rpm)
|
||||
if !ok {
|
||||
t.Fatal("request should have been allowed after token refill")
|
||||
}
|
||||
// We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0.
|
||||
if remaining < 0 {
|
||||
t.Errorf("remaining = %d, want >= 0", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_AllowReturnValues(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rateLimitRPM int
|
||||
numRequests int
|
||||
wantLastAllowed bool
|
||||
wantLastRemaining int
|
||||
}{
|
||||
{
|
||||
name: "remaining decrements correctly",
|
||||
rateLimitRPM: 5,
|
||||
numRequests: 1,
|
||||
wantLastAllowed: true,
|
||||
wantLastRemaining: 4,
|
||||
},
|
||||
{
|
||||
name: "remaining is zero at limit",
|
||||
rateLimitRPM: 3,
|
||||
numRequests: 3,
|
||||
wantLastAllowed: true,
|
||||
wantLastRemaining: 0,
|
||||
},
|
||||
{
|
||||
name: "denied returns zero remaining",
|
||||
rateLimitRPM: 2,
|
||||
numRequests: 3,
|
||||
wantLastAllowed: false,
|
||||
wantLastRemaining: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
var allowed bool
|
||||
var remaining int
|
||||
for i := 0; i < tt.numRequests; i++ {
|
||||
allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM)
|
||||
}
|
||||
|
||||
if allowed != tt.wantLastAllowed {
|
||||
t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed)
|
||||
}
|
||||
if remaining != tt.wantLastRemaining {
|
||||
t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rateLimitRPM int
|
||||
numRequests int
|
||||
wantStatusCode int
|
||||
wantLimitHeader string
|
||||
wantRetryAfter bool
|
||||
}{
|
||||
{
|
||||
name: "sets rate limit headers on allowed request",
|
||||
rateLimitRPM: 10,
|
||||
numRequests: 1,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantLimitHeader: "10",
|
||||
wantRetryAfter: false,
|
||||
},
|
||||
{
|
||||
name: "sets Retry-After header on 429",
|
||||
rateLimitRPM: 2,
|
||||
numRequests: 3,
|
||||
wantStatusCode: http.StatusTooManyRequests,
|
||||
wantLimitHeader: "2",
|
||||
wantRetryAfter: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "header-test-token",
|
||||
RateLimitRPM: tt.rateLimitRPM,
|
||||
}
|
||||
|
||||
handler := rl.Check(okHandler)
|
||||
|
||||
var rec *httptest.ResponseRecorder
|
||||
for i := 0; i < tt.numRequests; i++ {
|
||||
rec = httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), token)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
}
|
||||
|
||||
// Check the last response.
|
||||
if rec.Code != tt.wantStatusCode {
|
||||
t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode)
|
||||
}
|
||||
|
||||
// X-RateLimit-Limit header.
|
||||
limitHeader := rec.Header().Get("X-RateLimit-Limit")
|
||||
if limitHeader != tt.wantLimitHeader {
|
||||
t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader)
|
||||
}
|
||||
|
||||
// X-RateLimit-Remaining header must be present and numeric.
|
||||
remainingHeader := rec.Header().Get("X-RateLimit-Remaining")
|
||||
if remainingHeader == "" {
|
||||
t.Error("X-RateLimit-Remaining header is missing")
|
||||
} else if _, err := strconv.Atoi(remainingHeader); err != nil {
|
||||
t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader)
|
||||
}
|
||||
|
||||
// X-RateLimit-Reset header must be present and numeric.
|
||||
resetHeader := rec.Header().Get("X-RateLimit-Reset")
|
||||
if resetHeader == "" {
|
||||
t.Error("X-RateLimit-Reset header is missing")
|
||||
} else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil {
|
||||
t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader)
|
||||
}
|
||||
|
||||
// Retry-After header.
|
||||
retryAfter := rec.Header().Get("Retry-After")
|
||||
if tt.wantRetryAfter && retryAfter == "" {
|
||||
t.Error("Retry-After header is missing on 429 response")
|
||||
}
|
||||
if !tt.wantRetryAfter && retryAfter != "" {
|
||||
t.Errorf("Retry-After header should not be present, got %q", retryAfter)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
handler := rl.Check(okHandler)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// No API token in context.
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "unlimited-token",
|
||||
RateLimitRPM: 0, // zero means unlimited
|
||||
}
|
||||
|
||||
handler := rl.Check(okHandler)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx := withAPIToken(req.Context(), token)
|
||||
req = req.WithContext(ctx)
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_PerTokenIsolation(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
rpm := 2
|
||||
|
||||
// Exhaust token A.
|
||||
for i := 0; i < rpm; i++ {
|
||||
rl.allow("token-a", rpm)
|
||||
}
|
||||
ok, _, _ := rl.allow("token-a", rpm)
|
||||
if ok {
|
||||
t.Fatal("token-a should be rate limited")
|
||||
}
|
||||
|
||||
// Token B should still have its own bucket.
|
||||
ok, _, _ = rl.allow("token-b", rpm)
|
||||
if !ok {
|
||||
t.Fatal("token-b should not be affected by token-a's rate limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ResetAtIsFuture(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
// Consume one token so there's a deficit.
|
||||
_, _, resetAt := rl.allow("reset-token", 10)
|
||||
now := time.Now().Unix()
|
||||
|
||||
if resetAt < now {
|
||||
t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) {
|
||||
db := newTestDB(t)
|
||||
rl := NewRateLimiter(db)
|
||||
|
||||
token := &auth.APIToken{
|
||||
Name: "ctx-token",
|
||||
RateLimitRPM: 10,
|
||||
}
|
||||
|
||||
handler := rl.Check(okHandler)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
ctx, cancel := context.WithCancel(req.Context())
|
||||
ctx = withAPIToken(ctx, token)
|
||||
cancel() // Cancel immediately.
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// Should still process (rate limiter does not check context cancellation).
|
||||
handler.ServeHTTP(rec, req)
|
||||
// The handler itself may or may not respect cancelled context;
|
||||
// the key point is no panic occurs.
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@ package proxy
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
|
|
@ -12,7 +13,7 @@ import (
|
|||
"llm-gateway/internal/provider"
|
||||
)
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeError(w, http.StatusInternalServerError, "streaming not supported")
|
||||
|
|
@ -21,7 +22,18 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
|
|||
|
||||
var lastErr error
|
||||
|
||||
for _, route := range routes {
|
||||
for i, route := range routes {
|
||||
// Retry backoff between attempts
|
||||
if i > 0 {
|
||||
backoff := backoffDuration(i, h.cfg.Retry)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-r.Context().Done():
|
||||
writeError(w, http.StatusGatewayTimeout, "request cancelled")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
|
||||
|
||||
|
|
@ -30,67 +42,95 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
|
|||
if errors.As(err, &pe) && !pe.IsRetryable() {
|
||||
latency := time.Since(start).Milliseconds()
|
||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
||||
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
||||
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
||||
if h.healthTracker != nil {
|
||||
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
||||
}
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
writeErrorRaw(w, pe.StatusCode, pe.Body)
|
||||
return
|
||||
}
|
||||
lastErr = err
|
||||
latency := time.Since(start).Milliseconds()
|
||||
log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
|
||||
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
||||
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
||||
if h.healthTracker != nil {
|
||||
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply streaming timeout
|
||||
var streamCtx context.Context
|
||||
var streamCancel context.CancelFunc
|
||||
if h.cfg.Server.StreamingTimeout > 0 {
|
||||
streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
|
||||
} else {
|
||||
streamCtx, streamCancel = context.WithCancel(r.Context())
|
||||
}
|
||||
|
||||
// Stream the response
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
inputTokens, outputTokens := 0, 0
|
||||
scanner := bufio.NewScanner(body)
|
||||
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
scanDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(scanDone)
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-streamCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Parse usage from the final chunk if available
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data != "[DONE]" {
|
||||
var chunk streamChunk
|
||||
if json.Unmarshal([]byte(data), &chunk) == nil {
|
||||
if chunk.Usage != nil {
|
||||
inputTokens = chunk.Usage.PromptTokens
|
||||
outputTokens = chunk.Usage.CompletionTokens
|
||||
}
|
||||
// Override model name in chunk
|
||||
if chunk.Model != "" {
|
||||
chunk.Model = req.Model
|
||||
if rewritten, err := json.Marshal(chunk); err == nil {
|
||||
line = "data: " + string(rewritten)
|
||||
line := scanner.Text()
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data != "[DONE]" {
|
||||
var chunk streamChunk
|
||||
if json.Unmarshal([]byte(data), &chunk) == nil {
|
||||
if chunk.Usage != nil {
|
||||
inputTokens = chunk.Usage.PromptTokens
|
||||
outputTokens = chunk.Usage.CompletionTokens
|
||||
}
|
||||
if chunk.Model != "" {
|
||||
chunk.Model = req.Model
|
||||
if rewritten, err := json.Marshal(chunk); err == nil {
|
||||
line = "data: " + string(rewritten)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Write([]byte(line + "\n"))
|
||||
flusher.Flush()
|
||||
w.Write([]byte(line + "\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-scanDone:
|
||||
// Normal completion
|
||||
case <-streamCtx.Done():
|
||||
log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
|
||||
}
|
||||
|
||||
body.Close()
|
||||
streamCancel()
|
||||
|
||||
latency := time.Since(start).Milliseconds()
|
||||
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
|
||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
|
||||
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
|
||||
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
|
||||
if h.healthTracker != nil {
|
||||
h.healthTracker.Record(route.Provider.Name(), latency, nil)
|
||||
}
|
||||
|
|
@ -98,6 +138,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
|
|||
}
|
||||
|
||||
// All providers failed
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
if lastErr != nil {
|
||||
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
|
||||
} else {
|
||||
|
|
|
|||
102
llm-gateway/internal/storage/audit.go
Normal file
102
llm-gateway/internal/storage/audit.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuditEntry struct {
|
||||
ID int64 `json:"id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Action string `json:"action"`
|
||||
TargetType string `json:"target_type"`
|
||||
TargetID string `json:"target_id"`
|
||||
Details string `json:"details"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AuditLogger struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
func NewAuditLogger(db *DB) *AuditLogger {
|
||||
return &AuditLogger{db: db}
|
||||
}
|
||||
|
||||
func (a *AuditLogger) Log(entry AuditEntry) {
|
||||
if entry.Timestamp == 0 {
|
||||
entry.Timestamp = time.Now().Unix()
|
||||
}
|
||||
_, err := a.db.Exec(`INSERT INTO audit_log
|
||||
(timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
entry.Timestamp, entry.UserID, entry.Username, entry.Action,
|
||||
entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: audit log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type AuditQueryResult struct {
|
||||
Entries []AuditEntry `json:"entries"`
|
||||
Page int `json:"page"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
offset := (page - 1) * limit
|
||||
|
||||
where := "WHERE timestamp >= ?"
|
||||
args := []any{since}
|
||||
|
||||
if action != "" {
|
||||
where += " AND action = ?"
|
||||
args = append(args, action)
|
||||
}
|
||||
|
||||
var total int
|
||||
countArgs := make([]any, len(args))
|
||||
copy(countArgs, args)
|
||||
a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total)
|
||||
|
||||
totalPages := (total + limit - 1) / limit
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action,
|
||||
COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''),
|
||||
COALESCE(ip_address, ''), COALESCE(request_id, '')
|
||||
FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := a.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []AuditEntry
|
||||
for rows.Next() {
|
||||
var e AuditEntry
|
||||
rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action,
|
||||
&e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID)
|
||||
entries = append(entries, e)
|
||||
}
|
||||
if entries == nil {
|
||||
entries = []AuditEntry{}
|
||||
}
|
||||
|
||||
return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
|
||||
}
|
||||
250
llm-gateway/internal/storage/debuglog.go
Normal file
250
llm-gateway/internal/storage/debuglog.go
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DebugLogEntry is one captured request/response exchange. Searchable
// metadata lives in the debug_log table; the potentially large bodies and
// headers live in a JSON file on disk referenced by FilePath (legacy
// pre-migration rows keep them in DB columns instead).
type DebugLogEntry struct {
	ID             int64  `json:"id"`              // autoincrement primary key
	RequestID      string `json:"request_id"`      // gateway request identifier; also the file name stem
	Timestamp      int64  `json:"timestamp"`       // Unix seconds; filled in by Log when zero
	TokenName      string `json:"token_name"`      // API token used for the request
	Model          string `json:"model"`           // client-facing model name
	Provider       string `json:"provider"`        // upstream provider that served the request
	RequestBody    string `json:"request_body"`    // hydrated lazily via populateFromFile
	ResponseBody   string `json:"response_body"`   // hydrated lazily via populateFromFile
	RequestHeaders string `json:"request_headers"` // hydrated lazily via populateFromFile
	ResponseStatus int    `json:"response_status"` // upstream HTTP status
	FilePath       string `json:"-"`               // on-disk JSON file holding the bodies; never serialized
}
|
||||
|
||||
// debugFile is the JSON structure written to disk. It holds only the bulky
// parts of an exchange; searchable metadata stays in the debug_log table.
type debugFile struct {
	RequestHeaders string `json:"request_headers"`
	RequestBody    string `json:"request_body"`
	ResponseBody   string `json:"response_body"`
}
|
||||
|
||||
// DebugLogger captures full request/response exchanges for debugging.
// Capture can be toggled at runtime; the flag is atomic, so flipping it
// while requests are in flight is safe.
type DebugLogger struct {
	db      *DB         // metadata rows go into the debug_log table
	enabled atomic.Bool // runtime on/off switch for capture
	dataDir string      // base directory under which debug-logs/ is created
}
|
||||
|
||||
func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger {
|
||||
dl := &DebugLogger{db: db, dataDir: dataDir}
|
||||
dl.enabled.Store(enabled)
|
||||
return dl
|
||||
}
|
||||
|
||||
// SetEnabled toggles debug capture at runtime (atomic; safe from any goroutine).
func (d *DebugLogger) SetEnabled(v bool) {
	d.enabled.Store(v)
}
|
||||
|
||||
// IsEnabled reports whether debug capture is currently on (atomic read).
func (d *DebugLogger) IsEnabled() bool {
	return d.enabled.Load()
}
|
||||
|
||||
// debugLogDir returns the base directory for debug log files
// (<dataDir>/debug-logs).
func (d *DebugLogger) debugLogDir() string {
	return filepath.Join(d.dataDir, "debug-logs")
}
|
||||
|
||||
// debugFilePath builds the file path for a debug log entry.
|
||||
func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string {
|
||||
date := ts.Format("2006-01-02")
|
||||
return filepath.Join(d.debugLogDir(), date, requestID+".json")
|
||||
}
|
||||
|
||||
func (d *DebugLogger) Log(entry DebugLogEntry) {
|
||||
if !d.IsEnabled() {
|
||||
return
|
||||
}
|
||||
if entry.Timestamp == 0 {
|
||||
entry.Timestamp = time.Now().Unix()
|
||||
}
|
||||
|
||||
ts := time.Unix(entry.Timestamp, 0)
|
||||
fp := d.debugFilePath(entry.RequestID, ts)
|
||||
|
||||
// Write body file
|
||||
if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil {
|
||||
log.Printf("ERROR: debug log mkdir: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
df := debugFile{
|
||||
RequestHeaders: entry.RequestHeaders,
|
||||
RequestBody: entry.RequestBody,
|
||||
ResponseBody: entry.ResponseBody,
|
||||
}
|
||||
data, err := json.Marshal(df)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: debug log marshal: %v", err)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(fp, data, 0644); err != nil {
|
||||
log.Printf("ERROR: debug log write: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Insert metadata into DB (no bodies)
|
||||
_, err = d.db.Exec(`INSERT INTO debug_log
|
||||
(request_id, timestamp, token_name, model, provider, response_status, file_path)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model,
|
||||
entry.Provider, entry.ResponseStatus, fp,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: debug log db insert: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DebugLogQueryResult is one page of debug log entries plus pagination metadata.
type DebugLogQueryResult struct {
	Entries []DebugLogEntry `json:"entries"`     // never nil; empty slice when no rows match
	Page    int             `json:"page"`        // 1-based page that was returned
	TotalPages int          `json:"total_pages"` // always >= 1, even when Total == 0
	Total   int             `json:"total"`       // total rows across all pages
}
|
||||
|
||||
// Query returns paginated debug log metadata (no bodies — fast).
|
||||
func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
offset := (page - 1) * limit
|
||||
|
||||
var total int
|
||||
d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total)
|
||||
|
||||
totalPages := (total + limit - 1) / limit
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
|
||||
COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
|
||||
FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset)
|
||||
if err != nil {
|
||||
return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var entries []DebugLogEntry
|
||||
for rows.Next() {
|
||||
var e DebugLogEntry
|
||||
rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
|
||||
&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
|
||||
entries = append(entries, e)
|
||||
}
|
||||
if entries == nil {
|
||||
entries = []DebugLogEntry{}
|
||||
}
|
||||
|
||||
return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
|
||||
}
|
||||
|
||||
// QueryFull returns paginated debug log entries including request/response bodies read from files.
|
||||
func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult {
|
||||
result := d.Query(page, limit)
|
||||
for i := range result.Entries {
|
||||
d.populateFromFile(&result.Entries[i])
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetByRequestID returns a single debug log entry with bodies read from file.
|
||||
func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry {
|
||||
var e DebugLogEntry
|
||||
err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
|
||||
COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
|
||||
FROM debug_log WHERE request_id = ?`, requestID).Scan(
|
||||
&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
|
||||
&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
d.populateFromFile(&e)
|
||||
return &e
|
||||
}
|
||||
|
||||
// populateFromFile reads body data from the debug file on disk.
|
||||
// Falls back to DB columns for pre-migration entries that have no file_path.
|
||||
func (d *DebugLogger) populateFromFile(e *DebugLogEntry) {
|
||||
if e.FilePath == "" {
|
||||
// Legacy entry: try reading bodies from DB columns
|
||||
d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '')
|
||||
FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders)
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(e.FilePath)
|
||||
if err != nil {
|
||||
log.Printf("WARN: debug log read file %s: %v", e.FilePath, err)
|
||||
return
|
||||
}
|
||||
var df debugFile
|
||||
if err := json.Unmarshal(data, &df); err != nil {
|
||||
log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err)
|
||||
return
|
||||
}
|
||||
e.RequestHeaders = df.RequestHeaders
|
||||
e.RequestBody = df.RequestBody
|
||||
e.ResponseBody = df.ResponseBody
|
||||
}
|
||||
|
||||
// Cleanup removes debug log entries and files older than retentionDays.
|
||||
func (d *DebugLogger) Cleanup(retentionDays int) error {
|
||||
cutoff := time.Now().AddDate(0, 0, -retentionDays)
|
||||
cutoffUnix := cutoff.Unix()
|
||||
|
||||
// Delete old DB rows
|
||||
result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete old debug rows: %w", err)
|
||||
}
|
||||
affected, _ := result.RowsAffected()
|
||||
if affected > 0 {
|
||||
log.Printf("Cleaned up %d old debug log entries", affected)
|
||||
}
|
||||
|
||||
// Remove old date directories
|
||||
baseDir := d.debugLogDir()
|
||||
dirs, err := os.ReadDir(baseDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read debug log dir: %w", err)
|
||||
}
|
||||
|
||||
cutoffDate := cutoff.Format("2006-01-02")
|
||||
sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
|
||||
|
||||
for _, dir := range dirs {
|
||||
if !dir.IsDir() {
|
||||
continue
|
||||
}
|
||||
// Date directories are named YYYY-MM-DD; string comparison works
|
||||
if strings.Compare(dir.Name(), cutoffDate) < 0 {
|
||||
dirPath := filepath.Join(baseDir, dir.Name())
|
||||
if err := os.RemoveAll(dirPath); err != nil {
|
||||
log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err)
|
||||
} else {
|
||||
log.Printf("Removed old debug log directory: %s", dir.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ import (
|
|||
)
|
||||
|
||||
type RequestLog struct {
|
||||
RequestID string
|
||||
Timestamp int64
|
||||
TokenName string
|
||||
Model string
|
||||
|
|
@ -93,8 +94,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
|
|||
}
|
||||
|
||||
stmt, err := tx.Prepare(`INSERT INTO request_logs
|
||||
(timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: preparing log statement: %v", err)
|
||||
tx.Rollback()
|
||||
|
|
@ -112,7 +113,7 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
|
|||
cached = 1
|
||||
}
|
||||
_, err := stmt.Exec(
|
||||
r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
|
||||
r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
|
||||
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
|
||||
r.Status, r.ErrorMessage, streaming, cached,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
-- Down migration: remove max_concurrent from api_tokens.
-- SQLite doesn't support DROP COLUMN in older versions, so we recreate the table
-- by copying every column except max_concurrent.
-- NOTE(review): CREATE TABLE ... AS SELECT does not carry over indexes or
-- constraints on api_tokens — confirm none exist before relying on this down path.
CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens;
DROP TABLE api_tokens;
ALTER TABLE api_tokens_backup RENAME TO api_tokens;
|
||||
|
|
@ -0,0 +1 @@
|
|||
-- Per-token concurrency limit. NOTE(review): default 0 — presumably means
-- "no limit"; confirm against the token middleware.
ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0;
|
||||
|
|
@ -0,0 +1 @@
|
|||
-- Down migration for 005: drop only the index. The request_id column itself
-- remains (older SQLite lacks DROP COLUMN); it is harmless to keep.
DROP INDEX IF EXISTS idx_request_logs_request_id;
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
-- Tag each request log row with the gateway request ID so it can be
-- cross-referenced with audit and debug log entries.
ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT '';
CREATE INDEX idx_request_logs_request_id ON request_logs(request_id);
|
||||
|
|
@ -0,0 +1 @@
|
|||
-- Down migration for 006: remove the audit trail entirely (data is lost).
DROP TABLE IF EXISTS audit_log;
|
||||
14
llm-gateway/internal/storage/migrations/006_audit_log.up.sql
Normal file
14
llm-gateway/internal/storage/migrations/006_audit_log.up.sql
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
-- Audit trail of administrative actions (see storage.AuditEntry in Go).
CREATE TABLE audit_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp INTEGER NOT NULL,            -- Unix seconds
    user_id INTEGER,                       -- acting user; NULL when unknown
    username TEXT NOT NULL DEFAULT '',
    action TEXT NOT NULL,                  -- free-form action name; queries filter on exact match
    target_type TEXT DEFAULT '',
    target_id TEXT DEFAULT '',
    details TEXT DEFAULT '',
    ip_address TEXT DEFAULT '',
    request_id TEXT DEFAULT ''             -- correlates with request_logs.request_id (migration 005)
);
-- AuditLogger.Query filters by timestamp range and optionally by exact action.
CREATE INDEX idx_audit_timestamp ON audit_log(timestamp);
CREATE INDEX idx_audit_action ON audit_log(action);
|
||||
|
|
@ -0,0 +1 @@
|
|||
-- Down migration for 007: remove debug capture metadata entirely.
-- NOTE(review): on-disk debug-logs/ files are not removed by this migration.
DROP TABLE IF EXISTS debug_log;
|
||||
14
llm-gateway/internal/storage/migrations/007_debug_log.up.sql
Normal file
14
llm-gateway/internal/storage/migrations/007_debug_log.up.sql
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
-- Captured request/response exchanges for debugging (see storage.DebugLogEntry).
CREATE TABLE debug_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    request_id TEXT NOT NULL,
    timestamp INTEGER NOT NULL,            -- Unix seconds
    token_name TEXT DEFAULT '',
    model TEXT DEFAULT '',
    provider TEXT DEFAULT '',
    request_body TEXT DEFAULT '',          -- legacy storage; migration 008 moves bodies to on-disk files
    response_body TEXT DEFAULT '',
    request_headers TEXT DEFAULT '',
    response_status INTEGER DEFAULT 0
);
CREATE INDEX idx_debug_request_id ON debug_log(request_id);
CREATE INDEX idx_debug_timestamp ON debug_log(timestamp);
|
||||
|
|
@ -0,0 +1 @@
|
|||
-- no-op: file_path column is harmless to keep
|
||||
|
|
@ -0,0 +1 @@
|
|||
-- Bodies move from DB columns to JSON files on disk; file_path points at the
-- file. Rows with an empty file_path are legacy entries read from DB columns.
ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT '';
|
||||
|
|
@ -1,147 +0,0 @@
|
|||
# new-api Channel Configuration
|
||||
|
||||
After first start, access the new-api web UI at `http://<server>:4000` to configure channels.
|
||||
|
||||
Default admin credentials: `root` / `123456` — **change immediately**.
|
||||
|
||||
## API Token for Open WebUI
|
||||
|
||||
Create an API token in new-api's token management. Use this token as `OPENWEBUI_API_KEY` in `.env`.
|
||||
|
||||
## Channels to Create
|
||||
|
||||
Configure each channel via **Channels > Add Channel** in the web UI.
|
||||
|
||||
### 1. DeepInfra (Priority 1)
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Name | DeepInfra |
|
||||
| Type | OpenAI |
|
||||
| Base URL | `https://api.deepinfra.com/v1/openai` |
|
||||
| Key | `$DEEPINFRA_API_KEY` |
|
||||
| Priority | 1 |
|
||||
| Models | See model mapping below |
|
||||
|
||||
### 2. SiliconFlow (Priority 2)
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Name | SiliconFlow |
|
||||
| Type | OpenAI |
|
||||
| Base URL | `https://api.siliconflow.com/v1` |
|
||||
| Key | `$SILICONFLOW_API_KEY` |
|
||||
| Priority | 2 |
|
||||
| Models | See model mapping below |
|
||||
|
||||
### 3. OpenRouter (Priority 3)
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Name | OpenRouter |
|
||||
| Type | OpenAI |
|
||||
| Base URL | `https://openrouter.ai/api/v1` |
|
||||
| Key | `$OPENROUTER_API_KEY` |
|
||||
| Priority | 3 |
|
||||
| Models | See model mapping below |
|
||||
|
||||
### 4. Groq (Priority 1)
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Name | Groq |
|
||||
| Type | OpenAI |
|
||||
| Base URL | `https://api.groq.com/openai/v1` |
|
||||
| Key | `$GROQ_API_KEY` |
|
||||
| Priority | 1 |
|
||||
| Models | `llama-3.3-70b` |
|
||||
|
||||
### 5. Cerebras (Priority 1)
|
||||
|
||||
| Field | Value |
|
||||
|---|---|
|
||||
| Name | Cerebras |
|
||||
| Type | OpenAI |
|
||||
| Base URL | `https://api.cerebras.ai/v1` |
|
||||
| Key | `$CEREBRAS_API_KEY` |
|
||||
| Priority | 1 |
|
||||
| Models | `llama-3.3-70b-cerebras` |
|
||||
|
||||
## Model Mapping per Channel
|
||||
|
||||
new-api uses model aliasing: the "model name" is what clients see, the "actual model" is what's sent to the provider.
|
||||
|
||||
### DeepInfra Models
|
||||
|
||||
| Client Model Name | Actual Provider Model |
|
||||
|---|---|
|
||||
| `deepseek-v3.2` | `deepseek-ai/DeepSeek-V3.2` |
|
||||
| `deepseek-r1` | `deepseek-ai/DeepSeek-R1` |
|
||||
| `gpt-oss` | `openai/gpt-oss-120b` |
|
||||
| `gpt-oss-20b` | `openai/gpt-oss-20b` |
|
||||
| `nemotron-super` | `nvidia/Llama-3.3-Nemotron-Super-49B-v1.5` |
|
||||
| `nemotron-nano` | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` |
|
||||
| `devstral` | `mistralai/Devstral-Small-2505` |
|
||||
| `glm-4.6` | `zai-org/GLM-4.6` |
|
||||
| `glm-4.7` | `zai-org/GLM-4.7` |
|
||||
| `glm-5` | `zai-org/GLM-5` |
|
||||
| `kimi-k2` | `moonshotai/Kimi-K2-Instruct-0905` |
|
||||
| `kimi-k2.5` | `moonshotai/Kimi-K2.5` |
|
||||
| `deepseek-v3-free` | `deepseek-ai/DeepSeek-V3` |
|
||||
|
||||
### SiliconFlow Models
|
||||
|
||||
| Client Model Name | Actual Provider Model |
|
||||
|---|---|
|
||||
| `deepseek-v3.2` | `deepseek-ai/DeepSeek-V3.2` |
|
||||
| `glm-4.7` | `THUDM/GLM-4-32B-0414` |
|
||||
| `kimi-k2` | `moonshotai/Kimi-K2-Instruct-0905` |
|
||||
| `qwen3-coder` | `Qwen/Qwen3-Coder-480B-A35B-Instruct` |
|
||||
| `qwen3-coder-30b` | `Qwen/Qwen3-Coder-30B-A3B-Instruct` |
|
||||
|
||||
### OpenRouter Models
|
||||
|
||||
| Client Model Name | Actual Provider Model |
|
||||
|---|---|
|
||||
| `deepseek-v3.2` | `deepseek/deepseek-chat-v3-0324` |
|
||||
| `deepseek-v3-free` | `deepseek/deepseek-chat-v3-0324:free` |
|
||||
| `kimi-k2.5` | `moonshotai/kimi-k2.5` |
|
||||
| `minimax-m2.5` | `minimax/minimax-m2.5` |
|
||||
| `gpt-4.1-mini` | `openai/gpt-4.1-mini` |
|
||||
| `gpt-4.1` | `openai/gpt-4.1` |
|
||||
| `gemini-3-flash-preview` | `google/gemini-3-flash-preview` |
|
||||
| `gemini-2.5-pro` | `google/gemini-2.5-pro-preview` |
|
||||
| `claude-sonnet` | `anthropic/claude-sonnet-4` |
|
||||
| `trinity-large-preview` | `arcee-ai/trinity-large-preview` |
|
||||
|
||||
### Groq Models
|
||||
|
||||
| Client Model Name | Actual Provider Model |
|
||||
|---|---|
|
||||
| `llama-3.3-70b` | `llama-3.3-70b-versatile` |
|
||||
|
||||
### Cerebras Models
|
||||
|
||||
| Client Model Name | Actual Provider Model |
|
||||
|---|---|
|
||||
| `llama-3.3-70b-cerebras` | `llama-3.3-70b` |
|
||||
|
||||
## Fallback Behavior
|
||||
|
||||
new-api handles fallbacks via priority levels:
|
||||
- When a model exists on multiple channels, the highest priority (lowest number) channel is tried first
|
||||
- If it fails, it automatically falls back to the next priority level
|
||||
|
||||
For example, `deepseek-v3.2` exists on:
|
||||
1. DeepInfra (priority 1) — tried first
|
||||
2. SiliconFlow (priority 2) — fallback
|
||||
3. OpenRouter (priority 3) — last resort
|
||||
|
||||
## Grafana Setup
|
||||
|
||||
After first start, access Grafana at `http://<server>:3001`:
|
||||
1. Login with `admin` / `$GRAFANA_ADMIN_PASSWORD`
|
||||
2. Add data source: **Prometheus** with URL `http://victoriametrics:8428`
|
||||
3. Import dashboards:
|
||||
- Node Exporter Full: dashboard ID `1860`
|
||||
- Redis: dashboard ID `763`
|
||||
|
|
@ -1,218 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
# Configures new-api channels and API token via the admin API.
|
||||
# Run once after first boot: ./new-api/init-channels.sh
|
||||
#
|
||||
# Requires these env vars (or .env file in project root):
|
||||
# NEW_API_PASSWORD - admin password
|
||||
# DEEPINFRA_API_KEY
|
||||
# SILICONFLOW_API_KEY
|
||||
# OPENROUTER_API_KEY
|
||||
# GROQ_API_KEY
|
||||
# CEREBRAS_API_KEY
|
||||
#
|
||||
# Optional:
|
||||
# NEW_API_USERNAME - admin username (default: root)
|
||||
# NEW_API_BASE - API base URL (default: http://localhost:4000)
|
||||
|
||||
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ENV_FILE="${SCRIPT_DIR}/../.env"

# Load .env if present. set -a exports every variable sourced from the file
# so the curl/python3 child processes below can see them.
if [[ -f "$ENV_FILE" ]]; then
  set -a
  # shellcheck disable=SC1090
  source "$ENV_FILE"
  set +a
fi

API_BASE="${NEW_API_BASE:-http://localhost:4000}"
USERNAME="${NEW_API_USERNAME:-root}"
# :? aborts immediately with this message if the password is unset/empty.
PASSWORD="${NEW_API_PASSWORD:?Set NEW_API_PASSWORD to the admin password}"
# Session cookie jar; removed on every exit path by the trap below.
COOKIE_JAR=$(mktemp)
USER_ID=""   # populated by login(); required for the New-Api-User header
trap 'rm -f "$COOKIE_JAR"' EXIT
|
||||
|
||||
# ── Login and get user ID ───────────────────────────────
# Authenticates against /api/user/login, stores the session cookie in
# COOKIE_JAR, and sets the global USER_ID (needed for the New-Api-User
# header on later admin calls). Exits the whole script on failure.
login() {
  echo "Logging in as ${USERNAME}..."
  local resp
  # The JSON body is built by python3 so special characters in the password
  # are escaped correctly instead of being interpolated into JSON by hand.
  resp=$(curl -s -c "$COOKIE_JAR" "${API_BASE}/api/user/login" \
    -H "Content-Type: application/json" \
    -d "$(python3 -c "
import json, sys
print(json.dumps({'username': sys.argv[1], 'password': sys.argv[2]}))
" "$USERNAME" "$PASSWORD")")

  local success
  success=$(echo "$resp" | python3 -c "import sys,json; print(json.load(sys.stdin).get('success', False))")
  # python3 prints the literal strings "True"/"False", hence the comparison.
  if [[ "$success" != "True" ]]; then
    echo "ERROR: Login failed: ${resp}"
    exit 1
  fi

  USER_ID=$(echo "$resp" | python3 -c "import sys,json; print(json.load(sys.stdin)['data']['id'])")
  echo " Logged in (user ID: ${USER_ID})."
}
|
||||
|
||||
# ── API call helper (cookie + New-Api-User header) ─────
# api_call ENDPOINT [extra curl args...]
# Issues an authenticated request against the admin API, reusing the session
# cookie and the New-Api-User header; extra args are passed straight to curl.
api_call() {
  local endpoint=$1
  shift
  local -a auth_opts=(
    -b "$COOKIE_JAR"
    -H "New-Api-User: ${USER_ID}"
    -H "Content-Type: application/json"
  )
  curl -s "${auth_opts[@]}" "${API_BASE}${endpoint}" "$@"
}
|
||||
|
||||
# ── Create channel ──────────────────────────────────────
# create_channel NAME TYPE KEY BASE_URL PRIORITY MODELS MODEL_MAPPING
#   NAME          display name in the new-api UI
#   TYPE          numeric channel type (NOTE(review): every caller passes 1 —
#                 presumably the OpenAI-compatible type; confirm against
#                 new-api's channel-type enum)
#   KEY           provider API key
#   BASE_URL      provider endpoint base URL
#   PRIORITY      lower number = tried first when a model is on several channels
#   MODELS        comma-separated client-facing model names
#   MODEL_MAPPING JSON object mapping client model name -> provider model
# Failures are reported (truncated to 500 bytes) but do not abort the script.
create_channel() {
  local name="$1" type="$2" key="$3" base_url="$4" priority="$5" models="$6" model_mapping="$7"

  echo "Creating channel: ${name} (priority ${priority})..."

  local payload
  # Build the JSON payload with python3 so keys and URLs are escaped safely.
  payload=$(python3 -c "
import json, sys
print(json.dumps({
    'type': int(sys.argv[1]),
    'name': sys.argv[2],
    'key': sys.argv[3],
    'base_url': sys.argv[4],
    'models': sys.argv[5],
    'model_mapping': sys.argv[6],
    'priority': int(sys.argv[7]),
    'status': 1,
    'group': 'default',
    'weight': 1,
    'auto_ban': 1
}))
" "$type" "$name" "$key" "$base_url" "$models" "$model_mapping" "$priority")

  local resp success
  resp=$(api_call "/api/channel/" -d "$payload")

  success=$(echo "$resp" | python3 -c "import sys,json; print(json.load(sys.stdin).get('success', False))")
  if [[ "$success" == "True" ]]; then
    echo " OK"
  else
    # head -c 500 keeps a huge HTML error page from flooding the terminal.
    echo " FAILED: ${resp}" | head -c 500
    echo
  fi
}
|
||||
|
||||
# ── Wait for new-api ────────────────────────────────────
# Poll the root URL until new-api answers, up to 30 attempts 2s apart
# (~60s); give up with an error after the 30th failed attempt.
echo "Waiting for new-api at ${API_BASE}..."
attempt=1
until curl -sf "${API_BASE}/" > /dev/null 2>&1; do
  if (( attempt >= 30 )); then
    echo "ERROR: new-api did not become ready in time."
    exit 1
  fi
  attempt=$((attempt + 1))
  sleep 2
done
echo "new-api is ready."
|
||||
|
||||
# ── Login ───────────────────────────────────────────────
login

# ── Generate system access token for future use ─────────
# Best-effort: this script only needs the session cookie; the access token
# is merely printed so the operator can save it for later scripted calls.
echo ""
echo "Generating system access token..."
ACCESS_TOKEN_RESP=$(api_call "/api/user/token")
ACCESS_TOKEN=$(echo "$ACCESS_TOKEN_RESP" | python3 -c "
import sys, json
data = json.load(sys.stdin)
if data.get('success'):
    print(data.get('data', ''))
else:
    print('')
" 2>/dev/null || echo "")

if [[ -n "$ACCESS_TOKEN" ]]; then
  echo " Access token: ${ACCESS_TOKEN}"
  echo " Save as NEW_API_ACCESS_TOKEN in .env for future API use."
  echo " Usage: -H 'Authorization: Bearer ${ACCESS_TOKEN}' -H 'New-Api-User: ${USER_ID}'"
else
  # Non-fatal: the rest of the script keeps using the session cookie.
  echo " Could not generate access token (non-critical, using session)."
fi
|
||||
|
||||
# ── Channels ────────────────────────────────────────────
# Arguments per call: name, type, key, base URL, priority, client model
# list, and the JSON client→provider model mapping. Models offered by more
# than one channel fall back from lower to higher priority numbers.

create_channel "DeepInfra" 1 \
"${DEEPINFRA_API_KEY:?}" \
"https://api.deepinfra.com/v1/openai" \
1 \
"deepseek-v3.2,deepseek-r1,gpt-oss,gpt-oss-20b,nemotron-super,nemotron-nano,devstral,glm-4.6,glm-4.7,glm-5,kimi-k2,kimi-k2.5" \
'{"deepseek-v3.2":"deepseek-ai/DeepSeek-V3.2","deepseek-r1":"deepseek-ai/DeepSeek-R1","gpt-oss":"openai/gpt-oss-120b","gpt-oss-20b":"openai/gpt-oss-20b","nemotron-super":"nvidia/Llama-3.3-Nemotron-Super-49B-v1.5","nemotron-nano":"nvidia/NVIDIA-Nemotron-Nano-9B-v2","devstral":"mistralai/Devstral-Small-2505","glm-4.6":"zai-org/GLM-4.6","glm-4.7":"zai-org/GLM-4.7","glm-5":"zai-org/GLM-5","kimi-k2":"moonshotai/Kimi-K2-Instruct-0905","kimi-k2.5":"moonshotai/Kimi-K2.5"}'

# SiliconFlow is the first fallback (priority 2) for the shared models.
create_channel "SiliconFlow" 1 \
"${SILICONFLOW_API_KEY:?}" \
"https://api.siliconflow.com/v1" \
2 \
"deepseek-v3.2,glm-4.7,kimi-k2,qwen3-coder,qwen3-coder-30b" \
'{"deepseek-v3.2":"deepseek-ai/DeepSeek-V3.2","glm-4.7":"THUDM/GLM-4-32B-0414","kimi-k2":"moonshotai/Kimi-K2-Instruct-0905","qwen3-coder":"Qwen/Qwen3-Coder-480B-A35B-Instruct","qwen3-coder-30b":"Qwen/Qwen3-Coder-30B-A3B-Instruct"}'

# OpenRouter is the last resort (priority 3) and also hosts paid frontier models.
create_channel "OpenRouter" 1 \
"${OPENROUTER_API_KEY:?}" \
"https://openrouter.ai/api/v1" \
3 \
"deepseek-v3.2,deepseek-v3-free,kimi-k2.5,minimax-m2.5,gpt-4.1-mini,gpt-4.1,gemini-3-flash-preview,gemini-2.5-pro,claude-sonnet,trinity-large-preview" \
'{"deepseek-v3.2":"deepseek/deepseek-chat-v3-0324","deepseek-v3-free":"deepseek/deepseek-chat-v3-0324:free","kimi-k2.5":"moonshotai/kimi-k2.5","minimax-m2.5":"minimax/minimax-m2.5","gpt-4.1-mini":"openai/gpt-4.1-mini","gpt-4.1":"openai/gpt-4.1","gemini-3-flash-preview":"google/gemini-3-flash-preview","gemini-2.5-pro":"google/gemini-2.5-pro-preview","claude-sonnet":"anthropic/claude-sonnet-4","trinity-large-preview":"arcee-ai/trinity-large-preview"}'

create_channel "Groq" 1 \
"${GROQ_API_KEY:?}" \
"https://api.groq.com/openai/v1" \
1 \
"llama-3.3-70b" \
'{"llama-3.3-70b":"llama-3.3-70b-versatile"}'

create_channel "Cerebras" 1 \
"${CEREBRAS_API_KEY:?}" \
"https://api.cerebras.ai/v1" \
1 \
"llama-3.3-70b-cerebras" \
'{"llama-3.3-70b-cerebras":"llama-3.3-70b"}'
|
||||
|
||||
# ── Create API token for Open WebUI ─────────────────────
echo ""
echo "Creating API token for Open WebUI..."
# unlimited_quota avoids having to top up the token's quota manually.
TOKEN_RESP=$(api_call "/api/token/" -d "$(python3 -c "
import json
print(json.dumps({
    'name': 'open-webui',
    'remain_quota': 0,
    'unlimited_quota': True
}))
")")

# On success TOKEN_KEY is the raw key; otherwise it starts with "FAILED: "
# so the summary below can branch on the prefix.
TOKEN_KEY=$(echo "$TOKEN_RESP" | python3 -c "
import sys, json
data = json.load(sys.stdin)
if data.get('success'):
    print(data['data']['key'])
else:
    print('FAILED: ' + data.get('message', 'unknown error'))
" 2>/dev/null || echo "FAILED: could not parse response")

echo ""
echo "══════════════════════════════════════"
echo "Setup complete!"
echo ""
if [[ "$TOKEN_KEY" != FAILED* ]]; then
  echo "Open WebUI API key: ${TOKEN_KEY}"
  echo "Set OPENWEBUI_API_KEY=${TOKEN_KEY} in your .env"
  echo ""
  echo "Test:"
  echo " curl ${API_BASE}/v1/chat/completions \\"
  echo " -H 'Authorization: Bearer ${TOKEN_KEY}' \\"
  echo " -H 'Content-Type: application/json' \\"
  echo " -d '{\"model\":\"deepseek-v3.2\",\"messages\":[{\"role\":\"user\",\"content\":\"hi\"}]}'"
else
  echo "Token creation: ${TOKEN_KEY}"
  echo "Create a token manually in the new-api UI."
fi
echo "══════════════════════════════════════"
|
||||
Loading…
Reference in a new issue