diff --git a/llm-gateway.yaml b/llm-gateway.yaml deleted file mode 100644 index 4025d81..0000000 --- a/llm-gateway.yaml +++ /dev/null @@ -1,372 +0,0 @@ -server: - listen: "0.0.0.0:3000" - request_timeout: 300s - max_request_body_mb: 10 - session_secret: "${SESSION_SECRET}" - default_admin: - username: "${ADMIN_USERNAME}" - password: "${ADMIN_PASSWORD}" - -tokens: - - name: "open-webui" - key: "${OPENWEBUI_API_KEY}" - rate_limit_rpm: 0 # unlimited - daily_budget_usd: 0 - - name: "opencode" - key: "${OPENCODE_API_KEY}" - rate_limit_rpm: 0 # unlimited - daily_budget_usd: 0 - -pricing_lookup: - # url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default - refresh_interval: 6h - -database: - path: "/data/gateway.db" - retention_days: 90 - -debug: - enabled: true - retention_days: 90 - # data_dir: "/data" # defaults to directory of database.path - # max_body_bytes: 0 # 0 = unlimited (save full bodies) - -cache: - enabled: true - address: "valkey:6379" - ttl: 3600 - -providers: - - name: deepinfra - base_url: "https://api.deepinfra.com/v1/openai" - api_key: "${DEEPINFRA_API_KEY}" - priority: 1 - timeout: 120s - - name: siliconflow - base_url: "https://api.siliconflow.com/v1" - api_key: "${SILICONFLOW_API_KEY}" - priority: 2 - timeout: 120s - - name: openrouter - base_url: "https://openrouter.ai/api/v1" - api_key: "${OPENROUTER_API_KEY}" - priority: 3 - timeout: 120s - - name: groq - base_url: "https://api.groq.com/openai/v1" - api_key: "${GROQ_API_KEY}" - priority: 1 - timeout: 120s - - name: cerebras - base_url: "https://api.cerebras.ai/v1" - api_key: "${CEREBRAS_API_KEY}" - priority: 1 - timeout: 120s - -models: - # ═══ TIER 1: Free (OpenRouter free models, $0) ═══ - # NOTE: Commented out — free models are heavily rate-limited upstream. - # Uncomment if you want best-effort free access. - # - name: "llama-3.3-70b-free" - # routes: - # - provider: openrouter - # model: "meta-llama/llama-3.3-70b-instruct:free" - # - name: "deepseek-r1-free" - # routes: - # - provider: openrouter - # model: "deepseek/deepseek-r1-0528:free" - # - name: "gpt-oss-free" - # routes: - # - provider: openrouter - # model: "openai/gpt-oss-120b:free" - # - name: "gpt-oss-20b-free" - # routes: - # - provider: openrouter - # model: "openai/gpt-oss-20b:free" - # - name: "qwen3-coder-free" - # routes: - # - provider: openrouter - # model: "qwen/qwen3-coder:free" - # - name: "qwen3-235b-free" - # routes: - # - provider: openrouter - # model: "qwen/qwen3-235b-a22b-thinking-2507" - # - name: "glm-4.5-air-free" - # routes: - # - provider: openrouter - # model: "z-ai/glm-4.5-air:free" - # - name: "nemotron-nano-free" - # routes: - # - provider: openrouter - # model: "nvidia/nemotron-nano-9b-v2:free" - # - name: "trinity-large-free" - # routes: - # - provider: openrouter - # model: "arcee-ai/trinity-large-preview:free" - # - name: "mistral-small-free" - # routes: - # - provider: openrouter - # model: "mistralai/mistral-small-3.1-24b-instruct:free" - # - name: "gemma-3-27b-free" - # routes: - # - provider: openrouter - # model: "google/gemma-3-27b-it:free" - # - name: "step-3.5-flash-free" - # routes: - # - provider: openrouter - # model: "stepfun/step-3.5-flash:free" - - # ═══ TIER 2: Low cost (Groq, Cerebras — free tier with rate limits) ═══ - - name: "llama-3.1-8b" - routes: - - provider: groq - model: "llama-3.1-8b-instant" - pricing: { input: 0.05, output: 0.08 } - - provider: cerebras - model: "llama3.1-8b" - pricing: { input: 0.10, output: 0.10 } - - provider: deepinfra - model: "meta-llama/Meta-Llama-3.1-8B-Instruct" - pricing: { input: 0.03, output: 0.05 } - - - name: "llama-3.3-70b" - routes: - - provider: deepinfra - model: "meta-llama/Llama-3.3-70B-Instruct-Turbo" - pricing: { input: 0.23, output: 0.40 } - - provider: groq - model: "llama-3.3-70b-versatile" - pricing: { input: 0.59, output: 0.79 } - - provider: cerebras - model: "llama-3.3-70b" - pricing: { input: 0.85, output: 1.20 } - - - name: "gpt-oss" - routes: - - provider: groq - model: "openai/gpt-oss-120b" - pricing: { input: 0.15, output: 0.60 } - - provider: cerebras - model: "gpt-oss-120b" - pricing: { input: 0.35, output: 0.75 } - - provider: deepinfra - model: "openai/gpt-oss-120b" - pricing: { input: 0.05, output: 0.24 } - - - name: "gpt-oss-20b" - routes: - - provider: groq - model: "openai/gpt-oss-20b" - pricing: { input: 0.075, output: 0.30 } - - provider: deepinfra - model: "openai/gpt-oss-20b" - pricing: { input: 0.04, output: 0.16 } - - - name: "llama-4-scout" - routes: - - provider: groq - model: "meta-llama/llama-4-scout-17b-16e-instruct" - pricing: { input: 0.11, output: 0.34 } - - - name: "llama-4-maverick" - routes: - - provider: groq - model: "meta-llama/llama-4-maverick-17b-128e-instruct" - pricing: { input: 0.20, output: 0.60 } - - - name: "qwen3-32b" - routes: - - provider: groq - model: "qwen/qwen3-32b" - pricing: { input: 0.29, output: 0.59 } - - provider: cerebras - model: "qwen-3-32b" - - # ═══ TIER 3: DeepSeek V3.2 (cheapest flagship) ═══ - - name: "deepseek-v3.2" - routes: - - provider: deepinfra - model: "deepseek-ai/DeepSeek-V3.2" - pricing: { input: 0.26, output: 0.38 } - - provider: siliconflow - model: "deepseek-ai/DeepSeek-V3.2" - pricing: { input: 0.27, output: 0.42 } - - provider: openrouter - model: "deepseek/deepseek-chat-v3-0324" - pricing: { input: 0.30, output: 0.88 } - - # ═══ TIER 4: Ultra-cheap DeepInfra ═══ - - name: "nemotron-super" - routes: - - provider: deepinfra - model: "nvidia/Llama-3.3-Nemotron-Super-49B-v1.5" - pricing: { input: 0.10, output: 0.40 } - - - name: "nemotron-nano" - routes: - - provider: deepinfra - model: "nvidia/NVIDIA-Nemotron-Nano-9B-v2" - pricing: { input: 0.04, output: 0.16 } - - # ═══ TIER 5: DeepSeek R1 & reasoning ═══ - - name: "deepseek-r1" - routes: - - provider: deepinfra - model: "deepseek-ai/DeepSeek-R1-0528" - - provider: openrouter - model: "deepseek/deepseek-r1" - - - name: "deepseek-r1-distill-llama-70b" - routes: - - provider: deepinfra - model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B" - - - name: "devstral-small" - routes: - - provider: openrouter - model: "mistralai/devstral-small" - - - name: "devstral-medium" - routes: - - provider: openrouter - model: "mistralai/devstral-medium" - - # ═══ TIER 6: GLM ═══ - - name: "glm-4.6" - routes: - - provider: deepinfra - model: "zai-org/GLM-4.6" - pricing: { input: 0.60, output: 1.90 } - - - name: "glm-4.7" - routes: - - provider: deepinfra - model: "zai-org/GLM-4.7" - pricing: { input: 0.40, output: 1.75 } - - provider: cerebras - model: "zai-glm-4.7" - pricing: { input: 2.25, output: 2.75 } - - provider: siliconflow - model: "THUDM/GLM-4-32B-0414" - - - name: "glm-5" - routes: - - provider: deepinfra - model: "zai-org/GLM-5" - pricing: { input: 0.80, output: 2.56 } - - # ═══ TIER 7: Kimi ═══ - - name: "kimi-k2" - routes: - - provider: groq - model: "moonshotai/kimi-k2-instruct-0905" - pricing: { input: 1.00, output: 3.00 } - - provider: deepinfra - model: "moonshotai/Kimi-K2-Instruct-0905" - pricing: { input: 0.50, output: 2.00 } - - provider: siliconflow - model: "moonshotai/Kimi-K2-Instruct-0905" - pricing: { input: 0.58, output: 2.29 } - - - name: "kimi-k2.5" - routes: - - provider: deepinfra - model: "moonshotai/Kimi-K2.5" - pricing: { input: 0.45, output: 2.25 } - - provider: openrouter - model: "moonshotai/kimi-k2.5" - - # ═══ TIER 8: SiliconFlow (Qwen) ═══ - - name: "qwen3-coder" - routes: - - provider: siliconflow - model: "Qwen/Qwen3-Coder-480B-A35B-Instruct" - pricing: { input: 1.14, output: 2.28 } - - - name: "qwen3-coder-30b" - routes: - - provider: siliconflow - model: "Qwen/Qwen3-Coder-30B-A3B-Instruct" - - # ═══ TIER 9: OpenRouter premium (paid) ═══ - - name: "minimax-m2.5" - routes: - - provider: openrouter - model: "minimax/minimax-m2.5" - - - name: "gpt-4.1-mini" - routes: - - provider: openrouter - model: "openai/gpt-4.1-mini" - - - name: "gpt-4.1" - routes: - - provider: openrouter - model: "openai/gpt-4.1" - - - name: "gemini-3-flash-preview" - routes: - - provider: openrouter - model: "google/gemini-3-flash-preview" - - - name: "gemini-2.5-pro" - routes: - - provider: openrouter - model: "google/gemini-2.5-pro-preview" - - # ═══ TIER 10: Vision / Multimodal ═══ - - name: "gemma-3-4b" - routes: - - provider: openrouter - model: "google/gemma-3-4b-it" - pricing: { input: 0.017, output: 0.068 } - - provider: deepinfra - model: "google/gemma-3-4b-it" - pricing: { input: 0.04, output: 0.08 } - - - name: "gemma-3-12b" - routes: - - provider: openrouter - model: "google/gemma-3-12b-it" - pricing: { input: 0.03, output: 0.10 } - - provider: deepinfra - model: "google/gemma-3-12b-it" - pricing: { input: 0.04, output: 0.13 } - - - name: "gemma-3-27b" - routes: - - provider: openrouter - model: "google/gemma-3-27b-it" - pricing: { input: 0.04, output: 0.15 } - - provider: deepinfra - model: "google/gemma-3-27b-it" - pricing: { input: 0.08, output: 0.16 } - - - name: "qwen3-vl-8b" - routes: - - provider: openrouter - model: "qwen/qwen3-vl-8b-instruct" - pricing: { input: 0.08, output: 0.50 } - - provider: deepinfra - model: "Qwen/Qwen3-VL-8B-Instruct" - pricing: { input: 0.18, output: 0.69 } - - - name: "qwen3-vl-32b" - routes: - - provider: openrouter - model: "qwen/qwen3-vl-32b-instruct" - pricing: { input: 0.104, output: 0.416 } - - - name: "qwen2.5-vl-32b" - routes: - - provider: openrouter - model: "qwen/qwen2.5-vl-32b-instruct" - pricing: { input: 0.05, output: 0.22 } - - provider: deepinfra - model: "Qwen/Qwen2.5-VL-32B-Instruct" - pricing: { input: 0.20, output: 0.60 } - - - name: "claude-sonnet" - routes: - - provider: openrouter - model: "anthropic/claude-sonnet-4" diff --git a/llm-gateway/.env.example b/llm-gateway/.env.example deleted file mode 100644 index a1acce2..0000000 --- a/llm-gateway/.env.example +++ /dev/null @@ -1,19 +0,0 @@ -# LLM Gateway Environment Variables - -# Session secret (required for persistent sessions) -SESSION_SECRET=change-me-to-a-random-string - -# Default admin (created on first run if no users exist) -ADMIN_USERNAME=admin -ADMIN_PASSWORD=change-me-min-8-chars - -# Static API tokens (seeded on startup) -OPENWEBUI_API_KEY=sk-your-openwebui-key -PERSONAL_API_KEY=sk-your-personal-key - -# Provider API keys -DEEPINFRA_API_KEY= -SILICONFLOW_API_KEY= -OPENROUTER_API_KEY= -GROQ_API_KEY= -CEREBRAS_API_KEY= diff --git a/llm-gateway/.gitignore b/llm-gateway/.gitignore deleted file mode 100644 index 759cb20..0000000 --- a/llm-gateway/.gitignore +++ /dev/null @@ -1,18 +0,0 @@ -# Binaries -gateway -llm-gateway - -# Database -*.db -*.db-journal -*.db-wal -*.db-shm - -# Debug log files -debug-logs/ - -# Local config -configs/config.local.yaml - -# Environment -.env diff --git a/llm-gateway/Dockerfile b/llm-gateway/Dockerfile deleted file mode 100644 index 9fdb103..0000000 --- a/llm-gateway/Dockerfile +++ /dev/null @@ -1,15 +0,0 @@ -FROM golang:1.24-alpine AS builder -WORKDIR /src -COPY go.mod go.sum ./ -RUN go mod download -COPY . . -RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o /llm-gateway ./cmd/gateway - -FROM alpine:3.19 -RUN apk add --no-cache ca-certificates tzdata -COPY --from=builder /llm-gateway /usr/local/bin/llm-gateway -RUN mkdir -p /data -VOLUME /data -EXPOSE 3000 -ENTRYPOINT ["llm-gateway"] -CMD ["-config", "/etc/llm-gateway/config.yaml"] diff --git a/llm-gateway/Makefile b/llm-gateway/Makefile deleted file mode 100644 index f042edf..0000000 --- a/llm-gateway/Makefile +++ /dev/null @@ -1,16 +0,0 @@ -.PHONY: build run clean docker - -BINARY=llm-gateway -VERSION=$(shell git describe --tags --always --dirty 2>/dev/null || echo dev) - -build: - go build -ldflags="-s -w -X main.version=$(VERSION)" -o $(BINARY) ./cmd/gateway - -run: build - ./$(BINARY) -config configs/config.yaml - -clean: - rm -f $(BINARY) - -docker: - docker build -t llm-gateway:latest . diff --git a/llm-gateway/cmd/gateway/main.go b/llm-gateway/cmd/gateway/main.go deleted file mode 100644 index c8bbd06..0000000 --- a/llm-gateway/cmd/gateway/main.go +++ /dev/null @@ -1,434 +0,0 @@ -package main - -import ( - "context" - "flag" - "log" - "net/http" - "os" - "os/signal" - "path/filepath" - "syscall" - "time" - - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" - gocors "github.com/go-chi/cors" - "github.com/prometheus/client_golang/prometheus/promhttp" - - "llm-gateway/internal/auth" - "llm-gateway/internal/cache" - "llm-gateway/internal/config" - "llm-gateway/internal/dashboard" - "llm-gateway/internal/metrics" - "llm-gateway/internal/pricing" - "llm-gateway/internal/provider" - "llm-gateway/internal/proxy" - "llm-gateway/internal/storage" - "llm-gateway/internal/webhook" -) - -var version = "dev" - -func main() { - configPath := flag.String("config", "configs/config.yaml", "path to config file") - flag.Parse() - - log.Printf("llm-gateway %s starting", version) - - cfg, err := config.Load(*configPath) - if err != nil { - log.Fatalf("Failed to load config: %v", err) - } - - // Pricing lookup (fetches from URL, refreshes periodically) - pricingLookup := pricing.NewLookup(cfg.Pricing.URL, cfg.Pricing.RefreshInterval) - defer pricingLookup.Close() - - // Auto-fill missing pricing from fetched data - for i, m := range cfg.Models { - for j, r := range m.Routes { - if r.Pricing.Input == 0 && r.Pricing.Output == 0 { - if pricingLookup.FillMissing(r.Provider, r.Model, &cfg.Models[i].Routes[j].Pricing.Input, &cfg.Models[i].Routes[j].Pricing.Output) { - log.Printf("Auto-filled pricing for %s via %s: $%.2f/$%.2f per 1M tokens", - m.Name, r.Provider, cfg.Models[i].Routes[j].Pricing.Input, cfg.Models[i].Routes[j].Pricing.Output) - } - } - } - } - - // Database - db, err := storage.Open(cfg.Database.Path) - if err != nil { - log.Fatalf("Failed to open database: %v", err) - } - defer db.Close() - - if err := db.CleanupOldRecords(cfg.Database.RetentionDays); err != nil { - log.Printf("WARNING: retention cleanup failed: %v", err) - } - - asyncLogger := storage.NewAsyncLogger(db, 1000) - defer asyncLogger.Close() - - // SSE broker for real-time dashboard updates - sseBroker := dashboard.NewSSEBroker() - asyncLogger.OnFlush = sseBroker.Notify - - // Cache (optional) - var c *cache.Cache - if cfg.Cache.Enabled { - c, err = cache.New(cfg.Cache.Address, cfg.Cache.TTL) - if err != nil { - log.Printf("WARNING: cache disabled: %v", err) - } else { - log.Printf("Cache connected to %s", cfg.Cache.Address) - } - } - - // Provider registry - registry, err := provider.NewRegistry(cfg) - if err != nil { - log.Fatalf("Failed to build provider registry: %v", err) - } - log.Printf("Registered %d models", len(cfg.Models)) - - // Provider health tracker - healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker) - - // Webhook notifier - var notifier *webhook.Notifier - if len(cfg.Webhooks) > 0 { - notifier = webhook.NewNotifier(cfg.Webhooks) - defer notifier.Close() - log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks)) - - // Wire health tracker state changes to webhook - healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) { - eventType := webhook.EventCircuitBreakerOpen - if to == provider.CircuitClosed { - eventType = webhook.EventCircuitBreakerClosed - } - notifier.Notify(webhook.Event{ - Type: eventType, - Data: map[string]any{ - "provider": providerName, - "from": from.String(), - "to": to.String(), - }, - }) - } - } - - // Auth store (static tokens checked in-memory, not seeded to DB) - var staticTokens []auth.StaticToken - for _, t := range cfg.Tokens { - if t.Key != "" { - staticTokens = append(staticTokens, auth.StaticToken{ - Name: t.Name, - Key: t.Key, - RateLimitRPM: t.RateLimitRPM, - DailyBudgetUSD: t.DailyBudgetUSD, - MonthlyBudgetUSD: t.MonthlyBudgetUSD, - MaxConcurrent: t.MaxConcurrent, - }) - log.Printf("Loaded static token: %s", t.Name) - } - } - authStore := auth.NewStore(db.DB, staticTokens) - authMiddleware := auth.NewMiddleware(authStore) - authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret) - - // Audit logger - auditLogger := storage.NewAuditLogger(db) - auditLogger.OnWrite = sseBroker.Notify - authHandlers.SetAuditLogger(auditLogger) - - // Debug logger - debugDataDir := cfg.Debug.DataDir - if debugDataDir == "" { - debugDataDir = filepath.Dir(cfg.Database.Path) - } - debugLogger := storage.NewDebugLogger(db, cfg.Debug.Enabled, debugDataDir) - debugLogger.OnWrite = sseBroker.Notify - - // Seed default admin - seedDefaultAdmin(cfg, authStore) - - // Metrics - m := metrics.New() - - // Handlers - proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker) - proxyHandler.SetDebugLogger(debugLogger) - - // Request deduplication - if cfg.Dedup.Enabled { - dedup := proxy.NewDeduplicator(cfg.Dedup.Window) - defer dedup.Close() - proxyHandler.SetDeduplicator(dedup) - log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window) - } - - modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg) - proxyAuth := proxy.NewAuthMiddleware(authStore) - rateLimiter := proxy.NewRateLimiter(db) - if notifier != nil { - rateLimiter.SetNotifier(notifier) - } - concurrencyLimiter := proxy.NewConcurrencyLimiter() - statsAPI := dashboard.NewStatsAPI(db, authStore) - statsAPI.SetHealthTracker(healthTracker) - statsAPI.SetAuditLogger(auditLogger) - statsAPI.SetDebugLogger(debugLogger) - statsAPI.SetConfigPath(*configPath) - if c != nil { - statsAPI.SetCache(c) - } - dash := dashboard.NewDashboard(authStore, statsAPI) - dash.SetRegistry(registry) - dash.SetAuditLogger(auditLogger) - dash.SetDebugLogger(debugLogger) - if c != nil { - dash.SetCache(c) - } - - // Export handler - exportHandler := dashboard.NewExportHandler(db, authStore) - - // Router - r := chi.NewRouter() - - // CORS (before other middleware) - if cfg.CORS.Enabled { - r.Use(gocors.Handler(gocors.Options{ - AllowedOrigins: cfg.CORS.AllowedOrigins, - AllowedMethods: cfg.CORS.AllowedMethods, - AllowedHeaders: cfg.CORS.AllowedHeaders, - MaxAge: cfg.CORS.MaxAge, - AllowCredentials: true, - })) - } - - r.Use(middleware.RealIP) - r.Use(middleware.Recoverer) - r.Use(middleware.RequestID) - - // Health & metrics (public) - r.Get("/health", func(w http.ResponseWriter, r *http.Request) { - if err := db.Ping(); err != nil { - http.Error(w, "database unhealthy", http.StatusServiceUnavailable) - return - } - if c != nil { - if err := c.Ping(r.Context()); err != nil { - http.Error(w, "cache unhealthy", http.StatusServiceUnavailable) - return - } - } - w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) - }) - r.Handle("/metrics", promhttp.Handler()) - - // OpenAI-compatible API (API token auth via Bearer header) - r.Group(func(r chi.Router) { - r.Use(proxyAuth.Authenticate) - r.Use(rateLimiter.Check) - r.Use(concurrencyLimiter.Check) - r.Post("/v1/chat/completions", proxyHandler.ChatCompletions) - r.Post("/v1/embeddings", proxyHandler.Embeddings) - r.Get("/v1/models", modelsHandler.ListModels) - }) - - // Auth pages (public) - r.Get("/login", dash.LoginPage) - r.Get("/setup", dash.SetupPage) - - // Auth API endpoints (public) - r.Post("/api/auth/login", authHandlers.Login) - r.Post("/api/auth/setup", authHandlers.Setup) - r.Post("/api/auth/login/totp", authHandlers.LoginTOTP) - - // Favicon (prevent 401 noise in browser console) - r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNoContent) - }) - - // Root redirect - r.Get("/", func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/dashboard", http.StatusFound) - }) - - // Authenticated pages and API - r.Group(func(r chi.Router) { - r.Use(authMiddleware.RequireAuth) - - // Dashboard pages (HTMX) - r.Get("/dashboard", dash.DashboardPage) - r.Get("/logs", dash.LogsPage) - r.Get("/models", dash.ModelsPage) - r.Get("/tokens", dash.TokensPage) - r.Get("/settings", dash.SettingsPage) - - // Admin-only pages - r.Group(func(r chi.Router) { - r.Use(authMiddleware.RequireAdmin) - r.Get("/users", dash.UsersPage) - r.Get("/audit", dash.AuditPage) - r.Get("/debug", dash.DebugPage) - }) - - // Auth API - r.Post("/api/auth/logout", authHandlers.Logout) - r.Get("/api/auth/me", authHandlers.Me) - r.Put("/api/auth/me/password", authHandlers.ChangePassword) - r.Put("/api/auth/me/username", authHandlers.ChangeUsername) - r.Put("/api/auth/me/email", authHandlers.ChangeEmail) - r.Post("/api/auth/totp/setup", authHandlers.TOTPSetup) - r.Post("/api/auth/totp/verify", authHandlers.TOTPVerify) - r.Delete("/api/auth/totp", authHandlers.TOTPDisable) - - // API token management - r.Get("/api/tokens", authHandlers.ListTokens) - r.Post("/api/tokens", authHandlers.CreateToken) - r.Delete("/api/tokens/{id}", authHandlers.DeleteToken) - - // SSE events - r.Get("/api/events", sseBroker.ServeHTTP) - - // Dashboard stats - r.Get("/api/stats/summary", statsAPI.Summary) - r.Get("/api/stats/models", statsAPI.Models) - r.Get("/api/stats/providers", statsAPI.Providers) - r.Get("/api/stats/tokens", statsAPI.Tokens) - r.Get("/api/stats/timeseries", statsAPI.Timeseries) - r.Get("/api/stats/logs", statsAPI.Logs) - r.Get("/api/stats/latency", statsAPI.Latency) - r.Get("/api/stats/cost-breakdown", statsAPI.CostBreakdown) - r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler) - r.Get("/api/stats/cache", statsAPI.CacheStats) - - // Data export - r.Get("/api/export/logs", exportHandler.ExportLogs) - r.Get("/api/export/stats", exportHandler.ExportStats) - - // Admin-only: user management, audit, debug, config validation - r.Group(func(r chi.Router) { - r.Use(authMiddleware.RequireAdmin) - r.Get("/api/auth/users", authHandlers.ListUsers) - r.Post("/api/auth/users", authHandlers.CreateUser) - r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser) - - // Audit log - r.Get("/api/stats/audit", statsAPI.AuditLogs) - - // Config validation - r.Get("/api/config/validate", statsAPI.ValidateConfig) - - // Debug logging - r.Post("/api/debug/toggle", statsAPI.DebugToggle) - r.Get("/api/debug/status", statsAPI.DebugStatus) - r.Get("/api/debug/logs", statsAPI.DebugLogs) - r.Get("/api/debug/logs/{requestID}", statsAPI.DebugLogByRequestID) - }) - }) - - // Periodic session cleanup and debug log cleanup - go func() { - ticker := time.NewTicker(1 * time.Hour) - defer ticker.Stop() - for range ticker.C { - if err := authStore.CleanExpiredSessions(); err != nil { - log.Printf("WARNING: session cleanup failed: %v", err) - } - if err := debugLogger.Cleanup(cfg.Debug.RetentionDays); err != nil { - log.Printf("WARNING: debug log cleanup failed: %v", err) - } - } - }() - - // Server - srv := &http.Server{ - Addr: cfg.Server.Listen, - Handler: r, - ReadTimeout: 30 * time.Second, - WriteTimeout: cfg.Server.RequestTimeout + 10*time.Second, - IdleTimeout: 120 * time.Second, - } - - // Config hot-reload via SIGHUP - config.WatchReload(*configPath, func(newCfg *config.Config) { - // Reload registry (models, providers, routes) - if err := registry.Reload(newCfg); err != nil { - log.Printf("ERROR: registry reload failed: %v", err) - return - } - log.Printf("Reloaded %d models", len(newCfg.Models)) - - // Reload pricing - for i, m := range newCfg.Models { - for j, rt := range m.Routes { - if rt.Pricing.Input == 0 && rt.Pricing.Output == 0 { - pricingLookup.FillMissing(rt.Provider, rt.Model, - &newCfg.Models[i].Routes[j].Pricing.Input, - &newCfg.Models[i].Routes[j].Pricing.Output) - } - } - } - - // Reload static tokens - var newStaticTokens []auth.StaticToken - for _, t := range newCfg.Tokens { - if t.Key != "" { - newStaticTokens = append(newStaticTokens, auth.StaticToken{ - Name: t.Name, - Key: t.Key, - RateLimitRPM: t.RateLimitRPM, - DailyBudgetUSD: t.DailyBudgetUSD, - MonthlyBudgetUSD: t.MonthlyBudgetUSD, - MaxConcurrent: t.MaxConcurrent, - }) - } - } - authStore.SetStaticTokens(newStaticTokens) - - // Update config pointer for retry/debug/etc - cfg = newCfg - }) - - // Graceful shutdown - done := make(chan os.Signal, 1) - signal.Notify(done, os.Interrupt, syscall.SIGTERM) - - go func() { - log.Printf("Listening on %s", cfg.Server.Listen) - if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("Server failed: %v", err) - } - }() - - <-done - log.Println("Shutting down...") - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - srv.Shutdown(ctx) - - log.Println("Stopped") -} - -// seedDefaultAdmin creates the default admin user if no users exist. -func seedDefaultAdmin(cfg *config.Config, authStore *auth.Store) { - if !authStore.HasAnyUser() { - da := cfg.Server.DefaultAdmin - if da.Username != "" && da.Password != "" { - user, err := authStore.CreateUser(da.Username, da.Password, true) - if err != nil { - log.Printf("WARNING: failed to create default admin: %v", err) - } else { - log.Printf("Created default admin user: %s (id=%d)", user.Username, user.ID) - } - } - } -} diff --git a/llm-gateway/configs/config.yaml b/llm-gateway/configs/config.yaml deleted file mode 100644 index 91f1913..0000000 --- a/llm-gateway/configs/config.yaml +++ /dev/null @@ -1,140 +0,0 @@ -server: - listen: "0.0.0.0:3000" - request_timeout: 300s - max_request_body_mb: 10 - session_secret: "${SESSION_SECRET}" - default_admin: - username: "${ADMIN_USERNAME}" - password: "${ADMIN_PASSWORD}" - -tokens: - - name: "open-webui" - key: "${OPENWEBUI_API_KEY}" - rate_limit_rpm: 0 # unlimited - daily_budget_usd: 5.0 - - name: "rayandrew" - key: "${PERSONAL_API_KEY}" - rate_limit_rpm: 0 # unlimited - daily_budget_usd: 10.0 - -pricing_lookup: - # url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default - refresh_interval: 6h - -database: - path: "/data/gateway.db" - retention_days: 90 - -cache: - enabled: true - address: "valkey:6379" - ttl: 3600 - -providers: - - name: deepinfra - base_url: "https://api.deepinfra.com/v1/openai" - api_key: "${DEEPINFRA_API_KEY}" - priority: 1 - timeout: 120s - - name: siliconflow - base_url: "https://api.siliconflow.com/v1" - api_key: "${SILICONFLOW_API_KEY}" - priority: 2 - timeout: 120s - - name: openrouter - base_url: "https://openrouter.ai/api/v1" - api_key: "${OPENROUTER_API_KEY}" - priority: 3 - timeout: 120s - - name: groq - base_url: "https://api.groq.com/openai/v1" - api_key: "${GROQ_API_KEY}" - priority: 1 - timeout: 120s - - name: cerebras - base_url: "https://api.cerebras.ai/v1" - api_key: "${CEREBRAS_API_KEY}" - priority: 1 - timeout: 120s - -models: - - name: "deepseek-v3.2" - routes: - - provider: deepinfra - model: "deepseek-ai/DeepSeek-V3.2" - pricing: { input: 0.26, output: 0.38 } - - provider: siliconflow - model: "deepseek-ai/DeepSeek-V3.2" - pricing: { input: 0.27, output: 0.42 } - - provider: openrouter - model: "deepseek/deepseek-chat-v3-0324" - pricing: { input: 0.30, output: 0.88 } - - - name: "llama-3.3-70b" - routes: - - provider: groq - model: "llama-3.3-70b-versatile" - pricing: { input: 0, output: 0 } - - provider: deepinfra - model: "meta-llama/Llama-3.3-70B-Instruct" - pricing: { input: 0.23, output: 0.40 } - - - name: "llama-3.1-8b" - routes: - - provider: groq - model: "llama-3.1-8b-instant" - pricing: { input: 0, output: 0 } - - provider: cerebras - model: "llama-3.1-8b" - pricing: { input: 0, output: 0 } - - provider: deepinfra - model: "meta-llama/Meta-Llama-3.1-8B-Instruct" - pricing: { input: 0.03, output: 0.05 } - - - name: "qwen-2.5-72b" - routes: - - provider: groq - model: "qwen-2.5-72b" - pricing: { input: 0, output: 0 } - - provider: deepinfra - model: "Qwen/Qwen2.5-72B-Instruct" - pricing: { input: 0.23, output: 0.40 } - - - name: "qwen-2.5-coder-32b" - routes: - - provider: groq - model: "qwen-2.5-coder-32b" - pricing: { input: 0, output: 0 } - - provider: deepinfra - model: "Qwen/Qwen2.5-Coder-32B-Instruct" - pricing: { input: 0.07, output: 0.16 } - - - name: "gemma-2-9b" - routes: - - provider: groq - model: "gemma2-9b-it" - pricing: { input: 0, output: 0 } - - - name: "deepseek-r1" - routes: - - provider: deepinfra - model: "deepseek-ai/DeepSeek-R1" - pricing: { input: 0.40, output: 1.60 } - - provider: openrouter - model: "deepseek/deepseek-r1" - pricing: { input: 0.55, output: 2.19 } - - - name: "deepseek-r1-distill-llama-70b" - routes: - - provider: groq - model: "deepseek-r1-distill-llama-70b" - pricing: { input: 0, output: 0 } - - provider: deepinfra - model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B" - pricing: { input: 0.23, output: 0.69 } - - - name: "deepseek-r1-distill-qwen-32b" - routes: - - provider: deepinfra - model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" - pricing: { input: 0.07, output: 0.16 } diff --git a/llm-gateway/go.mod b/llm-gateway/go.mod deleted file mode 100644 index 08c7f3b..0000000 --- a/llm-gateway/go.mod +++ /dev/null @@ -1,39 +0,0 @@ -module llm-gateway - -go 1.24.0 - -require ( - github.com/go-chi/chi/v5 v5.2.5 - github.com/go-chi/cors v1.2.2 - github.com/golang-migrate/migrate/v4 v4.19.1 - github.com/pquerna/otp v1.5.0 - github.com/prometheus/client_golang v1.23.2 - github.com/redis/go-redis/v9 v9.17.3 - golang.org/x/crypto v0.48.0 - gopkg.in/yaml.v3 v3.0.1 - modernc.org/sqlite v1.45.0 -) - -require ( - github.com/beorn7/perks v1.0.1 // indirect - github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dustin/go-humanize v1.0.1 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/kr/text v0.2.0 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect - github.com/ncruces/go-strftime v1.0.0 // indirect - github.com/prometheus/client_model v0.6.2 // indirect - github.com/prometheus/common v0.66.1 // indirect - github.com/prometheus/procfs v0.16.1 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - go.yaml.in/yaml/v2 v2.4.2 // indirect - golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/sys v0.41.0 // indirect - google.golang.org/protobuf v1.36.8 // indirect - modernc.org/libc v1.67.6 // indirect - modernc.org/mathutil v1.7.1 // indirect - modernc.org/memory v1.11.0 // indirect -) diff --git a/llm-gateway/go.sum b/llm-gateway/go.sum deleted file mode 100644 index 853ad13..0000000 --- a/llm-gateway/go.sum +++ /dev/null @@ -1,123 +0,0 @@ -github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= -github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= -github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= -github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= -github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= -github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= -github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= -github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= -github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= -github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= -github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= -github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= -github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= -github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= -github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= -github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= -github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= -github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= -github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= -github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= -github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= -github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= -github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= -github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= -github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= -github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= -github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= -github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= -go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= -go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= -golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= -golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= -golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= -golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= -google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= -google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= -modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= -modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= -modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= -modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= -modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= -modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= -modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= -modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= -modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= -modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= -modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= -modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= -modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= -modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= -modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= -modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= -modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= -modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= -modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= -modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= -modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= -modernc.org/sqlite v1.45.0 h1:r51cSGzKpbptxnby+EIIz5fop4VuE4qFoVEjNvWoObs= -modernc.org/sqlite v1.45.0/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= -modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= -modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= -modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= -modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/llm-gateway/internal/auth/handlers.go b/llm-gateway/internal/auth/handlers.go deleted file mode 100644 index 7585587..0000000 --- a/llm-gateway/internal/auth/handlers.go +++ /dev/null @@ -1,1067 +0,0 @@ -package auth - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "html/template" - "net/http" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-chi/chi/v5" - - "llm-gateway/internal/storage" -) - -type Handlers struct { - store *Store - sessionSecret string - loginLimiter *loginRateLimiter - auditLogger *storage.AuditLogger -} - -func NewHandlers(store *Store, sessionSecret string) *Handlers { - return &Handlers{ - store: store, - sessionSecret: sessionSecret, - loginLimiter: newLoginRateLimiter(), - } -} - -func (h *Handlers) SetAuditLogger(al *storage.AuditLogger) { - h.auditLogger = al -} - -func (h *Handlers) audit(r *http.Request, action, targetType, targetID, details string) { - if h.auditLogger == nil { - return - } - user := UserFromContext(r.Context()) - var userID int64 - var username string - if user != nil { - userID = user.ID - username = user.Username - } - ip := r.RemoteAddr - if fwd := r.Header.Get("X-Real-IP"); fwd != "" { - ip = fwd - } - h.auditLogger.Log(storage.AuditEntry{ - UserID: userID, - Username: username, - Action: action, - TargetType: targetType, - TargetID: targetID, - Details: details, - IPAddress: ip, - }) -} - -// Login brute-force protection -type loginRateLimiter struct { - mu sync.Mutex - attempts map[string][]time.Time -} - -func newLoginRateLimiter() *loginRateLimiter { - return &loginRateLimiter{attempts: make(map[string][]time.Time)} -} - -func (l *loginRateLimiter) allow(ip string) bool { - l.mu.Lock() - defer l.mu.Unlock() - - now := time.Now() - cutoff := now.Add(-1 * time.Minute) - - // Clean old entries - recent := l.attempts[ip][:0] - for _, t := range l.attempts[ip] { - if t.After(cutoff) { - recent = append(recent, t) - } - } - l.attempts[ip] = recent - - if len(recent) >= 5 { - return false - } - l.attempts[ip] = append(l.attempts[ip], now) - return true -} - -func (h *Handlers) Status(w http.ResponseWriter, r *http.Request) { - initialized := h.store.HasAnyUser() - - resp := map[string]any{ - "initialized": initialized, - "logged_in": false, - } - - cookie, err := r.Cookie(sessionCookieName) - if err == nil && cookie.Value != "" { - sess, err := h.store.GetSession(cookie.Value) - if err == nil { - user, err := h.store.GetUserByID(sess.UserID) - if err == nil { - resp["logged_in"] = true - resp["user"] = map[string]any{ - "id": user.ID, - "username": user.Username, - "is_admin": user.IsAdmin, - "totp_enabled": user.TOTPEnabled, - } - } - } - } - - writeJSON(w, resp) -} - -func (h *Handlers) Setup(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - - if h.store.HasAnyUser() { - if htmx { - writeHTMXError(w, "already initialized") - return - } - writeError(w, http.StatusBadRequest, "already initialized") - return - } - - var req struct { - Username string `json:"username"` - Password string `json:"password"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - if req.Username == "" || req.Password == "" { - if htmx { - writeHTMXError(w, "username and password required") - return - } - writeError(w, http.StatusBadRequest, "username and password required") - return - } - if len(req.Password) < 8 { - if htmx { - writeHTMXError(w, "password must be at least 8 characters") - return - } - writeError(w, http.StatusBadRequest, "password must be at least 8 characters") - return - } - - user, err := h.store.CreateUser(req.Username, req.Password, true) - if err != nil { - msg := "failed to create user: " + err.Error() - if htmx { - writeHTMXError(w, msg) - return - } - writeError(w, http.StatusInternalServerError, msg) - return - } - - // Auto-login - sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour) - if err != nil { - if htmx { - writeHTMXError(w, "failed to create session") - return - } - writeError(w, http.StatusInternalServerError, "failed to create session") - return - } - - h.audit(r, "auth.setup", "user", fmt.Sprintf("%d", user.ID), "initial setup") - - h.setSessionCookie(w, sessionID) - if htmx { - writeHTMXRedirect(w, "/dashboard") - return - } - writeJSON(w, map[string]any{ - "user": map[string]any{ - "id": user.ID, - "username": user.Username, - "is_admin": user.IsAdmin, - }, - }) -} - -func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - - ip := r.RemoteAddr - if fwd := r.Header.Get("X-Real-IP"); fwd != "" { - ip = fwd - } - if !h.loginLimiter.allow(ip) { - if htmx { - writeHTMXError(w, "too many login attempts, try again in a minute") - return - } - writeError(w, http.StatusTooManyRequests, "too many login attempts, try again in a minute") - return - } - - var req struct { - Username string `json:"username"` - Password string `json:"password"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - - user, err := h.store.GetUserByUsername(req.Username) - if err != nil { - if htmx { - writeHTMXError(w, "invalid credentials") - return - } - writeError(w, http.StatusUnauthorized, "invalid credentials") - return - } - - if !h.store.CheckPassword(user, req.Password) { - if htmx { - writeHTMXError(w, "invalid credentials") - return - } - writeError(w, http.StatusUnauthorized, "invalid credentials") - return - } - - if user.TOTPEnabled { - // Set pending cookie for TOTP step - pending := h.signPendingToken(user.ID) - http.SetCookie(w, &http.Cookie{ - Name: "llmgw_pending", - Value: pending, - Path: "/", - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - MaxAge: 300, // 5 minutes - }) - if htmx { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - fmt.Fprint(w, `
-
- - -
- -
`) - return - } - writeJSON(w, map[string]any{"require_totp": true}) - return - } - - sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour) - if err != nil { - if htmx { - writeHTMXError(w, "failed to create session") - return - } - writeError(w, http.StatusInternalServerError, "failed to create session") - return - } - - h.audit(r, "auth.login", "user", fmt.Sprintf("%d", user.ID), user.Username) - - h.setSessionCookie(w, sessionID) - if htmx { - writeHTMXRedirect(w, "/dashboard") - return - } - writeJSON(w, map[string]any{ - "require_totp": false, - "user": map[string]any{ - "id": user.ID, - "username": user.Username, - "is_admin": user.IsAdmin, - }, - }) -} - -func (h *Handlers) LoginTOTP(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - - cookie, err := r.Cookie("llmgw_pending") - if err != nil || cookie.Value == "" { - if htmx { - writeHTMXError(w, "no pending login") - return - } - writeError(w, http.StatusBadRequest, "no pending login") - return - } - - userID, err := h.verifyPendingToken(cookie.Value) - if err != nil { - if htmx { - writeHTMXError(w, "invalid or expired pending login") - return - } - writeError(w, http.StatusBadRequest, "invalid or expired pending login") - return - } - - var req struct { - Code string `json:"code"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - - user, err := h.store.GetUserByID(userID) - if err != nil { - if htmx { - writeHTMXError(w, "user not found") - return - } - writeError(w, http.StatusBadRequest, "user not found") - return - } - - if !ValidateTOTPCode(user.TOTPSecret, req.Code) { - if htmx { - writeHTMXError(w, "invalid TOTP code") - return - } - writeError(w, http.StatusUnauthorized, "invalid TOTP code") - return - } - - // Clear pending cookie - http.SetCookie(w, &http.Cookie{ - Name: "llmgw_pending", - Value: "", - Path: "/", - HttpOnly: true, - MaxAge: -1, - }) - - sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour) - if err != nil { - if htmx { - writeHTMXError(w, "failed to create session") - return - } - writeError(w, http.StatusInternalServerError, "failed to create session") - return - } - - h.setSessionCookie(w, sessionID) - if htmx { - writeHTMXRedirect(w, "/dashboard") - return - } - writeJSON(w, map[string]any{ - "user": map[string]any{ - "id": user.ID, - "username": user.Username, - "is_admin": user.IsAdmin, - }, - }) -} - -func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) { - h.audit(r, "auth.logout", "", "", "") - - cookie, err := r.Cookie(sessionCookieName) - if err == nil { - h.store.DeleteSession(cookie.Value) - } - - http.SetCookie(w, &http.Cookie{ - Name: sessionCookieName, - Value: "", - Path: "/", - HttpOnly: true, - MaxAge: -1, - }) - - if isHTMX(r) { - writeHTMXRedirect(w, "/login") - return - } - writeJSON(w, map[string]string{"status": "ok"}) -} - -func (h *Handlers) Me(w http.ResponseWriter, r *http.Request) { - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - writeJSON(w, map[string]any{ - "id": user.ID, - "username": user.Username, - "is_admin": user.IsAdmin, - "totp_enabled": user.TOTPEnabled, - }) -} - -func (h *Handlers) TOTPSetup(w http.ResponseWriter, r *http.Request) { - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - key, err := GenerateTOTPKey(user.Username) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to generate TOTP key") - return - } - - if err := h.store.SetTOTPSecret(user.ID, key.Secret()); err != nil { - writeError(w, http.StatusInternalServerError, "failed to save TOTP secret") - return - } - - writeJSON(w, map[string]string{ - "secret": key.Secret(), - "uri": key.URL(), - }) -} - -func (h *Handlers) TOTPVerify(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - var req struct { - Code string `json:"code"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - - // Re-fetch user to get latest TOTP secret - user, err := h.store.GetUserByID(user.ID) - if err != nil { - if htmx { - writeHTMXError(w, "failed to fetch user") - return - } - writeError(w, http.StatusInternalServerError, "failed to fetch user") - return - } - - if user.TOTPSecret == "" { - if htmx { - writeHTMXError(w, "TOTP not set up yet") - return - } - writeError(w, http.StatusBadRequest, "TOTP not set up, call /api/auth/totp/setup first") - return - } - - if !ValidateTOTPCode(user.TOTPSecret, req.Code) { - if htmx { - writeHTMXError(w, "invalid TOTP code") - return - } - writeError(w, http.StatusBadRequest, "invalid TOTP code") - return - } - - if err := h.store.EnableTOTP(user.ID); err != nil { - if htmx { - writeHTMXError(w, "failed to enable TOTP") - return - } - writeError(w, http.StatusInternalServerError, "failed to enable TOTP") - return - } - - h.audit(r, "totp.enable", "user", fmt.Sprintf("%d", user.ID), "") - if htmx { - // Trigger settings page reload to show updated TOTP status - w.Header().Set("HX-Trigger", "settingsRefresh") - writeHTMXSuccess(w, "Two-factor authentication enabled") - return - } - writeJSON(w, map[string]string{"status": "totp_enabled"}) -} - -func (h *Handlers) TOTPDisable(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - if err := h.store.DisableTOTP(user.ID); err != nil { - if htmx { - writeHTMXError(w, "failed to disable TOTP") - return - } - writeError(w, http.StatusInternalServerError, "failed to disable TOTP") - return - } - - h.audit(r, "totp.disable", "user", fmt.Sprintf("%d", user.ID), "") - if htmx { - w.Header().Set("HX-Trigger", "settingsRefresh") - writeHTMXSuccess(w, "Two-factor authentication disabled") - return - } - writeJSON(w, map[string]string{"status": "totp_disabled"}) -} - -// User management (admin only) - -func (h *Handlers) ListUsers(w http.ResponseWriter, r *http.Request) { - users, err := h.store.ListUsers() - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list users") - return - } - - // Strip sensitive fields - type safeUser struct { - ID int64 `json:"id"` - Username string `json:"username"` - IsAdmin bool `json:"is_admin"` - TOTPEnabled bool `json:"totp_enabled"` - CreatedAt int64 `json:"created_at"` - } - result := make([]safeUser, len(users)) - for i, u := range users { - result[i] = safeUser{ - ID: u.ID, - Username: u.Username, - IsAdmin: u.IsAdmin, - TOTPEnabled: u.TOTPEnabled, - CreatedAt: u.CreatedAt, - } - } - writeJSON(w, result) -} - -func (h *Handlers) CreateUser(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - var req struct { - Username string `json:"username"` - Password string `json:"password"` - IsAdmin bool `json:"is_admin"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - if req.Username == "" || req.Password == "" { - if htmx { - writeHTMXError(w, "username and password required") - return - } - writeError(w, http.StatusBadRequest, "username and password required") - return - } - if len(req.Password) < 8 { - if htmx { - writeHTMXError(w, "password must be at least 8 characters") - return - } - writeError(w, http.StatusBadRequest, "password must be at least 8 characters") - return - } - - user, err := h.store.CreateUser(req.Username, req.Password, req.IsAdmin) - if err != nil { - msg := "failed to create user: " + err.Error() - if htmx { - writeHTMXError(w, msg) - return - } - writeError(w, http.StatusInternalServerError, msg) - return - } - - h.audit(r, "user.create", "user", fmt.Sprintf("%d", user.ID), user.Username) - if htmx { - writeHTMXRefresh(w) - return - } - writeJSON(w, map[string]any{ - "id": user.ID, - "username": user.Username, - "is_admin": user.IsAdmin, - }) -} - -func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - idStr := chi.URLParam(r, "id") - id, err := strconv.ParseInt(idStr, 10, 64) - if err != nil { - if htmx { - writeHTMXError(w, "invalid user ID") - return - } - writeError(w, http.StatusBadRequest, "invalid user ID") - return - } - - // Prevent deleting yourself - user := UserFromContext(r.Context()) - if user != nil && user.ID == id { - if htmx { - writeHTMXError(w, "cannot delete yourself") - return - } - writeError(w, http.StatusBadRequest, "cannot delete yourself") - return - } - - if err := h.store.DeleteUser(id); err != nil { - if htmx { - writeHTMXError(w, err.Error()) - return - } - writeError(w, http.StatusBadRequest, err.Error()) - return - } - - h.audit(r, "user.delete", "user", idStr, "") - if htmx { - writeHTMXRefresh(w) - return - } - writeJSON(w, map[string]string{"status": "deleted"}) -} - -// API Token management - -func (h *Handlers) ListTokens(w http.ResponseWriter, r *http.Request) { - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - var userID int64 - if !user.IsAdmin { - userID = user.ID - } - // userID=0 means list all (admin) - - tokens, err := h.store.ListAPITokens(userID) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to list tokens") - return - } - if tokens == nil { - tokens = []APIToken{} - } - writeJSON(w, tokens) -} - -func (h *Handlers) CreateToken(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - var req struct { - Name string `json:"name"` - RateLimitRPM int `json:"rate_limit_rpm"` - DailyBudgetUSD float64 `json:"daily_budget_usd"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - if req.Name == "" { - if htmx { - writeHTMXError(w, "name is required") - return - } - writeError(w, http.StatusBadRequest, "name is required") - return - } - // RateLimitRPM: 0 = unlimited, negative treated as 0 - if req.RateLimitRPM < 0 { - req.RateLimitRPM = 0 - } - - plainKey, token, err := h.store.CreateAPIToken(user.ID, req.Name, req.RateLimitRPM, req.DailyBudgetUSD) - if err != nil { - msg := "failed to create token: " + err.Error() - if htmx { - writeHTMXError(w, msg) - return - } - writeError(w, http.StatusInternalServerError, msg) - return - } - - h.audit(r, "token.create", "token", fmt.Sprintf("%d", token.ID), req.Name) - if htmx { - // Return HTML with the key display and trigger page refresh - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Header().Set("HX-Trigger", "tokenCreated") - escaped := template.HTMLEscapeString(plainKey) - fmt.Fprintf(w, `
Token created! Copy the key below — it won't be shown again.
-
%s
`, escaped) - return - } - writeJSON(w, map[string]any{ - "key": plainKey, - "token": token, - }) -} - -func (h *Handlers) DeleteToken(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - idStr := chi.URLParam(r, "id") - id, err := strconv.ParseInt(idStr, 10, 64) - if err != nil { - if htmx { - writeHTMXError(w, "invalid token ID") - return - } - writeError(w, http.StatusBadRequest, "invalid token ID") - return - } - - // Non-admin can only delete own tokens - if !user.IsAdmin { - token, err := h.store.GetAPIToken(id) - if err != nil { - if htmx { - writeHTMXError(w, "token not found") - return - } - writeError(w, http.StatusNotFound, "token not found") - return - } - if token.UserID != user.ID { - if htmx { - writeHTMXError(w, "not your token") - return - } - writeError(w, http.StatusForbidden, "not your token") - return - } - } - - if err := h.store.DeleteAPIToken(id); err != nil { - if htmx { - writeHTMXError(w, "failed to delete token") - return - } - writeError(w, http.StatusInternalServerError, "failed to delete token") - return - } - - h.audit(r, "token.delete", "token", idStr, "") - if htmx { - writeHTMXRefresh(w) - return - } - writeJSON(w, map[string]string{"status": "deleted"}) -} - -// Self-service endpoints - -func (h *Handlers) ChangePassword(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - var req struct { - CurrentPassword string `json:"current_password"` - NewPassword string `json:"new_password"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - if req.NewPassword == "" || len(req.NewPassword) < 8 { - if htmx { - writeHTMXError(w, "new password must be at least 8 characters") - return - } - writeError(w, http.StatusBadRequest, "new password must be at least 8 characters") - return - } - - // Re-fetch user to get password hash - user, err := h.store.GetUserByID(user.ID) - if err != nil { - if htmx { - writeHTMXError(w, "failed to fetch user") - return - } - writeError(w, http.StatusInternalServerError, "failed to fetch user") - return - } - - if !h.store.CheckPassword(user, req.CurrentPassword) { - if htmx { - writeHTMXError(w, "current password is incorrect") - return - } - writeError(w, http.StatusUnauthorized, "current password is incorrect") - return - } - - if err := h.store.UpdatePassword(user.ID, req.NewPassword); err != nil { - if htmx { - writeHTMXError(w, "failed to update password") - return - } - writeError(w, http.StatusInternalServerError, "failed to update password") - return - } - - h.audit(r, "password.change", "user", fmt.Sprintf("%d", user.ID), "") - if htmx { - writeHTMXSuccess(w, "Password updated") - return - } - writeJSON(w, map[string]string{"status": "password_updated"}) -} - -func (h *Handlers) ChangeUsername(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - var req struct { - NewUsername string `json:"new_username"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - if req.NewUsername == "" { - if htmx { - writeHTMXError(w, "username is required") - return - } - writeError(w, http.StatusBadRequest, "username is required") - return - } - - // Check uniqueness - existing, err := h.store.GetUserByUsername(req.NewUsername) - if err == nil && existing.ID != user.ID { - if htmx { - writeHTMXError(w, "username already taken") - return - } - writeError(w, http.StatusConflict, "username already taken") - return - } - - if err := h.store.UpdateUsername(user.ID, req.NewUsername); err != nil { - if htmx { - writeHTMXError(w, "failed to update username") - return - } - writeError(w, http.StatusInternalServerError, "failed to update username") - return - } - - h.audit(r, "username.change", "user", fmt.Sprintf("%d", user.ID), req.NewUsername) - if htmx { - writeHTMXSuccess(w, "Username updated") - return - } - writeJSON(w, map[string]string{"status": "username_updated"}) -} - -func (h *Handlers) ChangeEmail(w http.ResponseWriter, r *http.Request) { - htmx := isHTMX(r) - user := UserFromContext(r.Context()) - if user == nil { - writeError(w, http.StatusUnauthorized, "not authenticated") - return - } - - var req struct { - Email string `json:"email"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - if htmx { - writeHTMXError(w, "invalid request") - return - } - writeError(w, http.StatusBadRequest, "invalid JSON") - return - } - - if err := h.store.UpdateEmail(user.ID, req.Email); err != nil { - if htmx { - writeHTMXError(w, "failed to update email") - return - } - writeError(w, http.StatusInternalServerError, "failed to update email") - return - } - - if htmx { - writeHTMXSuccess(w, "Email updated") - return - } - writeJSON(w, map[string]string{"status": "email_updated"}) -} - -// Helpers - -func (h *Handlers) setSessionCookie(w http.ResponseWriter, sessionID string) { - http.SetCookie(w, &http.Cookie{ - Name: sessionCookieName, - Value: sessionID, - Path: "/", - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - MaxAge: sessionTTLDays * 24 * 60 * 60, - }) -} - -func (h *Handlers) signPendingToken(userID int64) string { - data := fmt.Sprintf("%d:%d", userID, time.Now().Unix()) - mac := hmac.New(sha256.New, []byte(h.sessionSecret)) - mac.Write([]byte(data)) - sig := hex.EncodeToString(mac.Sum(nil)) - return data + ":" + sig -} - -func (h *Handlers) verifyPendingToken(token string) (int64, error) { - parts := strings.SplitN(token, ":", 3) - if len(parts) != 3 { - return 0, fmt.Errorf("invalid format") - } - - userID, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return 0, fmt.Errorf("invalid user ID") - } - - ts, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return 0, fmt.Errorf("invalid timestamp") - } - - // Check expiry (5 minutes) - if time.Now().Unix()-ts > 300 { - return 0, fmt.Errorf("expired") - } - - // Verify HMAC - data := parts[0] + ":" + parts[1] - mac := hmac.New(sha256.New, []byte(h.sessionSecret)) - mac.Write([]byte(data)) - expectedSig := hex.EncodeToString(mac.Sum(nil)) - - if !hmac.Equal([]byte(parts[2]), []byte(expectedSig)) { - return 0, fmt.Errorf("invalid signature") - } - - return userID, nil -} - -func isHTMX(r *http.Request) bool { - return r.Header.Get("HX-Request") == "true" -} - -func writeHTMXError(w http.ResponseWriter, msg string) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - fmt.Fprintf(w, `
%s
`, template.HTMLEscapeString(msg)) -} - -func writeHTMXSuccess(w http.ResponseWriter, msg string) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - fmt.Fprintf(w, `
%s
`, template.HTMLEscapeString(msg)) -} - -func writeHTMXRedirect(w http.ResponseWriter, url string) { - w.Header().Set("HX-Redirect", url) - w.WriteHeader(http.StatusOK) -} - -func writeHTMXRefresh(w http.ResponseWriter) { - w.Header().Set("HX-Refresh", "true") - w.WriteHeader(http.StatusOK) -} - -func writeJSON(w http.ResponseWriter, v any) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(v) -} - -func writeError(w http.ResponseWriter, code int, msg string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - json.NewEncoder(w).Encode(map[string]string{"error": msg}) -} diff --git a/llm-gateway/internal/auth/middleware.go b/llm-gateway/internal/auth/middleware.go deleted file mode 100644 index bceadf1..0000000 --- a/llm-gateway/internal/auth/middleware.go +++ /dev/null @@ -1,83 +0,0 @@ -package auth - -import ( - "context" - "encoding/json" - "net/http" - "strings" -) - -type contextKey string - -const userContextKey contextKey = "auth_user" - -const ( - sessionCookieName = "llmgw_session" - sessionTTLDays = 7 -) - -type Middleware struct { - store *Store -} - -func NewMiddleware(store *Store) *Middleware { - return &Middleware{store: store} -} - -func UserFromContext(ctx context.Context) *User { - u, _ := ctx.Value(userContextKey).(*User) - return u -} - -func (m *Middleware) RequireAuth(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie(sessionCookieName) - if err != nil || cookie.Value == "" { - m.unauthorized(w, r) - return - } - - sess, err := m.store.GetSession(cookie.Value) - if err != nil { - m.unauthorized(w, r) - return - } - - user, err := m.store.GetUserByID(sess.UserID) - if err != nil { - m.unauthorized(w, r) - return - } - - ctx := context.WithValue(r.Context(), userContextKey, user) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -func (m *Middleware) RequireAdmin(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - user := UserFromContext(r.Context()) - if user == nil || !user.IsAdmin { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusForbidden) - json.NewEncoder(w).Encode(map[string]string{"error": "admin access required"}) - return - } - next.ServeHTTP(w, r) - }) -} - -func (m *Middleware) unauthorized(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("HX-Request") == "true" { - w.Header().Set("HX-Redirect", "/login") - w.WriteHeader(http.StatusUnauthorized) - return - } - if strings.HasPrefix(r.URL.Path, "/api/") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusUnauthorized) - json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"}) - return - } - http.Redirect(w, r, "/login", http.StatusFound) -} diff --git a/llm-gateway/internal/auth/store.go b/llm-gateway/internal/auth/store.go deleted file mode 100644 index 4f62ff7..0000000 --- a/llm-gateway/internal/auth/store.go +++ /dev/null @@ -1,391 +0,0 @@ -package auth - -import ( - "crypto/rand" - "crypto/sha256" - "database/sql" - "encoding/hex" - "fmt" - "time" - - "golang.org/x/crypto/bcrypt" -) - -type User struct { - ID int64 `json:"id"` - Username string `json:"username"` - Email string `json:"email"` - PasswordHash string `json:"-"` - IsAdmin bool `json:"is_admin"` - TOTPSecret string `json:"-"` - TOTPEnabled bool `json:"totp_enabled"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` -} - -type Session struct { - ID string - UserID int64 - CreatedAt int64 - ExpiresAt int64 -} - -type APIToken struct { - ID int64 `json:"id"` - Name string `json:"name"` - KeyPrefix string `json:"key_prefix"` - KeyHash string `json:"-"` - UserID int64 `json:"user_id"` - RateLimitRPM int `json:"rate_limit_rpm"` - DailyBudgetUSD float64 `json:"daily_budget_usd"` - MonthlyBudgetUSD float64 `json:"monthly_budget_usd"` - MaxConcurrent int `json:"max_concurrent"` - CreatedAt int64 `json:"created_at"` - LastUsedAt int64 `json:"last_used_at"` -} - -// StaticToken represents a token defined in config (checked in-memory, never stored in DB). -type StaticToken struct { - Name string - Key string - RateLimitRPM int - DailyBudgetUSD float64 - MonthlyBudgetUSD float64 - MaxConcurrent int -} - -type Store struct { - db *sql.DB - staticTokens []StaticToken -} - -func NewStore(db *sql.DB, staticTokens []StaticToken) *Store { - return &Store{db: db, staticTokens: staticTokens} -} - -// SetStaticTokens updates the static tokens list (used for config hot-reload). -func (s *Store) SetStaticTokens(tokens []StaticToken) { - s.staticTokens = tokens -} - -func (s *Store) HasAnyUser() bool { - var count int - s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) - return count > 0 -} - -func (s *Store) CreateUser(username, password string, isAdmin bool) (*User, error) { - hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) - if err != nil { - return nil, fmt.Errorf("hashing password: %w", err) - } - - now := time.Now().Unix() - adminInt := 0 - if isAdmin { - adminInt = 1 - } - - result, err := s.db.Exec( - "INSERT INTO users (username, password_hash, is_admin, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", - username, string(hash), adminInt, now, now, - ) - if err != nil { - return nil, fmt.Errorf("creating user: %w", err) - } - - id, _ := result.LastInsertId() - return &User{ - ID: id, - Username: username, - IsAdmin: isAdmin, - CreatedAt: now, - UpdatedAt: now, - }, nil -} - -func (s *Store) GetUserByUsername(username string) (*User, error) { - return s.scanUser(s.db.QueryRow( - "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE username = ?", - username, - )) -} - -func (s *Store) GetUserByID(id int64) (*User, error) { - return s.scanUser(s.db.QueryRow( - "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE id = ?", - id, - )) -} - -func (s *Store) scanUser(row *sql.Row) (*User, error) { - var u User - var isAdmin, totpEnabled int - var totpSecret sql.NullString - var email sql.NullString - err := row.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt) - if err != nil { - return nil, err - } - u.Email = email.String - u.IsAdmin = isAdmin == 1 - u.TOTPEnabled = totpEnabled == 1 - u.TOTPSecret = totpSecret.String - return &u, nil -} - -func (s *Store) ListUsers() ([]User, error) { - rows, err := s.db.Query("SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users ORDER BY id") - if err != nil { - return nil, err - } - defer rows.Close() - - var users []User - for rows.Next() { - var u User - var isAdmin, totpEnabled int - var totpSecret sql.NullString - var email sql.NullString - if err := rows.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt); err != nil { - return nil, err - } - u.Email = email.String - u.IsAdmin = isAdmin == 1 - u.TOTPEnabled = totpEnabled == 1 - u.TOTPSecret = totpSecret.String - users = append(users, u) - } - return users, nil -} - -func (s *Store) DeleteUser(id int64) error { - // Prevent deleting the last admin - var adminCount int - s.db.QueryRow("SELECT COUNT(*) FROM users WHERE is_admin = 1").Scan(&adminCount) - - var isAdmin int - s.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", id).Scan(&isAdmin) - if isAdmin == 1 && adminCount <= 1 { - return fmt.Errorf("cannot delete the last admin user") - } - - _, err := s.db.Exec("DELETE FROM users WHERE id = ?", id) - return err -} - -func (s *Store) UpdatePassword(userID int64, newPassword string) error { - hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) - if err != nil { - return fmt.Errorf("hashing password: %w", err) - } - _, err = s.db.Exec("UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?", string(hash), time.Now().Unix(), userID) - return err -} - -func (s *Store) CheckPassword(user *User, password string) bool { - return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) == nil -} - -func (s *Store) SetTOTPSecret(userID int64, secret string) error { - _, err := s.db.Exec("UPDATE users SET totp_secret = ?, updated_at = ? WHERE id = ?", secret, time.Now().Unix(), userID) - return err -} - -func (s *Store) EnableTOTP(userID int64) error { - _, err := s.db.Exec("UPDATE users SET totp_enabled = 1, updated_at = ? WHERE id = ?", time.Now().Unix(), userID) - return err -} - -func (s *Store) DisableTOTP(userID int64) error { - _, err := s.db.Exec("UPDATE users SET totp_enabled = 0, totp_secret = '', updated_at = ? WHERE id = ?", time.Now().Unix(), userID) - return err -} - -// Session management - -func (s *Store) CreateSession(userID int64, ttl time.Duration) (string, error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", fmt.Errorf("generating session ID: %w", err) - } - id := hex.EncodeToString(b) - now := time.Now().Unix() - expiresAt := time.Now().Add(ttl).Unix() - - _, err := s.db.Exec( - "INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)", - id, userID, now, expiresAt, - ) - if err != nil { - return "", fmt.Errorf("creating session: %w", err) - } - return id, nil -} - -func (s *Store) GetSession(sessionID string) (*Session, error) { - var sess Session - err := s.db.QueryRow( - "SELECT id, user_id, created_at, expires_at FROM sessions WHERE id = ? AND expires_at > ?", - sessionID, time.Now().Unix(), - ).Scan(&sess.ID, &sess.UserID, &sess.CreatedAt, &sess.ExpiresAt) - if err != nil { - return nil, err - } - return &sess, nil -} - -func (s *Store) DeleteSession(id string) error { - _, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", id) - return err -} - -func (s *Store) CleanExpiredSessions() error { - _, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now().Unix()) - return err -} - -// API Token management - -func (s *Store) CreateAPIToken(userID int64, name string, rateLimitRPM int, dailyBudgetUSD float64) (string, *APIToken, error) { - // Generate sk- prefixed random key - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", nil, fmt.Errorf("generating token: %w", err) - } - plainKey := "sk-" + hex.EncodeToString(b) - keyPrefix := plainKey[:11] // "sk-" + first 8 hex chars - - hash := sha256.Sum256([]byte(plainKey)) - keyHash := hex.EncodeToString(hash[:]) - - now := time.Now().Unix() - result, err := s.db.Exec( - "INSERT INTO api_tokens (name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", - name, keyHash, keyPrefix, userID, rateLimitRPM, dailyBudgetUSD, now, - ) - if err != nil { - return "", nil, fmt.Errorf("creating API token: %w", err) - } - - id, _ := result.LastInsertId() - token := &APIToken{ - ID: id, - Name: name, - KeyPrefix: keyPrefix, - KeyHash: keyHash, - UserID: userID, - RateLimitRPM: rateLimitRPM, - DailyBudgetUSD: dailyBudgetUSD, - CreatedAt: now, - } - return plainKey, token, nil -} - -func (s *Store) LookupAPIToken(key string) (*APIToken, error) { - // Check static tokens first (from config, never stored in DB) - for _, st := range s.staticTokens { - if st.Key == key { - prefix := st.Key - if len(prefix) > 11 { - prefix = prefix[:11] - } - return &APIToken{ - ID: -1, // sentinel: static token - Name: st.Name, - KeyPrefix: prefix, - RateLimitRPM: st.RateLimitRPM, - DailyBudgetUSD: st.DailyBudgetUSD, - MonthlyBudgetUSD: st.MonthlyBudgetUSD, - MaxConcurrent: st.MaxConcurrent, - }, nil - } - } - - // Fall back to DB tokens - hash := sha256.Sum256([]byte(key)) - keyHash := hex.EncodeToString(hash[:]) - - var t APIToken - err := s.db.QueryRow( - "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?", - keyHash, - ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) - if err != nil { - return nil, err - } - return &t, nil -} - -func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { - // Include static tokens (shown for all users, not deletable) - var tokens []APIToken - for _, st := range s.staticTokens { - prefix := st.Key - if len(prefix) > 11 { - prefix = prefix[:11] - } - tokens = append(tokens, APIToken{ - ID: -1, // sentinel: static token - Name: st.Name, - KeyPrefix: prefix, - RateLimitRPM: st.RateLimitRPM, - DailyBudgetUSD: st.DailyBudgetUSD, - MonthlyBudgetUSD: st.MonthlyBudgetUSD, - MaxConcurrent: st.MaxConcurrent, - }) - } - - // DB tokens - var rows *sql.Rows - var err error - if userID == 0 { - rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id") - } else { - rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) - } - if err != nil { - return tokens, nil - } - defer rows.Close() - - for rows.Next() { - var t APIToken - if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil { - return tokens, nil - } - tokens = append(tokens, t) - } - return tokens, nil -} - -func (s *Store) DeleteAPIToken(id int64) error { - _, err := s.db.Exec("DELETE FROM api_tokens WHERE id = ?", id) - return err -} - -func (s *Store) GetAPIToken(id int64) (*APIToken, error) { - var t APIToken - err := s.db.QueryRow( - "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?", - id, - ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) - if err != nil { - return nil, err - } - return &t, nil -} - -func (s *Store) UpdateAPITokenLastUsed(id int64) { - s.db.Exec("UPDATE api_tokens SET last_used_at = ? WHERE id = ?", time.Now().Unix(), id) -} - -func (s *Store) UpdateUsername(userID int64, newUsername string) error { - _, err := s.db.Exec("UPDATE users SET username = ?, updated_at = ? WHERE id = ?", newUsername, time.Now().Unix(), userID) - return err -} - -func (s *Store) UpdateEmail(userID int64, email string) error { - _, err := s.db.Exec("UPDATE users SET email = ?, updated_at = ? WHERE id = ?", email, time.Now().Unix(), userID) - return err -} diff --git a/llm-gateway/internal/auth/store_test.go b/llm-gateway/internal/auth/store_test.go deleted file mode 100644 index d566224..0000000 --- a/llm-gateway/internal/auth/store_test.go +++ /dev/null @@ -1,301 +0,0 @@ -package auth - -import ( - "database/sql" - "testing" - "time" - - _ "modernc.org/sqlite" -) - -func setupTestDB(t *testing.T) *sql.DB { - t.Helper() - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("opening test db: %v", err) - } - - // Create tables - _, err = db.Exec(` - CREATE TABLE users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT UNIQUE NOT NULL, - email TEXT DEFAULT '', - password_hash TEXT NOT NULL, - is_admin INTEGER DEFAULT 0, - totp_secret TEXT DEFAULT '', - totp_enabled INTEGER DEFAULT 0, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL - ); - CREATE TABLE sessions ( - id TEXT PRIMARY KEY, - user_id INTEGER NOT NULL, - created_at INTEGER NOT NULL, - expires_at INTEGER NOT NULL - ); - CREATE TABLE api_tokens ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - key_hash TEXT NOT NULL, - key_prefix TEXT NOT NULL, - user_id INTEGER NOT NULL, - rate_limit_rpm INTEGER DEFAULT 0, - daily_budget_usd REAL DEFAULT 0, - monthly_budget_usd REAL DEFAULT 0, - max_concurrent INTEGER DEFAULT 0, - created_at INTEGER NOT NULL, - last_used_at INTEGER DEFAULT 0 - ); - `) - if err != nil { - t.Fatalf("creating tables: %v", err) - } - return db -} - -func TestCreateUser(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - user, err := store.CreateUser("alice", "password123", true) - if err != nil { - t.Fatalf("CreateUser: %v", err) - } - if user.Username != "alice" { - t.Errorf("expected username 'alice', got '%s'", user.Username) - } - if !user.IsAdmin { - t.Error("expected admin user") - } - if user.ID == 0 { - t.Error("expected non-zero ID") - } -} - -func TestGetUserByUsername(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - store.CreateUser("bob", "password123", false) - - user, err := store.GetUserByUsername("bob") - if err != nil { - t.Fatalf("GetUserByUsername: %v", err) - } - if user.Username != "bob" { - t.Errorf("expected 'bob', got '%s'", user.Username) - } - - _, err = store.GetUserByUsername("nonexistent") - if err == nil { - t.Error("expected error for nonexistent user") - } -} - -func TestCheckPassword(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - store.CreateUser("charlie", "correctpassword", false) - user, _ := store.GetUserByUsername("charlie") - - if !store.CheckPassword(user, "correctpassword") { - t.Error("correct password should match") - } - if store.CheckPassword(user, "wrongpassword") { - t.Error("wrong password should not match") - } -} - -func TestUpdatePassword(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - user, _ := store.CreateUser("dave", "oldpass12", false) - - if err := store.UpdatePassword(user.ID, "newpass12"); err != nil { - t.Fatalf("UpdatePassword: %v", err) - } - - user, _ = store.GetUserByUsername("dave") - if store.CheckPassword(user, "oldpass12") { - t.Error("old password should not work") - } - if !store.CheckPassword(user, "newpass12") { - t.Error("new password should work") - } -} - -func TestDeleteUser(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - user1, _ := store.CreateUser("admin1", "password1234", true) - user2, _ := store.CreateUser("user2", "password1234", false) - - // Can delete non-admin - if err := store.DeleteUser(user2.ID); err != nil { - t.Fatalf("DeleteUser: %v", err) - } - - // Cannot delete last admin - if err := store.DeleteUser(user1.ID); err == nil { - t.Error("should not be able to delete last admin") - } -} - -func TestHasAnyUser(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - if store.HasAnyUser() { - t.Error("should have no users initially") - } - - store.CreateUser("first", "password1234", true) - - if !store.HasAnyUser() { - t.Error("should have users after creation") - } -} - -func TestSessionCRUD(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - user, _ := store.CreateUser("sessuser", "password1234", false) - - sessionID, err := store.CreateSession(user.ID, 1*time.Hour) - if err != nil { - t.Fatalf("CreateSession: %v", err) - } - - sess, err := store.GetSession(sessionID) - if err != nil { - t.Fatalf("GetSession: %v", err) - } - if sess.UserID != user.ID { - t.Errorf("expected user ID %d, got %d", user.ID, sess.UserID) - } - - if err := store.DeleteSession(sessionID); err != nil { - t.Fatalf("DeleteSession: %v", err) - } - - _, err = store.GetSession(sessionID) - if err == nil { - t.Error("session should be deleted") - } -} - -func TestStaticTokenLookup(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - - staticTokens := []StaticToken{ - {Name: "test-token", Key: "sk-test-key-12345678", RateLimitRPM: 60, DailyBudgetUSD: 10.0, MaxConcurrent: 5}, - } - store := NewStore(db, staticTokens) - - token, err := store.LookupAPIToken("sk-test-key-12345678") - if err != nil { - t.Fatalf("LookupAPIToken: %v", err) - } - if token.Name != "test-token" { - t.Errorf("expected 'test-token', got '%s'", token.Name) - } - if token.ID != -1 { - t.Errorf("static token should have ID -1, got %d", token.ID) - } - if token.RateLimitRPM != 60 { - t.Errorf("expected RPM 60, got %d", token.RateLimitRPM) - } - if token.MaxConcurrent != 5 { - t.Errorf("expected max_concurrent 5, got %d", token.MaxConcurrent) - } - - // Non-existent token - _, err = store.LookupAPIToken("nonexistent") - if err == nil { - t.Error("should error on nonexistent token") - } -} - -func TestDBTokenCRUD(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - store := NewStore(db, nil) - - user, _ := store.CreateUser("tokenuser", "password1234", false) - - plainKey, token, err := store.CreateAPIToken(user.ID, "my-token", 100, 5.0) - if err != nil { - t.Fatalf("CreateAPIToken: %v", err) - } - if plainKey == "" { - t.Error("plain key should not be empty") - } - if token.Name != "my-token" { - t.Errorf("expected 'my-token', got '%s'", token.Name) - } - - // Lookup by key - found, err := store.LookupAPIToken(plainKey) - if err != nil { - t.Fatalf("LookupAPIToken: %v", err) - } - if found.Name != "my-token" { - t.Errorf("expected 'my-token', got '%s'", found.Name) - } - - // List tokens - tokens, err := store.ListAPITokens(user.ID) - if err != nil { - t.Fatalf("ListAPITokens: %v", err) - } - if len(tokens) != 1 { - t.Errorf("expected 1 token, got %d", len(tokens)) - } - - // Delete - if err := store.DeleteAPIToken(token.ID); err != nil { - t.Fatalf("DeleteAPIToken: %v", err) - } - - _, err = store.LookupAPIToken(plainKey) - if err == nil { - t.Error("token should be deleted") - } -} - -func TestSetStaticTokens(t *testing.T) { - db := setupTestDB(t) - defer db.Close() - - store := NewStore(db, nil) - - _, err := store.LookupAPIToken("key1") - if err == nil { - t.Error("should not find token before setting") - } - - store.SetStaticTokens([]StaticToken{ - {Name: "new-token", Key: "key1"}, - }) - - token, err := store.LookupAPIToken("key1") - if err != nil { - t.Fatalf("after SetStaticTokens: %v", err) - } - if token.Name != "new-token" { - t.Errorf("expected 'new-token', got '%s'", token.Name) - } -} diff --git a/llm-gateway/internal/auth/totp.go b/llm-gateway/internal/auth/totp.go deleted file mode 100644 index b3092ff..0000000 --- a/llm-gateway/internal/auth/totp.go +++ /dev/null @@ -1,17 +0,0 @@ -package auth - -import ( - "github.com/pquerna/otp" - "github.com/pquerna/otp/totp" -) - -func GenerateTOTPKey(username string) (*otp.Key, error) { - return totp.Generate(totp.GenerateOpts{ - Issuer: "LLM Gateway", - AccountName: username, - }) -} - -func ValidateTOTPCode(secret, code string) bool { - return totp.Validate(code, secret) -} diff --git a/llm-gateway/internal/cache/cache.go b/llm-gateway/internal/cache/cache.go deleted file mode 100644 index 40852c5..0000000 --- a/llm-gateway/internal/cache/cache.go +++ /dev/null @@ -1,178 +0,0 @@ -package cache - -import ( - "context" - "crypto/sha256" - "fmt" - "time" - - "github.com/redis/go-redis/v9" -) - -type Cache struct { - client *redis.Client - ttl time.Duration -} - -func New(addr string, ttlSeconds int) (*Cache, error) { - client := redis.NewClient(&redis.Options{ - Addr: addr, - }) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := client.Ping(ctx).Err(); err != nil { - return nil, fmt.Errorf("connecting to Valkey: %w", err) - } - - ttl := time.Duration(ttlSeconds) * time.Second - if ttl == 0 { - ttl = 1 * time.Hour - } - - return &Cache{client: client, ttl: ttl}, nil -} - -func (c *Cache) Get(ctx context.Context, model string, requestBody []byte) ([]byte, error) { - key := c.cacheKey(model, requestBody) - data, err := c.client.Get(ctx, key).Bytes() - if err == redis.Nil { - return nil, nil - } - return data, err -} - -func (c *Cache) Set(ctx context.Context, model string, requestBody, responseBody []byte) error { - key := c.cacheKey(model, requestBody) - return c.client.Set(ctx, key, responseBody, c.ttl).Err() -} - -func (c *Cache) Ping(ctx context.Context) error { - return c.client.Ping(ctx).Err() -} - -func (c *Cache) Close() error { - return c.client.Close() -} - -// CacheStats holds cache statistics from the Valkey/Redis server. -type CacheStats struct { - Hits int64 `json:"hits"` - Misses int64 `json:"misses"` - HitRate float64 `json:"hit_rate"` - MemoryUsed string `json:"memory_used"` - Keys int64 `json:"keys"` - Connected bool `json:"connected"` -} - -// Stats returns cache statistics by querying Valkey/Redis INFO. -func (c *Cache) Stats(ctx context.Context) *CacheStats { - stats := &CacheStats{} - - // Check connectivity - if err := c.client.Ping(ctx).Err(); err != nil { - return stats - } - stats.Connected = true - - // Parse INFO stats for hits/misses - info, err := c.client.Info(ctx, "stats").Result() - if err == nil { - stats.Hits = parseInfoInt(info, "keyspace_hits") - stats.Misses = parseInfoInt(info, "keyspace_misses") - total := stats.Hits + stats.Misses - if total > 0 { - stats.HitRate = float64(stats.Hits) / float64(total) - } - } - - // Parse INFO memory - memInfo, err := c.client.Info(ctx, "memory").Result() - if err == nil { - stats.MemoryUsed = parseInfoString(memInfo, "used_memory_human") - } - - // Parse INFO keyspace - ksInfo, err := c.client.Info(ctx, "keyspace").Result() - if err == nil { - stats.Keys = parseKeyspaceKeys(ksInfo) - } - - return stats -} - -func parseInfoInt(info, key string) int64 { - prefix := key + ":" - for _, line := range splitLines(info) { - if len(line) > len(prefix) && line[:len(prefix)] == prefix { - var v int64 - fmt.Sscanf(line[len(prefix):], "%d", &v) - return v - } - } - return 0 -} - -func parseInfoString(info, key string) string { - prefix := key + ":" - for _, line := range splitLines(info) { - if len(line) > len(prefix) && line[:len(prefix)] == prefix { - val := line[len(prefix):] - // Trim trailing \r - if len(val) > 0 && val[len(val)-1] == '\r' { - val = val[:len(val)-1] - } - return val - } - } - return "" -} - -func parseKeyspaceKeys(info string) int64 { - // Format: db0:keys=123,expires=45,avg_ttl=6789 - for _, line := range splitLines(info) { - if len(line) > 3 && line[:2] == "db" { - prefix := "keys=" - idx := -1 - for i := 0; i <= len(line)-len(prefix); i++ { - if line[i:i+len(prefix)] == prefix { - idx = i + len(prefix) - break - } - } - if idx >= 0 { - end := idx - for end < len(line) && line[end] >= '0' && line[end] <= '9' { - end++ - } - var v int64 - fmt.Sscanf(line[idx:end], "%d", &v) - return v - } - } - } - return 0 -} - -func splitLines(s string) []string { - var lines []string - start := 0 - for i := 0; i < len(s); i++ { - if s[i] == '\n' { - lines = append(lines, s[start:i]) - start = i + 1 - } - } - if start < len(s) { - lines = append(lines, s[start:]) - } - return lines -} - -func (c *Cache) cacheKey(model string, requestBody []byte) string { - h := sha256.New() - h.Write([]byte(model)) - h.Write(requestBody) - return fmt.Sprintf("llm-gw:%x", h.Sum(nil)) -} diff --git a/llm-gateway/internal/cache/cache_test.go b/llm-gateway/internal/cache/cache_test.go deleted file mode 100644 index 9e72e37..0000000 --- a/llm-gateway/internal/cache/cache_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package cache - -import ( - "testing" -) - -func TestCacheKey_Deterministic(t *testing.T) { - c := &Cache{} - - model := "gpt-4" - body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) - - key1 := c.cacheKey(model, body) - key2 := c.cacheKey(model, body) - - if key1 != key2 { - t.Errorf("cache key not deterministic: %s != %s", key1, key2) - } - - if key1 == "" { - t.Error("cache key is empty") - } -} - -func TestCacheKey_DifferentInputs(t *testing.T) { - c := &Cache{} - - body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`) - - key1 := c.cacheKey("gpt-4", body) - key2 := c.cacheKey("gpt-3.5", body) - - if key1 == key2 { - t.Error("different models should produce different cache keys") - } - - key3 := c.cacheKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"world"}]}`)) - if key1 == key3 { - t.Error("different bodies should produce different cache keys") - } -} - -func TestCacheKey_HasPrefix(t *testing.T) { - c := &Cache{} - key := c.cacheKey("gpt-4", []byte("test")) - - if len(key) < 7 || key[:7] != "llm-gw:" { - t.Errorf("cache key should start with 'llm-gw:', got: %s", key) - } -} - -func TestParseInfoInt(t *testing.T) { - info := "keyspace_hits:42\nkeyspace_misses:10\n" - - hits := parseInfoInt(info, "keyspace_hits") - if hits != 42 { - t.Errorf("expected 42, got %d", hits) - } - - misses := parseInfoInt(info, "keyspace_misses") - if misses != 10 { - t.Errorf("expected 10, got %d", misses) - } - - unknown := parseInfoInt(info, "nonexistent") - if unknown != 0 { - t.Errorf("expected 0 for unknown key, got %d", unknown) - } -} - -func TestParseInfoString(t *testing.T) { - info := "used_memory_human:1.5M\r\nother:value\r\n" - - mem := parseInfoString(info, "used_memory_human") - if mem != "1.5M" { - t.Errorf("expected '1.5M', got '%s'", mem) - } - - unknown := parseInfoString(info, "nonexistent") - if unknown != "" { - t.Errorf("expected empty for unknown key, got '%s'", unknown) - } -} - -func TestParseKeyspaceKeys(t *testing.T) { - info := "# Keyspace\ndb0:keys=123,expires=45,avg_ttl=6789\n" - - keys := parseKeyspaceKeys(info) - if keys != 123 { - t.Errorf("expected 123, got %d", keys) - } - - empty := parseKeyspaceKeys("# Keyspace\n") - if empty != 0 { - t.Errorf("expected 0 for empty keyspace, got %d", empty) - } -} - -func TestSplitLines(t *testing.T) { - lines := splitLines("a\nb\nc") - if len(lines) != 3 { - t.Errorf("expected 3 lines, got %d", len(lines)) - } - if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" { - t.Errorf("unexpected lines: %v", lines) - } - - single := splitLines("hello") - if len(single) != 1 || single[0] != "hello" { - t.Errorf("single line: %v", single) - } -} diff --git a/llm-gateway/internal/config/config.go b/llm-gateway/internal/config/config.go deleted file mode 100644 index ffa934b..0000000 --- a/llm-gateway/internal/config/config.go +++ /dev/null @@ -1,312 +0,0 @@ -package config - -import ( - "crypto/rand" - "encoding/hex" - "fmt" - "log" - "os" - "time" - - "gopkg.in/yaml.v3" -) - -type Config struct { - Server ServerConfig `yaml:"server"` - Database DatabaseConfig `yaml:"database"` - Cache CacheConfig `yaml:"cache"` - Pricing PricingLookupConfig `yaml:"pricing_lookup"` - CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"` - Retry RetryConfig `yaml:"retry"` - Debug DebugConfig `yaml:"debug"` - CORS CORSConfig `yaml:"cors"` - Dedup DedupConfig `yaml:"dedup"` - Webhooks []WebhookConfig `yaml:"webhooks"` - Providers []ProviderConfig `yaml:"providers"` - Models []ModelConfig `yaml:"models"` - Tokens []TokenConfig `yaml:"tokens"` -} - -type DedupConfig struct { - Enabled bool `yaml:"enabled"` - Window time.Duration `yaml:"window"` // max time to wait for dedup result -} - -type WebhookConfig struct { - URL string `yaml:"url"` - Events []string `yaml:"events"` // event types to send - Secret string `yaml:"secret"` // optional HMAC secret -} - -type PricingLookupConfig struct { - URL string `yaml:"url"` - RefreshInterval time.Duration `yaml:"refresh_interval"` -} - -type DefaultAdminConfig struct { - Username string `yaml:"username"` - Password string `yaml:"password"` -} - -type TokenConfig struct { - Name string `yaml:"name"` - Key string `yaml:"key"` - RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited - DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited - MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited - MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited -} - -type ServerConfig struct { - Listen string `yaml:"listen"` - RequestTimeout time.Duration `yaml:"request_timeout"` - StreamingTimeout time.Duration `yaml:"streaming_timeout"` - MaxRequestBodyMB int `yaml:"max_request_body_mb"` - SessionSecret string `yaml:"session_secret"` - DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` -} - -type CircuitBreakerConfig struct { - Enabled bool `yaml:"enabled"` - ErrorThreshold float64 `yaml:"error_threshold"` - MinRequests int `yaml:"min_requests"` - CooldownDuration time.Duration `yaml:"cooldown_duration"` -} - -type RetryConfig struct { - InitialBackoff time.Duration `yaml:"initial_backoff"` - MaxBackoff time.Duration `yaml:"max_backoff"` - Multiplier float64 `yaml:"multiplier"` -} - -type DebugConfig struct { - Enabled bool `yaml:"enabled"` - MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies) - RetentionDays int `yaml:"retention_days"` - DataDir string `yaml:"data_dir"` -} - -type CORSConfig struct { - Enabled bool `yaml:"enabled"` - AllowedOrigins []string `yaml:"allowed_origins"` - AllowedMethods []string `yaml:"allowed_methods"` - AllowedHeaders []string `yaml:"allowed_headers"` - MaxAge int `yaml:"max_age"` -} - -type DatabaseConfig struct { - Path string `yaml:"path"` - RetentionDays int `yaml:"retention_days"` -} - -type CacheConfig struct { - Enabled bool `yaml:"enabled"` - Address string `yaml:"address"` - TTL int `yaml:"ttl"` // seconds -} - -type ProviderConfig struct { - Name string `yaml:"name"` - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` - Priority int `yaml:"priority"` - Timeout time.Duration `yaml:"timeout"` -} - -type ModelConfig struct { - Name string `yaml:"name"` - Aliases []string `yaml:"aliases"` - Routes []RouteConfig `yaml:"routes"` - LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost - RequestTimeout time.Duration `yaml:"request_timeout"` // per-model override; 0 = use server default - StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default -} - -type RouteConfig struct { - Provider string `yaml:"provider"` - Model string `yaml:"model"` - Pricing PricingConfig `yaml:"pricing"` -} - -type PricingConfig struct { - Input float64 `yaml:"input"` // cost per 1M tokens - Output float64 `yaml:"output"` // cost per 1M tokens -} - -func Load(path string) (*Config, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("reading config: %w", err) - } - - // Expand environment variables - expanded := os.ExpandEnv(string(data)) - - var cfg Config - if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil { - return nil, fmt.Errorf("parsing config: %w", err) - } - - if err := cfg.Validate(); err != nil { - return nil, fmt.Errorf("validating config: %w", err) - } - - return &cfg, nil -} - -// Validate checks the config for correctness and applies defaults. -func (c *Config) Validate() error { - if c.Server.Listen == "" { - c.Server.Listen = "0.0.0.0:3000" - } - if c.Server.RequestTimeout == 0 { - c.Server.RequestTimeout = 300 * time.Second - } - if c.Server.MaxRequestBodyMB == 0 { - c.Server.MaxRequestBodyMB = 10 - } - if c.Server.SessionSecret == "" { - b := make([]byte, 32) - rand.Read(b) - c.Server.SessionSecret = hex.EncodeToString(b) - log.Println("WARNING: no session_secret configured, generated random one (sessions won't survive restart)") - } - if c.Database.Path == "" { - c.Database.Path = "gateway.db" - } - if c.Database.RetentionDays == 0 { - c.Database.RetentionDays = 90 - } - if c.Pricing.RefreshInterval == 0 { - c.Pricing.RefreshInterval = 6 * time.Hour - } - - // Server defaults - if c.Server.StreamingTimeout == 0 { - c.Server.StreamingTimeout = 5 * time.Minute - } - - // Circuit breaker defaults - if c.CircuitBreaker.ErrorThreshold == 0 { - c.CircuitBreaker.ErrorThreshold = 0.5 - } - if c.CircuitBreaker.MinRequests == 0 { - c.CircuitBreaker.MinRequests = 5 - } - if c.CircuitBreaker.CooldownDuration == 0 { - c.CircuitBreaker.CooldownDuration = 30 * time.Second - } - - // Retry defaults - if c.Retry.InitialBackoff == 0 { - c.Retry.InitialBackoff = 100 * time.Millisecond - } - if c.Retry.MaxBackoff == 0 { - c.Retry.MaxBackoff = 5 * time.Second - } - if c.Retry.Multiplier == 0 { - c.Retry.Multiplier = 2.0 - } - - // Debug defaults - if c.Debug.RetentionDays == 0 { - c.Debug.RetentionDays = 90 - } - - // CORS defaults - if c.CORS.MaxAge == 0 { - c.CORS.MaxAge = 300 - } - - // Dedup defaults - if c.Dedup.Window == 0 { - c.Dedup.Window = 30 * time.Second - } - - if len(c.Providers) == 0 { - return fmt.Errorf("at least one provider is required") - } - providerNames := make(map[string]bool) - for i, p := range c.Providers { - if p.Name == "" || p.BaseURL == "" || p.APIKey == "" { - return fmt.Errorf("provider %d: name, base_url, and api_key are required", i) - } - if providerNames[p.Name] { - return fmt.Errorf("duplicate provider name: %s", p.Name) - } - providerNames[p.Name] = true - if c.Providers[i].Timeout == 0 { - c.Providers[i].Timeout = 120 * time.Second - } - if c.Providers[i].Priority == 0 { - c.Providers[i].Priority = 1 - } - } - - if len(c.Models) == 0 { - return fmt.Errorf("at least one model is required") - } - modelNames := make(map[string]bool) - for i, m := range c.Models { - if m.Name == "" { - return fmt.Errorf("model %d: name is required", i) - } - if modelNames[m.Name] { - return fmt.Errorf("duplicate model name: %s", m.Name) - } - modelNames[m.Name] = true - for _, alias := range m.Aliases { - if modelNames[alias] { - return fmt.Errorf("model alias %s conflicts with existing model or alias", alias) - } - modelNames[alias] = true - } - if len(m.Routes) == 0 { - return fmt.Errorf("model %s: at least one route is required", m.Name) - } - for j, r := range m.Routes { - if r.Provider == "" || r.Model == "" { - return fmt.Errorf("model %s route %d: provider and model are required", m.Name, j) - } - if !providerNames[r.Provider] { - return fmt.Errorf("model %s route %d: unknown provider %s", m.Name, j, r.Provider) - } - } - } - - // Validate tokens (optional section) - for i, t := range c.Tokens { - if t.Key == "" { - log.Printf("WARNING: token %d (%s) has empty key, skipping", i, t.Name) - continue - } - if t.Name == "" { - c.Tokens[i].Name = fmt.Sprintf("token-%d", i) - } - } - - return nil -} - -// ValidateBytes parses raw YAML and returns a list of validation errors. -func ValidateBytes(data []byte) []string { - expanded := os.ExpandEnv(string(data)) - var cfg Config - if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil { - return []string{"parse error: " + err.Error()} - } - if err := cfg.Validate(); err != nil { - return []string{err.Error()} - } - return nil -} - -// ProviderByName returns the provider config by name. -func (c *Config) ProviderByName(name string) *ProviderConfig { - for i := range c.Providers { - if c.Providers[i].Name == name { - return &c.Providers[i] - } - } - return nil -} diff --git a/llm-gateway/internal/config/config_test.go b/llm-gateway/internal/config/config_test.go deleted file mode 100644 index 3a65f7a..0000000 --- a/llm-gateway/internal/config/config_test.go +++ /dev/null @@ -1,737 +0,0 @@ -package config - -import ( - "fmt" - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -// writeConfigFile creates a temporary YAML config file and returns its path. -func writeConfigFile(t *testing.T, content string) string { - t.Helper() - f, err := os.CreateTemp(t.TempDir(), "config-*.yaml") - if err != nil { - t.Fatalf("creating temp file: %v", err) - } - if _, err := f.WriteString(content); err != nil { - f.Close() - t.Fatalf("writing temp file: %v", err) - } - f.Close() - return f.Name() -} - -// minimalValidConfig returns a minimal valid YAML config string. -func minimalValidConfig() string { - return ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-test-key - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -` -} - -func TestLoad_ValidConfig(t *testing.T) { - path := writeConfigFile(t, ` -server: - listen: "127.0.0.1:8080" - request_timeout: 60s - streaming_timeout: 120s - max_request_body_mb: 5 - session_secret: "test-secret-1234567890abcdef1234567890abcdef" - -database: - path: "/tmp/test.db" - retention_days: 30 - -pricing_lookup: - url: "https://pricing.example.com" - refresh_interval: 1h - -circuit_breaker: - enabled: true - error_threshold: 0.3 - min_requests: 10 - cooldown_duration: 60s - -retry: - initial_backoff: 200ms - max_backoff: 10s - multiplier: 3.0 - -debug: - enabled: true - max_body_bytes: 65536 - retention_days: 60 - -cors: - enabled: true - allowed_origins: - - "https://example.com" - allowed_methods: - - GET - - POST - allowed_headers: - - Authorization - max_age: 600 - -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-test-key - priority: 2 - timeout: 60s - - name: anthropic - base_url: https://api.anthropic.com/v1 - api_key: sk-ant-test - priority: 1 - timeout: 30s - -models: - - name: gpt-4 - aliases: - - gpt4 - routes: - - provider: openai - model: gpt-4 - pricing: - input: 30.0 - output: 60.0 - load_balancing: first - - name: claude-3 - routes: - - provider: anthropic - model: claude-3-opus-20240229 - -tokens: - - name: test-token - key: tok-abc123 - rate_limit_rpm: 100 - daily_budget_usd: 10.0 - max_concurrent: 5 -`) - - cfg, err := Load(path) - if err != nil { - t.Fatalf("Load() returned error: %v", err) - } - - // Server - if cfg.Server.Listen != "127.0.0.1:8080" { - t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080") - } - if cfg.Server.RequestTimeout != 60*time.Second { - t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second) - } - if cfg.Server.StreamingTimeout != 120*time.Second { - t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second) - } - if cfg.Server.MaxRequestBodyMB != 5 { - t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5) - } - if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" { - t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef") - } - - // Database - if cfg.Database.Path != "/tmp/test.db" { - t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db") - } - if cfg.Database.RetentionDays != 30 { - t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30) - } - - // Pricing - if cfg.Pricing.URL != "https://pricing.example.com" { - t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com") - } - if cfg.Pricing.RefreshInterval != 1*time.Hour { - t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour) - } - - // Circuit breaker - if !cfg.CircuitBreaker.Enabled { - t.Error("CircuitBreaker.Enabled = false, want true") - } - if cfg.CircuitBreaker.ErrorThreshold != 0.3 { - t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3) - } - if cfg.CircuitBreaker.MinRequests != 10 { - t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10) - } - if cfg.CircuitBreaker.CooldownDuration != 60*time.Second { - t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second) - } - - // Retry - if cfg.Retry.InitialBackoff != 200*time.Millisecond { - t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond) - } - if cfg.Retry.MaxBackoff != 10*time.Second { - t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second) - } - if cfg.Retry.Multiplier != 3.0 { - t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0) - } - - // Debug - if !cfg.Debug.Enabled { - t.Error("Debug.Enabled = false, want true") - } - if cfg.Debug.MaxBodyBytes != 65536 { - t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536) - } - if cfg.Debug.RetentionDays != 60 { - t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60) - } - - // CORS - if !cfg.CORS.Enabled { - t.Error("CORS.Enabled = false, want true") - } - if cfg.CORS.MaxAge != 600 { - t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600) - } - - // Providers - if len(cfg.Providers) != 2 { - t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers)) - } - if cfg.Providers[0].Name != "openai" { - t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai") - } - if cfg.Providers[0].Timeout != 60*time.Second { - t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second) - } - - // Models - if len(cfg.Models) != 2 { - t.Fatalf("len(Models) = %d, want 2", len(cfg.Models)) - } - if cfg.Models[0].LoadBalancing != "first" { - t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first") - } - if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" { - t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases) - } - if cfg.Models[0].Routes[0].Pricing.Input != 30.0 { - t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input) - } - - // Tokens - if len(cfg.Tokens) != 1 { - t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens)) - } - if cfg.Tokens[0].Name != "test-token" { - t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token") - } - if cfg.Tokens[0].RateLimitRPM != 100 { - t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM) - } -} - -func TestValidate_Defaults(t *testing.T) { - path := writeConfigFile(t, minimalValidConfig()) - cfg, err := Load(path) - if err != nil { - t.Fatalf("Load() returned error: %v", err) - } - - tests := []struct { - name string - got any - want any - }{ - // Server defaults - {"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"}, - {"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second}, - {"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute}, - {"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10}, - - // Database defaults - {"Database.Path", cfg.Database.Path, "gateway.db"}, - {"Database.RetentionDays", cfg.Database.RetentionDays, 90}, - - // Pricing defaults - {"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour}, - - // Circuit breaker defaults - {"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5}, - {"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5}, - {"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second}, - - // Retry defaults - {"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond}, - {"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second}, - {"Retry.Multiplier", cfg.Retry.Multiplier, 2.0}, - - // Debug defaults - {"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0}, - {"Debug.RetentionDays", cfg.Debug.RetentionDays, 90}, - - // CORS defaults - {"CORS.MaxAge", cfg.CORS.MaxAge, 300}, - - // Provider defaults - {"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second}, - {"Providers[0].Priority", cfg.Providers[0].Priority, 1}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Compare using formatted strings to handle different numeric types - gotStr := formatValue(tt.got) - wantStr := formatValue(tt.want) - if gotStr != wantStr { - t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want) - } - }) - } - - // SessionSecret should be auto-generated (non-empty, 64 hex chars) - if cfg.Server.SessionSecret == "" { - t.Error("SessionSecret should be auto-generated when empty") - } - if len(cfg.Server.SessionSecret) != 64 { - t.Errorf("SessionSecret length = %d, want 64 hex chars", len(cfg.Server.SessionSecret)) - } -} - -func formatValue(v any) string { - switch val := v.(type) { - case time.Duration: - return val.String() - case float64: - return fmt.Sprintf("%g", val) - case int: - return fmt.Sprintf("%d", val) - case string: - return val - default: - return fmt.Sprintf("%v", val) - } -} - -func TestLoad_FileNotFound(t *testing.T) { - _, err := Load(filepath.Join(t.TempDir(), "nonexistent.yaml")) - if err == nil { - t.Fatal("Load() should return error for nonexistent file") - } -} - -func TestLoad_InvalidYAML(t *testing.T) { - path := writeConfigFile(t, `{{{invalid yaml`) - _, err := Load(path) - if err == nil { - t.Fatal("Load() should return error for invalid YAML") - } -} - -func TestValidate_DuplicateProviderNames(t *testing.T) { - path := writeConfigFile(t, ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key1 - - name: openai - base_url: https://api.openai.com/v2 - api_key: sk-key2 - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -`) - - _, err := Load(path) - if err == nil { - t.Fatal("Load() should return error for duplicate provider names") - } - wantSubstr := "duplicate provider name: openai" - if !strings.Contains(err.Error(), wantSubstr) { - t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr) - } -} - -func TestValidate_DuplicateModelNames(t *testing.T) { - path := writeConfigFile(t, ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key1 - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 - - name: gpt-4 - routes: - - provider: openai - model: gpt-4-turbo -`) - - _, err := Load(path) - if err == nil { - t.Fatal("Load() should return error for duplicate model names") - } - wantSubstr := "duplicate model name: gpt-4" - if !strings.Contains(err.Error(), wantSubstr) { - t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr) - } -} - -func TestValidate_AliasConflicts(t *testing.T) { - tests := []struct { - name string - config string - wantErr string - }{ - { - name: "alias conflicts with model name", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key1 - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 - - name: claude-3 - aliases: - - gpt-4 - routes: - - provider: openai - model: claude-3 -`, - wantErr: "model alias gpt-4 conflicts with existing model or alias", - }, - { - name: "alias conflicts with another alias", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key1 - -models: - - name: gpt-4 - aliases: - - fast-model - routes: - - provider: openai - model: gpt-4 - - name: claude-3 - aliases: - - fast-model - routes: - - provider: openai - model: claude-3 -`, - wantErr: "model alias fast-model conflicts with existing model or alias", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - path := writeConfigFile(t, tt.config) - _, err := Load(path) - if err == nil { - t.Fatal("Load() should return error for alias conflicts") - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr) - } - }) - } -} - -func TestValidate_MissingRequiredFields(t *testing.T) { - tests := []struct { - name string - config string - wantErr string - }{ - { - name: "no providers", - config: `models: [{name: test, routes: [{provider: x, model: y}]}]`, - wantErr: "at least one provider is required", - }, - { - name: "no models", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key -`, - wantErr: "at least one model is required", - }, - { - name: "provider missing name", - config: ` -providers: - - base_url: https://api.openai.com/v1 - api_key: sk-key - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -`, - wantErr: "provider 0: name, base_url, and api_key are required", - }, - { - name: "provider missing base_url", - config: ` -providers: - - name: openai - api_key: sk-key - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -`, - wantErr: "provider 0: name, base_url, and api_key are required", - }, - { - name: "provider missing api_key", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -`, - wantErr: "provider 0: name, base_url, and api_key are required", - }, - { - name: "model missing name", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key - -models: - - routes: - - provider: openai - model: gpt-4 -`, - wantErr: "model 0: name is required", - }, - { - name: "model missing routes", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key - -models: - - name: gpt-4 -`, - wantErr: "model gpt-4: at least one route is required", - }, - { - name: "route missing provider", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key - -models: - - name: gpt-4 - routes: - - model: gpt-4 -`, - wantErr: "model gpt-4 route 0: provider and model are required", - }, - { - name: "route missing model", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key - -models: - - name: gpt-4 - routes: - - provider: openai -`, - wantErr: "model gpt-4 route 0: provider and model are required", - }, - { - name: "route references unknown provider", - config: ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-key - -models: - - name: gpt-4 - routes: - - provider: anthropic - model: gpt-4 -`, - wantErr: "model gpt-4 route 0: unknown provider anthropic", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - path := writeConfigFile(t, tt.config) - _, err := Load(path) - if err == nil { - t.Fatalf("Load() should return error, want %q", tt.wantErr) - } - if !strings.Contains(err.Error(), tt.wantErr) { - t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr) - } - }) - } -} - -func TestProviderByName(t *testing.T) { - path := writeConfigFile(t, ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: sk-openai - - name: anthropic - base_url: https://api.anthropic.com/v1 - api_key: sk-anthropic - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -`) - - cfg, err := Load(path) - if err != nil { - t.Fatalf("Load() returned error: %v", err) - } - - tests := []struct { - name string - lookup string - wantNil bool - wantName string - }{ - {"existing provider openai", "openai", false, "openai"}, - {"existing provider anthropic", "anthropic", false, "anthropic"}, - {"nonexistent provider", "google", true, ""}, - {"empty name", "", true, ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := cfg.ProviderByName(tt.lookup) - if tt.wantNil { - if p != nil { - t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p) - } - } else { - if p == nil { - t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup) - } - if p.Name != tt.wantName { - t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName) - } - } - }) - } - - // Verify returned pointer refers to the actual config entry - p := cfg.ProviderByName("openai") - if p.APIKey != "sk-openai" { - t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai") - } -} - -func TestLoad_EnvironmentVariableExpansion(t *testing.T) { - t.Setenv("TEST_API_KEY", "sk-from-env") - t.Setenv("TEST_BASE_URL", "https://env.example.com/v1") - t.Setenv("TEST_PROVIDER_NAME", "env-provider") - - path := writeConfigFile(t, ` -providers: - - name: $TEST_PROVIDER_NAME - base_url: ${TEST_BASE_URL} - api_key: ${TEST_API_KEY} - -models: - - name: test-model - routes: - - provider: env-provider - model: gpt-4 -`) - - cfg, err := Load(path) - if err != nil { - t.Fatalf("Load() returned error: %v", err) - } - - if cfg.Providers[0].Name != "env-provider" { - t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider") - } - if cfg.Providers[0].BaseURL != "https://env.example.com/v1" { - t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1") - } - if cfg.Providers[0].APIKey != "sk-from-env" { - t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env") - } -} - -func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) { - // Ensure the variable is not set - t.Setenv("TEST_UNSET_VAR", "") - os.Unsetenv("TEST_UNSET_VAR") - - path := writeConfigFile(t, ` -providers: - - name: openai - base_url: https://api.openai.com/v1 - api_key: ${TEST_UNSET_VAR} - -models: - - name: gpt-4 - routes: - - provider: openai - model: gpt-4 -`) - - _, err := Load(path) - if err == nil { - t.Fatal("Load() should return error when env var expands to empty required field") - } - // api_key will be empty, so validation should catch it - if !strings.Contains(err.Error(), "api_key are required") { - t.Errorf("error = %q, want to contain api_key validation message", err.Error()) - } -} diff --git a/llm-gateway/internal/config/watcher.go b/llm-gateway/internal/config/watcher.go deleted file mode 100644 index f63a478..0000000 --- a/llm-gateway/internal/config/watcher.go +++ /dev/null @@ -1,27 +0,0 @@ -package config - -import ( - "log" - "os" - "os/signal" - "syscall" -) - -// WatchReload listens for SIGHUP and calls the callback with the new config. -func WatchReload(configPath string, callback func(*Config)) { - sighup := make(chan os.Signal, 1) - signal.Notify(sighup, syscall.SIGHUP) - - go func() { - for range sighup { - log.Println("SIGHUP received, reloading config...") - newCfg, err := Load(configPath) - if err != nil { - log.Printf("ERROR: config reload failed: %v", err) - continue - } - callback(newCfg) - log.Println("Config reloaded successfully") - } - }() -} diff --git a/llm-gateway/internal/dashboard/api.go b/llm-gateway/internal/dashboard/api.go deleted file mode 100644 index fe75a2d..0000000 --- a/llm-gateway/internal/dashboard/api.go +++ /dev/null @@ -1,775 +0,0 @@ -package dashboard - -import ( - "encoding/json" - "net/http" - "os" - "sort" - "strconv" - "time" - - "github.com/go-chi/chi/v5" - - "llm-gateway/internal/auth" - "llm-gateway/internal/cache" - "llm-gateway/internal/config" - "llm-gateway/internal/provider" - "llm-gateway/internal/storage" -) - -// Exported types for template rendering and JSON API. - -type Period struct { - Requests int `json:"requests"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` - Errors int `json:"errors"` - CachedHits int `json:"cached_hits"` -} - -type SummaryResult struct { - Today *Period `json:"today"` - Week *Period `json:"week"` - Month *Period `json:"month"` -} - -type ModelStats struct { - Model string `json:"model"` - Requests int `json:"requests"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` - AvgLatencyMS float64 `json:"avg_latency_ms"` -} - -type ProviderStats struct { - Provider string `json:"provider"` - Requests int `json:"requests"` - Successes int `json:"successes"` - Errors int `json:"errors"` - AvgLatencyMS float64 `json:"avg_latency_ms"` - CostUSD float64 `json:"cost_usd"` -} - -type TokenUsageStats struct { - TokenName string `json:"token_name"` - Requests int `json:"requests"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` -} - -// RequestLogEntry represents a single request log row. -type RequestLogEntry struct { - RequestID string `json:"request_id"` - Timestamp int64 `json:"timestamp"` - TokenName string `json:"token_name"` - Model string `json:"model"` - Provider string `json:"provider"` - ProviderModel string `json:"provider_model"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` - LatencyMS int64 `json:"latency_ms"` - Status string `json:"status"` - ErrorMessage string `json:"error_message"` - Streaming bool `json:"streaming"` - Cached bool `json:"cached"` -} - -// LogsResult holds paginated logs. -type LogsResult struct { - Logs []RequestLogEntry `json:"logs"` - Page int `json:"page"` - TotalPages int `json:"total_pages"` - Total int `json:"total"` -} - -// LatencyResult holds latency percentiles. -type LatencyResult struct { - P50 float64 `json:"p50"` - P95 float64 `json:"p95"` - P99 float64 `json:"p99"` - Avg float64 `json:"avg"` - Min float64 `json:"min"` - Max float64 `json:"max"` -} - -// CostBreakdownEntry holds cost data grouped by day and dimension. -type CostBreakdownEntry struct { - Day string `json:"day"` - GroupBy string `json:"group_by"` - CostUSD float64 `json:"cost_usd"` - Requests int `json:"requests"` -} - -type StatsAPI struct { - db *storage.DB - authStore *auth.Store - healthTracker *provider.HealthTracker - cache *cache.Cache - auditLogger *storage.AuditLogger - debugLogger *storage.DebugLogger - configPath string -} - -func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI { - return &StatsAPI{db: db, authStore: authStore} -} - -// SetHealthTracker sets the provider health tracker. -func (s *StatsAPI) SetHealthTracker(ht *provider.HealthTracker) { - s.healthTracker = ht -} - -// SetCache sets the cache for stats. -func (s *StatsAPI) SetCache(c *cache.Cache) { - s.cache = c -} - -// SetAuditLogger sets the audit logger. -func (s *StatsAPI) SetAuditLogger(al *storage.AuditLogger) { - s.auditLogger = al -} - -// SetDebugLogger sets the debug logger. -func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) { - s.debugLogger = dl -} - -// SetConfigPath sets the config file path for validation. -func (s *StatsAPI) SetConfigPath(path string) { - s.configPath = path -} - -// TokenNamesForUser returns the token names that belong to the user. -// Admins get nil (no filter), non-admins get their token names. -func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string { - if user == nil || user.IsAdmin { - return nil - } - tokens, err := s.authStore.ListAPITokens(user.ID) - if err != nil { - return []string{"__none__"} - } - names := make([]string, len(tokens)) - for i, t := range tokens { - names[i] = t.Name - } - if len(names) == 0 { - return []string{"__none__"} - } - return names -} - -// tokenNamesForUser returns token names from request context (for HTTP handlers). -func (s *StatsAPI) tokenNamesForUser(r *http.Request) []string { - user := auth.UserFromContext(r.Context()) - return s.TokenNamesForUser(user) -} - -func buildTokenFilter(tokenNames []string) (string, []any) { - if tokenNames == nil { - return "", nil - } - placeholders := "" - args := make([]any, len(tokenNames)) - for i, n := range tokenNames { - if i > 0 { - placeholders += "," - } - placeholders += "?" - args[i] = n - } - return " AND token_name IN (" + placeholders + ")", args -} - -// Data-fetching methods (used by both JSON handlers and template handlers). - -func (s *StatsAPI) GetSummary(tokenNames []string) *SummaryResult { - now := time.Now() - todayStart := now.Truncate(24 * time.Hour).Unix() - weekStart := now.AddDate(0, 0, -7).Unix() - monthStart := now.AddDate(0, -1, 0).Unix() - - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - result := &SummaryResult{ - Today: &Period{}, - Week: &Period{}, - Month: &Period{}, - } - - periods := map[string]struct { - since int64 - period *Period - }{ - "today": {todayStart, result.Today}, - "week": {weekStart, result.Week}, - "month": {monthStart, result.Month}, - } - - for _, p := range periods { - args := append([]any{p.since}, filterArgs...) - row := s.db.QueryRow(`SELECT - COUNT(*), - COALESCE(SUM(input_tokens), 0), - COALESCE(SUM(output_tokens), 0), - COALESCE(SUM(cost_usd), 0), - COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0), - COALESCE(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END), 0) - FROM request_logs WHERE timestamp >= ?`+tokenFilter, args...) - row.Scan(&p.period.Requests, &p.period.InputTokens, &p.period.OutputTokens, &p.period.CostUSD, &p.period.Errors, &p.period.CachedHits) - } - - return result -} - -func (s *StatsAPI) GetModels(tokenNames []string) []ModelStats { - since := time.Now().AddDate(0, 0, -30).Unix() - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - args := append([]any{since}, filterArgs...) - rows, err := s.db.Query(`SELECT - model, - COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0) as input_tokens, - COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cost_usd), 0) as cost, - COALESCE(AVG(latency_ms), 0) as avg_latency - FROM request_logs WHERE timestamp >= ?`+tokenFilter+` - GROUP BY model ORDER BY requests DESC`, args...) - if err != nil { - return nil - } - defer rows.Close() - - var results []ModelStats - for rows.Next() { - var m ModelStats - rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS) - results = append(results, m) - } - return results -} - -func (s *StatsAPI) GetProviders(tokenNames []string) []ProviderStats { - since := time.Now().AddDate(0, 0, -30).Unix() - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - args := append([]any{since}, filterArgs...) - rows, err := s.db.Query(`SELECT - provider, - COUNT(*) as requests, - COALESCE(SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END), 0) as successes, - COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0) as errors, - COALESCE(AVG(latency_ms), 0) as avg_latency, - COALESCE(SUM(cost_usd), 0) as cost - FROM request_logs WHERE timestamp >= ?`+tokenFilter+` - GROUP BY provider ORDER BY requests DESC`, args...) - if err != nil { - return nil - } - defer rows.Close() - - var results []ProviderStats - for rows.Next() { - var p ProviderStats - rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD) - results = append(results, p) - } - return results -} - -func (s *StatsAPI) GetTokenUsage(tokenNames []string) []TokenUsageStats { - since := time.Now().AddDate(0, 0, -30).Unix() - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - args := append([]any{since}, filterArgs...) - rows, err := s.db.Query(`SELECT - token_name, - COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0) as input_tokens, - COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cost_usd), 0) as cost - FROM request_logs WHERE timestamp >= ?`+tokenFilter+` - GROUP BY token_name ORDER BY requests DESC`, args...) - if err != nil { - return nil - } - defer rows.Close() - - var results []TokenUsageStats - for rows.Next() { - var t TokenUsageStats - rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD) - results = append(results, t) - } - return results -} - -// GetLogs returns paginated request logs with filters. -func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status string) *LogsResult { - if page < 1 { - page = 1 - } - limit := 50 - offset := (page - 1) * limit - - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - where := "WHERE 1=1" + tokenFilter - args := make([]any, 0) - args = append(args, filterArgs...) - - if model != "" { - where += " AND model = ?" - args = append(args, model) - } - if token != "" { - where += " AND token_name = ?" - args = append(args, token) - } - if status != "" { - where += " AND status = ?" - args = append(args, status) - } - - // Get total count - var total int - countArgs := make([]any, len(args)) - copy(countArgs, args) - s.db.QueryRow("SELECT COUNT(*) FROM request_logs "+where, countArgs...).Scan(&total) - - totalPages := (total + limit - 1) / limit - if totalPages < 1 { - totalPages = 1 - } - - // Get page - query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model, - input_tokens, output_tokens, cost_usd, latency_ms, status, - COALESCE(error_message, ''), streaming, cached - FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?` - args = append(args, limit, offset) - - rows, err := s.db.Query(query, args...) - if err != nil { - return &LogsResult{Logs: []RequestLogEntry{}, Page: page, TotalPages: totalPages, Total: total} - } - defer rows.Close() - - var logs []RequestLogEntry - for rows.Next() { - var l RequestLogEntry - var streaming, cached int - rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel, - &l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status, - &l.ErrorMessage, &streaming, &cached) - l.Streaming = streaming == 1 - l.Cached = cached == 1 - logs = append(logs, l) - } - if logs == nil { - logs = []RequestLogEntry{} - } - - return &LogsResult{ - Logs: logs, - Page: page, - TotalPages: totalPages, - Total: total, - } -} - -// GetDistinctModels returns distinct model names from logs. -func (s *StatsAPI) GetDistinctModels() []string { - rows, err := s.db.Query("SELECT DISTINCT model FROM request_logs ORDER BY model") - if err != nil { - return nil - } - defer rows.Close() - var models []string - for rows.Next() { - var m string - rows.Scan(&m) - models = append(models, m) - } - return models -} - -// GetDistinctTokens returns distinct token names from logs. -func (s *StatsAPI) GetDistinctTokens() []string { - rows, err := s.db.Query("SELECT DISTINCT token_name FROM request_logs ORDER BY token_name") - if err != nil { - return nil - } - defer rows.Close() - var tokens []string - for rows.Next() { - var t string - rows.Scan(&t) - tokens = append(tokens, t) - } - return tokens -} - -// GetLatency computes latency percentiles from request_logs. -func (s *StatsAPI) GetLatency(tokenNames []string, period, model, providerName string) *LatencyResult { - var since int64 - switch period { - case "7d": - since = time.Now().AddDate(0, 0, -7).Unix() - case "30d": - since = time.Now().AddDate(0, -1, 0).Unix() - default: - since = time.Now().Add(-24 * time.Hour).Unix() - } - - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - where := "WHERE timestamp >= ? AND status = 'success'" + tokenFilter - args := []any{since} - args = append(args, filterArgs...) - - if model != "" { - where += " AND model = ?" - args = append(args, model) - } - if providerName != "" { - where += " AND provider = ?" - args = append(args, providerName) - } - - rows, err := s.db.Query("SELECT latency_ms FROM request_logs "+where+" ORDER BY latency_ms", args...) - if err != nil { - return &LatencyResult{} - } - defer rows.Close() - - var latencies []float64 - for rows.Next() { - var l float64 - rows.Scan(&l) - latencies = append(latencies, l) - } - - if len(latencies) == 0 { - return &LatencyResult{} - } - - sort.Float64s(latencies) - n := len(latencies) - var sum float64 - for _, l := range latencies { - sum += l - } - - return &LatencyResult{ - P50: latencies[n*50/100], - P95: latencies[n*95/100], - P99: latencies[min(n*99/100, n-1)], - Avg: sum / float64(n), - Min: latencies[0], - Max: latencies[n-1], - } -} - -// GetCostBreakdown returns cost data grouped by day and dimension. -func (s *StatsAPI) GetCostBreakdown(tokenNames []string, period, groupBy string) []CostBreakdownEntry { - var since int64 - switch period { - case "30d": - since = time.Now().AddDate(0, -1, 0).Unix() - case "7d": - since = time.Now().AddDate(0, 0, -7).Unix() - default: - since = time.Now().Add(-24 * time.Hour).Unix() - } - - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - groupCol := "model" - if groupBy == "token" { - groupCol = "token_name" - } else if groupBy == "provider" { - groupCol = "provider" - } - - args := []any{since} - args = append(args, filterArgs...) - - query := `SELECT date(timestamp, 'unixepoch') as day, ` + groupCol + `, - COALESCE(SUM(cost_usd), 0), COUNT(*) - FROM request_logs WHERE timestamp >= ?` + tokenFilter + ` - GROUP BY day, ` + groupCol + ` ORDER BY day, ` + groupCol - - rows, err := s.db.Query(query, args...) - if err != nil { - return nil - } - defer rows.Close() - - var results []CostBreakdownEntry - for rows.Next() { - var e CostBreakdownEntry - rows.Scan(&e.Day, &e.GroupBy, &e.CostUSD, &e.Requests) - results = append(results, e) - } - return results -} - -// JSON HTTP handlers (thin wrappers). - -func (s *StatsAPI) Summary(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - result := s.GetSummary(tokenNames) - writeJSON(w, result) -} - -func (s *StatsAPI) Models(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - results := s.GetModels(tokenNames) - writeJSON(w, results) -} - -func (s *StatsAPI) Providers(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - results := s.GetProviders(tokenNames) - writeJSON(w, results) -} - -func (s *StatsAPI) Tokens(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - results := s.GetTokenUsage(tokenNames) - writeJSON(w, results) -} - -func (s *StatsAPI) Timeseries(w http.ResponseWriter, r *http.Request) { - period := r.URL.Query().Get("period") - var since int64 - var groupFmt string - switch period { - case "7d": - since = time.Now().AddDate(0, 0, -7).Unix() - groupFmt = "%Y-%m-%d" - case "30d": - since = time.Now().AddDate(0, -1, 0).Unix() - groupFmt = "%Y-%m-%d" - default: - since = time.Now().Add(-24 * time.Hour).Unix() - groupFmt = "%Y-%m-%d %H:00" - } - - tokenNames := s.tokenNamesForUser(r) - tokenFilter, filterArgs := buildTokenFilter(tokenNames) - - args := append([]any{since}, filterArgs...) - rows, err := s.db.Query(`SELECT - strftime('`+groupFmt+`', timestamp, 'unixepoch') as bucket, - COUNT(*) as requests, - COALESCE(SUM(cost_usd), 0) as cost, - COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens - FROM request_logs WHERE timestamp >= ?`+tokenFilter+` - GROUP BY bucket ORDER BY bucket`, args...) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer rows.Close() - - type point struct { - Bucket string `json:"bucket"` - Requests int `json:"requests"` - CostUSD float64 `json:"cost_usd"` - TotalTokens int `json:"total_tokens"` - } - - var results []point - for rows.Next() { - var p point - rows.Scan(&p.Bucket, &p.Requests, &p.CostUSD, &p.TotalTokens) - results = append(results, p) - } - writeJSON(w, results) -} - -// Logs serves the paginated logs API. -func (s *StatsAPI) Logs(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - page, _ := strconv.Atoi(r.URL.Query().Get("page")) - model := r.URL.Query().Get("model") - token := r.URL.Query().Get("token") - status := r.URL.Query().Get("status") - result := s.GetLogs(tokenNames, page, model, token, status) - writeJSON(w, result) -} - -// Latency serves latency percentiles API. -func (s *StatsAPI) Latency(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - period := r.URL.Query().Get("period") - model := r.URL.Query().Get("model") - providerName := r.URL.Query().Get("provider") - result := s.GetLatency(tokenNames, period, model, providerName) - writeJSON(w, result) -} - -// CostBreakdown serves cost breakdown API. -func (s *StatsAPI) CostBreakdown(w http.ResponseWriter, r *http.Request) { - tokenNames := s.tokenNamesForUser(r) - period := r.URL.Query().Get("period") - groupBy := r.URL.Query().Get("group_by") - if groupBy == "" { - groupBy = "model" - } - result := s.GetCostBreakdown(tokenNames, period, groupBy) - writeJSON(w, result) -} - -// ProviderHealthHandler serves provider health status API. -func (s *StatsAPI) ProviderHealthHandler(w http.ResponseWriter, r *http.Request) { - if s.healthTracker == nil { - writeJSON(w, []provider.ProviderHealth{}) - return - } - writeJSON(w, s.healthTracker.Status()) -} - -// CacheStats serves cache statistics API. -func (s *StatsAPI) CacheStats(w http.ResponseWriter, r *http.Request) { - if s.cache == nil { - writeJSON(w, map[string]any{"enabled": false}) - return - } - stats := s.cache.Stats(r.Context()) - writeJSON(w, stats) -} - -// AuditLogs serves the audit log API (admin-only). -func (s *StatsAPI) AuditLogs(w http.ResponseWriter, r *http.Request) { - if s.auditLogger == nil { - writeJSON(w, map[string]any{"entries": []any{}, "total": 0}) - return - } - page, _ := strconv.Atoi(r.URL.Query().Get("page")) - action := r.URL.Query().Get("action") - since := time.Now().AddDate(0, 0, -30).Unix() - if sinceStr := r.URL.Query().Get("since"); sinceStr != "" { - if s, err := strconv.ParseInt(sinceStr, 10, 64); err == nil { - since = s - } - } - result := s.auditLogger.Query(since, action, page, 50) - writeJSON(w, result) -} - -// DebugToggle enables/disables debug logging at runtime. -func (s *StatsAPI) DebugToggle(w http.ResponseWriter, r *http.Request) { - if s.debugLogger == nil { - writeJSON(w, map[string]any{"error": "debug logger not configured"}) - return - } - var req struct { - Enabled bool `json:"enabled"` - } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - w.WriteHeader(http.StatusBadRequest) - writeJSON(w, map[string]string{"error": "invalid JSON"}) - return - } - s.debugLogger.SetEnabled(req.Enabled) - writeJSON(w, map[string]any{"enabled": s.debugLogger.IsEnabled()}) -} - -// DebugStatus returns whether debug logging is enabled. -func (s *StatsAPI) DebugStatus(w http.ResponseWriter, r *http.Request) { - enabled := false - if s.debugLogger != nil { - enabled = s.debugLogger.IsEnabled() - } - writeJSON(w, map[string]any{"enabled": enabled}) -} - -// DebugLogs serves paginated debug log entries. -func (s *StatsAPI) DebugLogs(w http.ResponseWriter, r *http.Request) { - if s.debugLogger == nil { - writeJSON(w, map[string]any{"entries": []any{}, "total": 0}) - return - } - page, _ := strconv.Atoi(r.URL.Query().Get("page")) - result := s.debugLogger.Query(page, 50) - writeJSON(w, result) -} - -// DebugLogByRequestID serves a single debug log entry by request ID. -func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) { - if s.debugLogger == nil { - w.WriteHeader(http.StatusNotFound) - writeJSON(w, map[string]string{"error": "debug logger not configured"}) - return - } - requestID := chi.URLParam(r, "requestID") - entry := s.debugLogger.GetByRequestID(requestID) - if entry == nil { - w.WriteHeader(http.StatusNotFound) - writeJSON(w, map[string]string{"error": "not found"}) - return - } - writeJSON(w, entry) -} - -// ValidateConfig validates the config file at the stored path. -// Returns HTML for HTMX requests, JSON otherwise. -func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) { - if s.configPath == "" { - if r.Header.Get("HX-Request") == "true" { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write([]byte(`
Config path not set
`)) - } else { - w.WriteHeader(http.StatusInternalServerError) - writeJSON(w, map[string]any{"valid": false, "errors": []string{"config path not set"}}) - } - return - } - data, err := os.ReadFile(s.configPath) - if err != nil { - msg := "failed to read config: " + err.Error() - if r.Header.Get("HX-Request") == "true" { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.Write([]byte(`
` + msg + `
`)) - } else { - w.WriteHeader(http.StatusInternalServerError) - writeJSON(w, map[string]any{"valid": false, "errors": []string{msg}}) - } - return - } - errs := config.ValidateBytes(data) - - if r.Header.Get("HX-Request") == "true" { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - if len(errs) > 0 { - html := `
Configuration errors:
" - w.Write([]byte(html)) - } else { - w.Write([]byte(`
Configuration is valid.
`)) - } - return - } - - if len(errs) > 0 { - writeJSON(w, map[string]any{"valid": false, "errors": errs}) - return - } - writeJSON(w, map[string]any{"valid": true, "errors": []string{}}) -} - -func writeJSON(w http.ResponseWriter, v any) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(v) -} diff --git a/llm-gateway/internal/dashboard/export.go b/llm-gateway/internal/dashboard/export.go deleted file mode 100644 index 2ac0875..0000000 --- a/llm-gateway/internal/dashboard/export.go +++ /dev/null @@ -1,297 +0,0 @@ -package dashboard - -import ( - "encoding/csv" - "encoding/json" - "fmt" - "net/http" - "strconv" - "time" - - "llm-gateway/internal/auth" - "llm-gateway/internal/storage" -) - -type ExportHandler struct { - db *storage.DB - authStore *auth.Store -} - -func NewExportHandler(db *storage.DB, authStore *auth.Store) *ExportHandler { - return &ExportHandler{db: db, authStore: authStore} -} - -// ExportLogs exports request logs as CSV or JSON. -func (e *ExportHandler) ExportLogs(w http.ResponseWriter, r *http.Request) { - format := r.URL.Query().Get("format") - if format == "" { - format = "json" - } - - // Build query - where := "WHERE 1=1" - var args []any - - if from := r.URL.Query().Get("from"); from != "" { - if ts, err := strconv.ParseInt(from, 10, 64); err == nil { - where += " AND timestamp >= ?" - args = append(args, ts) - } - } - if to := r.URL.Query().Get("to"); to != "" { - if ts, err := strconv.ParseInt(to, 10, 64); err == nil { - where += " AND timestamp <= ?" - args = append(args, ts) - } - } - if model := r.URL.Query().Get("model"); model != "" { - where += " AND model = ?" - args = append(args, model) - } - if token := r.URL.Query().Get("token"); token != "" { - where += " AND token_name = ?" - args = append(args, token) - } - if status := r.URL.Query().Get("status"); status != "" { - where += " AND status = ?" - args = append(args, status) - } - - // Token filtering for non-admins - user := auth.UserFromContext(r.Context()) - if user != nil && !user.IsAdmin { - tokens, err := e.authStore.ListAPITokens(user.ID) - if err != nil || len(tokens) == 0 { - where += " AND 1=0" - } else { - where += " AND token_name IN (" - for i, t := range tokens { - if i > 0 { - where += "," - } - where += "?" - args = append(args, t.Name) - } - where += ")" - } - } - - query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model, - input_tokens, output_tokens, cost_usd, latency_ms, status, - COALESCE(error_message, ''), streaming, cached - FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT 100000` - - rows, err := e.db.Query(query, args...) - if err != nil { - http.Error(w, "query failed", http.StatusInternalServerError) - return - } - defer rows.Close() - - type logRow struct { - RequestID string `json:"request_id"` - Timestamp int64 `json:"timestamp"` - TokenName string `json:"token_name"` - Model string `json:"model"` - Provider string `json:"provider"` - ProviderModel string `json:"provider_model"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` - LatencyMS int64 `json:"latency_ms"` - Status string `json:"status"` - ErrorMessage string `json:"error_message"` - Streaming bool `json:"streaming"` - Cached bool `json:"cached"` - } - - var results []logRow - for rows.Next() { - var l logRow - var streaming, cached int - rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel, - &l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status, - &l.ErrorMessage, &streaming, &cached) - l.Streaming = streaming == 1 - l.Cached = cached == 1 - results = append(results, l) - } - - now := time.Now().Format("20060102-150405") - - switch format { - case "csv": - w.Header().Set("Content-Type", "text/csv") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", now)) - writer := csv.NewWriter(w) - writer.Write([]string{"request_id", "timestamp", "token_name", "model", "provider", "provider_model", - "input_tokens", "output_tokens", "cost_usd", "latency_ms", "status", "error_message", "streaming", "cached"}) - for _, l := range results { - writer.Write([]string{ - l.RequestID, - strconv.FormatInt(l.Timestamp, 10), - l.TokenName, l.Model, l.Provider, l.ProviderModel, - strconv.Itoa(l.InputTokens), strconv.Itoa(l.OutputTokens), - fmt.Sprintf("%.8f", l.CostUSD), - strconv.FormatInt(l.LatencyMS, 10), - l.Status, l.ErrorMessage, - strconv.FormatBool(l.Streaming), strconv.FormatBool(l.Cached), - }) - } - writer.Flush() - default: - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.json", now)) - json.NewEncoder(w).Encode(results) - } -} - -// ExportStats exports aggregated stats as CSV or JSON. -func (e *ExportHandler) ExportStats(w http.ResponseWriter, r *http.Request) { - format := r.URL.Query().Get("format") - if format == "" { - format = "json" - } - statsType := r.URL.Query().Get("type") - if statsType == "" { - statsType = "summary" - } - - now := time.Now().Format("20060102-150405") - since := time.Now().AddDate(0, -1, 0).Unix() - - switch statsType { - case "models": - rows, err := e.db.Query(`SELECT model, COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), - COALESCE(SUM(cost_usd), 0), COALESCE(AVG(latency_ms), 0) - FROM request_logs WHERE timestamp >= ? GROUP BY model ORDER BY requests DESC`, since) - if err != nil { - http.Error(w, "query failed", http.StatusInternalServerError) - return - } - defer rows.Close() - - type modelRow struct { - Model string `json:"model"` - Requests int `json:"requests"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` - AvgLatencyMS float64 `json:"avg_latency_ms"` - } - var results []modelRow - for rows.Next() { - var m modelRow - rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS) - results = append(results, m) - } - - if format == "csv" { - w.Header().Set("Content-Type", "text/csv") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.csv", now)) - writer := csv.NewWriter(w) - writer.Write([]string{"model", "requests", "input_tokens", "output_tokens", "cost_usd", "avg_latency_ms"}) - for _, m := range results { - writer.Write([]string{m.Model, strconv.Itoa(m.Requests), strconv.Itoa(m.InputTokens), - strconv.Itoa(m.OutputTokens), fmt.Sprintf("%.8f", m.CostUSD), fmt.Sprintf("%.2f", m.AvgLatencyMS)}) - } - writer.Flush() - } else { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.json", now)) - json.NewEncoder(w).Encode(results) - } - - case "providers": - rows, err := e.db.Query(`SELECT provider, COUNT(*) as requests, - COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0), - COALESCE(SUM(CASE WHEN status='error' THEN 1 ELSE 0 END), 0), - COALESCE(AVG(latency_ms), 0), COALESCE(SUM(cost_usd), 0) - FROM request_logs WHERE timestamp >= ? GROUP BY provider ORDER BY requests DESC`, since) - if err != nil { - http.Error(w, "query failed", http.StatusInternalServerError) - return - } - defer rows.Close() - - type providerRow struct { - Provider string `json:"provider"` - Requests int `json:"requests"` - Successes int `json:"successes"` - Errors int `json:"errors"` - AvgLatencyMS float64 `json:"avg_latency_ms"` - CostUSD float64 `json:"cost_usd"` - } - var results []providerRow - for rows.Next() { - var p providerRow - rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD) - results = append(results, p) - } - - if format == "csv" { - w.Header().Set("Content-Type", "text/csv") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.csv", now)) - writer := csv.NewWriter(w) - writer.Write([]string{"provider", "requests", "successes", "errors", "avg_latency_ms", "cost_usd"}) - for _, p := range results { - writer.Write([]string{p.Provider, strconv.Itoa(p.Requests), strconv.Itoa(p.Successes), - strconv.Itoa(p.Errors), fmt.Sprintf("%.2f", p.AvgLatencyMS), fmt.Sprintf("%.8f", p.CostUSD)}) - } - writer.Flush() - } else { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.json", now)) - json.NewEncoder(w).Encode(results) - } - - case "tokens": - rows, err := e.db.Query(`SELECT token_name, COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), - COALESCE(SUM(cost_usd), 0) - FROM request_logs WHERE timestamp >= ? GROUP BY token_name ORDER BY requests DESC`, since) - if err != nil { - http.Error(w, "query failed", http.StatusInternalServerError) - return - } - defer rows.Close() - - type tokenRow struct { - TokenName string `json:"token_name"` - Requests int `json:"requests"` - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CostUSD float64 `json:"cost_usd"` - } - var results []tokenRow - for rows.Next() { - var t tokenRow - rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD) - results = append(results, t) - } - - if format == "csv" { - w.Header().Set("Content-Type", "text/csv") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.csv", now)) - writer := csv.NewWriter(w) - writer.Write([]string{"token_name", "requests", "input_tokens", "output_tokens", "cost_usd"}) - for _, t := range results { - writer.Write([]string{t.TokenName, strconv.Itoa(t.Requests), strconv.Itoa(t.InputTokens), - strconv.Itoa(t.OutputTokens), fmt.Sprintf("%.8f", t.CostUSD)}) - } - writer.Flush() - } else { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.json", now)) - json.NewEncoder(w).Encode(results) - } - - default: // summary - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-summary-%s.json", now)) - statsAPI := NewStatsAPI(e.db, e.authStore) - result := statsAPI.GetSummary(nil) - json.NewEncoder(w).Encode(result) - } -} diff --git a/llm-gateway/internal/dashboard/handler.go b/llm-gateway/internal/dashboard/handler.go deleted file mode 100644 index e672637..0000000 --- a/llm-gateway/internal/dashboard/handler.go +++ /dev/null @@ -1,426 +0,0 @@ -package dashboard - -import ( - "embed" - "fmt" - "html/template" - "net/http" - "strconv" - "time" - - "llm-gateway/internal/auth" - "llm-gateway/internal/cache" - "llm-gateway/internal/provider" - "llm-gateway/internal/storage" -) - -//go:embed templates/*.html templates/partials/*.html -var templateFiles embed.FS - -var templateFuncs = template.FuncMap{ - "formatTime": func(ts int64) string { - if ts == 0 { - return "never" - } - return time.Unix(ts, 0).Format("2006-01-02") - }, - "formatTimeDetail": func(ts int64) string { - if ts == 0 { - return "never" - } - return time.Unix(ts, 0).Format("2006-01-02 15:04:05") - }, - "addInt": func(a, b int) int { - return a + b - }, - "subInt": func(a, b int) int { - return a - b - }, - "formatCost": func(v float64) string { - if v == 0 { - return "$0.00" - } - if v < 0.01 { - return fmt.Sprintf("$%.6f", v) - } - return fmt.Sprintf("$%.4f", v) - }, - "formatPrice": func(v float64) string { - if v == 0 { - return "-" - } - return fmt.Sprintf("$%.2f", v) - }, - "formatPct": func(v float64) string { - return fmt.Sprintf("%.1f%%", v*100) - }, - "budgetPct": func(spend, budget float64) float64 { - if budget <= 0 { - return 0 - } - return spend / budget * 100 - }, - "budgetColor": func(pct float64) string { - if pct >= 80 { - return "#f87171" - } - if pct >= 50 { - return "#fbbf24" - } - return "#4ade80" - }, - "seq": func(start, end int) []int { - var s []int - for i := start; i <= end; i++ { - s = append(s, i) - } - return s - }, - "paginationStart": func(page, totalPages int) int { - start := page - 2 - if start < 1 { - start = 1 - } - if totalPages-start < 4 && totalPages > 4 { - start = totalPages - 4 - } - return start - }, - "paginationEnd": func(page, totalPages int) int { - start := page - 2 - if start < 1 { - start = 1 - } - end := start + 4 - if end > totalPages { - end = totalPages - } - return end - }, -} - -// PageData is the common data passed to all templates. -type PageData struct { - ActivePage string - User *auth.User - // Dashboard data - Summary *SummaryResult - Models []ModelStats - Providers []ProviderStats - TokenStats []TokenUsageStats - ProviderHealth []provider.ProviderHealth - Latency *LatencyResult - CacheEnabled bool - CacheInfo *cache.CacheStats - // Tokens page data - Tokens []auth.APIToken - TokenSpend map[string]float64 - // Users page data - Users []auth.User - // Logs page data - LogsResult *LogsResult - LogModels []string - LogTokens []string - FilterModel string - FilterToken string - FilterStatus string - // Models routing page data - ModelRoutes []provider.ModelRouteInfo - // Audit page data - AuditResult *storage.AuditQueryResult - AuditFilterActions []string - FilterAction string - // Debug page data - DebugResult *storage.DebugLogQueryResult - DebugEnabled bool -} - -// Dashboard serves the HTMX-based dashboard pages. -type Dashboard struct { - templates *template.Template - authStore *auth.Store - statsAPI *StatsAPI - registry *provider.Registry - cache *cache.Cache - auditLogger *storage.AuditLogger - debugLogger *storage.DebugLogger -} - -// NewDashboard creates a new Dashboard handler. -func NewDashboard(authStore *auth.Store, statsAPI *StatsAPI) *Dashboard { - tmpl := template.Must( - template.New("").Funcs(templateFuncs).ParseFS(templateFiles, - "templates/*.html", - "templates/partials/*.html", - ), - ) - - return &Dashboard{ - templates: tmpl, - authStore: authStore, - statsAPI: statsAPI, - } -} - -// SetRegistry sets the provider registry for model routing display. -func (d *Dashboard) SetRegistry(r *provider.Registry) { - d.registry = r -} - -// SetCache sets the cache reference for cache stats display. -func (d *Dashboard) SetCache(c *cache.Cache) { - d.cache = c -} - -// SetAuditLogger sets the audit logger for the audit page. -func (d *Dashboard) SetAuditLogger(al *storage.AuditLogger) { - d.auditLogger = al -} - -// SetDebugLogger sets the debug logger for the debug page. -func (d *Dashboard) SetDebugLogger(dl *storage.DebugLogger) { - d.debugLogger = dl -} - -// LoginPage serves the login page. -func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) { - if !d.authStore.HasAnyUser() { - http.Redirect(w, r, "/setup", http.StatusFound) - return - } - if user := d.getSessionUser(r); user != nil { - http.Redirect(w, r, "/dashboard", http.StatusFound) - return - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - d.templates.ExecuteTemplate(w, "login", nil) -} - -// SetupPage serves the initial setup page. -func (d *Dashboard) SetupPage(w http.ResponseWriter, r *http.Request) { - if d.authStore.HasAnyUser() { - http.Redirect(w, r, "/login", http.StatusFound) - return - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - d.templates.ExecuteTemplate(w, "setup", nil) -} - -// DashboardPage serves the main dashboard view. -func (d *Dashboard) DashboardPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - tokenNames := d.statsAPI.TokenNamesForUser(user) - - data := PageData{ - ActivePage: "dashboard", - User: user, - Summary: d.statsAPI.GetSummary(tokenNames), - Models: d.statsAPI.GetModels(tokenNames), - Providers: d.statsAPI.GetProviders(tokenNames), - TokenStats: d.statsAPI.GetTokenUsage(tokenNames), - Latency: d.statsAPI.GetLatency(tokenNames, "24h", "", ""), - } - - // Provider health - if d.statsAPI.healthTracker != nil { - data.ProviderHealth = d.statsAPI.healthTracker.Status() - } - - // Cache stats - if d.cache != nil { - data.CacheEnabled = true - data.CacheInfo = d.cache.Stats(r.Context()) - } - - d.renderDashboardPage(w, r, "partials/dashboard.html", data) -} - -// LogsPage serves the request logs view. -func (d *Dashboard) LogsPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - tokenNames := d.statsAPI.TokenNamesForUser(user) - - page, _ := strconv.Atoi(r.URL.Query().Get("page")) - if page < 1 { - page = 1 - } - model := r.URL.Query().Get("model") - token := r.URL.Query().Get("token") - status := r.URL.Query().Get("status") - - data := PageData{ - ActivePage: "logs", - User: user, - LogsResult: d.statsAPI.GetLogs(tokenNames, page, model, token, status), - LogModels: d.statsAPI.GetDistinctModels(), - LogTokens: d.statsAPI.GetDistinctTokens(), - FilterModel: model, - FilterToken: token, - FilterStatus: status, - } - - d.renderDashboardPage(w, r, "partials/logs.html", data) -} - -// ModelsPage serves the model routing table view. -func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - - data := PageData{ - ActivePage: "models", - User: user, - } - - if d.registry != nil { - data.ModelRoutes = d.registry.AllRoutes() - } - - if d.statsAPI.healthTracker != nil { - data.ProviderHealth = d.statsAPI.healthTracker.Status() - } - - d.renderDashboardPage(w, r, "partials/models-page.html", data) -} - -// TokensPage serves the tokens management view. -func (d *Dashboard) TokensPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - - var userID int64 - if !user.IsAdmin { - userID = user.ID - } - - tokens, _ := d.authStore.ListAPITokens(userID) - if tokens == nil { - tokens = []auth.APIToken{} - } - - // Get today's spend for budget display - spend, _ := d.statsAPI.db.TodaySpendAll() - if spend == nil { - spend = make(map[string]float64) - } - - d.renderDashboardPage(w, r, "partials/tokens.html", PageData{ - ActivePage: "tokens", - User: user, - Tokens: tokens, - TokenSpend: spend, - }) -} - -// UsersPage serves the user management view (admin only). -func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - users, _ := d.authStore.ListUsers() - - d.renderDashboardPage(w, r, "partials/users.html", PageData{ - ActivePage: "users", - User: user, - Users: users, - }) -} - -// AuditPage serves the audit log view (admin only). -func (d *Dashboard) AuditPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - - page, _ := strconv.Atoi(r.URL.Query().Get("page")) - if page < 1 { - page = 1 - } - action := r.URL.Query().Get("action") - since := time.Now().AddDate(0, 0, -30).Unix() - - var auditResult *storage.AuditQueryResult - if d.auditLogger != nil { - auditResult = d.auditLogger.Query(since, action, page, 50) - } else { - auditResult = &storage.AuditQueryResult{Entries: []storage.AuditEntry{}, Page: 1, TotalPages: 1} - } - - // Common audit action types for the filter dropdown - actions := []string{"login", "logout", "create_user", "delete_user", "create_token", "delete_token", "change_password", "setup_totp", "disable_totp"} - - d.renderDashboardPage(w, r, "partials/audit.html", PageData{ - ActivePage: "audit", - User: user, - AuditResult: auditResult, - AuditFilterActions: actions, - FilterAction: action, - }) -} - -// DebugPage serves the debug logging view (admin only). -func (d *Dashboard) DebugPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - - page, _ := strconv.Atoi(r.URL.Query().Get("page")) - if page < 1 { - page = 1 - } - - var debugResult *storage.DebugLogQueryResult - debugEnabled := false - if d.debugLogger != nil { - debugResult = d.debugLogger.QueryFull(page, 50) - debugEnabled = d.debugLogger.IsEnabled() - } else { - debugResult = &storage.DebugLogQueryResult{Entries: []storage.DebugLogEntry{}, Page: 1, TotalPages: 1} - } - - d.renderDashboardPage(w, r, "partials/debug.html", PageData{ - ActivePage: "debug", - User: user, - DebugResult: debugResult, - DebugEnabled: debugEnabled, - }) -} - -// SettingsPage serves the settings view. -func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) { - user := auth.UserFromContext(r.Context()) - user, _ = d.authStore.GetUserByID(user.ID) - - d.renderDashboardPage(w, r, "partials/settings.html", PageData{ - ActivePage: "settings", - User: user, - }) -} - -// renderDashboardPage renders either the full layout or just the content partial. -func (d *Dashboard) renderDashboardPage(w http.ResponseWriter, r *http.Request, partialFile string, data PageData) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - - if r.Header.Get("HX-Request") == "true" { - tmpl := template.Must( - template.New("").Funcs(templateFuncs).ParseFS(templateFiles, "templates/"+partialFile), - ) - tmpl.ExecuteTemplate(w, "content", data) - } else { - tmpl := template.Must( - template.New("").Funcs(templateFuncs).ParseFS(templateFiles, - "templates/layout.html", - "templates/"+partialFile, - ), - ) - tmpl.ExecuteTemplate(w, "layout", data) - } -} - -func (d *Dashboard) getSessionUser(r *http.Request) *auth.User { - cookie, err := r.Cookie("llmgw_session") - if err != nil || cookie.Value == "" { - return nil - } - sess, err := d.authStore.GetSession(cookie.Value) - if err != nil { - return nil - } - user, err := d.authStore.GetUserByID(sess.UserID) - if err != nil { - return nil - } - return user -} diff --git a/llm-gateway/internal/dashboard/sse.go b/llm-gateway/internal/dashboard/sse.go deleted file mode 100644 index aafd88d..0000000 --- a/llm-gateway/internal/dashboard/sse.go +++ /dev/null @@ -1,73 +0,0 @@ -package dashboard - -import ( - "fmt" - "net/http" - "sync" -) - -// SSEBroker manages Server-Sent Events connections. -type SSEBroker struct { - mu sync.RWMutex - clients map[chan struct{}]struct{} -} - -// NewSSEBroker creates a new SSE broker. -func NewSSEBroker() *SSEBroker { - return &SSEBroker{ - clients: make(map[chan struct{}]struct{}), - } -} - -// Notify sends a refresh signal to all connected SSE clients. -func (b *SSEBroker) Notify() { - b.mu.RLock() - defer b.mu.RUnlock() - for ch := range b.clients { - select { - case ch <- struct{}{}: - default: - // Client not ready, skip - } - } -} - -// ServeHTTP handles SSE connections. -func (b *SSEBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) { - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "streaming not supported", http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - - ch := make(chan struct{}, 1) - b.mu.Lock() - b.clients[ch] = struct{}{} - b.mu.Unlock() - - defer func() { - b.mu.Lock() - delete(b.clients, ch) - b.mu.Unlock() - }() - - // Send initial connection event - fmt.Fprintf(w, "event: connected\ndata: ok\n\n") - flusher.Flush() - - ctx := r.Context() - for { - select { - case <-ctx.Done(): - return - case <-ch: - fmt.Fprintf(w, "event: refresh\ndata: updated\n\n") - flusher.Flush() - } - } -} diff --git a/llm-gateway/internal/dashboard/templates/layout.html b/llm-gateway/internal/dashboard/templates/layout.html deleted file mode 100644 index 1b35132..0000000 --- a/llm-gateway/internal/dashboard/templates/layout.html +++ /dev/null @@ -1,366 +0,0 @@ -{{define "layout"}} - - - - - -LLM Gateway - - - - - - - - - - - -
- -
- {{template "content" .}} -
-
- - - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/login.html b/llm-gateway/internal/dashboard/templates/login.html deleted file mode 100644 index 500cc79..0000000 --- a/llm-gateway/internal/dashboard/templates/login.html +++ /dev/null @@ -1,44 +0,0 @@ -{{define "login"}} - - - - - -Login - LLM Gateway - - - - - -
-

LLM Gateway

-
-
-
-
- - -
-
- - -
- -
-
-
- - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/audit.html b/llm-gateway/internal/dashboard/templates/partials/audit.html deleted file mode 100644 index 7697964..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/audit.html +++ /dev/null @@ -1,83 +0,0 @@ -{{define "content"}} - - -
- - -
- -
- - - - - - - - - - - - - {{range .AuditResult.Entries}} - - - - - - - - - {{end}} - {{if not .AuditResult.Entries}} - - {{end}} - -
TimeUserActionTargetDetailsIP
{{formatTimeDetail .Timestamp}}{{.Username}}{{.Action}}{{if .TargetType}}{{.TargetType}}{{if .TargetID}}/{{.TargetID}}{{end}}{{else}}-{{end}}{{if .Details}}{{.Details}}{{else}}-{{end}}{{if .IPAddress}}{{.IPAddress}}{{else}}-{{end}}
No audit log entries
- - {{if gt .AuditResult.TotalPages 1}} - - {{end}} -
- - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/dashboard.html b/llm-gateway/internal/dashboard/templates/partials/dashboard.html deleted file mode 100644 index 715313b..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/dashboard.html +++ /dev/null @@ -1,230 +0,0 @@ -{{define "content"}} - - -
- {{with .Summary.Today}} -
Requests Today
{{.Requests}}
-
Cost Today
{{formatCost .CostUSD}}
-
Tokens Today
{{addInt .InputTokens .OutputTokens}}
{{.InputTokens}} in / {{.OutputTokens}} out
-
Errors Today
{{.Errors}}
-
Cache Hits
{{.CachedHits}}
- {{end}} - {{with .Summary.Week}} -
Cost (7d)
{{formatCost .CostUSD}}
- {{end}} -
- -{{if .ProviderHealth}} -
-

Provider Health

-
- {{range .ProviderHealth}} -
- {{.Provider}} - {{.Status}} - {{if eq .CircuitState "open"}}circuit open{{end}} - {{if eq .CircuitState "half-open"}}half-open{{end}} - {{printf "%.0f" .AvgLatency}}ms avg | {{formatPct .ErrorRate}} errors -
- {{end}} -
-
-{{end}} - -{{if .Latency}}{{if gt .Latency.Max 0.0}} -
-
P50 Latency
{{printf "%.0f" .Latency.P50}}ms
-
P95 Latency
{{printf "%.0f" .Latency.P95}}ms
-
P99 Latency
{{printf "%.0f" .Latency.P99}}ms
-
Avg Latency
{{printf "%.0f" .Latency.Avg}}ms
-
-{{end}}{{end}} - -{{if .CacheEnabled}}{{if .CacheInfo}}{{if .CacheInfo.Connected}} -
-
Cache Hit Rate
{{formatPct .CacheInfo.HitRate}}
{{.CacheInfo.Hits}} hits / {{.CacheInfo.Misses}} misses
-
Cache Memory
{{.CacheInfo.MemoryUsed}}
-
Cached Keys
{{.CacheInfo.Keys}}
-
-{{end}}{{end}}{{end}} - -
- - - -
-
-

Requests & Cost

- -
- -
-

Cost Breakdown

-
- - - -
- -
- -{{if .Models}} -
-

ModelsCSVJSON

- - - - {{range .Models}} - - - - - - - - {{end}} - -
ModelRequestsTokens (in/out)CostAvg Latency
{{.Model}}{{.Requests}}{{.InputTokens}} / {{.OutputTokens}}{{formatCost .CostUSD}}{{printf "%.0f" .AvgLatencyMS}}ms
-
-{{end}} - -{{if .Providers}} -
-

ProvidersCSVJSON

- - - - {{range .Providers}} - - - - - - - - - {{end}} - -
ProviderRequestsSuccessErrorsAvg LatencyCost
{{.Provider}}{{.Requests}}{{.Successes}}{{.Errors}}{{printf "%.0f" .AvgLatencyMS}}ms{{formatCost .CostUSD}}
-
-{{end}} - -{{if .TokenStats}} -
-

API Token UsageCSVJSON

- - - - {{range .TokenStats}} - - - - - - - {{end}} - -
TokenRequestsTokens (in/out)Cost
{{.TokenName}}{{.Requests}}{{.InputTokens}} / {{.OutputTokens}}{{formatCost .CostUSD}}
-
-{{end}} - - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/debug.html b/llm-gateway/internal/dashboard/templates/partials/debug.html deleted file mode 100644 index 4b55fd5..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/debug.html +++ /dev/null @@ -1,94 +0,0 @@ -{{define "content"}} - - -
- Debug Mode - - {{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}} -
- -
- - - - - - - - - - - - - - {{range $i, $entry := .DebugResult.Entries}} - - - - - - - - - - - - - {{end}} - {{if not .DebugResult.Entries}} - - {{end}} - -
TimeRequest IDTokenModelProviderStatus
-
-
Request Headers:
-
{{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}
-
Request Body:
-
{{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}
-
Response Body:
-
{{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}
-
-
No debug log entries
- - {{if gt .DebugResult.TotalPages 1}} - - {{end}} -
- - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/logs.html b/llm-gateway/internal/dashboard/templates/partials/logs.html deleted file mode 100644 index bf8c640..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/logs.html +++ /dev/null @@ -1,133 +0,0 @@ -{{define "content"}} - - -
- - - - - - - -
- -
- - - - - - - - - - - - - - - {{range $i, $log := .LogsResult.Logs}} - - - - - - - - - - - {{if $log.ErrorMessage}} - - - - {{end}} - {{end}} - {{if not .LogsResult.Logs}} - - {{end}} - -
TimeTokenModelProviderStatusLatencyTokensCost
{{formatTimeDetail $log.Timestamp}}{{$log.TokenName}}{{$log.Model}}{{$log.Provider}} - {{if eq $log.Status "success"}}success - {{else if eq $log.Status "error"}}error - {{else if eq $log.Status "cached"}}cached - {{else}}{{$log.Status}}{{end}} - {{if $log.Streaming}} stream{{end}} - {{$log.LatencyMS}}ms{{$log.InputTokens}} / {{$log.OutputTokens}}{{formatCost $log.CostUSD}}
-
{{$log.ErrorMessage}}
-
No logs found
- - {{if gt .LogsResult.TotalPages 1}} - - {{end}} -
- - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/models-page.html b/llm-gateway/internal/dashboard/templates/partials/models-page.html deleted file mode 100644 index 6d4a60b..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/models-page.html +++ /dev/null @@ -1,53 +0,0 @@ -{{define "content"}} - - -{{if .ModelRoutes}} -{{range .ModelRoutes}} -
-

{{.Name}}{{if .Aliases}} aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}{{end}}

- - - - - - - - - - - - - {{$health := $.ProviderHealth}} - {{range .Routes}} - - - - - - - - - {{end}} - -
ProviderProvider ModelPriorityInput Price (per 1M)Output Price (per 1M)Health
{{.ProviderName}}{{.ProviderModel}}{{.Priority}}{{formatPrice .InputPrice}}{{formatPrice .OutputPrice}} - {{$pname := .ProviderName}} - {{range $health}} - {{if eq .Provider $pname}} - {{if eq .Status "healthy"}}healthy - {{else if eq .Status "degraded"}}degraded - {{else}}down - {{end}} - {{end}} - {{end}} - {{if not $health}}-{{end}} -
-
-{{end}} -{{else}} -
- No models configured -
-{{end}} -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/settings.html b/llm-gateway/internal/dashboard/templates/partials/settings.html deleted file mode 100644 index f63c0a5..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/settings.html +++ /dev/null @@ -1,141 +0,0 @@ -{{define "content"}} - - -
-

Profile

-
-
-
- -
- - -
-
-
-
-
- -
- - -
-
-
-
- -
-

Change Password

-
-
-
- - -
-
- - -
-
- - -
- -
-
- -
-

Two-Factor Authentication

-
- {{if .User.TOTPEnabled}} -

Two-factor authentication is enabled.

- - {{else}} -

Two-factor authentication is not enabled.

- - {{end}} -
- -
- -{{if .User.IsAdmin}} -
-

Config Validation

-

Validate the current gateway configuration file for errors.

- -
-
-{{end}} - - - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/tokens.html b/llm-gateway/internal/dashboard/templates/partials/tokens.html deleted file mode 100644 index 974f5e4..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/tokens.html +++ /dev/null @@ -1,142 +0,0 @@ -{{define "content"}} - - - - -
-

Static Tokens (from config, managed via environment variables)

- - - - {{range .Tokens}}{{if lt .ID 0}} - - - - - - - - - {{end}}{{end}} - -
NamePrefixRate LimitBudgetToday's Spend
{{.Name}}{{.KeyPrefix}}...{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}} - {{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}} - {{if gt .MonthlyBudgetUSD 0.0}}
${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}} -
- {{$spend := index $.TokenSpend .Name}} - {{if gt .DailyBudgetUSD 0.0}} - {{$pct := budgetPct $spend .DailyBudgetUSD}} -
-
-
${{printf "%.4f" $spend}} / ${{printf "%.2f" .DailyBudgetUSD}} ({{printf "%.1f" $pct}}%)
-
- {{else}} - {{if gt $spend 0.0}}{{formatCost $spend}}{{else}}-{{end}} - {{end}} -
config
-
- -
-

Dynamic Tokens (created via dashboard)

- - - - {{range .Tokens}}{{if gt .ID 0}} - - - - - - - - - - - {{end}}{{end}} - -
NamePrefixRate LimitBudgetToday's SpendCreatedLast Used
{{.Name}}{{.KeyPrefix}}...{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}} - {{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}} - {{if gt .MonthlyBudgetUSD 0.0}}
${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}} -
- {{$spend := index $.TokenSpend .Name}} - {{if gt .DailyBudgetUSD 0.0}} - {{$pct := budgetPct $spend .DailyBudgetUSD}} -
-
-
${{printf "%.4f" $spend}} / ${{printf "%.2f" .DailyBudgetUSD}} ({{printf "%.1f" $pct}}%)
-
- {{else}} - {{if gt $spend 0.0}}{{formatCost $spend}}{{else}}-{{end}} - {{end}} -
{{formatTime .CreatedAt}}{{if gt .LastUsedAt 0}}{{formatTime .LastUsedAt}}{{else}}never{{end}}
-
- - - - - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/users.html b/llm-gateway/internal/dashboard/templates/partials/users.html deleted file mode 100644 index 2402b49..0000000 --- a/llm-gateway/internal/dashboard/templates/partials/users.html +++ /dev/null @@ -1,70 +0,0 @@ -{{define "content"}} - - -
- - - - {{range .Users}} - - - - - - - - - {{end}} - -
IDUsernameRole2FACreated
{{.ID}}{{.Username}}{{if .IsAdmin}}Admin{{else}}User{{end}}{{if .TOTPEnabled}}Enabled{{else}}Off{{end}}{{formatTime .CreatedAt}}{{if ne .ID $.User.ID}}{{end}}
-
- - - - - -{{end}} diff --git a/llm-gateway/internal/dashboard/templates/setup.html b/llm-gateway/internal/dashboard/templates/setup.html deleted file mode 100644 index 11c359f..0000000 --- a/llm-gateway/internal/dashboard/templates/setup.html +++ /dev/null @@ -1,59 +0,0 @@ -{{define "setup"}} - - - - - -Setup - LLM Gateway - - - - - -
-

LLM Gateway Setup

-

Create the first admin account

-
-
-
- - -
-
- - -
-
- - -
- -
-
- - - -{{end}} diff --git a/llm-gateway/internal/metrics/prometheus.go b/llm-gateway/internal/metrics/prometheus.go deleted file mode 100644 index 620f51e..0000000 --- a/llm-gateway/internal/metrics/prometheus.go +++ /dev/null @@ -1,73 +0,0 @@ -package metrics - -import ( - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promauto" -) - -type Metrics struct { - requestsTotal *prometheus.CounterVec - requestDuration *prometheus.HistogramVec - tokensTotal *prometheus.CounterVec - costTotal *prometheus.CounterVec - cacheHits prometheus.Counter - cacheMisses prometheus.Counter -} - -func New() *Metrics { - return &Metrics{ - requestsTotal: promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "llm_gateway_requests_total", - Help: "Total number of LLM requests", - }, []string{"model", "provider", "token_name", "status"}), - - requestDuration: promauto.NewHistogramVec(prometheus.HistogramOpts{ - Name: "llm_gateway_request_duration_ms", - Help: "Request duration in milliseconds", - Buckets: []float64{100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000}, - }, []string{"model", "provider"}), - - tokensTotal: promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "llm_gateway_tokens_total", - Help: "Total tokens processed", - }, []string{"model", "provider", "type"}), - - costTotal: promauto.NewCounterVec(prometheus.CounterOpts{ - Name: "llm_gateway_cost_usd_total", - Help: "Total cost in USD", - }, []string{"model", "provider", "token_name"}), - - cacheHits: promauto.NewCounter(prometheus.CounterOpts{ - Name: "llm_gateway_cache_hits_total", - Help: "Total number of cache hits", - }), - - cacheMisses: promauto.NewCounter(prometheus.CounterOpts{ - Name: "llm_gateway_cache_misses_total", - Help: "Total number of cache misses", - }), - } -} - -func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, latencyMS int64, inputTokens, outputTokens int, cost float64) { - m.requestsTotal.WithLabelValues(model, providerName, tokenName, status).Inc() - m.requestDuration.WithLabelValues(model, providerName).Observe(float64(latencyMS)) - - if inputTokens > 0 { - m.tokensTotal.WithLabelValues(model, providerName, "input").Add(float64(inputTokens)) - } - if outputTokens > 0 { - m.tokensTotal.WithLabelValues(model, providerName, "output").Add(float64(outputTokens)) - } - if cost > 0 { - m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost) - } -} - -func (m *Metrics) RecordCacheHit() { - m.cacheHits.Inc() -} - -func (m *Metrics) RecordCacheMiss() { - m.cacheMisses.Inc() -} diff --git a/llm-gateway/internal/pricing/pricing.go b/llm-gateway/internal/pricing/pricing.go deleted file mode 100644 index 69b01d8..0000000 --- a/llm-gateway/internal/pricing/pricing.go +++ /dev/null @@ -1,191 +0,0 @@ -package pricing - -import ( - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "sync" - "time" -) - -const defaultPricesURL = "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" - -// Provider represents a provider entry in genai_prices.json. -type Provider struct { - ID string `json:"id"` - Models []Model `json:"models"` -} - -// Model represents a model entry with pricing. -type Model struct { - ID string `json:"id"` - Prices json.RawMessage `json:"prices"` -} - -// Lookup provides pricing data fetched from genai-prices. -type Lookup struct { - mu sync.RWMutex - prices map[string][2]float64 - url string - stopCh chan struct{} -} - -// NewLookup creates a Lookup that fetches pricing data immediately and refreshes every interval. -// If url is empty, uses the default genai-prices URL. -// Returns a usable Lookup even if the initial fetch fails (prices will be empty until next refresh). -func NewLookup(url string, interval time.Duration) *Lookup { - if url == "" { - url = defaultPricesURL - } - l := &Lookup{ - prices: make(map[string][2]float64), - url: url, - stopCh: make(chan struct{}), - } - - // Initial fetch - l.refresh() - - // Background refresh - go func() { - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - l.refresh() - case <-l.stopCh: - return - } - } - }() - - return l -} - -// Close stops the background refresh goroutine. -func (l *Lookup) Close() { - close(l.stopCh) -} - -// Get returns (inputPer1M, outputPer1M) for a provider:model pair. -// Returns (0, 0) if not found. -func (l *Lookup) Get(provider, model string) (float64, float64) { - if l == nil { - return 0, 0 - } - l.mu.RLock() - defer l.mu.RUnlock() - key := fmt.Sprintf("%s:%s", provider, model) - if p, ok := l.prices[key]; ok { - return p[0], p[1] - } - return 0, 0 -} - -// FillMissing fills in zero-value pricing from the lookup data. -// Returns the number of prices filled. -func (l *Lookup) FillMissing(provider, model string, input, output *float64) bool { - if l == nil || (*input > 0 && *output > 0) { - return false - } - i, o := l.Get(provider, model) - if i == 0 && o == 0 { - return false - } - if *input == 0 { - *input = i - } - if *output == 0 { - *output = o - } - return true -} - -func (l *Lookup) refresh() { - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Get(l.url) - if err != nil { - log.Printf("WARNING: failed to fetch pricing data: %v", err) - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Printf("WARNING: pricing data fetch returned %d", resp.StatusCode) - return - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Printf("WARNING: failed to read pricing data: %v", err) - return - } - - var providers []Provider - if err := json.Unmarshal(body, &providers); err != nil { - log.Printf("WARNING: failed to parse pricing data: %v", err) - return - } - - prices := make(map[string][2]float64) - for _, p := range providers { - for _, m := range p.Models { - input, output := parsePrices(m.Prices) - if input > 0 || output > 0 { - key := fmt.Sprintf("%s:%s", p.ID, m.ID) - prices[key] = [2]float64{input, output} - } - } - } - - l.mu.Lock() - l.prices = prices - l.mu.Unlock() - - log.Printf("Loaded pricing data: %d model prices from genai-prices", len(prices)) -} - -// parsePrices handles the different shapes of the "prices" field: -// - object: {"input_mtok": 0.5, "output_mtok": 1.0} -// - array: [{"prices": {"input_mtok": 0.5, ...}}, ...] (time-of-day; use first entry) -func parsePrices(raw json.RawMessage) (input, output float64) { - if len(raw) == 0 { - return 0, 0 - } - - // Try as object first (most common) - var obj map[string]any - if json.Unmarshal(raw, &obj) == nil { - return extractPrice(obj, "input_mtok"), extractPrice(obj, "output_mtok") - } - - // Try as array (time-of-day pricing) — use first entry - var arr []struct { - Prices map[string]any `json:"prices"` - } - if json.Unmarshal(raw, &arr) == nil && len(arr) > 0 { - return extractPrice(arr[0].Prices, "input_mtok"), extractPrice(arr[0].Prices, "output_mtok") - } - - return 0, 0 -} - -// extractPrice handles both simple float and tiered pricing (uses base price). -func extractPrice(prices map[string]any, key string) float64 { - v, ok := prices[key] - if !ok { - return 0 - } - switch val := v.(type) { - case float64: - return val - case map[string]any: - if base, ok := val["base"].(float64); ok { - return base - } - } - return 0 -} diff --git a/llm-gateway/internal/provider/balancer.go b/llm-gateway/internal/provider/balancer.go deleted file mode 100644 index 602d885..0000000 --- a/llm-gateway/internal/provider/balancer.go +++ /dev/null @@ -1,144 +0,0 @@ -package provider - -import ( - "math/rand" - "sort" - "sync/atomic" -) - -// LoadBalancer reorders routes for load distribution. -type LoadBalancer interface { - Reorder(routes []Route) []Route -} - -// NewLoadBalancer creates a load balancer by strategy name. -func NewLoadBalancer(strategy string) LoadBalancer { - switch strategy { - case "round-robin": - return &RoundRobinBalancer{} - case "random": - return &RandomBalancer{} - case "least-cost": - return &LeastCostBalancer{} - default: - return &FirstBalancer{} - } -} - -// FirstBalancer is a no-op that preserves original order. -type FirstBalancer struct{} - -func (b *FirstBalancer) Reorder(routes []Route) []Route { - return routes -} - -// RoundRobinBalancer rotates routes within same-priority groups. -type RoundRobinBalancer struct { - counter atomic.Uint64 -} - -func (b *RoundRobinBalancer) Reorder(routes []Route) []Route { - if len(routes) <= 1 { - return routes - } - - result := make([]Route, len(routes)) - copy(result, routes) - - // Group by priority and rotate within each group - groups := groupByPriority(result) - idx := 0 - count := b.counter.Add(1) - for _, group := range groups { - if len(group) > 1 { - offset := int(count) % len(group) - for j := 0; j < len(group); j++ { - result[idx] = group[(j+offset)%len(group)] - idx++ - } - } else { - result[idx] = group[0] - idx++ - } - } - - return result -} - -// RandomBalancer shuffles routes within same-priority groups. -type RandomBalancer struct{} - -func (b *RandomBalancer) Reorder(routes []Route) []Route { - if len(routes) <= 1 { - return routes - } - - result := make([]Route, len(routes)) - copy(result, routes) - - groups := groupByPriority(result) - idx := 0 - for _, group := range groups { - rand.Shuffle(len(group), func(i, j int) { - group[i], group[j] = group[j], group[i] - }) - for _, r := range group { - result[idx] = r - idx++ - } - } - - return result -} - -// LeastCostBalancer sorts by price within same-priority groups. -type LeastCostBalancer struct{} - -func (b *LeastCostBalancer) Reorder(routes []Route) []Route { - if len(routes) <= 1 { - return routes - } - - result := make([]Route, len(routes)) - copy(result, routes) - - groups := groupByPriority(result) - idx := 0 - for _, group := range groups { - sort.Slice(group, func(i, j int) bool { - costI := group[i].InputPrice + group[i].OutputPrice - costJ := group[j].InputPrice + group[j].OutputPrice - return costI < costJ - }) - for _, r := range group { - result[idx] = r - idx++ - } - } - - return result -} - -// groupByPriority splits routes into groups of same priority, preserving order. -func groupByPriority(routes []Route) [][]Route { - if len(routes) == 0 { - return nil - } - - var groups [][]Route - currentPriority := routes[0].Priority - currentGroup := []Route{routes[0]} - - for i := 1; i < len(routes); i++ { - if routes[i].Priority == currentPriority { - currentGroup = append(currentGroup, routes[i]) - } else { - groups = append(groups, currentGroup) - currentPriority = routes[i].Priority - currentGroup = []Route{routes[i]} - } - } - groups = append(groups, currentGroup) - - return groups -} diff --git a/llm-gateway/internal/provider/balancer_test.go b/llm-gateway/internal/provider/balancer_test.go deleted file mode 100644 index cc5378e..0000000 --- a/llm-gateway/internal/provider/balancer_test.go +++ /dev/null @@ -1,294 +0,0 @@ -package provider - -import ( - "fmt" - "testing" -) - -type routeSpec struct { - name string - priority int - input float64 - output float64 -} - -func makeRoutes(specs ...routeSpec) []Route { - routes := make([]Route, len(specs)) - for i, s := range specs { - routes[i] = Route{ - Provider: &mockProvider{name: s.name}, - ProviderModel: s.name + "-model", - Priority: s.priority, - InputPrice: s.input, - OutputPrice: s.output, - } - } - return routes -} - -func routeNames(routes []Route) []string { - names := make([]string, len(routes)) - for i, r := range routes { - names[i] = r.Provider.Name() - } - return names -} - -func TestFirstBalancer_PreservesOrder(t *testing.T) { - routes := makeRoutes( - routeSpec{"a", 1, 1.0, 1.0}, - routeSpec{"b", 1, 2.0, 2.0}, - routeSpec{"c", 1, 3.0, 3.0}, - ) - - b := &FirstBalancer{} - result := b.Reorder(routes) - - names := routeNames(result) - if names[0] != "a" || names[1] != "b" || names[2] != "c" { - t.Fatalf("expected [a b c], got %v", names) - } -} - -func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) { - routes := makeRoutes( - routeSpec{"a", 1, 1.0, 1.0}, - routeSpec{"b", 1, 1.0, 1.0}, - routeSpec{"c", 1, 1.0, 1.0}, - ) - - b := &RoundRobinBalancer{} - - // Collect the first element from multiple calls - seen := make(map[string]bool) - for i := 0; i < 6; i++ { - result := b.Reorder(routes) - seen[result[0].Provider.Name()] = true - } - - // All routes should have appeared as first at some point - for _, name := range []string{"a", "b", "c"} { - if !seen[name] { - t.Errorf("expected %q to appear as first element in rotation", name) - } - } -} - -func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) { - routes := makeRoutes( - routeSpec{"a", 1, 1.0, 1.0}, - routeSpec{"b", 1, 1.0, 1.0}, - routeSpec{"c", 2, 1.0, 1.0}, - ) - - b := &RoundRobinBalancer{} - - // Priority 2 route should always be last - for i := 0; i < 5; i++ { - result := b.Reorder(routes) - if result[2].Provider.Name() != "c" { - t.Fatalf("expected priority-2 route 'c' at the end, got %q", result[2].Provider.Name()) - } - } -} - -func TestRandomBalancer_AllRoutesPresent(t *testing.T) { - routes := makeRoutes( - routeSpec{"a", 1, 1.0, 1.0}, - routeSpec{"b", 1, 1.0, 1.0}, - routeSpec{"c", 1, 1.0, 1.0}, - ) - - b := &RandomBalancer{} - - for i := 0; i < 10; i++ { - result := b.Reorder(routes) - if len(result) != 3 { - t.Fatalf("expected 3 routes, got %d", len(result)) - } - - names := make(map[string]bool) - for _, r := range result { - names[r.Provider.Name()] = true - } - for _, want := range []string{"a", "b", "c"} { - if !names[want] { - t.Errorf("missing route %q in result", want) - } - } - } -} - -func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) { - routes := makeRoutes( - routeSpec{"a", 1, 1.0, 1.0}, - routeSpec{"b", 1, 1.0, 1.0}, - routeSpec{"c", 2, 1.0, 1.0}, - ) - - b := &RandomBalancer{} - - for i := 0; i < 10; i++ { - result := b.Reorder(routes) - if result[2].Provider.Name() != "c" { - t.Fatalf("expected priority-2 route 'c' last, got %q", result[2].Provider.Name()) - } - } -} - -func TestLeastCostBalancer_SortsByCost(t *testing.T) { - routes := makeRoutes( - routeSpec{"expensive", 1, 10.0, 10.0}, - routeSpec{"cheap", 1, 1.0, 1.0}, - routeSpec{"medium", 1, 5.0, 5.0}, - ) - - b := &LeastCostBalancer{} - result := b.Reorder(routes) - - names := routeNames(result) - expected := []string{"cheap", "medium", "expensive"} - for i, want := range expected { - if names[i] != want { - t.Errorf("position %d: got %q, want %q", i, names[i], want) - } - } -} - -func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) { - routes := makeRoutes( - routeSpec{"expensive-p1", 1, 10.0, 10.0}, - routeSpec{"cheap-p1", 1, 1.0, 1.0}, - routeSpec{"cheap-p2", 2, 0.5, 0.5}, - ) - - b := &LeastCostBalancer{} - result := b.Reorder(routes) - - names := routeNames(result) - // Within priority 1, cheap should come first; priority 2 always last - if names[0] != "cheap-p1" { - t.Errorf("expected cheap-p1 first, got %q", names[0]) - } - if names[1] != "expensive-p1" { - t.Errorf("expected expensive-p1 second, got %q", names[1]) - } - if names[2] != "cheap-p2" { - t.Errorf("expected cheap-p2 last, got %q", names[2]) - } -} - -func TestGroupByPriority(t *testing.T) { - tests := []struct { - name string - priorities []int - wantGroups [][]int - }{ - { - name: "empty", - priorities: nil, - wantGroups: nil, - }, - { - name: "single", - priorities: []int{1}, - wantGroups: [][]int{{1}}, - }, - { - name: "all same", - priorities: []int{1, 1, 1}, - wantGroups: [][]int{{1, 1, 1}}, - }, - { - name: "two groups", - priorities: []int{1, 1, 2, 2}, - wantGroups: [][]int{{1, 1}, {2, 2}}, - }, - { - name: "three groups", - priorities: []int{1, 2, 2, 3}, - wantGroups: [][]int{{1}, {2, 2}, {3}}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var routes []Route - for _, p := range tt.priorities { - routes = append(routes, Route{Priority: p}) - } - - groups := groupByPriority(routes) - - if tt.wantGroups == nil { - if groups != nil { - t.Fatalf("expected nil groups, got %v", groups) - } - return - } - - if len(groups) != len(tt.wantGroups) { - t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups)) - } - - for i, wg := range tt.wantGroups { - if len(groups[i]) != len(wg) { - t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i])) - continue - } - for j, wp := range wg { - if groups[i][j].Priority != wp { - t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority) - } - } - } - }) - } -} - -func TestBalancer_SingleRoute(t *testing.T) { - routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0}) - - balancers := []struct { - name string - balancer LoadBalancer - }{ - {"first", &FirstBalancer{}}, - {"round-robin", &RoundRobinBalancer{}}, - {"random", &RandomBalancer{}}, - {"least-cost", &LeastCostBalancer{}}, - } - - for _, bb := range balancers { - t.Run(bb.name, func(t *testing.T) { - result := bb.balancer.Reorder(routes) - if len(result) != 1 || result[0].Provider.Name() != "only" { - t.Fatalf("expected single route 'only', got %v", routeNames(result)) - } - }) - } -} - -func TestNewLoadBalancer(t *testing.T) { - tests := []struct { - strategy string - wantType string - }{ - {"round-robin", "*provider.RoundRobinBalancer"}, - {"random", "*provider.RandomBalancer"}, - {"least-cost", "*provider.LeastCostBalancer"}, - {"first", "*provider.FirstBalancer"}, - {"unknown", "*provider.FirstBalancer"}, - {"", "*provider.FirstBalancer"}, - } - - for _, tt := range tests { - t.Run(tt.strategy, func(t *testing.T) { - b := NewLoadBalancer(tt.strategy) - got := fmt.Sprintf("%T", b) - if got != tt.wantType { - t.Errorf("NewLoadBalancer(%q) = %s, want %s", tt.strategy, got, tt.wantType) - } - }) - } -} diff --git a/llm-gateway/internal/provider/health.go b/llm-gateway/internal/provider/health.go deleted file mode 100644 index ad3638a..0000000 --- a/llm-gateway/internal/provider/health.go +++ /dev/null @@ -1,264 +0,0 @@ -package provider - -import ( - "sync" - "time" - - "llm-gateway/internal/config" -) - -// CircuitState represents the state of a circuit breaker. -type CircuitState int - -const ( - CircuitClosed CircuitState = iota // normal operation - CircuitOpen // blocking requests - CircuitHalfOpen // testing with probe request -) - -func (s CircuitState) String() string { - switch s { - case CircuitClosed: - return "closed" - case CircuitOpen: - return "open" - case CircuitHalfOpen: - return "half-open" - default: - return "unknown" - } -} - -// ProviderCircuit tracks circuit breaker state for a single provider. -type ProviderCircuit struct { - State CircuitState - OpenedAt time.Time - LastProbe time.Time -} - -// HealthEvent represents a single request outcome for a provider. -type HealthEvent struct { - Timestamp time.Time - LatencyMS int64 - IsError bool - ErrorMsg string -} - -// ProviderHealth is the computed health status for a provider. -type ProviderHealth struct { - Provider string `json:"provider"` - Status string `json:"status"` // healthy, degraded, down - ErrorRate float64 `json:"error_rate"` - AvgLatency float64 `json:"avg_latency_ms"` - Total int `json:"total"` - Errors int `json:"errors"` - CircuitState string `json:"circuit_state"` -} - -// HealthTracker tracks per-provider health using a sliding window. -type HealthTracker struct { - mu sync.RWMutex - windows map[string][]HealthEvent - windowDu time.Duration - circuits map[string]*ProviderCircuit - cbConfig config.CircuitBreakerConfig - OnStateChange func(provider string, from, to CircuitState) -} - -// NewHealthTracker creates a health tracker with the given window duration. -func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker { - if window == 0 { - window = 5 * time.Minute - } - return &HealthTracker{ - windows: make(map[string][]HealthEvent), - circuits: make(map[string]*ProviderCircuit), - windowDu: window, - cbConfig: cbCfg, - } -} - -// IsAvailable returns true if the provider's circuit breaker allows requests. -func (h *HealthTracker) IsAvailable(provider string) bool { - if !h.cbConfig.Enabled { - return true - } - - h.mu.RLock() - defer h.mu.RUnlock() - - circuit, ok := h.circuits[provider] - if !ok { - return true // no circuit = closed = available - } - - switch circuit.State { - case CircuitOpen: - // Check if cooldown has elapsed -> transition to half-open - if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration { - return true // will transition to half-open on next record - } - return false - case CircuitHalfOpen: - return true // allow probe - default: - return true - } -} - -// Record adds a health event for a provider and evaluates circuit transitions. -func (h *HealthTracker) Record(provider string, latencyMS int64, err error) { - event := HealthEvent{ - Timestamp: time.Now(), - LatencyMS: latencyMS, - IsError: err != nil, - } - if err != nil { - event.ErrorMsg = err.Error() - } - - h.mu.Lock() - defer h.mu.Unlock() - - h.windows[provider] = append(h.windows[provider], event) - h.prune(provider) - - if h.cbConfig.Enabled { - h.evaluateCircuit(provider, err) - } -} - -// evaluateCircuit transitions circuit breaker state. Must be called with lock held. -func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) { - circuit, ok := h.circuits[providerName] - if !ok { - circuit = &ProviderCircuit{State: CircuitClosed} - h.circuits[providerName] = circuit - } - - prevState := circuit.State - - switch circuit.State { - case CircuitClosed: - // Check if error threshold exceeded - errorRate, total := h.errorRateUnlocked(providerName) - if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold { - circuit.State = CircuitOpen - circuit.OpenedAt = time.Now() - } - case CircuitOpen: - // Check if cooldown elapsed -> half-open - if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration { - circuit.State = CircuitHalfOpen - circuit.LastProbe = time.Now() - // Evaluate the probe result immediately - if lastErr == nil { - circuit.State = CircuitClosed - } else { - circuit.State = CircuitOpen - circuit.OpenedAt = time.Now() - } - } - case CircuitHalfOpen: - if lastErr == nil { - circuit.State = CircuitClosed - } else { - circuit.State = CircuitOpen - circuit.OpenedAt = time.Now() - } - } - - if circuit.State != prevState && h.OnStateChange != nil { - cb := h.OnStateChange - from, to := prevState, circuit.State - // Call outside lock to avoid deadlocks - go cb(providerName, from, to) - } -} - -// errorRateUnlocked computes error rate within window. Must be called with lock held. -func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) { - cutoff := time.Now().Add(-h.windowDu) - events := h.windows[provider] - var total, errors int - for _, e := range events { - if e.Timestamp.Before(cutoff) { - continue - } - total++ - if e.IsError { - errors++ - } - } - if total == 0 { - return 0, 0 - } - return float64(errors) / float64(total), total -} - -// Status returns computed health for all tracked providers. -func (h *HealthTracker) Status() []ProviderHealth { - h.mu.RLock() - defer h.mu.RUnlock() - - cutoff := time.Now().Add(-h.windowDu) - var results []ProviderHealth - - for provider, events := range h.windows { - var total, errors int - var totalLatency int64 - - for _, e := range events { - if e.Timestamp.Before(cutoff) { - continue - } - total++ - totalLatency += e.LatencyMS - if e.IsError { - errors++ - } - } - - if total == 0 { - continue - } - - errorRate := float64(errors) / float64(total) - status := "healthy" - if errorRate >= 0.5 { - status = "down" - } else if errorRate >= 0.1 { - status = "degraded" - } - - circuitState := "closed" - if circuit, ok := h.circuits[provider]; ok { - circuitState = circuit.State.String() - } - - results = append(results, ProviderHealth{ - Provider: provider, - Status: status, - ErrorRate: errorRate, - AvgLatency: float64(totalLatency) / float64(total), - Total: total, - Errors: errors, - CircuitState: circuitState, - }) - } - - return results -} - -// prune removes events outside the window. Must be called with lock held. -func (h *HealthTracker) prune(provider string) { - cutoff := time.Now().Add(-h.windowDu) - events := h.windows[provider] - i := 0 - for i < len(events) && events[i].Timestamp.Before(cutoff) { - i++ - } - if i > 0 { - h.windows[provider] = events[i:] - } -} diff --git a/llm-gateway/internal/provider/health_test.go b/llm-gateway/internal/provider/health_test.go deleted file mode 100644 index 8dda99e..0000000 --- a/llm-gateway/internal/provider/health_test.go +++ /dev/null @@ -1,345 +0,0 @@ -package provider - -import ( - "errors" - "testing" - "time" - - "llm-gateway/internal/config" -) - -func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker { - return NewHealthTracker(window, cb) -} - -func defaultCBConfig() config.CircuitBreakerConfig { - return config.CircuitBreakerConfig{ - Enabled: true, - ErrorThreshold: 0.5, - MinRequests: 3, - CooldownDuration: 100 * time.Millisecond, - } -} - -func TestHealthTracker_Record(t *testing.T) { - ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{}) - - ht.Record("provA", 100, nil) - ht.Record("provA", 200, errors.New("fail")) - ht.Record("provB", 50, nil) - - ht.mu.RLock() - defer ht.mu.RUnlock() - - if len(ht.windows["provA"]) != 2 { - t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"])) - } - if len(ht.windows["provB"]) != 1 { - t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"])) - } - - // Verify event fields - ev := ht.windows["provA"][1] - if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 { - t.Fatalf("unexpected event fields: %+v", ev) - } -} - -func TestHealthTracker_Status(t *testing.T) { - tests := []struct { - name string - successCount int - errorCount int - wantStatus string - wantErrorRate float64 - wantTotal int - wantErrors int - }{ - { - name: "healthy - no errors", - successCount: 10, - errorCount: 0, - wantStatus: "healthy", - wantErrorRate: 0.0, - wantTotal: 10, - wantErrors: 0, - }, - { - name: "healthy - below 10% errors", - successCount: 19, - errorCount: 1, - wantStatus: "healthy", - wantErrorRate: 0.05, - wantTotal: 20, - wantErrors: 1, - }, - { - name: "degraded - 20% errors", - successCount: 8, - errorCount: 2, - wantStatus: "degraded", - wantErrorRate: 0.2, - wantTotal: 10, - wantErrors: 2, - }, - { - name: "degraded - exactly 10% errors", - successCount: 9, - errorCount: 1, - wantStatus: "degraded", - wantErrorRate: 0.1, - wantTotal: 10, - wantErrors: 1, - }, - { - name: "down - 50% errors", - successCount: 5, - errorCount: 5, - wantStatus: "down", - wantErrorRate: 0.5, - wantTotal: 10, - wantErrors: 5, - }, - { - name: "down - all errors", - successCount: 0, - errorCount: 5, - wantStatus: "down", - wantErrorRate: 1.0, - wantTotal: 5, - wantErrors: 5, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{}) - - for i := 0; i < tt.successCount; i++ { - ht.Record("prov", 100, nil) - } - for i := 0; i < tt.errorCount; i++ { - ht.Record("prov", 100, errors.New("err")) - } - - statuses := ht.Status() - if len(statuses) != 1 { - t.Fatalf("expected 1 status, got %d", len(statuses)) - } - - s := statuses[0] - if s.Status != tt.wantStatus { - t.Errorf("status = %q, want %q", s.Status, tt.wantStatus) - } - if s.Total != tt.wantTotal { - t.Errorf("total = %d, want %d", s.Total, tt.wantTotal) - } - if s.Errors != tt.wantErrors { - t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors) - } - // Allow small float tolerance - if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 { - t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate) - } - }) - } -} - -func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) { - cb := defaultCBConfig() - cb.MinRequests = 3 - cb.ErrorThreshold = 0.5 - - ht := newTestTracker(5*time.Minute, cb) - - // Record errors to exceed threshold (3 errors out of 3 = 100% > 50%) - ht.Record("prov", 100, errors.New("err")) - ht.Record("prov", 100, errors.New("err")) - ht.Record("prov", 100, errors.New("err")) - - ht.mu.RLock() - state := ht.circuits["prov"].State - ht.mu.RUnlock() - - if state != CircuitOpen { - t.Fatalf("expected CircuitOpen, got %s", state) - } - - if ht.IsAvailable("prov") { - t.Fatal("expected IsAvailable=false when circuit is open") - } -} - -func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) { - cb := defaultCBConfig() - cb.CooldownDuration = 50 * time.Millisecond - - ht := newTestTracker(5*time.Minute, cb) - - // Trip the circuit - for i := 0; i < 5; i++ { - ht.Record("prov", 100, errors.New("err")) - } - - if ht.IsAvailable("prov") { - t.Fatal("expected circuit open, IsAvailable should be false") - } - - // Wait for cooldown - time.Sleep(60 * time.Millisecond) - - // After cooldown, IsAvailable should return true (will transition to half-open) - if !ht.IsAvailable("prov") { - t.Fatal("expected IsAvailable=true after cooldown") - } -} - -func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) { - cb := defaultCBConfig() - cb.CooldownDuration = 10 * time.Millisecond - - ht := newTestTracker(5*time.Minute, cb) - - // Trip the circuit - for i := 0; i < 5; i++ { - ht.Record("prov", 100, errors.New("err")) - } - - // Wait for cooldown so next Record transitions through Open->HalfOpen - time.Sleep(20 * time.Millisecond) - - // A successful record should transition: Open -> HalfOpen -> Closed - ht.Record("prov", 100, nil) - - ht.mu.RLock() - state := ht.circuits["prov"].State - ht.mu.RUnlock() - - if state != CircuitClosed { - t.Fatalf("expected CircuitClosed after success in half-open, got %s", state) - } - - if !ht.IsAvailable("prov") { - t.Fatal("expected IsAvailable=true after circuit closed") - } -} - -func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) { - cb := defaultCBConfig() - cb.CooldownDuration = 10 * time.Millisecond - - ht := newTestTracker(5*time.Minute, cb) - - // Trip the circuit - for i := 0; i < 5; i++ { - ht.Record("prov", 100, errors.New("err")) - } - - // Wait for cooldown - time.Sleep(20 * time.Millisecond) - - // A failed record should transition: Open -> HalfOpen -> Open - ht.Record("prov", 100, errors.New("still failing")) - - ht.mu.RLock() - state := ht.circuits["prov"].State - ht.mu.RUnlock() - - if state != CircuitOpen { - t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state) - } -} - -func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) { - ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false}) - - // Even with errors, IsAvailable should return true when CB is disabled - for i := 0; i < 10; i++ { - ht.Record("prov", 100, errors.New("err")) - } - - if !ht.IsAvailable("prov") { - t.Fatal("expected IsAvailable=true when circuit breaker disabled") - } -} - -func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) { - ht := newTestTracker(5*time.Minute, defaultCBConfig()) - - if !ht.IsAvailable("unknown") { - t.Fatal("expected IsAvailable=true for unknown provider (no circuit)") - } -} - -func TestHealthTracker_WindowPruning(t *testing.T) { - // Use a tiny window so events expire quickly - ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{}) - - ht.Record("prov", 100, nil) - ht.Record("prov", 200, nil) - - // Wait for events to expire - time.Sleep(60 * time.Millisecond) - - // Record a new event to trigger pruning - ht.Record("prov", 300, nil) - - ht.mu.RLock() - count := len(ht.windows["prov"]) - ht.mu.RUnlock() - - if count != 1 { - t.Fatalf("expected 1 event after pruning, got %d", count) - } -} - -func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) { - ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{}) - - ht.Record("prov", 100, nil) - - // Wait for events to expire - time.Sleep(60 * time.Millisecond) - - statuses := ht.Status() - if len(statuses) != 0 { - t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses)) - } -} - -func TestHealthTracker_Status_AvgLatency(t *testing.T) { - ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{}) - - ht.Record("prov", 100, nil) - ht.Record("prov", 200, nil) - ht.Record("prov", 300, nil) - - statuses := ht.Status() - if len(statuses) != 1 { - t.Fatalf("expected 1 status, got %d", len(statuses)) - } - - want := 200.0 - if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 { - t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want) - } -} - -func TestHealthTracker_Status_CircuitStateReported(t *testing.T) { - cb := defaultCBConfig() - ht := newTestTracker(5*time.Minute, cb) - - // Trip the circuit - for i := 0; i < 5; i++ { - ht.Record("prov", 100, errors.New("err")) - } - - statuses := ht.Status() - if len(statuses) != 1 { - t.Fatalf("expected 1 status, got %d", len(statuses)) - } - - if statuses[0].CircuitState != "open" { - t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open") - } -} diff --git a/llm-gateway/internal/provider/openai.go b/llm-gateway/internal/provider/openai.go deleted file mode 100644 index 9c096b3..0000000 --- a/llm-gateway/internal/provider/openai.go +++ /dev/null @@ -1,178 +0,0 @@ -package provider - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "time" -) - -// OpenAIProvider is a generic OpenAI-compatible HTTP client. -type OpenAIProvider struct { - name string - baseURL string - apiKey string - client *http.Client -} - -func NewOpenAIProvider(name, baseURL, apiKey string, timeout time.Duration) *OpenAIProvider { - return &OpenAIProvider{ - name: name, - baseURL: baseURL, - apiKey: apiKey, - client: &http.Client{ - Timeout: timeout, - }, - } -} - -func (p *OpenAIProvider) Name() string { return p.name } - -func (p *OpenAIProvider) ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) { - reqCopy := *req - reqCopy.Model = model - reqCopy.Stream = false - - body, err := json.Marshal(reqCopy) - if err != nil { - return nil, fmt.Errorf("marshaling request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - p.setHeaders(httpReq) - - resp, err := p.client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("sending request: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("reading response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, &ProviderError{ - StatusCode: resp.StatusCode, - Body: string(respBody), - Provider: p.name, - } - } - - var chatResp ChatResponse - if err := json.Unmarshal(respBody, &chatResp); err != nil { - return nil, fmt.Errorf("unmarshaling response: %w", err) - } - - return &chatResp, nil -} - -func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) { - reqCopy := *req - reqCopy.Model = model - reqCopy.Stream = true - - body, err := json.Marshal(reqCopy) - if err != nil { - return nil, fmt.Errorf("marshaling request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - p.setHeaders(httpReq) - - resp, err := p.client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("sending request: %w", err) - } - - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - respBody, _ := io.ReadAll(resp.Body) - return nil, &ProviderError{ - StatusCode: resp.StatusCode, - Body: string(respBody), - Provider: p.name, - } - } - - return resp.Body, nil -} - -func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) { - reqCopy := *req - reqCopy.Model = model - - body, err := json.Marshal(reqCopy) - if err != nil { - return nil, fmt.Errorf("marshaling request: %w", err) - } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("creating request: %w", err) - } - p.setHeaders(httpReq) - - resp, err := p.client.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("sending request: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("reading response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, &ProviderError{ - StatusCode: resp.StatusCode, - Body: string(respBody), - Provider: p.name, - } - } - - var embResp EmbeddingResponse - if err := json.Unmarshal(respBody, &embResp); err != nil { - return nil, fmt.Errorf("unmarshaling response: %w", err) - } - - return &embResp, nil -} - -func (p *OpenAIProvider) setHeaders(req *http.Request) { - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+p.apiKey) - // Forward request ID if present in context - if reqID := req.Context().Value("requestID"); reqID != nil { - if id, ok := reqID.(string); ok && id != "" { - req.Header.Set("X-Request-ID", id) - } - } -} - -// ProviderError represents a non-200 response from a provider. -type ProviderError struct { - StatusCode int - Body string - Provider string -} - -func (e *ProviderError) Error() string { - return fmt.Sprintf("provider %s returned %d: %s", e.Provider, e.StatusCode, e.Body) -} - -// IsRetryable returns true if the error is a server-side error worth retrying with another provider. -func (e *ProviderError) IsRetryable() bool { - return e.StatusCode >= 500 || e.StatusCode == 429 -} diff --git a/llm-gateway/internal/provider/provider.go b/llm-gateway/internal/provider/provider.go deleted file mode 100644 index 3d552ea..0000000 --- a/llm-gateway/internal/provider/provider.go +++ /dev/null @@ -1,89 +0,0 @@ -package provider - -import ( - "context" - "io" -) - -// ChatRequest is the OpenAI-compatible chat completion request. -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Temperature *float64 `json:"temperature,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - Stream bool `json:"stream"` - Stop any `json:"stop,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - N *int `json:"n,omitempty"` - Tools []any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - ResponseFormat any `json:"response_format,omitempty"` - Extra map[string]any `json:"-"` // pass through unknown fields -} - -type Message struct { - Role string `json:"role"` - Content any `json:"content"` // string or []ContentPart - Name string `json:"name,omitempty"` - ToolCalls []any `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` -} - -type ChatResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []Choice `json:"choices"` - Usage *Usage `json:"usage,omitempty"` -} - -type Choice struct { - Index int `json:"index"` - Message Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// EmbeddingRequest is the OpenAI-compatible embedding request. -type EmbeddingRequest struct { - Model string `json:"model"` - Input any `json:"input"` // string or []string - EncodingFormat string `json:"encoding_format,omitempty"` -} - -// EmbeddingResponse is the OpenAI-compatible embedding response. -type EmbeddingResponse struct { - Object string `json:"object"` - Data []EmbeddingData `json:"data"` - Model string `json:"model"` - Usage *EmbeddingUsage `json:"usage,omitempty"` -} - -// EmbeddingData holds a single embedding vector. -type EmbeddingData struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` -} - -// EmbeddingUsage reports token usage for embeddings. -type EmbeddingUsage struct { - PromptTokens int `json:"prompt_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// Provider sends requests to an LLM API. -type Provider interface { - Name() string - ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) - ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) - Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) -} diff --git a/llm-gateway/internal/provider/registry.go b/llm-gateway/internal/provider/registry.go deleted file mode 100644 index 2096738..0000000 --- a/llm-gateway/internal/provider/registry.go +++ /dev/null @@ -1,214 +0,0 @@ -package provider - -import ( - "fmt" - "sort" - "sync" - "time" - - "llm-gateway/internal/config" -) - -// ModelTimeouts holds per-model timeout overrides. -type ModelTimeouts struct { - RequestTimeout time.Duration - StreamingTimeout time.Duration -} - -// Route maps a model to a specific provider with pricing. -type Route struct { - Provider Provider - ProviderModel string - Priority int - InputPrice float64 // per 1M tokens - OutputPrice float64 // per 1M tokens -} - -// Registry maps model names to provider routes. -type Registry struct { - mu sync.RWMutex - routes map[string][]Route - balancers map[string]LoadBalancer - aliases map[string]string // alias -> canonical name - order []string // preserves config order (canonical names only) - timeouts map[string]*ModelTimeouts -} - -func NewRegistry(cfg *config.Config) (*Registry, error) { - r := &Registry{} - if err := r.buildFromConfig(cfg); err != nil { - return nil, err - } - return r, nil -} - -func (r *Registry) buildFromConfig(cfg *config.Config) error { - // Build providers - providers := make(map[string]Provider) - for _, pc := range cfg.Providers { - providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout) - } - - // Build routes - routes := make(map[string][]Route) - balancers := make(map[string]LoadBalancer) - aliases := make(map[string]string) - order := make([]string, 0, len(cfg.Models)) - timeouts := make(map[string]*ModelTimeouts) - - for _, mc := range cfg.Models { - var modelRoutes []Route - for _, rc := range mc.Routes { - p, ok := providers[rc.Provider] - if !ok { - return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider) - } - pc := cfg.ProviderByName(rc.Provider) - priority := pc.Priority - modelRoutes = append(modelRoutes, Route{ - Provider: p, - ProviderModel: rc.Model, - Priority: priority, - InputPrice: rc.Pricing.Input, - OutputPrice: rc.Pricing.Output, - }) - } - // Sort by priority (lower = higher priority) - sort.Slice(modelRoutes, func(i, j int) bool { - return modelRoutes[i].Priority < modelRoutes[j].Priority - }) - routes[mc.Name] = modelRoutes - order = append(order, mc.Name) - - // Load balancer - balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing) - - // Per-model timeouts - if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 { - timeouts[mc.Name] = &ModelTimeouts{ - RequestTimeout: mc.RequestTimeout, - StreamingTimeout: mc.StreamingTimeout, - } - } - - // Register aliases - for _, alias := range mc.Aliases { - aliases[alias] = mc.Name - } - } - - r.mu.Lock() - r.routes = routes - r.balancers = balancers - r.aliases = aliases - r.order = order - r.timeouts = timeouts - r.mu.Unlock() - - return nil -} - -// Reload rebuilds routes from new config. Used for hot-reload. -func (r *Registry) Reload(cfg *config.Config) error { - return r.buildFromConfig(cfg) -} - -// Lookup returns the routes for a model name (resolving aliases). -func (r *Registry) Lookup(model string) ([]Route, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - - // Resolve alias - canonical := model - if alias, ok := r.aliases[model]; ok { - canonical = alias - } - - routes, ok := r.routes[canonical] - if !ok { - return nil, false - } - - // Apply load balancer - if balancer, ok := r.balancers[canonical]; ok { - routes = balancer.Reorder(routes) - } - - return routes, true -} - -// ModelNames returns all registered model names in config order (including aliases). -func (r *Registry) ModelNames() []string { - r.mu.RLock() - defer r.mu.RUnlock() - - var names []string - for _, name := range r.order { - names = append(names, name) - } - // Add aliases - for alias := range r.aliases { - names = append(names, alias) - } - return names -} - -// ModelTimeoutsFor returns per-model timeout overrides, resolving aliases. Returns nil if none set. -func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts { - r.mu.RLock() - defer r.mu.RUnlock() - - canonical := model - if alias, ok := r.aliases[model]; ok { - canonical = alias - } - return r.timeouts[canonical] -} - -// RouteInfo exposes route details for dashboard display. -type RouteInfo struct { - ProviderName string `json:"provider_name"` - ProviderModel string `json:"provider_model"` - Priority int `json:"priority"` - InputPrice float64 `json:"input_price"` - OutputPrice float64 `json:"output_price"` -} - -// ModelRouteInfo exposes a model and its routes for dashboard display. -type ModelRouteInfo struct { - Name string `json:"name"` - Aliases []string `json:"aliases,omitempty"` - Routes []RouteInfo `json:"routes"` -} - -// AllRoutes returns all models and their routes in config order. -func (r *Registry) AllRoutes() []ModelRouteInfo { - r.mu.RLock() - defer r.mu.RUnlock() - - // Build reverse alias map - modelAliases := make(map[string][]string) - for alias, canonical := range r.aliases { - modelAliases[canonical] = append(modelAliases[canonical], alias) - } - - results := make([]ModelRouteInfo, 0, len(r.order)) - for _, name := range r.order { - routes := r.routes[name] - info := ModelRouteInfo{ - Name: name, - Aliases: modelAliases[name], - } - for _, rt := range routes { - info.Routes = append(info.Routes, RouteInfo{ - ProviderName: rt.Provider.Name(), - ProviderModel: rt.ProviderModel, - Priority: rt.Priority, - InputPrice: rt.InputPrice, - OutputPrice: rt.OutputPrice, - }) - } - results = append(results, info) - } - return results -} diff --git a/llm-gateway/internal/provider/registry_test.go b/llm-gateway/internal/provider/registry_test.go deleted file mode 100644 index 6c2a5cc..0000000 --- a/llm-gateway/internal/provider/registry_test.go +++ /dev/null @@ -1,286 +0,0 @@ -package provider - -import ( - "context" - "io" - "testing" - - "llm-gateway/internal/config" -) - -// mockProvider implements the Provider interface for testing. -type mockProvider struct { - name string -} - -func (m *mockProvider) Name() string { return m.name } - -func (m *mockProvider) ChatCompletion(_ context.Context, _ string, _ *ChatRequest) (*ChatResponse, error) { - return nil, nil -} - -func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *ChatRequest) (io.ReadCloser, error) { - return nil, nil -} - -func (m *mockProvider) Embedding(_ context.Context, _ string, _ *EmbeddingRequest) (*EmbeddingResponse, error) { - return nil, nil -} - -// newTestRegistry builds a Registry directly without going through config parsing. -func newTestRegistry(models []testModel) *Registry { - r := &Registry{ - routes: make(map[string][]Route), - balancers: make(map[string]LoadBalancer), - aliases: make(map[string]string), - } - - for _, m := range models { - r.routes[m.name] = m.routes - r.balancers[m.name] = &FirstBalancer{} - r.order = append(r.order, m.name) - for _, alias := range m.aliases { - r.aliases[alias] = m.name - } - } - - return r -} - -type testModel struct { - name string - aliases []string - routes []Route -} - -func TestRegistry_Lookup_Canonical(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "gpt-4", - routes: []Route{ - {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, - }, - }, - }) - - routes, ok := reg.Lookup("gpt-4") - if !ok { - t.Fatal("expected Lookup to find gpt-4") - } - if len(routes) != 1 { - t.Fatalf("expected 1 route, got %d", len(routes)) - } - if routes[0].Provider.Name() != "openai" { - t.Errorf("expected provider 'openai', got %q", routes[0].Provider.Name()) - } -} - -func TestRegistry_Lookup_Alias(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "gpt-4", - aliases: []string{"gpt4", "gpt-4-latest"}, - routes: []Route{ - {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, - }, - }, - }) - - tests := []struct { - name string - model string - found bool - }{ - {"canonical", "gpt-4", true}, - {"alias1", "gpt4", true}, - {"alias2", "gpt-4-latest", true}, - {"unknown", "gpt-5", false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - routes, ok := reg.Lookup(tt.model) - if ok != tt.found { - t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found) - } - if tt.found && len(routes) != 1 { - t.Fatalf("expected 1 route, got %d", len(routes)) - } - }) - } -} - -func TestRegistry_ModelNames_IncludesAliases(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "gpt-4", - aliases: []string{"gpt4"}, - routes: []Route{ - {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, - }, - }, - { - name: "claude-3", - routes: []Route{ - {Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1}, - }, - }, - }) - - names := reg.ModelNames() - - want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true} - got := make(map[string]bool) - for _, n := range names { - got[n] = true - } - - for name := range want { - if !got[name] { - t.Errorf("expected %q in ModelNames, not found", name) - } - } - - if len(names) != len(want) { - t.Errorf("expected %d names, got %d: %v", len(want), len(names), names) - } -} - -func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "gpt-4", - aliases: []string{"gpt4", "gpt-4-latest"}, - routes: []Route{ - {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, - {Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2}, - }, - }, - }) - - allRoutes := reg.AllRoutes() - if len(allRoutes) != 1 { - t.Fatalf("expected 1 model, got %d", len(allRoutes)) - } - - m := allRoutes[0] - if m.Name != "gpt-4" { - t.Errorf("expected name 'gpt-4', got %q", m.Name) - } - - aliasSet := make(map[string]bool) - for _, a := range m.Aliases { - aliasSet[a] = true - } - if !aliasSet["gpt4"] || !aliasSet["gpt-4-latest"] { - t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", m.Aliases) - } - - if len(m.Routes) != 2 { - t.Fatalf("expected 2 routes, got %d", len(m.Routes)) - } - if m.Routes[0].ProviderName != "openai" { - t.Errorf("expected first route provider 'openai', got %q", m.Routes[0].ProviderName) - } - if m.Routes[1].ProviderName != "azure" { - t.Errorf("expected second route provider 'azure', got %q", m.Routes[1].ProviderName) - } -} - -func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "model-b", - routes: []Route{ - {Provider: &mockProvider{name: "prov"}, ProviderModel: "b", Priority: 1}, - }, - }, - { - name: "model-a", - routes: []Route{ - {Provider: &mockProvider{name: "prov"}, ProviderModel: "a", Priority: 1}, - }, - }, - }) - - allRoutes := reg.AllRoutes() - if len(allRoutes) != 2 { - t.Fatalf("expected 2 models, got %d", len(allRoutes)) - } - if allRoutes[0].Name != "model-b" { - t.Errorf("expected first model 'model-b', got %q", allRoutes[0].Name) - } - if allRoutes[1].Name != "model-a" { - t.Errorf("expected second model 'model-a', got %q", allRoutes[1].Name) - } -} - -func TestRegistry_PrioritySorting(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "multi-provider", - routes: []Route{ - {Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3}, - {Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1}, - {Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2}, - }, - }, - }) - - // Note: routes are stored as given (sorting happens during buildFromConfig). - // For this test we verify AllRoutes returns them in stored order. - allRoutes := reg.AllRoutes() - if len(allRoutes) != 1 { - t.Fatalf("expected 1 model, got %d", len(allRoutes)) - } - - routes := allRoutes[0].Routes - if len(routes) != 3 { - t.Fatalf("expected 3 routes, got %d", len(routes)) - } - - // Verify the priorities are present - priorities := make(map[int]bool) - for _, r := range routes { - priorities[r.Priority] = true - } - for _, p := range []int{1, 2, 3} { - if !priorities[p] { - t.Errorf("expected priority %d in routes", p) - } - } -} - -func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) { - cfg := &config.Config{ - Models: []config.ModelConfig{ - { - Name: "test-model", - Routes: []config.RouteConfig{ - {Provider: "nonexistent", Model: "m"}, - }, - }, - }, - } - - _, err := NewRegistry(cfg) - if err == nil { - t.Fatal("expected error for unknown provider, got nil") - } -} - -func TestRegistry_Lookup_NotFound(t *testing.T) { - reg := newTestRegistry([]testModel{ - { - name: "gpt-4", - routes: []Route{ - {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}, - }, - }, - }) - - _, ok := reg.Lookup("nonexistent") - if ok { - t.Fatal("expected Lookup to return false for nonexistent model") - } -} diff --git a/llm-gateway/internal/proxy/auth.go b/llm-gateway/internal/proxy/auth.go deleted file mode 100644 index ca0944f..0000000 --- a/llm-gateway/internal/proxy/auth.go +++ /dev/null @@ -1,43 +0,0 @@ -package proxy - -import ( - "net/http" - "strings" - - "llm-gateway/internal/auth" -) - -type AuthMiddleware struct { - authStore *auth.Store -} - -func NewAuthMiddleware(authStore *auth.Store) *AuthMiddleware { - return &AuthMiddleware{authStore: authStore} -} - -// Authenticate validates the bearer token against the DB and sets token info in context. -func (a *AuthMiddleware) Authenticate(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - hdr := r.Header.Get("Authorization") - if !strings.HasPrefix(hdr, "Bearer ") { - writeError(w, http.StatusUnauthorized, "missing or invalid Authorization header") - return - } - key := strings.TrimPrefix(hdr, "Bearer ") - - token, err := a.authStore.LookupAPIToken(key) - if err != nil { - writeError(w, http.StatusUnauthorized, "invalid API key") - return - } - - // Update last used asynchronously (skip for static tokens) - if token.ID > 0 { - go a.authStore.UpdateAPITokenLastUsed(token.ID) - } - - ctx := withTokenName(r.Context(), token.Name) - ctx = withAPIToken(ctx, token) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} diff --git a/llm-gateway/internal/proxy/concurrency.go b/llm-gateway/internal/proxy/concurrency.go deleted file mode 100644 index 4f28262..0000000 --- a/llm-gateway/internal/proxy/concurrency.go +++ /dev/null @@ -1,51 +0,0 @@ -package proxy - -import ( - "net/http" - "sync" - "sync/atomic" -) - -// ConcurrencyLimiter enforces per-token concurrent request limits. -type ConcurrencyLimiter struct { - mu sync.Mutex - counters map[string]*atomic.Int64 -} - -func NewConcurrencyLimiter() *ConcurrencyLimiter { - return &ConcurrencyLimiter{ - counters: make(map[string]*atomic.Int64), - } -} - -func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 { - cl.mu.Lock() - defer cl.mu.Unlock() - c, ok := cl.counters[tokenName] - if !ok { - c = &atomic.Int64{} - cl.counters[tokenName] = c - } - return c -} - -func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - apiToken := getAPIToken(r.Context()) - if apiToken == nil || apiToken.MaxConcurrent <= 0 { - next.ServeHTTP(w, r) - return - } - - counter := cl.getCounter(apiToken.Name) - current := counter.Add(1) - defer counter.Add(-1) - - if current > int64(apiToken.MaxConcurrent) { - writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded") - return - } - - next.ServeHTTP(w, r) - }) -} diff --git a/llm-gateway/internal/proxy/concurrency_test.go b/llm-gateway/internal/proxy/concurrency_test.go deleted file mode 100644 index fa6ccbf..0000000 --- a/llm-gateway/internal/proxy/concurrency_test.go +++ /dev/null @@ -1,317 +0,0 @@ -package proxy - -import ( - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - "time" - - "llm-gateway/internal/auth" -) - -func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) { - tests := []struct { - name string - maxConcurrent int - numRequests int - wantAllowed int - }{ - { - name: "single request within limit", - maxConcurrent: 5, - numRequests: 1, - wantAllowed: 1, - }, - { - name: "all requests within limit", - maxConcurrent: 5, - numRequests: 5, - wantAllowed: 5, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cl := NewConcurrencyLimiter() - - token := &auth.APIToken{ - Name: "conc-token", - MaxConcurrent: tt.maxConcurrent, - } - - var allowed atomic.Int64 - var wg sync.WaitGroup - // Use a channel to hold all goroutines inside the handler simultaneously. - gate := make(chan struct{}) - - handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - allowed.Add(1) - <-gate // Block until released. - w.WriteHeader(http.StatusOK) - })) - - for i := 0; i < tt.numRequests; i++ { - wg.Add(1) - go func() { - defer wg.Done() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), token) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - }() - } - - // Wait for goroutines to enter the handler. - time.Sleep(50 * time.Millisecond) - close(gate) - wg.Wait() - - if int(allowed.Load()) != tt.wantAllowed { - t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed) - } - }) - } -} - -func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) { - tests := []struct { - name string - maxConcurrent int - numRequests int - wantDenied int - }{ - { - name: "one over limit", - maxConcurrent: 2, - numRequests: 3, - wantDenied: 1, - }, - { - name: "many over limit", - maxConcurrent: 1, - numRequests: 5, - wantDenied: 4, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cl := NewConcurrencyLimiter() - - token := &auth.APIToken{ - Name: "conc-token", - MaxConcurrent: tt.maxConcurrent, - } - - var denied atomic.Int64 - var wg sync.WaitGroup - gate := make(chan struct{}) - - handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - <-gate - w.WriteHeader(http.StatusOK) - })) - - results := make([]int, tt.numRequests) - for i := 0; i < tt.numRequests; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), token) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - results[idx] = rec.Code - if rec.Code == http.StatusTooManyRequests { - denied.Add(1) - } - }(i) - } - - // Wait for goroutines to reach the handler or be rejected. - time.Sleep(50 * time.Millisecond) - close(gate) - wg.Wait() - - if int(denied.Load()) != tt.wantDenied { - t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied) - } - }) - } -} - -func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) { - cl := NewConcurrencyLimiter() - - token := &auth.APIToken{ - Name: "decrement-token", - MaxConcurrent: 1, - } - - handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - // First request should succeed and complete, decrementing the counter. - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), token) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("first request: status = %d, want %d", rec.Code, http.StatusOK) - } - - // Counter should have decremented. A second request should also succeed. - rec2 := httptest.NewRecorder() - req2 := httptest.NewRequest(http.MethodGet, "/", nil) - ctx2 := withAPIToken(req2.Context(), token) - req2 = req2.WithContext(ctx2) - handler.ServeHTTP(rec2, req2) - - if rec2.Code != http.StatusOK { - t.Errorf("second request after first completed: status = %d, want %d", rec2.Code, http.StatusOK) - } - - // Verify the internal counter is back to 0. - counter := cl.getCounter(token.Name) - val := counter.Load() - if val != 0 { - t.Errorf("counter = %d, want 0 after all requests completed", val) - } -} - -func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) { - tests := []struct { - name string - maxConcurrent int - numRequests int - }{ - { - name: "zero allows unlimited concurrent requests", - maxConcurrent: 0, - numRequests: 50, - }, - { - name: "negative allows unlimited concurrent requests", - maxConcurrent: -1, - numRequests: 50, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cl := NewConcurrencyLimiter() - - token := &auth.APIToken{ - Name: "unlimited-token", - MaxConcurrent: tt.maxConcurrent, - } - - var allowed atomic.Int64 - var wg sync.WaitGroup - gate := make(chan struct{}) - - handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - allowed.Add(1) - <-gate - w.WriteHeader(http.StatusOK) - })) - - for i := 0; i < tt.numRequests; i++ { - wg.Add(1) - go func() { - defer wg.Done() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), token) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - }() - } - - // Give goroutines time to enter the handler. - time.Sleep(100 * time.Millisecond) - close(gate) - wg.Wait() - - if int(allowed.Load()) != tt.numRequests { - t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests) - } - }) - } -} - -func TestConcurrencyLimiter_NoToken(t *testing.T) { - cl := NewConcurrencyLimiter() - - handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - // No API token in context. - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("status = %d, want %d (should pass through without token)", rec.Code, http.StatusOK) - } -} - -func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) { - cl := NewConcurrencyLimiter() - - tokenA := &auth.APIToken{ - Name: "token-a", - MaxConcurrent: 1, - } - tokenB := &auth.APIToken{ - Name: "token-b", - MaxConcurrent: 1, - } - - gateA := make(chan struct{}) - var wg sync.WaitGroup - - handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tok := getAPIToken(r.Context()) - if tok.Name == "token-a" { - <-gateA // Block token A's request. - } - w.WriteHeader(http.StatusOK) - })) - - // Start a request for token A that blocks. - wg.Add(1) - go func() { - defer wg.Done() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), tokenA) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - }() - - // Give token A's goroutine time to enter handler. - time.Sleep(50 * time.Millisecond) - - // Token B should not be affected by token A's in-flight request. - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), tokenB) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK) - } - - close(gateA) - wg.Wait() -} diff --git a/llm-gateway/internal/proxy/dedup.go b/llm-gateway/internal/proxy/dedup.go deleted file mode 100644 index d55b6c5..0000000 --- a/llm-gateway/internal/proxy/dedup.go +++ /dev/null @@ -1,107 +0,0 @@ -package proxy - -import ( - "crypto/sha256" - "encoding/hex" - "sync" - "time" -) - -// inflight represents an in-progress deduplicated request. -type inflight struct { - done chan struct{} - result []byte - statusCode int - createdAt time.Time -} - -// Deduplicator coalesces identical concurrent non-streaming requests. -type Deduplicator struct { - mu sync.Mutex - flights map[string]*inflight - window time.Duration - done chan struct{} -} - -// NewDeduplicator creates a new request deduplicator. -func NewDeduplicator(window time.Duration) *Deduplicator { - if window == 0 { - window = 30 * time.Second - } - d := &Deduplicator{ - flights: make(map[string]*inflight), - window: window, - done: make(chan struct{}), - } - go d.cleanup() - return d -} - -// DedupKey computes a dedup key from model name and request body. -func DedupKey(model string, body []byte) string { - h := sha256.New() - h.Write([]byte(model)) - h.Write([]byte{0}) - h.Write(body) - return hex.EncodeToString(h.Sum(nil)) -} - -// TryJoin attempts to join an in-flight request. Returns the inflight entry and -// whether this caller is the leader (true) or a follower (false). -func (d *Deduplicator) TryJoin(key string) (*inflight, bool) { - d.mu.Lock() - defer d.mu.Unlock() - - if f, ok := d.flights[key]; ok { - return f, false // follower - } - - f := &inflight{ - done: make(chan struct{}), - createdAt: time.Now(), - } - d.flights[key] = f - return f, true // leader -} - -// Complete signals completion of a deduplicated request. -func (d *Deduplicator) Complete(key string, result []byte, statusCode int) { - d.mu.Lock() - f, ok := d.flights[key] - delete(d.flights, key) - d.mu.Unlock() - - if ok { - f.result = result - f.statusCode = statusCode - close(f.done) - } -} - -// Close stops the background cleanup goroutine. -func (d *Deduplicator) Close() { - close(d.done) -} - -// cleanup periodically removes stale in-flight entries. -func (d *Deduplicator) cleanup() { - ticker := time.NewTicker(d.window) - defer ticker.Stop() - - for { - select { - case <-d.done: - return - case <-ticker.C: - d.mu.Lock() - now := time.Now() - for key, f := range d.flights { - if now.Sub(f.createdAt) > d.window*2 { - delete(d.flights, key) - close(f.done) // unblock any waiting followers - } - } - d.mu.Unlock() - } - } -} diff --git a/llm-gateway/internal/proxy/dedup_test.go b/llm-gateway/internal/proxy/dedup_test.go deleted file mode 100644 index 900a53d..0000000 --- a/llm-gateway/internal/proxy/dedup_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package proxy - -import ( - "sync" - "testing" - "time" -) - -func TestDedupKey(t *testing.T) { - k1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`)) - k2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`)) - k3 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`)) - - if k1 != k2 { - t.Error("identical requests should produce the same key") - } - if k1 == k3 { - t.Error("different requests should produce different keys") - } -} - -func TestDeduplicator_LeaderFollower(t *testing.T) { - d := NewDeduplicator(5 * time.Second) - defer d.Close() - - key := DedupKey("gpt-4", []byte(`test`)) - - // First call is leader - f1, isLeader := d.TryJoin(key) - if !isLeader { - t.Fatal("first caller should be leader") - } - - // Second call with same key is follower - f2, isLeader := d.TryJoin(key) - if isLeader { - t.Fatal("second caller should be follower") - } - if f1 != f2 { - t.Fatal("follower should get same inflight entry") - } - - // Complete the request - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - <-f2.done - if string(f2.result) != "response" { - t.Error("follower should receive leader's result") - } - if f2.statusCode != 200 { - t.Error("follower should receive leader's status code") - } - }() - - d.Complete(key, []byte("response"), 200) - wg.Wait() -} - -func TestDeduplicator_DifferentKeys(t *testing.T) { - d := NewDeduplicator(5 * time.Second) - defer d.Close() - - _, isLeader1 := d.TryJoin("key1") - _, isLeader2 := d.TryJoin("key2") - - if !isLeader1 || !isLeader2 { - t.Error("different keys should both be leaders") - } - - d.Complete("key1", []byte("r1"), 200) - d.Complete("key2", []byte("r2"), 200) -} diff --git a/llm-gateway/internal/proxy/handler.go b/llm-gateway/internal/proxy/handler.go deleted file mode 100644 index 1b24cc2..0000000 --- a/llm-gateway/internal/proxy/handler.go +++ /dev/null @@ -1,514 +0,0 @@ -package proxy - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log" - "net/http" - "sort" - "strings" - "time" - - "github.com/go-chi/chi/v5/middleware" - - "llm-gateway/internal/auth" - "llm-gateway/internal/cache" - "llm-gateway/internal/config" - "llm-gateway/internal/metrics" - "llm-gateway/internal/provider" - "llm-gateway/internal/storage" -) - -type contextKey string - -const tokenNameKey contextKey = "token_name" -const apiTokenKey contextKey = "api_token" - -func withTokenName(ctx context.Context, name string) context.Context { - return context.WithValue(ctx, tokenNameKey, name) -} - -func getTokenName(ctx context.Context) string { - name, _ := ctx.Value(tokenNameKey).(string) - return name -} - -func withAPIToken(ctx context.Context, token *auth.APIToken) context.Context { - return context.WithValue(ctx, apiTokenKey, token) -} - -func getAPIToken(ctx context.Context) *auth.APIToken { - t, _ := ctx.Value(apiTokenKey).(*auth.APIToken) - return t -} - -type Handler struct { - registry *provider.Registry - logger *storage.AsyncLogger - cache *cache.Cache - metrics *metrics.Metrics - cfg *config.Config - healthTracker *provider.HealthTracker - debugLogger *storage.DebugLogger - dedup *Deduplicator -} - -func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler { - return &Handler{ - registry: registry, - logger: logger, - cache: c, - metrics: m, - cfg: cfg, - healthTracker: ht, - } -} - -func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) { - h.debugLogger = dl -} - -func (h *Handler) SetDeduplicator(d *Deduplicator) { - h.dedup = d -} - -func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) - if err != nil { - writeError(w, http.StatusBadRequest, "failed to read request body") - return - } - - var req provider.ChatRequest - if err := json.Unmarshal(body, &req); err != nil { - writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error()) - return - } - - if req.Model == "" { - writeError(w, http.StatusBadRequest, "model is required") - return - } - - routes, ok := h.registry.Lookup(req.Model) - if !ok { - writeError(w, http.StatusNotFound, "model not found: "+req.Model) - return - } - - // Filter healthy routes (circuit breaker) - routes = h.filterHealthyRoutes(routes) - - tokenName := getTokenName(r.Context()) - requestID := middleware.GetReqID(r.Context()) - - // Check cache for non-streaming requests - if !req.Stream && h.cache != nil { - if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil { - h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true) - if h.metrics != nil { - h.metrics.RecordCacheHit() - } - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Cache", "HIT") - w.Header().Set("X-Request-ID", requestID) - w.Write(cached) - return - } - if h.metrics != nil { - h.metrics.RecordCacheMiss() - } - } - - // Apply per-model timeout for non-streaming requests - modelTimeouts := h.registry.ModelTimeoutsFor(req.Model) - - if req.Stream { - h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts) - return - } - - // Request deduplication for non-streaming requests - if h.dedup != nil { - dedupKey := DedupKey(req.Model, body) - flight, isLeader := h.dedup.TryJoin(dedupKey) - if !isLeader { - // Wait for the leader to complete - select { - case <-flight.done: - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Request-ID", requestID) - w.Header().Set("X-Dedup", "HIT") - w.WriteHeader(flight.statusCode) - w.Write(flight.result) - return - case <-r.Context().Done(): - writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup") - return - } - } - // Leader: proceed normally, but capture response for followers - defer func() { - // If we haven't completed yet (e.g., panic), clean up - }() - h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey) - return - } - - if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 { - ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout) - defer cancel() - r = r.WithContext(ctx) - } - - h.handleNonStream(w, r, &req, routes, tokenName, body, requestID) -} - -func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) { - var lastErr error - - for i, route := range routes { - // Retry backoff between attempts (not before first attempt) - if i > 0 { - backoff := backoffDuration(i, h.cfg.Retry) - select { - case <-time.After(backoff): - case <-r.Context().Done(): - writeError(w, http.StatusGatewayTimeout, "request cancelled") - return - } - } - - start := time.Now() - resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req) - latency := time.Since(start).Milliseconds() - - if err != nil { - var pe *provider.ProviderError - if errors.As(err, &pe) && !pe.IsRetryable() { - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, err) - } - w.Header().Set("X-Request-ID", requestID) - writeErrorRaw(w, pe.StatusCode, pe.Body) - return - } - lastErr = err - log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err) - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, err) - } - continue - } - - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, nil) - } - - inputTokens, outputTokens := 0, 0 - if resp.Usage != nil { - inputTokens = resp.Usage.PromptTokens - outputTokens = resp.Usage.CompletionTokens - } - cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice) - - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost) - h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false) - - resp.Model = req.Model - - respBytes, err := json.Marshal(resp) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to marshal response") - return - } - - if h.cache != nil { - h.cache.Set(r.Context(), req.Model, rawBody, respBytes) - } - - // Debug logging - if h.debugLogger != nil && h.debugLogger.IsEnabled() { - reqBody := string(rawBody) - respBody := string(respBytes) - if h.cfg.Debug.MaxBodyBytes > 0 { - if len(reqBody) > h.cfg.Debug.MaxBodyBytes { - reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes] - } - if len(respBody) > h.cfg.Debug.MaxBodyBytes { - respBody = respBody[:h.cfg.Debug.MaxBodyBytes] - } - } - h.debugLogger.Log(storage.DebugLogEntry{ - RequestID: requestID, - TokenName: tokenName, - Model: req.Model, - Provider: route.Provider.Name(), - RequestBody: reqBody, - ResponseBody: respBody, - RequestHeaders: formatHeaders(r.Header), - ResponseStatus: http.StatusOK, - }) - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Cache", "MISS") - w.Header().Set("X-Request-ID", requestID) - w.Write(respBytes) - return - } - - if lastErr != nil { - w.Header().Set("X-Request-ID", requestID) - writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) - } else { - w.Header().Set("X-Request-ID", requestID) - writeError(w, http.StatusBadGateway, "all providers failed") - } -} - -// handleNonStreamDedup wraps handleNonStream to capture the response for dedup followers. -func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) - if err != nil { - writeError(w, http.StatusBadRequest, "failed to read request body") - return - } - - var req provider.EmbeddingRequest - if err := json.Unmarshal(body, &req); err != nil { - writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error()) - return - } - - if req.Model == "" { - writeError(w, http.StatusBadRequest, "model is required") - return - } - - routes, ok := h.registry.Lookup(req.Model) - if !ok { - writeError(w, http.StatusNotFound, "model not found: "+req.Model) - return - } - - routes = h.filterHealthyRoutes(routes) - tokenName := getTokenName(r.Context()) - requestID := middleware.GetReqID(r.Context()) - - var lastErr error - for i, route := range routes { - if i > 0 { - backoff := backoffDuration(i, h.cfg.Retry) - select { - case <-time.After(backoff): - case <-r.Context().Done(): - writeError(w, http.StatusGatewayTimeout, "request cancelled") - return - } - } - - start := time.Now() - resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req) - latency := time.Since(start).Milliseconds() - - if err != nil { - var pe *provider.ProviderError - if errors.As(err, &pe) && !pe.IsRetryable() { - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error()) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, err) - } - w.Header().Set("X-Request-ID", requestID) - writeErrorRaw(w, pe.StatusCode, pe.Body) - return - } - lastErr = err - log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err) - h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error()) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, err) - } - continue - } - - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, nil) - } - - promptTokens := 0 - if resp.Usage != nil { - promptTokens = resp.Usage.PromptTokens - } - cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice - - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost) - h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "") - - resp.Model = req.Model - - respBytes, err := json.Marshal(resp) - if err != nil { - writeError(w, http.StatusInternalServerError, "failed to marshal response") - return - } - - w.Header().Set("Content-Type", "application/json") - w.Header().Set("X-Request-ID", requestID) - w.Write(respBytes) - return - } - - w.Header().Set("X-Request-ID", requestID) - if lastErr != nil { - writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) - } else { - writeError(w, http.StatusBadGateway, "all providers failed") - } -} - -func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) { - h.logger.Log(storage.RequestLog{ - RequestID: requestID, - Timestamp: time.Now().Unix(), - TokenName: tokenName, - Model: model, - Provider: providerName, - ProviderModel: providerModel, - InputTokens: inputTokens, - CostUSD: cost, - LatencyMS: latencyMS, - Status: status, - ErrorMessage: errMsg, - RequestType: "embedding", - }) -} - -func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) { - if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 { - ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout) - defer cancel() - r = r.WithContext(ctx) - } - - rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK} - h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID) - h.dedup.Complete(dedupKey, rec.body, rec.statusCode) -} - -// responseRecorder captures the response for dedup. -type responseRecorder struct { - http.ResponseWriter - statusCode int - body []byte -} - -func (r *responseRecorder) WriteHeader(code int) { - r.statusCode = code - r.ResponseWriter.WriteHeader(code) -} - -func (r *responseRecorder) Write(b []byte) (int, error) { - r.body = append(r.body, b...) - return r.ResponseWriter.Write(b) -} - -// filterHealthyRoutes removes providers with open circuit breakers. -// If all are filtered out, returns original routes as fallback. -func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route { - if h.healthTracker == nil { - return routes - } - var healthy []provider.Route - for _, r := range routes { - if h.healthTracker.IsAvailable(r.Provider.Name()) { - healthy = append(healthy, r) - } - } - if len(healthy) == 0 { - return routes // all-down fallback - } - return healthy -} - -// backoffDuration computes exponential backoff for the given attempt. -func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration { - d := cfg.InitialBackoff - for i := 1; i < attempt; i++ { - d = time.Duration(float64(d) * cfg.Multiplier) - if d > cfg.MaxBackoff { - d = cfg.MaxBackoff - break - } - } - return d -} - -func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) { - h.logger.Log(storage.RequestLog{ - RequestID: requestID, - Timestamp: time.Now().Unix(), - TokenName: tokenName, - Model: model, - Provider: providerName, - ProviderModel: providerModel, - InputTokens: inputTokens, - OutputTokens: outputTokens, - CostUSD: cost, - LatencyMS: latencyMS, - Status: status, - ErrorMessage: errMsg, - Streaming: streaming, - Cached: cached, - }) -} - -func computeCost(inputTokens, outputTokens int, inputPrice, outputPrice float64) float64 { - return (float64(inputTokens) / 1_000_000.0 * inputPrice) + (float64(outputTokens) / 1_000_000.0 * outputPrice) -} - -func writeError(w http.ResponseWriter, code int, msg string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": msg, - "type": "error", - "code": code, - }, - }) -} - -func writeErrorRaw(w http.ResponseWriter, code int, body string) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(code) - w.Write([]byte(body)) -} - -// formatHeaders serializes HTTP headers to a readable string, sorted by key. -// Sensitive headers (Authorization) are redacted. -func formatHeaders(h http.Header) string { - keys := make([]string, 0, len(h)) - for k := range h { - keys = append(keys, k) - } - sort.Strings(keys) - - var b strings.Builder - for _, k := range keys { - val := strings.Join(h[k], ", ") - if strings.EqualFold(k, "Authorization") { - val = "[REDACTED]" - } - fmt.Fprintf(&b, "%s: %s\n", k, val) - } - return b.String() -} diff --git a/llm-gateway/internal/proxy/models.go b/llm-gateway/internal/proxy/models.go deleted file mode 100644 index 184fbcc..0000000 --- a/llm-gateway/internal/proxy/models.go +++ /dev/null @@ -1,75 +0,0 @@ -package proxy - -import ( - "encoding/json" - "net/http" - "time" - - "llm-gateway/internal/config" - "llm-gateway/internal/provider" -) - -type ModelsHandler struct { - registry *provider.Registry - healthTracker *provider.HealthTracker - cfg *config.Config -} - -func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler { - return &ModelsHandler{ - registry: registry, - healthTracker: healthTracker, - cfg: cfg, - } -} - -func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) { - allRoutes := h.registry.AllRoutes() - models := make([]map[string]any, 0, len(allRoutes)) - - for _, m := range allRoutes { - providers := make([]map[string]any, 0, len(m.Routes)) - for _, rt := range m.Routes { - healthy := true - if h.healthTracker != nil { - healthy = h.healthTracker.IsAvailable(rt.ProviderName) - } - providers = append(providers, map[string]any{ - "name": rt.ProviderName, - "model": rt.ProviderModel, - "input_price": rt.InputPrice, - "output_price": rt.OutputPrice, - "priority": rt.Priority, - "healthy": healthy, - }) - } - - // Find load balancing strategy from config - loadBalancing := "first" - for _, mc := range h.cfg.Models { - if mc.Name == m.Name { - if mc.LoadBalancing != "" { - loadBalancing = mc.LoadBalancing - } - break - } - } - - models = append(models, map[string]any{ - "id": m.Name, - "object": "model", - "created": time.Now().Unix(), - "owned_by": "llm-gateway", - "providers": providers, - "provider_count": len(providers), - "load_balancing": loadBalancing, - "aliases": m.Aliases, - }) - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "object": "list", - "data": models, - }) -} diff --git a/llm-gateway/internal/proxy/ratelimit.go b/llm-gateway/internal/proxy/ratelimit.go deleted file mode 100644 index bc06174..0000000 --- a/llm-gateway/internal/proxy/ratelimit.go +++ /dev/null @@ -1,169 +0,0 @@ -package proxy - -import ( - "fmt" - "math" - "net/http" - "sync" - "time" - - "llm-gateway/internal/storage" - "llm-gateway/internal/webhook" -) - -type RateLimiter struct { - db *storage.DB - mu sync.Mutex - buckets map[string]*tokenBucket - notifier *webhook.Notifier - budgetNotified sync.Map // tracks which token+budget combos have been notified -} - -type tokenBucket struct { - tokens float64 - maxTokens float64 - refillRate float64 // tokens per second - lastRefill time.Time -} - -func NewRateLimiter(db *storage.DB) *RateLimiter { - return &RateLimiter{ - db: db, - buckets: make(map[string]*tokenBucket), - } -} - -// SetNotifier sets the webhook notifier for budget threshold alerts. -func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) { - rl.notifier = n -} - -func (rl *RateLimiter) Check(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - apiToken := getAPIToken(r.Context()) - if apiToken == nil { - next.ServeHTTP(w, r) - return - } - - tokenName := apiToken.Name - - // Check rate limit - if apiToken.RateLimitRPM > 0 { - allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM) - - // Set rate limit headers on all responses - w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM)) - w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining)) - w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt)) - - if !allowed { - retryAfter := resetAt - time.Now().Unix() - if retryAfter < 1 { - retryAfter = 1 - } - w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter)) - writeError(w, http.StatusTooManyRequests, "rate limit exceeded") - return - } - } - - // Check daily budget - if apiToken.DailyBudgetUSD > 0 { - spent, err := rl.db.TodaySpend(tokenName) - if err == nil { - if spent >= apiToken.DailyBudgetUSD { - writeError(w, http.StatusTooManyRequests, "daily budget exceeded") - return - } - rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD) - } - } - - // Check monthly budget - if apiToken.MonthlyBudgetUSD > 0 { - spent, err := rl.db.MonthSpend(tokenName) - if err == nil { - if spent >= apiToken.MonthlyBudgetUSD { - writeError(w, http.StatusTooManyRequests, "monthly budget exceeded") - return - } - rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD) - } - } - - next.ServeHTTP(w, r) - }) -} - -// checkBudgetThreshold fires a webhook notification when spend reaches 80% of budget. -func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) { - if rl.notifier == nil || budget <= 0 { - return - } - if spent/budget < 0.8 { - return - } - key := tokenName + ":" + budgetType - if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded { - return // already notified - } - rl.notifier.Notify(webhook.Event{ - Type: webhook.EventBudgetThreshold, - Data: map[string]any{ - "token": tokenName, - "budget_type": budgetType, - "spent": spent, - "budget": budget, - "percent": spent / budget * 100, - }, - }) -} - -func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) { - rl.mu.Lock() - defer rl.mu.Unlock() - - bucket, ok := rl.buckets[tokenName] - if !ok { - bucket = &tokenBucket{ - tokens: float64(rateLimitRPM), - maxTokens: float64(rateLimitRPM), - refillRate: float64(rateLimitRPM) / 60.0, - lastRefill: time.Now(), - } - rl.buckets[tokenName] = bucket - } - - now := time.Now() - elapsed := now.Sub(bucket.lastRefill).Seconds() - bucket.tokens += elapsed * bucket.refillRate - if bucket.tokens > bucket.maxTokens { - bucket.tokens = bucket.maxTokens - } - bucket.lastRefill = now - - remaining := int(math.Floor(bucket.tokens)) - if remaining < 0 { - remaining = 0 - } - - // Compute reset time: when bucket would be full again - deficit := bucket.maxTokens - bucket.tokens - var resetAt int64 - if deficit > 0 && bucket.refillRate > 0 { - resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix() - } else { - resetAt = now.Unix() - } - - if bucket.tokens < 1 { - return false, 0, resetAt - } - bucket.tokens-- - remaining = int(math.Floor(bucket.tokens)) - if remaining < 0 { - remaining = 0 - } - return true, remaining, resetAt -} diff --git a/llm-gateway/internal/proxy/ratelimit_test.go b/llm-gateway/internal/proxy/ratelimit_test.go deleted file mode 100644 index 20cb311..0000000 --- a/llm-gateway/internal/proxy/ratelimit_test.go +++ /dev/null @@ -1,374 +0,0 @@ -package proxy - -import ( - "context" - "database/sql" - "net/http" - "net/http/httptest" - "strconv" - "testing" - "time" - - _ "modernc.org/sqlite" - - "llm-gateway/internal/auth" - "llm-gateway/internal/storage" -) - -// newTestDB creates an in-memory SQLite database wrapped in storage.DB. -// It creates the request_logs table needed by TodaySpend. -func newTestDB(t *testing.T) *storage.DB { - t.Helper() - sqlDB, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("opening in-memory sqlite: %v", err) - } - t.Cleanup(func() { sqlDB.Close() }) - - // Create the minimal table needed for TodaySpend queries. - _, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - token_name TEXT, - cost_usd REAL, - timestamp INTEGER - )`) - if err != nil { - t.Fatalf("creating request_logs table: %v", err) - } - return &storage.DB{DB: sqlDB} -} - -// okHandler is a simple handler that writes 200 OK. -var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) -}) - -func TestRateLimiter_Allow(t *testing.T) { - tests := []struct { - name string - rateLimitRPM int - numRequests int - wantAllowed int - wantDenied int - }{ - { - name: "allows requests within limit", - rateLimitRPM: 10, - numRequests: 5, - wantAllowed: 5, - wantDenied: 0, - }, - { - name: "denies requests over limit", - rateLimitRPM: 3, - numRequests: 6, - wantAllowed: 3, - wantDenied: 3, - }, - { - name: "allows exactly up to limit", - rateLimitRPM: 5, - numRequests: 5, - wantAllowed: 5, - wantDenied: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - allowed := 0 - denied := 0 - for i := 0; i < tt.numRequests; i++ { - ok, _, _ := rl.allow("test-token", tt.rateLimitRPM) - if ok { - allowed++ - } else { - denied++ - } - } - - if allowed != tt.wantAllowed { - t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed) - } - if denied != tt.wantDenied { - t.Errorf("denied = %d, want %d", denied, tt.wantDenied) - } - }) - } -} - -func TestRateLimiter_TokenRefillsOverTime(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - rpm := 60 // 1 token per second refill rate - - // Exhaust all tokens. - for i := 0; i < rpm; i++ { - ok, _, _ := rl.allow("refill-token", rpm) - if !ok { - t.Fatalf("request %d should have been allowed", i) - } - } - - // Next request should be denied. - ok, _, _ := rl.allow("refill-token", rpm) - if ok { - t.Fatal("request should have been denied after exhausting tokens") - } - - // Manually advance the bucket's lastRefill to simulate time passing. - rl.mu.Lock() - bucket := rl.buckets["refill-token"] - bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second) - rl.mu.Unlock() - - // After 2 seconds at 1 token/sec, we should have ~2 tokens refilled. - ok, remaining, _ := rl.allow("refill-token", rpm) - if !ok { - t.Fatal("request should have been allowed after token refill") - } - // We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0. - if remaining < 0 { - t.Errorf("remaining = %d, want >= 0", remaining) - } -} - -func TestRateLimiter_AllowReturnValues(t *testing.T) { - tests := []struct { - name string - rateLimitRPM int - numRequests int - wantLastAllowed bool - wantLastRemaining int - }{ - { - name: "remaining decrements correctly", - rateLimitRPM: 5, - numRequests: 1, - wantLastAllowed: true, - wantLastRemaining: 4, - }, - { - name: "remaining is zero at limit", - rateLimitRPM: 3, - numRequests: 3, - wantLastAllowed: true, - wantLastRemaining: 0, - }, - { - name: "denied returns zero remaining", - rateLimitRPM: 2, - numRequests: 3, - wantLastAllowed: false, - wantLastRemaining: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - var allowed bool - var remaining int - for i := 0; i < tt.numRequests; i++ { - allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM) - } - - if allowed != tt.wantLastAllowed { - t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed) - } - if remaining != tt.wantLastRemaining { - t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining) - } - }) - } -} - -func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) { - tests := []struct { - name string - rateLimitRPM int - numRequests int - wantStatusCode int - wantLimitHeader string - wantRetryAfter bool - }{ - { - name: "sets rate limit headers on allowed request", - rateLimitRPM: 10, - numRequests: 1, - wantStatusCode: http.StatusOK, - wantLimitHeader: "10", - wantRetryAfter: false, - }, - { - name: "sets Retry-After header on 429", - rateLimitRPM: 2, - numRequests: 3, - wantStatusCode: http.StatusTooManyRequests, - wantLimitHeader: "2", - wantRetryAfter: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - token := &auth.APIToken{ - Name: "header-test-token", - RateLimitRPM: tt.rateLimitRPM, - } - - handler := rl.Check(okHandler) - - var rec *httptest.ResponseRecorder - for i := 0; i < tt.numRequests; i++ { - rec = httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), token) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - } - - // Check the last response. - if rec.Code != tt.wantStatusCode { - t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode) - } - - // X-RateLimit-Limit header. - limitHeader := rec.Header().Get("X-RateLimit-Limit") - if limitHeader != tt.wantLimitHeader { - t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader) - } - - // X-RateLimit-Remaining header must be present and numeric. - remainingHeader := rec.Header().Get("X-RateLimit-Remaining") - if remainingHeader == "" { - t.Error("X-RateLimit-Remaining header is missing") - } else if _, err := strconv.Atoi(remainingHeader); err != nil { - t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader) - } - - // X-RateLimit-Reset header must be present and numeric. - resetHeader := rec.Header().Get("X-RateLimit-Reset") - if resetHeader == "" { - t.Error("X-RateLimit-Reset header is missing") - } else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil { - t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader) - } - - // Retry-After header. - retryAfter := rec.Header().Get("Retry-After") - if tt.wantRetryAfter && retryAfter == "" { - t.Error("Retry-After header is missing on 429 response") - } - if !tt.wantRetryAfter && retryAfter != "" { - t.Errorf("Retry-After header should not be present, got %q", retryAfter) - } - }) - } -} - -func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - handler := rl.Check(okHandler) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - // No API token in context. - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK) - } -} - -func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - token := &auth.APIToken{ - Name: "unlimited-token", - RateLimitRPM: 0, // zero means unlimited - } - - handler := rl.Check(okHandler) - - for i := 0; i < 100; i++ { - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx := withAPIToken(req.Context(), token) - req = req.WithContext(ctx) - handler.ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK) - } - } -} - -func TestRateLimiter_PerTokenIsolation(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - rpm := 2 - - // Exhaust token A. - for i := 0; i < rpm; i++ { - rl.allow("token-a", rpm) - } - ok, _, _ := rl.allow("token-a", rpm) - if ok { - t.Fatal("token-a should be rate limited") - } - - // Token B should still have its own bucket. - ok, _, _ = rl.allow("token-b", rpm) - if !ok { - t.Fatal("token-b should not be affected by token-a's rate limit") - } -} - -func TestRateLimiter_ResetAtIsFuture(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - // Consume one token so there's a deficit. - _, _, resetAt := rl.allow("reset-token", 10) - now := time.Now().Unix() - - if resetAt < now { - t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now) - } -} - -func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) { - db := newTestDB(t) - rl := NewRateLimiter(db) - - token := &auth.APIToken{ - Name: "ctx-token", - RateLimitRPM: 10, - } - - handler := rl.Check(okHandler) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx, cancel := context.WithCancel(req.Context()) - ctx = withAPIToken(ctx, token) - cancel() // Cancel immediately. - req = req.WithContext(ctx) - - // Should still process (rate limiter does not check context cancellation). - handler.ServeHTTP(rec, req) - // The handler itself may or may not respect cancelled context; - // the key point is no panic occurs. -} diff --git a/llm-gateway/internal/proxy/stream.go b/llm-gateway/internal/proxy/stream.go deleted file mode 100644 index 3b5181d..0000000 --- a/llm-gateway/internal/proxy/stream.go +++ /dev/null @@ -1,195 +0,0 @@ -package proxy - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "log" - "net/http" - "strings" - "time" - - "llm-gateway/internal/provider" - "llm-gateway/internal/storage" -) - -func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) { - flusher, ok := w.(http.Flusher) - if !ok { - writeError(w, http.StatusInternalServerError, "streaming not supported") - return - } - - var lastErr error - - for i, route := range routes { - // Retry backoff between attempts - if i > 0 { - backoff := backoffDuration(i, h.cfg.Retry) - select { - case <-time.After(backoff): - case <-r.Context().Done(): - writeError(w, http.StatusGatewayTimeout, "request cancelled") - return - } - } - - start := time.Now() - body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req) - - if err != nil { - var pe *provider.ProviderError - if errors.As(err, &pe) && !pe.IsRetryable() { - latency := time.Since(start).Milliseconds() - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) - h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, err) - } - w.Header().Set("X-Request-ID", requestID) - writeErrorRaw(w, pe.StatusCode, pe.Body) - return - } - lastErr = err - latency := time.Since(start).Milliseconds() - log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err) - h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, err) - } - continue - } - - // Apply streaming timeout (per-model override takes precedence) - streamingTimeout := h.cfg.Server.StreamingTimeout - if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 { - streamingTimeout = modelTimeouts.StreamingTimeout - } - var streamCtx context.Context - var streamCancel context.CancelFunc - if streamingTimeout > 0 { - streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout) - } else { - streamCtx, streamCancel = context.WithCancel(r.Context()) - } - - // Stream the response - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("X-Request-ID", requestID) - w.WriteHeader(http.StatusOK) - - inputTokens, outputTokens := 0, 0 - scanner := bufio.NewScanner(body) - scanner.Buffer(make([]byte, 64*1024), 256*1024) - - // Capture streamed lines for debug logging - debugEnabled := h.debugLogger != nil && h.debugLogger.IsEnabled() - var debugLines []string - - scanDone := make(chan struct{}) - go func() { - defer close(scanDone) - for scanner.Scan() { - select { - case <-streamCtx.Done(): - return - default: - } - - line := scanner.Text() - - if strings.HasPrefix(line, "data: ") { - data := strings.TrimPrefix(line, "data: ") - if data != "[DONE]" { - var chunk streamChunk - if json.Unmarshal([]byte(data), &chunk) == nil { - if chunk.Usage != nil { - inputTokens = chunk.Usage.PromptTokens - outputTokens = chunk.Usage.CompletionTokens - } - if chunk.Model != "" { - chunk.Model = req.Model - if rewritten, err := json.Marshal(chunk); err == nil { - line = "data: " + string(rewritten) - } - } - } - } - } - - if debugEnabled { - debugLines = append(debugLines, line) - } - - w.Write([]byte(line + "\n")) - flusher.Flush() - } - }() - - select { - case <-scanDone: - // Normal completion - case <-streamCtx.Done(): - log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name()) - } - - body.Close() - streamCancel() - - latency := time.Since(start).Milliseconds() - cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice) - h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost) - h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false) - if h.healthTracker != nil { - h.healthTracker.Record(route.Provider.Name(), latency, nil) - } - - // Debug logging for streaming requests - if debugEnabled && len(debugLines) > 0 { - respBody := strings.Join(debugLines, "\n") - reqBody, _ := json.Marshal(req) - reqBodyStr := string(reqBody) - if h.cfg.Debug.MaxBodyBytes > 0 { - if len(reqBodyStr) > h.cfg.Debug.MaxBodyBytes { - reqBodyStr = reqBodyStr[:h.cfg.Debug.MaxBodyBytes] - } - if len(respBody) > h.cfg.Debug.MaxBodyBytes { - respBody = respBody[:h.cfg.Debug.MaxBodyBytes] - } - } - h.debugLogger.Log(storage.DebugLogEntry{ - RequestID: requestID, - TokenName: tokenName, - Model: req.Model, - Provider: route.Provider.Name(), - RequestBody: reqBodyStr, - ResponseBody: respBody, - RequestHeaders: formatHeaders(r.Header), - ResponseStatus: http.StatusOK, - }) - } - - return - } - - // All providers failed - w.Header().Set("X-Request-ID", requestID) - if lastErr != nil { - writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) - } else { - writeError(w, http.StatusBadGateway, "all providers failed") - } -} - -type streamChunk struct { - ID string `json:"id,omitempty"` - Object string `json:"object,omitempty"` - Created int64 `json:"created,omitempty"` - Model string `json:"model,omitempty"` - Choices []any `json:"choices,omitempty"` - Usage *provider.Usage `json:"usage,omitempty"` -} diff --git a/llm-gateway/internal/storage/audit.go b/llm-gateway/internal/storage/audit.go deleted file mode 100644 index 2b41a10..0000000 --- a/llm-gateway/internal/storage/audit.go +++ /dev/null @@ -1,105 +0,0 @@ -package storage - -import ( - "log" - "time" -) - -type AuditEntry struct { - ID int64 `json:"id"` - Timestamp int64 `json:"timestamp"` - UserID int64 `json:"user_id"` - Username string `json:"username"` - Action string `json:"action"` - TargetType string `json:"target_type"` - TargetID string `json:"target_id"` - Details string `json:"details"` - IPAddress string `json:"ip_address"` - RequestID string `json:"request_id"` -} - -type AuditLogger struct { - db *DB - OnWrite func() -} - -func NewAuditLogger(db *DB) *AuditLogger { - return &AuditLogger{db: db} -} - -func (a *AuditLogger) Log(entry AuditEntry) { - if entry.Timestamp == 0 { - entry.Timestamp = time.Now().Unix() - } - _, err := a.db.Exec(`INSERT INTO audit_log - (timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, - entry.Timestamp, entry.UserID, entry.Username, entry.Action, - entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID, - ) - if err != nil { - log.Printf("ERROR: audit log: %v", err) - } else if a.OnWrite != nil { - a.OnWrite() - } -} - -type AuditQueryResult struct { - Entries []AuditEntry `json:"entries"` - Page int `json:"page"` - TotalPages int `json:"total_pages"` - Total int `json:"total"` -} - -func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult { - if page < 1 { - page = 1 - } - if limit <= 0 { - limit = 50 - } - offset := (page - 1) * limit - - where := "WHERE timestamp >= ?" - args := []any{since} - - if action != "" { - where += " AND action = ?" - args = append(args, action) - } - - var total int - countArgs := make([]any, len(args)) - copy(countArgs, args) - a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total) - - totalPages := (total + limit - 1) / limit - if totalPages < 1 { - totalPages = 1 - } - - query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action, - COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''), - COALESCE(ip_address, ''), COALESCE(request_id, '') - FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?` - args = append(args, limit, offset) - - rows, err := a.db.Query(query, args...) - if err != nil { - return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total} - } - defer rows.Close() - - var entries []AuditEntry - for rows.Next() { - var e AuditEntry - rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action, - &e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID) - entries = append(entries, e) - } - if entries == nil { - entries = []AuditEntry{} - } - - return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total} -} diff --git a/llm-gateway/internal/storage/db.go b/llm-gateway/internal/storage/db.go deleted file mode 100644 index f7fc480..0000000 --- a/llm-gateway/internal/storage/db.go +++ /dev/null @@ -1,142 +0,0 @@ -package storage - -import ( - "database/sql" - "fmt" - "log" - "path/filepath" - "time" - - "github.com/golang-migrate/migrate/v4" - "github.com/golang-migrate/migrate/v4/database/sqlite" - "github.com/golang-migrate/migrate/v4/source/iofs" - _ "modernc.org/sqlite" - - "llm-gateway/internal/storage/migrations" -) - -type DB struct { - *sql.DB -} - -func Open(path string) (*DB, error) { - dir := filepath.Dir(path) - if dir != "." && dir != "" { - // Ensure directory exists — caller should create it if needed - } - - db, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000") - if err != nil { - return nil, fmt.Errorf("opening database: %w", err) - } - - // Performance pragmas - for _, pragma := range []string{ - "PRAGMA foreign_keys = ON", - "PRAGMA temp_store = MEMORY", - "PRAGMA mmap_size = 268435456", - } { - if _, err := db.Exec(pragma); err != nil { - return nil, fmt.Errorf("setting pragma %s: %w", pragma, err) - } - } - - db.SetMaxOpenConns(1) // SQLite is single-writer - db.SetMaxIdleConns(1) - - if err := runMigrations(db); err != nil { - return nil, fmt.Errorf("running migrations: %w", err) - } - - return &DB{db}, nil -} - -func runMigrations(db *sql.DB) error { - sourceDriver, err := iofs.New(migrations.FS, ".") - if err != nil { - return fmt.Errorf("creating migration source: %w", err) - } - - dbDriver, err := sqlite.WithInstance(db, &sqlite.Config{}) - if err != nil { - return fmt.Errorf("creating migration db driver: %w", err) - } - - m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite", dbDriver) - if err != nil { - return fmt.Errorf("creating migrator: %w", err) - } - - if err := m.Up(); err != nil && err != migrate.ErrNoChange { - return fmt.Errorf("applying migrations: %w", err) - } - - return nil -} - -// CleanupOldRecords deletes records older than retentionDays. -func (db *DB) CleanupOldRecords(retentionDays int) error { - cutoff := time.Now().AddDate(0, 0, -retentionDays).Unix() - result, err := db.Exec("DELETE FROM request_logs WHERE timestamp < ?", cutoff) - if err != nil { - return err - } - affected, _ := result.RowsAffected() - if affected > 0 { - log.Printf("Cleaned up %d old request log records", affected) - } - return nil -} - -// TodaySpend returns the total cost in USD for a given token today. -func (db *DB) TodaySpend(tokenName string) (float64, error) { - startOfDay := time.Now().Truncate(24 * time.Hour).Unix() - var total sql.NullFloat64 - err := db.QueryRow( - "SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?", - tokenName, startOfDay, - ).Scan(&total) - if err != nil { - return 0, err - } - return total.Float64, nil -} - -// MonthSpend returns the total cost in USD for a given token this month. -func (db *DB) MonthSpend(tokenName string) (float64, error) { - now := time.Now() - startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix() - var total sql.NullFloat64 - err := db.QueryRow( - "SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?", - tokenName, startOfMonth, - ).Scan(&total) - if err != nil { - return 0, err - } - return total.Float64, nil -} - -// TodaySpendAll returns today's spend for all tokens as a map. -func (db *DB) TodaySpendAll() (map[string]float64, error) { - startOfDay := time.Now().Truncate(24 * time.Hour).Unix() - rows, err := db.Query( - "SELECT token_name, SUM(cost_usd) FROM request_logs WHERE timestamp >= ? GROUP BY token_name", - startOfDay, - ) - if err != nil { - return nil, err - } - defer rows.Close() - - result := make(map[string]float64) - for rows.Next() { - var name string - var total float64 - if err := rows.Scan(&name, &total); err != nil { - continue - } - result[name] = total - } - return result, nil -} diff --git a/llm-gateway/internal/storage/debuglog.go b/llm-gateway/internal/storage/debuglog.go deleted file mode 100644 index 7fe5a83..0000000 --- a/llm-gateway/internal/storage/debuglog.go +++ /dev/null @@ -1,253 +0,0 @@ -package storage - -import ( - "encoding/json" - "fmt" - "log" - "os" - "path/filepath" - "sort" - "strings" - "sync/atomic" - "time" -) - -type DebugLogEntry struct { - ID int64 `json:"id"` - RequestID string `json:"request_id"` - Timestamp int64 `json:"timestamp"` - TokenName string `json:"token_name"` - Model string `json:"model"` - Provider string `json:"provider"` - RequestBody string `json:"request_body"` - ResponseBody string `json:"response_body"` - RequestHeaders string `json:"request_headers"` - ResponseStatus int `json:"response_status"` - FilePath string `json:"-"` -} - -// debugFile is the JSON structure written to disk. -type debugFile struct { - RequestHeaders string `json:"request_headers"` - RequestBody string `json:"request_body"` - ResponseBody string `json:"response_body"` -} - -type DebugLogger struct { - db *DB - enabled atomic.Bool - dataDir string - OnWrite func() -} - -func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger { - dl := &DebugLogger{db: db, dataDir: dataDir} - dl.enabled.Store(enabled) - return dl -} - -func (d *DebugLogger) SetEnabled(v bool) { - d.enabled.Store(v) -} - -func (d *DebugLogger) IsEnabled() bool { - return d.enabled.Load() -} - -// debugLogDir returns the base directory for debug log files. -func (d *DebugLogger) debugLogDir() string { - return filepath.Join(d.dataDir, "debug-logs") -} - -// debugFilePath builds the file path for a debug log entry. -func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string { - date := ts.Format("2006-01-02") - return filepath.Join(d.debugLogDir(), date, requestID+".json") -} - -func (d *DebugLogger) Log(entry DebugLogEntry) { - if !d.IsEnabled() { - return - } - if entry.Timestamp == 0 { - entry.Timestamp = time.Now().Unix() - } - - ts := time.Unix(entry.Timestamp, 0) - fp := d.debugFilePath(entry.RequestID, ts) - - // Write body file - if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil { - log.Printf("ERROR: debug log mkdir: %v", err) - return - } - - df := debugFile{ - RequestHeaders: entry.RequestHeaders, - RequestBody: entry.RequestBody, - ResponseBody: entry.ResponseBody, - } - data, err := json.Marshal(df) - if err != nil { - log.Printf("ERROR: debug log marshal: %v", err) - return - } - if err := os.WriteFile(fp, data, 0644); err != nil { - log.Printf("ERROR: debug log write: %v", err) - return - } - - // Insert metadata into DB (no bodies) - _, err = d.db.Exec(`INSERT INTO debug_log - (request_id, timestamp, token_name, model, provider, response_status, file_path) - VALUES (?, ?, ?, ?, ?, ?, ?)`, - entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model, - entry.Provider, entry.ResponseStatus, fp, - ) - if err != nil { - log.Printf("ERROR: debug log db insert: %v", err) - } else if d.OnWrite != nil { - d.OnWrite() - } -} - -type DebugLogQueryResult struct { - Entries []DebugLogEntry `json:"entries"` - Page int `json:"page"` - TotalPages int `json:"total_pages"` - Total int `json:"total"` -} - -// Query returns paginated debug log metadata (no bodies — fast). -func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult { - if page < 1 { - page = 1 - } - if limit <= 0 { - limit = 50 - } - offset := (page - 1) * limit - - var total int - d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total) - - totalPages := (total + limit - 1) / limit - if totalPages < 1 { - totalPages = 1 - } - - rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''), - COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '') - FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset) - if err != nil { - return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total} - } - defer rows.Close() - - var entries []DebugLogEntry - for rows.Next() { - var e DebugLogEntry - rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName, - &e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath) - entries = append(entries, e) - } - if entries == nil { - entries = []DebugLogEntry{} - } - - return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total} -} - -// QueryFull returns paginated debug log entries including request/response bodies read from files. -func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult { - result := d.Query(page, limit) - for i := range result.Entries { - d.populateFromFile(&result.Entries[i]) - } - return result -} - -// GetByRequestID returns a single debug log entry with bodies read from file. -func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry { - var e DebugLogEntry - err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''), - COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '') - FROM debug_log WHERE request_id = ?`, requestID).Scan( - &e.ID, &e.RequestID, &e.Timestamp, &e.TokenName, - &e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath) - if err != nil { - return nil - } - d.populateFromFile(&e) - return &e -} - -// populateFromFile reads body data from the debug file on disk. -// Falls back to DB columns for pre-migration entries that have no file_path. -func (d *DebugLogger) populateFromFile(e *DebugLogEntry) { - if e.FilePath == "" { - // Legacy entry: try reading bodies from DB columns - d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '') - FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders) - return - } - data, err := os.ReadFile(e.FilePath) - if err != nil { - log.Printf("WARN: debug log read file %s: %v", e.FilePath, err) - return - } - var df debugFile - if err := json.Unmarshal(data, &df); err != nil { - log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err) - return - } - e.RequestHeaders = df.RequestHeaders - e.RequestBody = df.RequestBody - e.ResponseBody = df.ResponseBody -} - -// Cleanup removes debug log entries and files older than retentionDays. -func (d *DebugLogger) Cleanup(retentionDays int) error { - cutoff := time.Now().AddDate(0, 0, -retentionDays) - cutoffUnix := cutoff.Unix() - - // Delete old DB rows - result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix) - if err != nil { - return fmt.Errorf("delete old debug rows: %w", err) - } - affected, _ := result.RowsAffected() - if affected > 0 { - log.Printf("Cleaned up %d old debug log entries", affected) - } - - // Remove old date directories - baseDir := d.debugLogDir() - dirs, err := os.ReadDir(baseDir) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return fmt.Errorf("read debug log dir: %w", err) - } - - cutoffDate := cutoff.Format("2006-01-02") - sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) - - for _, dir := range dirs { - if !dir.IsDir() { - continue - } - // Date directories are named YYYY-MM-DD; string comparison works - if strings.Compare(dir.Name(), cutoffDate) < 0 { - dirPath := filepath.Join(baseDir, dir.Name()) - if err := os.RemoveAll(dirPath); err != nil { - log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err) - } else { - log.Printf("Removed old debug log directory: %s", dir.Name()) - } - } - } - - return nil -} diff --git a/llm-gateway/internal/storage/logger.go b/llm-gateway/internal/storage/logger.go deleted file mode 100644 index ce54675..0000000 --- a/llm-gateway/internal/storage/logger.go +++ /dev/null @@ -1,138 +0,0 @@ -package storage - -import ( - "log" - "time" -) - -type RequestLog struct { - RequestID string - Timestamp int64 - TokenName string - Model string - Provider string - ProviderModel string - InputTokens int - OutputTokens int - CostUSD float64 - LatencyMS int64 - Status string // success, error, cached - ErrorMessage string - Streaming bool - Cached bool - RequestType string // "chat" or "embedding" -} - -type AsyncLogger struct { - db *DB - ch chan RequestLog - done chan struct{} - OnFlush func() // called after successful flush, if set -} - -func NewAsyncLogger(db *DB, bufferSize int) *AsyncLogger { - if bufferSize == 0 { - bufferSize = 1000 - } - l := &AsyncLogger{ - db: db, - ch: make(chan RequestLog, bufferSize), - done: make(chan struct{}), - } - go l.run() - return l -} - -func (l *AsyncLogger) Log(r RequestLog) { - select { - case l.ch <- r: - default: - log.Println("WARNING: request log buffer full, dropping entry") - } -} - -func (l *AsyncLogger) Close() { - close(l.ch) - <-l.done -} - -func (l *AsyncLogger) run() { - defer close(l.done) - - batch := make([]RequestLog, 0, 100) - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case r, ok := <-l.ch: - if !ok { - // Channel closed, flush remaining - if len(batch) > 0 { - l.flush(batch) - } - return - } - batch = append(batch, r) - if len(batch) >= 100 { - l.flush(batch) - batch = batch[:0] - } - case <-ticker.C: - if len(batch) > 0 { - l.flush(batch) - batch = batch[:0] - } - } - } -} - -func (l *AsyncLogger) flush(batch []RequestLog) { - tx, err := l.db.Begin() - if err != nil { - log.Printf("ERROR: starting log transaction: %v", err) - return - } - - stmt, err := tx.Prepare(`INSERT INTO request_logs - (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) - if err != nil { - log.Printf("ERROR: preparing log statement: %v", err) - tx.Rollback() - return - } - defer stmt.Close() - - for _, r := range batch { - streaming := 0 - if r.Streaming { - streaming = 1 - } - cached := 0 - if r.Cached { - cached = 1 - } - reqType := r.RequestType - if reqType == "" { - reqType = "chat" - } - _, err := stmt.Exec( - r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel, - r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS, - r.Status, r.ErrorMessage, streaming, cached, reqType, - ) - if err != nil { - log.Printf("ERROR: inserting log: %v", err) - } - } - - if err := tx.Commit(); err != nil { - log.Printf("ERROR: committing log batch: %v", err) - return - } - - if l.OnFlush != nil { - l.OnFlush() - } -} diff --git a/llm-gateway/internal/storage/migrations/001_init.down.sql b/llm-gateway/internal/storage/migrations/001_init.down.sql deleted file mode 100644 index bd1fad9..0000000 --- a/llm-gateway/internal/storage/migrations/001_init.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS request_logs; diff --git a/llm-gateway/internal/storage/migrations/001_init.up.sql b/llm-gateway/internal/storage/migrations/001_init.up.sql deleted file mode 100644 index 610d375..0000000 --- a/llm-gateway/internal/storage/migrations/001_init.up.sql +++ /dev/null @@ -1,20 +0,0 @@ -CREATE TABLE IF NOT EXISTS request_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp INTEGER NOT NULL, - token_name TEXT NOT NULL, - model TEXT NOT NULL, - provider TEXT NOT NULL, - provider_model TEXT NOT NULL, - input_tokens INTEGER DEFAULT 0, - output_tokens INTEGER DEFAULT 0, - cost_usd REAL DEFAULT 0, - latency_ms INTEGER DEFAULT 0, - status TEXT NOT NULL, - error_message TEXT DEFAULT '', - streaming INTEGER DEFAULT 0, - cached INTEGER DEFAULT 0 -); - -CREATE INDEX IF NOT EXISTS idx_timestamp ON request_logs(timestamp); -CREATE INDEX IF NOT EXISTS idx_token ON request_logs(token_name); -CREATE INDEX IF NOT EXISTS idx_model ON request_logs(model); diff --git a/llm-gateway/internal/storage/migrations/002_users.down.sql b/llm-gateway/internal/storage/migrations/002_users.down.sql deleted file mode 100644 index 12f43a1..0000000 --- a/llm-gateway/internal/storage/migrations/002_users.down.sql +++ /dev/null @@ -1,3 +0,0 @@ -DROP TABLE IF EXISTS api_tokens; -DROP TABLE IF EXISTS sessions; -DROP TABLE IF EXISTS users; diff --git a/llm-gateway/internal/storage/migrations/002_users.up.sql b/llm-gateway/internal/storage/migrations/002_users.up.sql deleted file mode 100644 index fb8eb11..0000000 --- a/llm-gateway/internal/storage/migrations/002_users.up.sql +++ /dev/null @@ -1,33 +0,0 @@ -CREATE TABLE users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username TEXT NOT NULL UNIQUE, - password_hash TEXT NOT NULL, - is_admin INTEGER NOT NULL DEFAULT 0, - totp_secret TEXT DEFAULT '', - totp_enabled INTEGER NOT NULL DEFAULT 0, - created_at INTEGER NOT NULL, - updated_at INTEGER NOT NULL -); - -CREATE TABLE sessions ( - id TEXT PRIMARY KEY, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - created_at INTEGER NOT NULL, - expires_at INTEGER NOT NULL -); -CREATE INDEX idx_sessions_user ON sessions(user_id); -CREATE INDEX idx_sessions_expires ON sessions(expires_at); - -CREATE TABLE api_tokens ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - key_hash TEXT NOT NULL, - key_prefix TEXT NOT NULL, - user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, - rate_limit_rpm INTEGER DEFAULT 60, - daily_budget_usd REAL DEFAULT 0, - created_at INTEGER NOT NULL, - last_used_at INTEGER DEFAULT 0 -); -CREATE UNIQUE INDEX idx_api_tokens_hash ON api_tokens(key_hash); -CREATE INDEX idx_api_tokens_user ON api_tokens(user_id); diff --git a/llm-gateway/internal/storage/migrations/003_user_email.down.sql b/llm-gateway/internal/storage/migrations/003_user_email.down.sql deleted file mode 100644 index e00308d..0000000 --- a/llm-gateway/internal/storage/migrations/003_user_email.down.sql +++ /dev/null @@ -1,5 +0,0 @@ --- SQLite doesn't support DROP COLUMN before 3.35.0, so we recreate -CREATE TABLE users_backup AS SELECT id, username, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users; -DROP TABLE users; -ALTER TABLE users_backup RENAME TO users; -CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username); diff --git a/llm-gateway/internal/storage/migrations/003_user_email.up.sql b/llm-gateway/internal/storage/migrations/003_user_email.up.sql deleted file mode 100644 index 9468211..0000000 --- a/llm-gateway/internal/storage/migrations/003_user_email.up.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE users ADD COLUMN email TEXT DEFAULT ''; diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql deleted file mode 100644 index e11bb40..0000000 --- a/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql +++ /dev/null @@ -1,4 +0,0 @@ --- SQLite doesn't support DROP COLUMN in older versions, so we recreate the table -CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens; -DROP TABLE api_tokens; -ALTER TABLE api_tokens_backup RENAME TO api_tokens; diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql deleted file mode 100644 index ccf0549..0000000 --- a/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0; diff --git a/llm-gateway/internal/storage/migrations/005_request_id.down.sql b/llm-gateway/internal/storage/migrations/005_request_id.down.sql deleted file mode 100644 index 819b90b..0000000 --- a/llm-gateway/internal/storage/migrations/005_request_id.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP INDEX IF EXISTS idx_request_logs_request_id; diff --git a/llm-gateway/internal/storage/migrations/005_request_id.up.sql b/llm-gateway/internal/storage/migrations/005_request_id.up.sql deleted file mode 100644 index ff54384..0000000 --- a/llm-gateway/internal/storage/migrations/005_request_id.up.sql +++ /dev/null @@ -1,2 +0,0 @@ -ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT ''; -CREATE INDEX idx_request_logs_request_id ON request_logs(request_id); diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.down.sql b/llm-gateway/internal/storage/migrations/006_audit_log.down.sql deleted file mode 100644 index b750c3b..0000000 --- a/llm-gateway/internal/storage/migrations/006_audit_log.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS audit_log; diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.up.sql b/llm-gateway/internal/storage/migrations/006_audit_log.up.sql deleted file mode 100644 index c2fc48a..0000000 --- a/llm-gateway/internal/storage/migrations/006_audit_log.up.sql +++ /dev/null @@ -1,14 +0,0 @@ -CREATE TABLE audit_log ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp INTEGER NOT NULL, - user_id INTEGER, - username TEXT NOT NULL DEFAULT '', - action TEXT NOT NULL, - target_type TEXT DEFAULT '', - target_id TEXT DEFAULT '', - details TEXT DEFAULT '', - ip_address TEXT DEFAULT '', - request_id TEXT DEFAULT '' -); -CREATE INDEX idx_audit_timestamp ON audit_log(timestamp); -CREATE INDEX idx_audit_action ON audit_log(action); diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.down.sql b/llm-gateway/internal/storage/migrations/007_debug_log.down.sql deleted file mode 100644 index 41353f5..0000000 --- a/llm-gateway/internal/storage/migrations/007_debug_log.down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS debug_log; diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.up.sql b/llm-gateway/internal/storage/migrations/007_debug_log.up.sql deleted file mode 100644 index 9a8441a..0000000 --- a/llm-gateway/internal/storage/migrations/007_debug_log.up.sql +++ /dev/null @@ -1,14 +0,0 @@ -CREATE TABLE debug_log ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - request_id TEXT NOT NULL, - timestamp INTEGER NOT NULL, - token_name TEXT DEFAULT '', - model TEXT DEFAULT '', - provider TEXT DEFAULT '', - request_body TEXT DEFAULT '', - response_body TEXT DEFAULT '', - request_headers TEXT DEFAULT '', - response_status INTEGER DEFAULT 0 -); -CREATE INDEX idx_debug_request_id ON debug_log(request_id); -CREATE INDEX idx_debug_timestamp ON debug_log(timestamp); diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql deleted file mode 100644 index 032a37d..0000000 --- a/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql +++ /dev/null @@ -1 +0,0 @@ --- no-op: file_path column is harmless to keep diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql deleted file mode 100644 index 7a5bf8b..0000000 --- a/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT ''; diff --git a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql deleted file mode 100644 index f288a97..0000000 --- a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd; diff --git a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql deleted file mode 100644 index 7d9dfbc..0000000 --- a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0; diff --git a/llm-gateway/internal/storage/migrations/010_request_type.down.sql b/llm-gateway/internal/storage/migrations/010_request_type.down.sql deleted file mode 100644 index 52e41cb..0000000 --- a/llm-gateway/internal/storage/migrations/010_request_type.down.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE request_logs DROP COLUMN request_type; diff --git a/llm-gateway/internal/storage/migrations/010_request_type.up.sql b/llm-gateway/internal/storage/migrations/010_request_type.up.sql deleted file mode 100644 index 20bca14..0000000 --- a/llm-gateway/internal/storage/migrations/010_request_type.up.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat'; diff --git a/llm-gateway/internal/storage/migrations/embed.go b/llm-gateway/internal/storage/migrations/embed.go deleted file mode 100644 index 91cca1c..0000000 --- a/llm-gateway/internal/storage/migrations/embed.go +++ /dev/null @@ -1,6 +0,0 @@ -package migrations - -import "embed" - -//go:embed *.sql -var FS embed.FS diff --git a/llm-gateway/internal/webhook/webhook.go b/llm-gateway/internal/webhook/webhook.go deleted file mode 100644 index 098a0e2..0000000 --- a/llm-gateway/internal/webhook/webhook.go +++ /dev/null @@ -1,123 +0,0 @@ -package webhook - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "log" - "net/http" - "time" - - "llm-gateway/internal/config" -) - -// Event types. -const ( - EventCircuitBreakerOpen = "circuit_breaker.open" - EventCircuitBreakerClosed = "circuit_breaker.closed" - EventBudgetThreshold = "budget.threshold" -) - -// Event represents a webhook notification payload. -type Event struct { - Type string `json:"type"` - Timestamp time.Time `json:"timestamp"` - Data map[string]any `json:"data"` -} - -// Notifier sends webhook notifications. -type Notifier struct { - webhooks []config.WebhookConfig - ch chan Event - done chan struct{} - client *http.Client -} - -// NewNotifier creates a webhook notifier from config. -func NewNotifier(webhooks []config.WebhookConfig) *Notifier { - n := &Notifier{ - webhooks: webhooks, - ch: make(chan Event, 100), - done: make(chan struct{}), - client: &http.Client{Timeout: 10 * time.Second}, - } - go n.run() - return n -} - -// Notify queues an event for delivery (non-blocking). -func (n *Notifier) Notify(evt Event) { - if evt.Timestamp.IsZero() { - evt.Timestamp = time.Now() - } - select { - case n.ch <- evt: - default: - log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type) - } -} - -// Close drains pending events and shuts down. -func (n *Notifier) Close() { - close(n.ch) - <-n.done -} - -func (n *Notifier) run() { - defer close(n.done) - for evt := range n.ch { - for _, wh := range n.webhooks { - if !n.shouldSend(wh, evt.Type) { - continue - } - n.send(wh, evt) - } - } -} - -func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool { - if len(wh.Events) == 0 { - return true // no filter = send all - } - for _, e := range wh.Events { - if e == eventType { - return true - } - } - return false -} - -func (n *Notifier) send(wh config.WebhookConfig, evt Event) { - body, err := json.Marshal(evt) - if err != nil { - log.Printf("ERROR: webhook marshal: %v", err) - return - } - - req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body)) - if err != nil { - log.Printf("ERROR: webhook request: %v", err) - return - } - req.Header.Set("Content-Type", "application/json") - - if wh.Secret != "" { - mac := hmac.New(sha256.New, []byte(wh.Secret)) - mac.Write(body) - sig := hex.EncodeToString(mac.Sum(nil)) - req.Header.Set("X-Webhook-Signature", "sha256="+sig) - } - - resp, err := n.client.Do(req) - if err != nil { - log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err) - return - } - resp.Body.Close() - - if resp.StatusCode >= 400 { - log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode) - } -}