diff --git a/llm-gateway.yaml b/llm-gateway.yaml
deleted file mode 100644
index 4025d81..0000000
--- a/llm-gateway.yaml
+++ /dev/null
@@ -1,372 +0,0 @@
-server:
- listen: "0.0.0.0:3000"
- request_timeout: 300s
- max_request_body_mb: 10
- session_secret: "${SESSION_SECRET}"
- default_admin:
- username: "${ADMIN_USERNAME}"
- password: "${ADMIN_PASSWORD}"
-
-tokens:
- - name: "open-webui"
- key: "${OPENWEBUI_API_KEY}"
- rate_limit_rpm: 0 # unlimited
- daily_budget_usd: 0
- - name: "opencode"
- key: "${OPENCODE_API_KEY}"
- rate_limit_rpm: 0 # unlimited
- daily_budget_usd: 0
-
-pricing_lookup:
- # url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default
- refresh_interval: 6h
-
-database:
- path: "/data/gateway.db"
- retention_days: 90
-
-debug:
- enabled: true
- retention_days: 90
- # data_dir: "/data" # defaults to directory of database.path
- # max_body_bytes: 0 # 0 = unlimited (save full bodies)
-
-cache:
- enabled: true
- address: "valkey:6379"
- ttl: 3600
-
-providers:
- - name: deepinfra
- base_url: "https://api.deepinfra.com/v1/openai"
- api_key: "${DEEPINFRA_API_KEY}"
- priority: 1
- timeout: 120s
- - name: siliconflow
- base_url: "https://api.siliconflow.com/v1"
- api_key: "${SILICONFLOW_API_KEY}"
- priority: 2
- timeout: 120s
- - name: openrouter
- base_url: "https://openrouter.ai/api/v1"
- api_key: "${OPENROUTER_API_KEY}"
- priority: 3
- timeout: 120s
- - name: groq
- base_url: "https://api.groq.com/openai/v1"
- api_key: "${GROQ_API_KEY}"
- priority: 1
- timeout: 120s
- - name: cerebras
- base_url: "https://api.cerebras.ai/v1"
- api_key: "${CEREBRAS_API_KEY}"
- priority: 1
- timeout: 120s
-
-models:
- # ═══ TIER 1: Free (OpenRouter free models, $0) ═══
- # NOTE: Commented out — free models are heavily rate-limited upstream.
- # Uncomment if you want best-effort free access.
- # - name: "llama-3.3-70b-free"
- # routes:
- # - provider: openrouter
- # model: "meta-llama/llama-3.3-70b-instruct:free"
- # - name: "deepseek-r1-free"
- # routes:
- # - provider: openrouter
- # model: "deepseek/deepseek-r1-0528:free"
- # - name: "gpt-oss-free"
- # routes:
- # - provider: openrouter
- # model: "openai/gpt-oss-120b:free"
- # - name: "gpt-oss-20b-free"
- # routes:
- # - provider: openrouter
- # model: "openai/gpt-oss-20b:free"
- # - name: "qwen3-coder-free"
- # routes:
- # - provider: openrouter
- # model: "qwen/qwen3-coder:free"
- # - name: "qwen3-235b-free"
- # routes:
- # - provider: openrouter
- # model: "qwen/qwen3-235b-a22b-thinking-2507"
- # - name: "glm-4.5-air-free"
- # routes:
- # - provider: openrouter
- # model: "z-ai/glm-4.5-air:free"
- # - name: "nemotron-nano-free"
- # routes:
- # - provider: openrouter
- # model: "nvidia/nemotron-nano-9b-v2:free"
- # - name: "trinity-large-free"
- # routes:
- # - provider: openrouter
- # model: "arcee-ai/trinity-large-preview:free"
- # - name: "mistral-small-free"
- # routes:
- # - provider: openrouter
- # model: "mistralai/mistral-small-3.1-24b-instruct:free"
- # - name: "gemma-3-27b-free"
- # routes:
- # - provider: openrouter
- # model: "google/gemma-3-27b-it:free"
- # - name: "step-3.5-flash-free"
- # routes:
- # - provider: openrouter
- # model: "stepfun/step-3.5-flash:free"
-
- # ═══ TIER 2: Low cost (Groq, Cerebras — free tier with rate limits) ═══
- - name: "llama-3.1-8b"
- routes:
- - provider: groq
- model: "llama-3.1-8b-instant"
- pricing: { input: 0.05, output: 0.08 }
- - provider: cerebras
- model: "llama3.1-8b"
- pricing: { input: 0.10, output: 0.10 }
- - provider: deepinfra
- model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
- pricing: { input: 0.03, output: 0.05 }
-
- - name: "llama-3.3-70b"
- routes:
- - provider: deepinfra
- model: "meta-llama/Llama-3.3-70B-Instruct-Turbo"
- pricing: { input: 0.23, output: 0.40 }
- - provider: groq
- model: "llama-3.3-70b-versatile"
- pricing: { input: 0.59, output: 0.79 }
- - provider: cerebras
- model: "llama-3.3-70b"
- pricing: { input: 0.85, output: 1.20 }
-
- - name: "gpt-oss"
- routes:
- - provider: groq
- model: "openai/gpt-oss-120b"
- pricing: { input: 0.15, output: 0.60 }
- - provider: cerebras
- model: "gpt-oss-120b"
- pricing: { input: 0.35, output: 0.75 }
- - provider: deepinfra
- model: "openai/gpt-oss-120b"
- pricing: { input: 0.05, output: 0.24 }
-
- - name: "gpt-oss-20b"
- routes:
- - provider: groq
- model: "openai/gpt-oss-20b"
- pricing: { input: 0.075, output: 0.30 }
- - provider: deepinfra
- model: "openai/gpt-oss-20b"
- pricing: { input: 0.04, output: 0.16 }
-
- - name: "llama-4-scout"
- routes:
- - provider: groq
- model: "meta-llama/llama-4-scout-17b-16e-instruct"
- pricing: { input: 0.11, output: 0.34 }
-
- - name: "llama-4-maverick"
- routes:
- - provider: groq
- model: "meta-llama/llama-4-maverick-17b-128e-instruct"
- pricing: { input: 0.20, output: 0.60 }
-
- - name: "qwen3-32b"
- routes:
- - provider: groq
- model: "qwen/qwen3-32b"
- pricing: { input: 0.29, output: 0.59 }
- - provider: cerebras
- model: "qwen-3-32b"
-
- # ═══ TIER 3: DeepSeek V3.2 (cheapest flagship) ═══
- - name: "deepseek-v3.2"
- routes:
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-V3.2"
- pricing: { input: 0.26, output: 0.38 }
- - provider: siliconflow
- model: "deepseek-ai/DeepSeek-V3.2"
- pricing: { input: 0.27, output: 0.42 }
- - provider: openrouter
- model: "deepseek/deepseek-chat-v3-0324"
- pricing: { input: 0.30, output: 0.88 }
-
- # ═══ TIER 4: Ultra-cheap DeepInfra ═══
- - name: "nemotron-super"
- routes:
- - provider: deepinfra
- model: "nvidia/Llama-3.3-Nemotron-Super-49B-v1.5"
- pricing: { input: 0.10, output: 0.40 }
-
- - name: "nemotron-nano"
- routes:
- - provider: deepinfra
- model: "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
- pricing: { input: 0.04, output: 0.16 }
-
- # ═══ TIER 5: DeepSeek R1 & reasoning ═══
- - name: "deepseek-r1"
- routes:
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-R1-0528"
- - provider: openrouter
- model: "deepseek/deepseek-r1"
-
- - name: "deepseek-r1-distill-llama-70b"
- routes:
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
-
- - name: "devstral-small"
- routes:
- - provider: openrouter
- model: "mistralai/devstral-small"
-
- - name: "devstral-medium"
- routes:
- - provider: openrouter
- model: "mistralai/devstral-medium"
-
- # ═══ TIER 6: GLM ═══
- - name: "glm-4.6"
- routes:
- - provider: deepinfra
- model: "zai-org/GLM-4.6"
- pricing: { input: 0.60, output: 1.90 }
-
- - name: "glm-4.7"
- routes:
- - provider: deepinfra
- model: "zai-org/GLM-4.7"
- pricing: { input: 0.40, output: 1.75 }
- - provider: cerebras
- model: "zai-glm-4.7"
- pricing: { input: 2.25, output: 2.75 }
- - provider: siliconflow
- model: "THUDM/GLM-4-32B-0414"
-
- - name: "glm-5"
- routes:
- - provider: deepinfra
- model: "zai-org/GLM-5"
- pricing: { input: 0.80, output: 2.56 }
-
- # ═══ TIER 7: Kimi ═══
- - name: "kimi-k2"
- routes:
- - provider: groq
- model: "moonshotai/kimi-k2-instruct-0905"
- pricing: { input: 1.00, output: 3.00 }
- - provider: deepinfra
- model: "moonshotai/Kimi-K2-Instruct-0905"
- pricing: { input: 0.50, output: 2.00 }
- - provider: siliconflow
- model: "moonshotai/Kimi-K2-Instruct-0905"
- pricing: { input: 0.58, output: 2.29 }
-
- - name: "kimi-k2.5"
- routes:
- - provider: deepinfra
- model: "moonshotai/Kimi-K2.5"
- pricing: { input: 0.45, output: 2.25 }
- - provider: openrouter
- model: "moonshotai/kimi-k2.5"
-
- # ═══ TIER 8: SiliconFlow (Qwen) ═══
- - name: "qwen3-coder"
- routes:
- - provider: siliconflow
- model: "Qwen/Qwen3-Coder-480B-A35B-Instruct"
- pricing: { input: 1.14, output: 2.28 }
-
- - name: "qwen3-coder-30b"
- routes:
- - provider: siliconflow
- model: "Qwen/Qwen3-Coder-30B-A3B-Instruct"
-
- # ═══ TIER 9: OpenRouter premium (paid) ═══
- - name: "minimax-m2.5"
- routes:
- - provider: openrouter
- model: "minimax/minimax-m2.5"
-
- - name: "gpt-4.1-mini"
- routes:
- - provider: openrouter
- model: "openai/gpt-4.1-mini"
-
- - name: "gpt-4.1"
- routes:
- - provider: openrouter
- model: "openai/gpt-4.1"
-
- - name: "gemini-3-flash-preview"
- routes:
- - provider: openrouter
- model: "google/gemini-3-flash-preview"
-
- - name: "gemini-2.5-pro"
- routes:
- - provider: openrouter
- model: "google/gemini-2.5-pro-preview"
-
- # ═══ TIER 10: Vision / Multimodal ═══
- - name: "gemma-3-4b"
- routes:
- - provider: openrouter
- model: "google/gemma-3-4b-it"
- pricing: { input: 0.017, output: 0.068 }
- - provider: deepinfra
- model: "google/gemma-3-4b-it"
- pricing: { input: 0.04, output: 0.08 }
-
- - name: "gemma-3-12b"
- routes:
- - provider: openrouter
- model: "google/gemma-3-12b-it"
- pricing: { input: 0.03, output: 0.10 }
- - provider: deepinfra
- model: "google/gemma-3-12b-it"
- pricing: { input: 0.04, output: 0.13 }
-
- - name: "gemma-3-27b"
- routes:
- - provider: openrouter
- model: "google/gemma-3-27b-it"
- pricing: { input: 0.04, output: 0.15 }
- - provider: deepinfra
- model: "google/gemma-3-27b-it"
- pricing: { input: 0.08, output: 0.16 }
-
- - name: "qwen3-vl-8b"
- routes:
- - provider: openrouter
- model: "qwen/qwen3-vl-8b-instruct"
- pricing: { input: 0.08, output: 0.50 }
- - provider: deepinfra
- model: "Qwen/Qwen3-VL-8B-Instruct"
- pricing: { input: 0.18, output: 0.69 }
-
- - name: "qwen3-vl-32b"
- routes:
- - provider: openrouter
- model: "qwen/qwen3-vl-32b-instruct"
- pricing: { input: 0.104, output: 0.416 }
-
- - name: "qwen2.5-vl-32b"
- routes:
- - provider: openrouter
- model: "qwen/qwen2.5-vl-32b-instruct"
- pricing: { input: 0.05, output: 0.22 }
- - provider: deepinfra
- model: "Qwen/Qwen2.5-VL-32B-Instruct"
- pricing: { input: 0.20, output: 0.60 }
-
- - name: "claude-sonnet"
- routes:
- - provider: openrouter
- model: "anthropic/claude-sonnet-4"
diff --git a/llm-gateway/.env.example b/llm-gateway/.env.example
deleted file mode 100644
index a1acce2..0000000
--- a/llm-gateway/.env.example
+++ /dev/null
@@ -1,19 +0,0 @@
-# LLM Gateway Environment Variables
-
-# Session secret (required for persistent sessions)
-SESSION_SECRET=change-me-to-a-random-string
-
-# Default admin (created on first run if no users exist)
-ADMIN_USERNAME=admin
-ADMIN_PASSWORD=change-me-min-8-chars
-
-# Static API tokens (seeded on startup)
-OPENWEBUI_API_KEY=sk-your-openwebui-key
-PERSONAL_API_KEY=sk-your-personal-key
-
-# Provider API keys
-DEEPINFRA_API_KEY=
-SILICONFLOW_API_KEY=
-OPENROUTER_API_KEY=
-GROQ_API_KEY=
-CEREBRAS_API_KEY=
diff --git a/llm-gateway/.gitignore b/llm-gateway/.gitignore
deleted file mode 100644
index 759cb20..0000000
--- a/llm-gateway/.gitignore
+++ /dev/null
@@ -1,18 +0,0 @@
-# Binaries
-gateway
-llm-gateway
-
-# Database
-*.db
-*.db-journal
-*.db-wal
-*.db-shm
-
-# Debug log files
-debug-logs/
-
-# Local config
-configs/config.local.yaml
-
-# Environment
-.env
diff --git a/llm-gateway/Dockerfile b/llm-gateway/Dockerfile
deleted file mode 100644
index 9fdb103..0000000
--- a/llm-gateway/Dockerfile
+++ /dev/null
@@ -1,15 +0,0 @@
-FROM golang:1.24-alpine AS builder
-WORKDIR /src
-COPY go.mod go.sum ./
-RUN go mod download
-COPY . .
-RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o /llm-gateway ./cmd/gateway
-
-FROM alpine:3.19
-RUN apk add --no-cache ca-certificates tzdata
-COPY --from=builder /llm-gateway /usr/local/bin/llm-gateway
-RUN mkdir -p /data
-VOLUME /data
-EXPOSE 3000
-ENTRYPOINT ["llm-gateway"]
-CMD ["-config", "/etc/llm-gateway/config.yaml"]
diff --git a/llm-gateway/Makefile b/llm-gateway/Makefile
deleted file mode 100644
index f042edf..0000000
--- a/llm-gateway/Makefile
+++ /dev/null
@@ -1,16 +0,0 @@
-.PHONY: build run clean docker
-
-BINARY=llm-gateway
-VERSION=$(shell git describe --tags --always --dirty 2>/dev/null || echo dev)
-
-build:
- go build -ldflags="-s -w -X main.version=$(VERSION)" -o $(BINARY) ./cmd/gateway
-
-run: build
- ./$(BINARY) -config configs/config.yaml
-
-clean:
- rm -f $(BINARY)
-
-docker:
- docker build -t llm-gateway:latest .
diff --git a/llm-gateway/cmd/gateway/main.go b/llm-gateway/cmd/gateway/main.go
deleted file mode 100644
index c8bbd06..0000000
--- a/llm-gateway/cmd/gateway/main.go
+++ /dev/null
@@ -1,434 +0,0 @@
-package main
-
-import (
- "context"
- "flag"
- "log"
- "net/http"
- "os"
- "os/signal"
- "path/filepath"
- "syscall"
- "time"
-
- "github.com/go-chi/chi/v5"
- "github.com/go-chi/chi/v5/middleware"
- gocors "github.com/go-chi/cors"
- "github.com/prometheus/client_golang/prometheus/promhttp"
-
- "llm-gateway/internal/auth"
- "llm-gateway/internal/cache"
- "llm-gateway/internal/config"
- "llm-gateway/internal/dashboard"
- "llm-gateway/internal/metrics"
- "llm-gateway/internal/pricing"
- "llm-gateway/internal/provider"
- "llm-gateway/internal/proxy"
- "llm-gateway/internal/storage"
- "llm-gateway/internal/webhook"
-)
-
-var version = "dev"
-
-func main() {
- configPath := flag.String("config", "configs/config.yaml", "path to config file")
- flag.Parse()
-
- log.Printf("llm-gateway %s starting", version)
-
- cfg, err := config.Load(*configPath)
- if err != nil {
- log.Fatalf("Failed to load config: %v", err)
- }
-
- // Pricing lookup (fetches from URL, refreshes periodically)
- pricingLookup := pricing.NewLookup(cfg.Pricing.URL, cfg.Pricing.RefreshInterval)
- defer pricingLookup.Close()
-
- // Auto-fill missing pricing from fetched data
- for i, m := range cfg.Models {
- for j, r := range m.Routes {
- if r.Pricing.Input == 0 && r.Pricing.Output == 0 {
- if pricingLookup.FillMissing(r.Provider, r.Model, &cfg.Models[i].Routes[j].Pricing.Input, &cfg.Models[i].Routes[j].Pricing.Output) {
- log.Printf("Auto-filled pricing for %s via %s: $%.2f/$%.2f per 1M tokens",
- m.Name, r.Provider, cfg.Models[i].Routes[j].Pricing.Input, cfg.Models[i].Routes[j].Pricing.Output)
- }
- }
- }
- }
-
- // Database
- db, err := storage.Open(cfg.Database.Path)
- if err != nil {
- log.Fatalf("Failed to open database: %v", err)
- }
- defer db.Close()
-
- if err := db.CleanupOldRecords(cfg.Database.RetentionDays); err != nil {
- log.Printf("WARNING: retention cleanup failed: %v", err)
- }
-
- asyncLogger := storage.NewAsyncLogger(db, 1000)
- defer asyncLogger.Close()
-
- // SSE broker for real-time dashboard updates
- sseBroker := dashboard.NewSSEBroker()
- asyncLogger.OnFlush = sseBroker.Notify
-
- // Cache (optional)
- var c *cache.Cache
- if cfg.Cache.Enabled {
- c, err = cache.New(cfg.Cache.Address, cfg.Cache.TTL)
- if err != nil {
- log.Printf("WARNING: cache disabled: %v", err)
- } else {
- log.Printf("Cache connected to %s", cfg.Cache.Address)
- }
- }
-
- // Provider registry
- registry, err := provider.NewRegistry(cfg)
- if err != nil {
- log.Fatalf("Failed to build provider registry: %v", err)
- }
- log.Printf("Registered %d models", len(cfg.Models))
-
- // Provider health tracker
- healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
-
- // Webhook notifier
- var notifier *webhook.Notifier
- if len(cfg.Webhooks) > 0 {
- notifier = webhook.NewNotifier(cfg.Webhooks)
- defer notifier.Close()
- log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks))
-
- // Wire health tracker state changes to webhook
- healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) {
- eventType := webhook.EventCircuitBreakerOpen
- if to == provider.CircuitClosed {
- eventType = webhook.EventCircuitBreakerClosed
- }
- notifier.Notify(webhook.Event{
- Type: eventType,
- Data: map[string]any{
- "provider": providerName,
- "from": from.String(),
- "to": to.String(),
- },
- })
- }
- }
-
- // Auth store (static tokens checked in-memory, not seeded to DB)
- var staticTokens []auth.StaticToken
- for _, t := range cfg.Tokens {
- if t.Key != "" {
- staticTokens = append(staticTokens, auth.StaticToken{
- Name: t.Name,
- Key: t.Key,
- RateLimitRPM: t.RateLimitRPM,
- DailyBudgetUSD: t.DailyBudgetUSD,
- MonthlyBudgetUSD: t.MonthlyBudgetUSD,
- MaxConcurrent: t.MaxConcurrent,
- })
- log.Printf("Loaded static token: %s", t.Name)
- }
- }
- authStore := auth.NewStore(db.DB, staticTokens)
- authMiddleware := auth.NewMiddleware(authStore)
- authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret)
-
- // Audit logger
- auditLogger := storage.NewAuditLogger(db)
- auditLogger.OnWrite = sseBroker.Notify
- authHandlers.SetAuditLogger(auditLogger)
-
- // Debug logger
- debugDataDir := cfg.Debug.DataDir
- if debugDataDir == "" {
- debugDataDir = filepath.Dir(cfg.Database.Path)
- }
- debugLogger := storage.NewDebugLogger(db, cfg.Debug.Enabled, debugDataDir)
- debugLogger.OnWrite = sseBroker.Notify
-
- // Seed default admin
- seedDefaultAdmin(cfg, authStore)
-
- // Metrics
- m := metrics.New()
-
- // Handlers
- proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
- proxyHandler.SetDebugLogger(debugLogger)
-
- // Request deduplication
- if cfg.Dedup.Enabled {
- dedup := proxy.NewDeduplicator(cfg.Dedup.Window)
- defer dedup.Close()
- proxyHandler.SetDeduplicator(dedup)
- log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window)
- }
-
- modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg)
- proxyAuth := proxy.NewAuthMiddleware(authStore)
- rateLimiter := proxy.NewRateLimiter(db)
- if notifier != nil {
- rateLimiter.SetNotifier(notifier)
- }
- concurrencyLimiter := proxy.NewConcurrencyLimiter()
- statsAPI := dashboard.NewStatsAPI(db, authStore)
- statsAPI.SetHealthTracker(healthTracker)
- statsAPI.SetAuditLogger(auditLogger)
- statsAPI.SetDebugLogger(debugLogger)
- statsAPI.SetConfigPath(*configPath)
- if c != nil {
- statsAPI.SetCache(c)
- }
- dash := dashboard.NewDashboard(authStore, statsAPI)
- dash.SetRegistry(registry)
- dash.SetAuditLogger(auditLogger)
- dash.SetDebugLogger(debugLogger)
- if c != nil {
- dash.SetCache(c)
- }
-
- // Export handler
- exportHandler := dashboard.NewExportHandler(db, authStore)
-
- // Router
- r := chi.NewRouter()
-
- // CORS (before other middleware)
- if cfg.CORS.Enabled {
- r.Use(gocors.Handler(gocors.Options{
- AllowedOrigins: cfg.CORS.AllowedOrigins,
- AllowedMethods: cfg.CORS.AllowedMethods,
- AllowedHeaders: cfg.CORS.AllowedHeaders,
- MaxAge: cfg.CORS.MaxAge,
- AllowCredentials: true,
- }))
- }
-
- r.Use(middleware.RealIP)
- r.Use(middleware.Recoverer)
- r.Use(middleware.RequestID)
-
- // Health & metrics (public)
- r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
- if err := db.Ping(); err != nil {
- http.Error(w, "database unhealthy", http.StatusServiceUnavailable)
- return
- }
- if c != nil {
- if err := c.Ping(r.Context()); err != nil {
- http.Error(w, "cache unhealthy", http.StatusServiceUnavailable)
- return
- }
- }
- w.WriteHeader(http.StatusOK)
- w.Write([]byte("OK"))
- })
- r.Handle("/metrics", promhttp.Handler())
-
- // OpenAI-compatible API (API token auth via Bearer header)
- r.Group(func(r chi.Router) {
- r.Use(proxyAuth.Authenticate)
- r.Use(rateLimiter.Check)
- r.Use(concurrencyLimiter.Check)
- r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
- r.Post("/v1/embeddings", proxyHandler.Embeddings)
- r.Get("/v1/models", modelsHandler.ListModels)
- })
-
- // Auth pages (public)
- r.Get("/login", dash.LoginPage)
- r.Get("/setup", dash.SetupPage)
-
- // Auth API endpoints (public)
- r.Post("/api/auth/login", authHandlers.Login)
- r.Post("/api/auth/setup", authHandlers.Setup)
- r.Post("/api/auth/login/totp", authHandlers.LoginTOTP)
-
- // Favicon (prevent 401 noise in browser console)
- r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusNoContent)
- })
-
- // Root redirect
- r.Get("/", func(w http.ResponseWriter, r *http.Request) {
- http.Redirect(w, r, "/dashboard", http.StatusFound)
- })
-
- // Authenticated pages and API
- r.Group(func(r chi.Router) {
- r.Use(authMiddleware.RequireAuth)
-
- // Dashboard pages (HTMX)
- r.Get("/dashboard", dash.DashboardPage)
- r.Get("/logs", dash.LogsPage)
- r.Get("/models", dash.ModelsPage)
- r.Get("/tokens", dash.TokensPage)
- r.Get("/settings", dash.SettingsPage)
-
- // Admin-only pages
- r.Group(func(r chi.Router) {
- r.Use(authMiddleware.RequireAdmin)
- r.Get("/users", dash.UsersPage)
- r.Get("/audit", dash.AuditPage)
- r.Get("/debug", dash.DebugPage)
- })
-
- // Auth API
- r.Post("/api/auth/logout", authHandlers.Logout)
- r.Get("/api/auth/me", authHandlers.Me)
- r.Put("/api/auth/me/password", authHandlers.ChangePassword)
- r.Put("/api/auth/me/username", authHandlers.ChangeUsername)
- r.Put("/api/auth/me/email", authHandlers.ChangeEmail)
- r.Post("/api/auth/totp/setup", authHandlers.TOTPSetup)
- r.Post("/api/auth/totp/verify", authHandlers.TOTPVerify)
- r.Delete("/api/auth/totp", authHandlers.TOTPDisable)
-
- // API token management
- r.Get("/api/tokens", authHandlers.ListTokens)
- r.Post("/api/tokens", authHandlers.CreateToken)
- r.Delete("/api/tokens/{id}", authHandlers.DeleteToken)
-
- // SSE events
- r.Get("/api/events", sseBroker.ServeHTTP)
-
- // Dashboard stats
- r.Get("/api/stats/summary", statsAPI.Summary)
- r.Get("/api/stats/models", statsAPI.Models)
- r.Get("/api/stats/providers", statsAPI.Providers)
- r.Get("/api/stats/tokens", statsAPI.Tokens)
- r.Get("/api/stats/timeseries", statsAPI.Timeseries)
- r.Get("/api/stats/logs", statsAPI.Logs)
- r.Get("/api/stats/latency", statsAPI.Latency)
- r.Get("/api/stats/cost-breakdown", statsAPI.CostBreakdown)
- r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler)
- r.Get("/api/stats/cache", statsAPI.CacheStats)
-
- // Data export
- r.Get("/api/export/logs", exportHandler.ExportLogs)
- r.Get("/api/export/stats", exportHandler.ExportStats)
-
- // Admin-only: user management, audit, debug, config validation
- r.Group(func(r chi.Router) {
- r.Use(authMiddleware.RequireAdmin)
- r.Get("/api/auth/users", authHandlers.ListUsers)
- r.Post("/api/auth/users", authHandlers.CreateUser)
- r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser)
-
- // Audit log
- r.Get("/api/stats/audit", statsAPI.AuditLogs)
-
- // Config validation
- r.Get("/api/config/validate", statsAPI.ValidateConfig)
-
- // Debug logging
- r.Post("/api/debug/toggle", statsAPI.DebugToggle)
- r.Get("/api/debug/status", statsAPI.DebugStatus)
- r.Get("/api/debug/logs", statsAPI.DebugLogs)
- r.Get("/api/debug/logs/{requestID}", statsAPI.DebugLogByRequestID)
- })
- })
-
- // Periodic session cleanup and debug log cleanup
- go func() {
- ticker := time.NewTicker(1 * time.Hour)
- defer ticker.Stop()
- for range ticker.C {
- if err := authStore.CleanExpiredSessions(); err != nil {
- log.Printf("WARNING: session cleanup failed: %v", err)
- }
- if err := debugLogger.Cleanup(cfg.Debug.RetentionDays); err != nil {
- log.Printf("WARNING: debug log cleanup failed: %v", err)
- }
- }
- }()
-
- // Server
- srv := &http.Server{
- Addr: cfg.Server.Listen,
- Handler: r,
- ReadTimeout: 30 * time.Second,
- WriteTimeout: cfg.Server.RequestTimeout + 10*time.Second,
- IdleTimeout: 120 * time.Second,
- }
-
- // Config hot-reload via SIGHUP
- config.WatchReload(*configPath, func(newCfg *config.Config) {
- // Reload registry (models, providers, routes)
- if err := registry.Reload(newCfg); err != nil {
- log.Printf("ERROR: registry reload failed: %v", err)
- return
- }
- log.Printf("Reloaded %d models", len(newCfg.Models))
-
- // Reload pricing
- for i, m := range newCfg.Models {
- for j, rt := range m.Routes {
- if rt.Pricing.Input == 0 && rt.Pricing.Output == 0 {
- pricingLookup.FillMissing(rt.Provider, rt.Model,
- &newCfg.Models[i].Routes[j].Pricing.Input,
- &newCfg.Models[i].Routes[j].Pricing.Output)
- }
- }
- }
-
- // Reload static tokens
- var newStaticTokens []auth.StaticToken
- for _, t := range newCfg.Tokens {
- if t.Key != "" {
- newStaticTokens = append(newStaticTokens, auth.StaticToken{
- Name: t.Name,
- Key: t.Key,
- RateLimitRPM: t.RateLimitRPM,
- DailyBudgetUSD: t.DailyBudgetUSD,
- MonthlyBudgetUSD: t.MonthlyBudgetUSD,
- MaxConcurrent: t.MaxConcurrent,
- })
- }
- }
- authStore.SetStaticTokens(newStaticTokens)
-
- // Update config pointer for retry/debug/etc
- cfg = newCfg
- })
-
- // Graceful shutdown
- done := make(chan os.Signal, 1)
- signal.Notify(done, os.Interrupt, syscall.SIGTERM)
-
- go func() {
- log.Printf("Listening on %s", cfg.Server.Listen)
- if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- log.Fatalf("Server failed: %v", err)
- }
- }()
-
- <-done
- log.Println("Shutting down...")
-
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- srv.Shutdown(ctx)
-
- log.Println("Stopped")
-}
-
-// seedDefaultAdmin creates the default admin user if no users exist.
-func seedDefaultAdmin(cfg *config.Config, authStore *auth.Store) {
- if !authStore.HasAnyUser() {
- da := cfg.Server.DefaultAdmin
- if da.Username != "" && da.Password != "" {
- user, err := authStore.CreateUser(da.Username, da.Password, true)
- if err != nil {
- log.Printf("WARNING: failed to create default admin: %v", err)
- } else {
- log.Printf("Created default admin user: %s (id=%d)", user.Username, user.ID)
- }
- }
- }
-}
diff --git a/llm-gateway/configs/config.yaml b/llm-gateway/configs/config.yaml
deleted file mode 100644
index 91f1913..0000000
--- a/llm-gateway/configs/config.yaml
+++ /dev/null
@@ -1,140 +0,0 @@
-server:
- listen: "0.0.0.0:3000"
- request_timeout: 300s
- max_request_body_mb: 10
- session_secret: "${SESSION_SECRET}"
- default_admin:
- username: "${ADMIN_USERNAME}"
- password: "${ADMIN_PASSWORD}"
-
-tokens:
- - name: "open-webui"
- key: "${OPENWEBUI_API_KEY}"
- rate_limit_rpm: 0 # unlimited
- daily_budget_usd: 5.0
- - name: "rayandrew"
- key: "${PERSONAL_API_KEY}"
- rate_limit_rpm: 0 # unlimited
- daily_budget_usd: 10.0
-
-pricing_lookup:
- # url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default
- refresh_interval: 6h
-
-database:
- path: "/data/gateway.db"
- retention_days: 90
-
-cache:
- enabled: true
- address: "valkey:6379"
- ttl: 3600
-
-providers:
- - name: deepinfra
- base_url: "https://api.deepinfra.com/v1/openai"
- api_key: "${DEEPINFRA_API_KEY}"
- priority: 1
- timeout: 120s
- - name: siliconflow
- base_url: "https://api.siliconflow.com/v1"
- api_key: "${SILICONFLOW_API_KEY}"
- priority: 2
- timeout: 120s
- - name: openrouter
- base_url: "https://openrouter.ai/api/v1"
- api_key: "${OPENROUTER_API_KEY}"
- priority: 3
- timeout: 120s
- - name: groq
- base_url: "https://api.groq.com/openai/v1"
- api_key: "${GROQ_API_KEY}"
- priority: 1
- timeout: 120s
- - name: cerebras
- base_url: "https://api.cerebras.ai/v1"
- api_key: "${CEREBRAS_API_KEY}"
- priority: 1
- timeout: 120s
-
-models:
- - name: "deepseek-v3.2"
- routes:
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-V3.2"
- pricing: { input: 0.26, output: 0.38 }
- - provider: siliconflow
- model: "deepseek-ai/DeepSeek-V3.2"
- pricing: { input: 0.27, output: 0.42 }
- - provider: openrouter
- model: "deepseek/deepseek-chat-v3-0324"
- pricing: { input: 0.30, output: 0.88 }
-
- - name: "llama-3.3-70b"
- routes:
- - provider: groq
- model: "llama-3.3-70b-versatile"
- pricing: { input: 0, output: 0 }
- - provider: deepinfra
- model: "meta-llama/Llama-3.3-70B-Instruct"
- pricing: { input: 0.23, output: 0.40 }
-
- - name: "llama-3.1-8b"
- routes:
- - provider: groq
- model: "llama-3.1-8b-instant"
- pricing: { input: 0, output: 0 }
- - provider: cerebras
- model: "llama-3.1-8b"
- pricing: { input: 0, output: 0 }
- - provider: deepinfra
- model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
- pricing: { input: 0.03, output: 0.05 }
-
- - name: "qwen-2.5-72b"
- routes:
- - provider: groq
- model: "qwen-2.5-72b"
- pricing: { input: 0, output: 0 }
- - provider: deepinfra
- model: "Qwen/Qwen2.5-72B-Instruct"
- pricing: { input: 0.23, output: 0.40 }
-
- - name: "qwen-2.5-coder-32b"
- routes:
- - provider: groq
- model: "qwen-2.5-coder-32b"
- pricing: { input: 0, output: 0 }
- - provider: deepinfra
- model: "Qwen/Qwen2.5-Coder-32B-Instruct"
- pricing: { input: 0.07, output: 0.16 }
-
- - name: "gemma-2-9b"
- routes:
- - provider: groq
- model: "gemma2-9b-it"
- pricing: { input: 0, output: 0 }
-
- - name: "deepseek-r1"
- routes:
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-R1"
- pricing: { input: 0.40, output: 1.60 }
- - provider: openrouter
- model: "deepseek/deepseek-r1"
- pricing: { input: 0.55, output: 2.19 }
-
- - name: "deepseek-r1-distill-llama-70b"
- routes:
- - provider: groq
- model: "deepseek-r1-distill-llama-70b"
- pricing: { input: 0, output: 0 }
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
- pricing: { input: 0.23, output: 0.69 }
-
- - name: "deepseek-r1-distill-qwen-32b"
- routes:
- - provider: deepinfra
- model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
- pricing: { input: 0.07, output: 0.16 }
diff --git a/llm-gateway/go.mod b/llm-gateway/go.mod
deleted file mode 100644
index 08c7f3b..0000000
--- a/llm-gateway/go.mod
+++ /dev/null
@@ -1,39 +0,0 @@
-module llm-gateway
-
-go 1.24.0
-
-require (
- github.com/go-chi/chi/v5 v5.2.5
- github.com/go-chi/cors v1.2.2
- github.com/golang-migrate/migrate/v4 v4.19.1
- github.com/pquerna/otp v1.5.0
- github.com/prometheus/client_golang v1.23.2
- github.com/redis/go-redis/v9 v9.17.3
- golang.org/x/crypto v0.48.0
- gopkg.in/yaml.v3 v3.0.1
- modernc.org/sqlite v1.45.0
-)
-
-require (
- github.com/beorn7/perks v1.0.1 // indirect
- github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
- github.com/cespare/xxhash/v2 v2.3.0 // indirect
- github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
- github.com/dustin/go-humanize v1.0.1 // indirect
- github.com/google/uuid v1.6.0 // indirect
- github.com/kr/text v0.2.0 // indirect
- github.com/mattn/go-isatty v0.0.20 // indirect
- github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
- github.com/ncruces/go-strftime v1.0.0 // indirect
- github.com/prometheus/client_model v0.6.2 // indirect
- github.com/prometheus/common v0.66.1 // indirect
- github.com/prometheus/procfs v0.16.1 // indirect
- github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
- go.yaml.in/yaml/v2 v2.4.2 // indirect
- golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
- golang.org/x/sys v0.41.0 // indirect
- google.golang.org/protobuf v1.36.8 // indirect
- modernc.org/libc v1.67.6 // indirect
- modernc.org/mathutil v1.7.1 // indirect
- modernc.org/memory v1.11.0 // indirect
-)
diff --git a/llm-gateway/go.sum b/llm-gateway/go.sum
deleted file mode 100644
index 853ad13..0000000
--- a/llm-gateway/go.sum
+++ /dev/null
@@ -1,123 +0,0 @@
-github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
-github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
-github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
-github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
-github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
-github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
-github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
-github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
-github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
-github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
-github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
-github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
-github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
-github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
-github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
-github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
-github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
-github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
-github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
-github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
-github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
-github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
-github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
-github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
-github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
-github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
-github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
-github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
-github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
-github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
-github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
-github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
-github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
-github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
-github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
-github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
-github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
-github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
-github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
-github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
-github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
-github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
-github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
-github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
-github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
-github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
-github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
-github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
-github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
-github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4=
-github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
-github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
-github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
-github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
-go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
-go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
-go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
-golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
-golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
-golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
-golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
-golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
-golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
-golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
-golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
-golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
-golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
-google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
-google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
-gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
-gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
-modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
-modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
-modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
-modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
-modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
-modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
-modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
-modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
-modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
-modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
-modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
-modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
-modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
-modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
-modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
-modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
-modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
-modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
-modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
-modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
-modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
-modernc.org/sqlite v1.45.0 h1:r51cSGzKpbptxnby+EIIz5fop4VuE4qFoVEjNvWoObs=
-modernc.org/sqlite v1.45.0/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
-modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
-modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
-modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
-modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
diff --git a/llm-gateway/internal/auth/handlers.go b/llm-gateway/internal/auth/handlers.go
deleted file mode 100644
index 7585587..0000000
--- a/llm-gateway/internal/auth/handlers.go
+++ /dev/null
@@ -1,1067 +0,0 @@
-package auth
-
-import (
-	"crypto/hmac"
-	"crypto/sha256"
-	"encoding/hex"
-	"encoding/json"
-	"fmt"
-	"html/template"
-	"net"
-	"net/http"
-	"strconv"
-	"strings"
-	"sync"
-	"time"
-
-	"github.com/go-chi/chi/v5"
-
-	"llm-gateway/internal/storage"
-)
-
-// Handlers groups the HTTP handlers of the auth package together with the
-// dependencies they share.
-type Handlers struct {
- store *Store // persistence for users, sessions and API tokens
- sessionSecret string // set by NewHandlers; presumably keys the pending-login token signature — confirm in signPendingToken
- loginLimiter *loginRateLimiter // per-client throttle consulted by Login
- auditLogger *storage.AuditLogger // optional; while nil, audit is a no-op
-}
-
-// NewHandlers returns a Handlers that uses store for persistence and
-// sessionSecret as its signing secret, with a fresh login rate limiter.
-// No audit logger is attached; call SetAuditLogger to add one.
-func NewHandlers(store *Store, sessionSecret string) *Handlers {
-	h := new(Handlers)
-	h.store = store
-	h.sessionSecret = sessionSecret
-	h.loginLimiter = newLoginRateLimiter()
-	return h
-}
-
-// SetAuditLogger installs al as the destination for audit entries written by
-// the handlers. Until it is called (or if al is nil) audit calls are dropped.
-func (h *Handlers) SetAuditLogger(al *storage.AuditLogger) {
- h.auditLogger = al
-}
-
-func (h *Handlers) audit(r *http.Request, action, targetType, targetID, details string) {
- if h.auditLogger == nil {
- return
- }
- user := UserFromContext(r.Context())
- var userID int64
- var username string
- if user != nil {
- userID = user.ID
- username = user.Username
- }
- ip := r.RemoteAddr
- if fwd := r.Header.Get("X-Real-IP"); fwd != "" {
- ip = fwd
- }
- h.auditLogger.Log(storage.AuditEntry{
- UserID: userID,
- Username: username,
- Action: action,
- TargetType: targetType,
- TargetID: targetID,
- Details: details,
- IPAddress: ip,
- })
-}
-
-// Login brute-force protection
-
-// loginRateLimiter is a sliding-window limiter. For every client key it keeps
-// the timestamps of the attempts accepted during the last minute.
-type loginRateLimiter struct {
-	mu       sync.Mutex
-	attempts map[string][]time.Time // guarded by mu; each slice is in chronological order
-}
-
-// newLoginRateLimiter returns an empty limiter ready for use.
-func newLoginRateLimiter() *loginRateLimiter {
-	return &loginRateLimiter{attempts: make(map[string][]time.Time)}
-}
-
-// allow reports whether the client identified by ip may make another login
-// attempt, and records the attempt when it may. At most five attempts are
-// accepted per key within any one-minute window; rejected attempts are not
-// recorded.
-//
-// Keys whose attempts have all expired are evicted once the table grows past
-// a fixed size, so clients that never come back do not accumulate forever.
-func (l *loginRateLimiter) allow(ip string) bool {
-	const (
-		window         = time.Minute
-		maxAttempts    = 5
-		sweepThreshold = 1024 // table size above which stale keys are purged
-	)
-
-	l.mu.Lock()
-	defer l.mu.Unlock()
-
-	now := time.Now()
-	cutoff := now.Add(-window)
-
-	if len(l.attempts) > sweepThreshold {
-		for key, times := range l.attempts {
-			// Slices are chronological, so the last element is the newest.
-			if len(times) == 0 || !times[len(times)-1].After(cutoff) {
-				delete(l.attempts, key)
-			}
-		}
-	}
-
-	// Drop this key's expired attempts in place, reusing the backing array.
-	recent := l.attempts[ip][:0]
-	for _, t := range l.attempts[ip] {
-		if t.After(cutoff) {
-			recent = append(recent, t)
-		}
-	}
-
-	if len(recent) >= maxAttempts {
-		l.attempts[ip] = recent
-		return false
-	}
-	l.attempts[ip] = append(recent, now)
-	return true
-}
-
-// Status reports whether the gateway has been initialized (at least one user
-// exists) and whether the request carries a valid session. When the session
-// cookie resolves to a user, a summary of that user is included under "user".
-// Any failure while resolving the session is treated as "not logged in".
-func (h *Handlers) Status(w http.ResponseWriter, r *http.Request) {
-	resp := map[string]any{
-		"initialized": h.store.HasAnyUser(),
-		"logged_in":   false,
-	}
-
-	if cookie, err := r.Cookie(sessionCookieName); err == nil && cookie.Value != "" {
-		if sess, err := h.store.GetSession(cookie.Value); err == nil {
-			if user, err := h.store.GetUserByID(sess.UserID); err == nil {
-				resp["user"] = map[string]any{
-					"id":           user.ID,
-					"username":     user.Username,
-					"is_admin":     user.IsAdmin,
-					"totp_enabled": user.TOTPEnabled,
-				}
-				resp["logged_in"] = true
-			}
-		}
-	}
-
-	writeJSON(w, resp)
-}
-
-// Setup creates the first (admin) user and logs it in.
-//
-// It refuses to run once any user exists. The JSON body must carry a
-// non-empty username and a password of at least 8 characters. On success a
-// 7-day session is created, the session cookie is set, and the response is
-// either an HTMX redirect to /dashboard or a JSON summary of the new user.
-// Failures are reported as an HTMX error fragment or a JSON error, depending
-// on the request.
-func (h *Handlers) Setup(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	if h.store.HasAnyUser() {
-		fail(http.StatusBadRequest, "already initialized", "already initialized")
-		return
-	}
-
-	var req struct {
-		Username string `json:"username"`
-		Password string `json:"password"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-
-	switch {
-	case req.Username == "" || req.Password == "":
-		fail(http.StatusBadRequest, "username and password required", "username and password required")
-		return
-	case len(req.Password) < 8:
-		fail(http.StatusBadRequest, "password must be at least 8 characters", "password must be at least 8 characters")
-		return
-	}
-
-	user, err := h.store.CreateUser(req.Username, req.Password, true)
-	if err != nil {
-		msg := "failed to create user: " + err.Error()
-		fail(http.StatusInternalServerError, msg, msg)
-		return
-	}
-
-	// Log the new admin in right away.
-	sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour)
-	if err != nil {
-		fail(http.StatusInternalServerError, "failed to create session", "failed to create session")
-		return
-	}
-
-	h.audit(r, "auth.setup", "user", fmt.Sprintf("%d", user.ID), "initial setup")
-	h.setSessionCookie(w, sessionID)
-
-	if htmx {
-		writeHTMXRedirect(w, "/dashboard")
-		return
-	}
-	writeJSON(w, map[string]any{
-		"user": map[string]any{
-			"id":       user.ID,
-			"username": user.Username,
-			"is_admin": user.IsAdmin,
-		},
-	})
-}
-
-// Login authenticates a user by username and password.
-//
-// Attempts are throttled per client address. The address is the X-Real-IP
-// header when present, otherwise the host part of r.RemoteAddr. The port is
-// stripped because it changes with every connection, which would give each
-// connection its own bucket and defeat the limit.
-// NOTE(review): X-Real-IP is client-controlled unless a trusted proxy
-// overwrites it — confirm the deployment always sits behind one.
-//
-// Unknown usernames and wrong passwords produce the same "invalid
-// credentials" error. When the user has TOTP enabled, no session is created;
-// a signed "llmgw_pending" cookie valid for 5 minutes is set and the client
-// is told to continue with the TOTP step. Otherwise a 7-day session is
-// created, the session cookie is set and the login is audited.
-func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	ip := r.Header.Get("X-Real-IP")
-	if ip == "" {
-		// RemoteAddr is normally "host:port"; keep the raw value if it is
-		// not in that form (e.g. already rewritten by a middleware).
-		ip = r.RemoteAddr
-		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
-			ip = host
-		}
-	}
-	if !h.loginLimiter.allow(ip) {
-		fail(http.StatusTooManyRequests,
-			"too many login attempts, try again in a minute",
-			"too many login attempts, try again in a minute")
-		return
-	}
-
-	var req struct {
-		Username string `json:"username"`
-		Password string `json:"password"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-
-	user, err := h.store.GetUserByUsername(req.Username)
-	if err != nil {
-		fail(http.StatusUnauthorized, "invalid credentials", "invalid credentials")
-		return
-	}
-	if !h.store.CheckPassword(user, req.Password) {
-		fail(http.StatusUnauthorized, "invalid credentials", "invalid credentials")
-		return
-	}
-
-	if user.TOTPEnabled {
-		// Password accepted; hand out a short-lived pending token for the TOTP step.
-		pending := h.signPendingToken(user.ID)
-		http.SetCookie(w, &http.Cookie{
-			Name:     "llmgw_pending",
-			Value:    pending,
-			Path:     "/",
-			HttpOnly: true,
-			SameSite: http.SameSiteLaxMode,
-			MaxAge:   300, // 5 minutes
-		})
-		if htmx {
-			w.Header().Set("Content-Type", "text/html; charset=utf-8")
-			// NOTE(review): the fragment written here looks empty in this
-			// view; it is reproduced unchanged — confirm against the original file.
-			fmt.Fprint(w, `
`)
-			return
-		}
-		writeJSON(w, map[string]any{"require_totp": true})
-		return
-	}
-
-	sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour)
-	if err != nil {
-		fail(http.StatusInternalServerError, "failed to create session", "failed to create session")
-		return
-	}
-
-	h.audit(r, "auth.login", "user", fmt.Sprintf("%d", user.ID), user.Username)
-
-	h.setSessionCookie(w, sessionID)
-	if htmx {
-		writeHTMXRedirect(w, "/dashboard")
-		return
-	}
-	writeJSON(w, map[string]any{
-		"require_totp": false,
-		"user": map[string]any{
-			"id":       user.ID,
-			"username": user.Username,
-			"is_admin": user.IsAdmin,
-		},
-	})
-}
-
-// LoginTOTP completes a login that Login left pending on the TOTP step.
-//
-// The request must carry the "llmgw_pending" cookie issued by Login and a
-// JSON body with the current TOTP code. Attempts go through the same
-// per-client limiter as Login, so codes cannot be guessed without bound
-// while the pending cookie is still valid. On success the pending cookie is
-// cleared, a 7-day session is created, the session cookie is set and the
-// login is audited in the same way as a password-only login.
-func (h *Handlers) LoginTOTP(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	// Same client key as Login: proxy header first, else RemoteAddr without port.
-	ip := r.Header.Get("X-Real-IP")
-	if ip == "" {
-		ip = r.RemoteAddr
-		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
-			ip = host
-		}
-	}
-	if !h.loginLimiter.allow(ip) {
-		fail(http.StatusTooManyRequests,
-			"too many login attempts, try again in a minute",
-			"too many login attempts, try again in a minute")
-		return
-	}
-
-	cookie, err := r.Cookie("llmgw_pending")
-	if err != nil || cookie.Value == "" {
-		fail(http.StatusBadRequest, "no pending login", "no pending login")
-		return
-	}
-
-	userID, err := h.verifyPendingToken(cookie.Value)
-	if err != nil {
-		fail(http.StatusBadRequest, "invalid or expired pending login", "invalid or expired pending login")
-		return
-	}
-
-	var req struct {
-		Code string `json:"code"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-
-	user, err := h.store.GetUserByID(userID)
-	if err != nil {
-		fail(http.StatusBadRequest, "user not found", "user not found")
-		return
-	}
-
-	if !ValidateTOTPCode(user.TOTPSecret, req.Code) {
-		fail(http.StatusUnauthorized, "invalid TOTP code", "invalid TOTP code")
-		return
-	}
-
-	// The pending token has served its purpose; expire the cookie.
-	http.SetCookie(w, &http.Cookie{
-		Name:     "llmgw_pending",
-		Value:    "",
-		Path:     "/",
-		HttpOnly: true,
-		MaxAge:   -1,
-	})
-
-	sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour)
-	if err != nil {
-		fail(http.StatusInternalServerError, "failed to create session", "failed to create session")
-		return
-	}
-
-	h.audit(r, "auth.login", "user", fmt.Sprintf("%d", user.ID), user.Username)
-
-	h.setSessionCookie(w, sessionID)
-	if htmx {
-		writeHTMXRedirect(w, "/dashboard")
-		return
-	}
-	writeJSON(w, map[string]any{
-		"user": map[string]any{
-			"id":       user.ID,
-			"username": user.Username,
-			"is_admin": user.IsAdmin,
-		},
-	})
-}
-
-// Logout ends the current session: it audits the logout, deletes the session
-// named by the session cookie (if the cookie is present), expires the cookie
-// and answers with an HTMX redirect to /login or a JSON status.
-func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) {
-	// Audit first, while the request still identifies the user.
-	h.audit(r, "auth.logout", "", "", "")
-
-	if cookie, err := r.Cookie(sessionCookieName); err == nil {
-		h.store.DeleteSession(cookie.Value)
-	}
-
-	expired := http.Cookie{
-		Name:     sessionCookieName,
-		Value:    "",
-		Path:     "/",
-		HttpOnly: true,
-		MaxAge:   -1, // tells the browser to drop the cookie now
-	}
-	http.SetCookie(w, &expired)
-
-	if !isHTMX(r) {
-		writeJSON(w, map[string]string{"status": "ok"})
-		return
-	}
-	writeHTMXRedirect(w, "/login")
-}
-
-// Me returns the authenticated user's id, username, admin flag and TOTP
-// status as JSON, or 401 when the request context carries no user.
-func (h *Handlers) Me(w http.ResponseWriter, r *http.Request) {
-	user := UserFromContext(r.Context())
-	if user == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	profile := map[string]any{
-		"id":           user.ID,
-		"username":     user.Username,
-		"is_admin":     user.IsAdmin,
-		"totp_enabled": user.TOTPEnabled,
-	}
-	writeJSON(w, profile)
-}
-
-// TOTPSetup generates a new TOTP key for the authenticated user, stores its
-// secret and returns the secret together with the provisioning URI.
-// NOTE(review): the secret is stored without checking whether TOTP is
-// already enabled for the user — confirm that replacing an active secret
-// is intended.
-func (h *Handlers) TOTPSetup(w http.ResponseWriter, r *http.Request) {
-	user := UserFromContext(r.Context())
-	if user == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	key, err := GenerateTOTPKey(user.Username)
-	if err != nil {
-		writeError(w, http.StatusInternalServerError, "failed to generate TOTP key")
-		return
-	}
-
-	secret := key.Secret()
-	if err := h.store.SetTOTPSecret(user.ID, secret); err != nil {
-		writeError(w, http.StatusInternalServerError, "failed to save TOTP secret")
-		return
-	}
-
-	writeJSON(w, map[string]string{
-		"secret": secret,
-		"uri":    key.URL(),
-	})
-}
-
-// TOTPVerify confirms a TOTP enrolment: it checks the submitted code against
-// the secret stored for the authenticated user and, when it matches, enables
-// TOTP for that user and audits the change.
-func (h *Handlers) TOTPVerify(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	ctxUser := UserFromContext(r.Context())
-	if ctxUser == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	var req struct {
-		Code string `json:"code"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-
-	// Load the user again so the secret checked is the one currently stored.
-	user, err := h.store.GetUserByID(ctxUser.ID)
-	if err != nil {
-		fail(http.StatusInternalServerError, "failed to fetch user", "failed to fetch user")
-		return
-	}
-
-	if user.TOTPSecret == "" {
-		fail(http.StatusBadRequest, "TOTP not set up yet", "TOTP not set up, call /api/auth/totp/setup first")
-		return
-	}
-	if !ValidateTOTPCode(user.TOTPSecret, req.Code) {
-		fail(http.StatusBadRequest, "invalid TOTP code", "invalid TOTP code")
-		return
-	}
-
-	if err := h.store.EnableTOTP(user.ID); err != nil {
-		fail(http.StatusInternalServerError, "failed to enable TOTP", "failed to enable TOTP")
-		return
-	}
-
-	h.audit(r, "totp.enable", "user", fmt.Sprintf("%d", user.ID), "")
-	if !htmx {
-		writeJSON(w, map[string]string{"status": "totp_enabled"})
-		return
-	}
-	// Ask the settings page to reload so it shows the new TOTP status.
-	w.Header().Set("HX-Trigger", "settingsRefresh")
-	writeHTMXSuccess(w, "Two-factor authentication enabled")
-}
-
-func (h *Handlers) TOTPDisable(w http.ResponseWriter, r *http.Request) {
- htmx := isHTMX(r)
- user := UserFromContext(r.Context())
- if user == nil {
- writeError(w, http.StatusUnauthorized, "not authenticated")
- return
- }
-
- if err := h.store.DisableTOTP(user.ID); err != nil {
- if htmx {
- writeHTMXError(w, "failed to disable TOTP")
- return
- }
- writeError(w, http.StatusInternalServerError, "failed to disable TOTP")
- return
- }
-
- h.audit(r, "totp.disable", "user", fmt.Sprintf("%d", user.ID), "")
- if htmx {
- w.Header().Set("HX-Trigger", "settingsRefresh")
- writeHTMXSuccess(w, "Two-factor authentication disabled")
- return
- }
- writeJSON(w, map[string]string{"status": "totp_disabled"})
-}
-
-// User management (admin only)
-
-// ListUsers returns every user as JSON, reduced to the fields that are safe
-// to expose (id, username, admin flag, TOTP status, creation time).
-func (h *Handlers) ListUsers(w http.ResponseWriter, r *http.Request) {
-	// safeUser is the public projection of a stored user; anything not
-	// listed here is left out of the response.
-	type safeUser struct {
-		ID          int64  `json:"id"`
-		Username    string `json:"username"`
-		IsAdmin     bool   `json:"is_admin"`
-		TOTPEnabled bool   `json:"totp_enabled"`
-		CreatedAt   int64  `json:"created_at"`
-	}
-
-	users, err := h.store.ListUsers()
-	if err != nil {
-		writeError(w, http.StatusInternalServerError, "failed to list users")
-		return
-	}
-
-	out := make([]safeUser, 0, len(users))
-	for _, u := range users {
-		out = append(out, safeUser{
-			ID:          u.ID,
-			Username:    u.Username,
-			IsAdmin:     u.IsAdmin,
-			TOTPEnabled: u.TOTPEnabled,
-			CreatedAt:   u.CreatedAt,
-		})
-	}
-	writeJSON(w, out)
-}
-
-// CreateUser creates a user from the JSON body (username, password,
-// is_admin). Username and password are required and the password must be at
-// least 8 characters long. The creation is audited; the response is an HTMX
-// refresh or a JSON summary of the new user.
-func (h *Handlers) CreateUser(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	var req struct {
-		Username string `json:"username"`
-		Password string `json:"password"`
-		IsAdmin  bool   `json:"is_admin"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-
-	switch {
-	case req.Username == "" || req.Password == "":
-		fail(http.StatusBadRequest, "username and password required", "username and password required")
-		return
-	case len(req.Password) < 8:
-		fail(http.StatusBadRequest, "password must be at least 8 characters", "password must be at least 8 characters")
-		return
-	}
-
-	user, err := h.store.CreateUser(req.Username, req.Password, req.IsAdmin)
-	if err != nil {
-		msg := "failed to create user: " + err.Error()
-		fail(http.StatusInternalServerError, msg, msg)
-		return
-	}
-
-	h.audit(r, "user.create", "user", fmt.Sprintf("%d", user.ID), user.Username)
-	if htmx {
-		writeHTMXRefresh(w)
-		return
-	}
-	writeJSON(w, map[string]any{
-		"id":       user.ID,
-		"username": user.Username,
-		"is_admin": user.IsAdmin,
-	})
-}
-
-// DeleteUser deletes the user named by the "id" URL parameter. A user cannot
-// delete their own account. Errors returned by the store are passed to the
-// client verbatim with status 400. The deletion is audited.
-func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports a 400 error in the format the client asked for.
-	fail := func(msg string) {
-		if htmx {
-			writeHTMXError(w, msg)
-			return
-		}
-		writeError(w, http.StatusBadRequest, msg)
-	}
-
-	idStr := chi.URLParam(r, "id")
-	id, err := strconv.ParseInt(idStr, 10, 64)
-	if err != nil {
-		fail("invalid user ID")
-		return
-	}
-
-	// Refuse self-deletion.
-	if self := UserFromContext(r.Context()); self != nil && self.ID == id {
-		fail("cannot delete yourself")
-		return
-	}
-
-	if err := h.store.DeleteUser(id); err != nil {
-		fail(err.Error())
-		return
-	}
-
-	h.audit(r, "user.delete", "user", idStr, "")
-	if htmx {
-		writeHTMXRefresh(w)
-		return
-	}
-	writeJSON(w, map[string]string{"status": "deleted"})
-}
-
-// API Token management
-
-func (h *Handlers) ListTokens(w http.ResponseWriter, r *http.Request) {
- user := UserFromContext(r.Context())
- if user == nil {
- writeError(w, http.StatusUnauthorized, "not authenticated")
- return
- }
-
- var userID int64
- if !user.IsAdmin {
- userID = user.ID
- }
- // userID=0 means list all (admin)
-
- tokens, err := h.store.ListAPITokens(userID)
- if err != nil {
- writeError(w, http.StatusInternalServerError, "failed to list tokens")
- return
- }
- if tokens == nil {
- tokens = []APIToken{}
- }
- writeJSON(w, tokens)
-}
-
-// CreateToken creates an API token owned by the authenticated user from the
-// JSON body (name, rate_limit_rpm, daily_budget_usd). The name is required;
-// a negative rate limit is stored as 0, which means unlimited. The plaintext
-// key is included in this response and the creation is audited.
-func (h *Handlers) CreateToken(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	user := UserFromContext(r.Context())
-	if user == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	var req struct {
-		Name           string  `json:"name"`
-		RateLimitRPM   int     `json:"rate_limit_rpm"`
-		DailyBudgetUSD float64 `json:"daily_budget_usd"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-	if req.Name == "" {
-		fail(http.StatusBadRequest, "name is required", "name is required")
-		return
-	}
-	// 0 means unlimited; clamp negative values to it.
-	req.RateLimitRPM = max(req.RateLimitRPM, 0)
-
-	plainKey, token, err := h.store.CreateAPIToken(user.ID, req.Name, req.RateLimitRPM, req.DailyBudgetUSD)
-	if err != nil {
-		msg := "failed to create token: " + err.Error()
-		fail(http.StatusInternalServerError, msg, msg)
-		return
-	}
-
-	h.audit(r, "token.create", "token", fmt.Sprintf("%d", token.ID), req.Name)
-	if !htmx {
-		writeJSON(w, map[string]any{
-			"key":   plainKey,
-			"token": token,
-		})
-		return
-	}
-
-	// HTMX: show the escaped key once and signal that a token was created.
-	w.Header().Set("Content-Type", "text/html; charset=utf-8")
-	w.Header().Set("HX-Trigger", "tokenCreated")
-	escaped := template.HTMLEscapeString(plainKey)
-	fmt.Fprintf(w, `Token created! Copy the key below — it won't be shown again.
-%s
`, escaped)
-}
-
-// DeleteToken deletes the API token named by the "id" URL parameter. Admins
-// may delete any token; other users only tokens they own. The deletion is
-// audited.
-func (h *Handlers) DeleteToken(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, msg string) {
-		if htmx {
-			writeHTMXError(w, msg)
-			return
-		}
-		writeError(w, status, msg)
-	}
-
-	user := UserFromContext(r.Context())
-	if user == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	idStr := chi.URLParam(r, "id")
-	id, err := strconv.ParseInt(idStr, 10, 64)
-	if err != nil {
-		fail(http.StatusBadRequest, "invalid token ID")
-		return
-	}
-
-	// Ownership is only checked for non-admins.
-	if !user.IsAdmin {
-		token, err := h.store.GetAPIToken(id)
-		switch {
-		case err != nil:
-			fail(http.StatusNotFound, "token not found")
-			return
-		case token.UserID != user.ID:
-			fail(http.StatusForbidden, "not your token")
-			return
-		}
-	}
-
-	if err := h.store.DeleteAPIToken(id); err != nil {
-		fail(http.StatusInternalServerError, "failed to delete token")
-		return
-	}
-
-	h.audit(r, "token.delete", "token", idStr, "")
-	if htmx {
-		writeHTMXRefresh(w)
-		return
-	}
-	writeJSON(w, map[string]string{"status": "deleted"})
-}
-
-// Self-service endpoints
-
-// ChangePassword replaces the authenticated user's password. The JSON body
-// carries current_password and new_password; the new password must be at
-// least 8 characters and the current one must match the stored hash. The
-// change is audited.
-func (h *Handlers) ChangePassword(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	ctxUser := UserFromContext(r.Context())
-	if ctxUser == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	var req struct {
-		CurrentPassword string `json:"current_password"`
-		NewPassword     string `json:"new_password"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-	// An empty password is shorter than 8, so one length check covers both cases.
-	if len(req.NewPassword) < 8 {
-		fail(http.StatusBadRequest, "new password must be at least 8 characters", "new password must be at least 8 characters")
-		return
-	}
-
-	// Load the stored user so the current password can be checked.
-	user, err := h.store.GetUserByID(ctxUser.ID)
-	if err != nil {
-		fail(http.StatusInternalServerError, "failed to fetch user", "failed to fetch user")
-		return
-	}
-	if !h.store.CheckPassword(user, req.CurrentPassword) {
-		fail(http.StatusUnauthorized, "current password is incorrect", "current password is incorrect")
-		return
-	}
-
-	if err := h.store.UpdatePassword(user.ID, req.NewPassword); err != nil {
-		fail(http.StatusInternalServerError, "failed to update password", "failed to update password")
-		return
-	}
-
-	h.audit(r, "password.change", "user", fmt.Sprintf("%d", user.ID), "")
-	if htmx {
-		writeHTMXSuccess(w, "Password updated")
-		return
-	}
-	writeJSON(w, map[string]string{"status": "password_updated"})
-}
-
-// ChangeUsername renames the authenticated user. The JSON body carries
-// new_username, which must be non-empty and not already used by a different
-// user. The change is audited with the new name as details.
-func (h *Handlers) ChangeUsername(w http.ResponseWriter, r *http.Request) {
-	htmx := isHTMX(r)
-	// fail reports an error in the format the client asked for.
-	fail := func(status int, htmxMsg, jsonMsg string) {
-		if htmx {
-			writeHTMXError(w, htmxMsg)
-			return
-		}
-		writeError(w, status, jsonMsg)
-	}
-
-	user := UserFromContext(r.Context())
-	if user == nil {
-		writeError(w, http.StatusUnauthorized, "not authenticated")
-		return
-	}
-
-	var req struct {
-		NewUsername string `json:"new_username"`
-	}
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		fail(http.StatusBadRequest, "invalid request", "invalid JSON")
-		return
-	}
-	if req.NewUsername == "" {
-		fail(http.StatusBadRequest, "username is required", "username is required")
-		return
-	}
-
-	// A successful lookup that returns someone else means the name is taken.
-	if holder, err := h.store.GetUserByUsername(req.NewUsername); err == nil && holder.ID != user.ID {
-		fail(http.StatusConflict, "username already taken", "username already taken")
-		return
-	}
-
-	if err := h.store.UpdateUsername(user.ID, req.NewUsername); err != nil {
-		fail(http.StatusInternalServerError, "failed to update username", "failed to update username")
-		return
-	}
-
-	h.audit(r, "username.change", "user", fmt.Sprintf("%d", user.ID), req.NewUsername)
-	if htmx {
-		writeHTMXSuccess(w, "Username updated")
-		return
-	}
-	writeJSON(w, map[string]string{"status": "username_updated"})
-}
-
-func (h *Handlers) ChangeEmail(w http.ResponseWriter, r *http.Request) {
- htmx := isHTMX(r)
- user := UserFromContext(r.Context())
- if user == nil {
- writeError(w, http.StatusUnauthorized, "not authenticated")
- return
- }
-
- var req struct {
- Email string `json:"email"`
- }
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- if htmx {
- writeHTMXError(w, "invalid request")
- return
- }
- writeError(w, http.StatusBadRequest, "invalid JSON")
- return
- }
-
- if err := h.store.UpdateEmail(user.ID, req.Email); err != nil {
- if htmx {
- writeHTMXError(w, "failed to update email")
- return
- }
- writeError(w, http.StatusInternalServerError, "failed to update email")
- return
- }
-
- if htmx {
- writeHTMXSuccess(w, "Email updated")
- return
- }
- writeJSON(w, map[string]string{"status": "email_updated"})
-}
-
-// Helpers
-
-func (h *Handlers) setSessionCookie(w http.ResponseWriter, sessionID string) {
- http.SetCookie(w, &http.Cookie{
- Name: sessionCookieName,
- Value: sessionID,
- Path: "/",
- HttpOnly: true,
- SameSite: http.SameSiteLaxMode,
- MaxAge: sessionTTLDays * 24 * 60 * 60,
- })
-}
-
-func (h *Handlers) signPendingToken(userID int64) string {
- data := fmt.Sprintf("%d:%d", userID, time.Now().Unix())
- mac := hmac.New(sha256.New, []byte(h.sessionSecret))
- mac.Write([]byte(data))
- sig := hex.EncodeToString(mac.Sum(nil))
- return data + ":" + sig
-}
-
-func (h *Handlers) verifyPendingToken(token string) (int64, error) {
- parts := strings.SplitN(token, ":", 3)
- if len(parts) != 3 {
- return 0, fmt.Errorf("invalid format")
- }
-
- userID, err := strconv.ParseInt(parts[0], 10, 64)
- if err != nil {
- return 0, fmt.Errorf("invalid user ID")
- }
-
- ts, err := strconv.ParseInt(parts[1], 10, 64)
- if err != nil {
- return 0, fmt.Errorf("invalid timestamp")
- }
-
- // Check expiry (5 minutes)
- if time.Now().Unix()-ts > 300 {
- return 0, fmt.Errorf("expired")
- }
-
- // Verify HMAC
- data := parts[0] + ":" + parts[1]
- mac := hmac.New(sha256.New, []byte(h.sessionSecret))
- mac.Write([]byte(data))
- expectedSig := hex.EncodeToString(mac.Sum(nil))
-
- if !hmac.Equal([]byte(parts[2]), []byte(expectedSig)) {
- return 0, fmt.Errorf("invalid signature")
- }
-
- return userID, nil
-}
-
-func isHTMX(r *http.Request) bool {
- return r.Header.Get("HX-Request") == "true"
-}
-
-func writeHTMXError(w http.ResponseWriter, msg string) {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
- fmt.Fprintf(w, `%s
`, template.HTMLEscapeString(msg))
-}
-
-func writeHTMXSuccess(w http.ResponseWriter, msg string) {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
- fmt.Fprintf(w, `%s
`, template.HTMLEscapeString(msg))
-}
-
-func writeHTMXRedirect(w http.ResponseWriter, url string) {
- w.Header().Set("HX-Redirect", url)
- w.WriteHeader(http.StatusOK)
-}
-
-func writeHTMXRefresh(w http.ResponseWriter) {
- w.Header().Set("HX-Refresh", "true")
- w.WriteHeader(http.StatusOK)
-}
-
-func writeJSON(w http.ResponseWriter, v any) {
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(v)
-}
-
-func writeError(w http.ResponseWriter, code int, msg string) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(code)
- json.NewEncoder(w).Encode(map[string]string{"error": msg})
-}
diff --git a/llm-gateway/internal/auth/middleware.go b/llm-gateway/internal/auth/middleware.go
deleted file mode 100644
index bceadf1..0000000
--- a/llm-gateway/internal/auth/middleware.go
+++ /dev/null
@@ -1,83 +0,0 @@
-package auth
-
-import (
- "context"
- "encoding/json"
- "net/http"
- "strings"
-)
-
-type contextKey string
-
-const userContextKey contextKey = "auth_user"
-
-const (
- sessionCookieName = "llmgw_session"
- sessionTTLDays = 7
-)
-
-type Middleware struct {
- store *Store
-}
-
-func NewMiddleware(store *Store) *Middleware {
- return &Middleware{store: store}
-}
-
-func UserFromContext(ctx context.Context) *User {
- u, _ := ctx.Value(userContextKey).(*User)
- return u
-}
-
-func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- cookie, err := r.Cookie(sessionCookieName)
- if err != nil || cookie.Value == "" {
- m.unauthorized(w, r)
- return
- }
-
- sess, err := m.store.GetSession(cookie.Value)
- if err != nil {
- m.unauthorized(w, r)
- return
- }
-
- user, err := m.store.GetUserByID(sess.UserID)
- if err != nil {
- m.unauthorized(w, r)
- return
- }
-
- ctx := context.WithValue(r.Context(), userContextKey, user)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
-}
-
-func (m *Middleware) RequireAdmin(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user := UserFromContext(r.Context())
- if user == nil || !user.IsAdmin {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusForbidden)
- json.NewEncoder(w).Encode(map[string]string{"error": "admin access required"})
- return
- }
- next.ServeHTTP(w, r)
- })
-}
-
-func (m *Middleware) unauthorized(w http.ResponseWriter, r *http.Request) {
- if r.Header.Get("HX-Request") == "true" {
- w.Header().Set("HX-Redirect", "/login")
- w.WriteHeader(http.StatusUnauthorized)
- return
- }
- if strings.HasPrefix(r.URL.Path, "/api/") {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusUnauthorized)
- json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"})
- return
- }
- http.Redirect(w, r, "/login", http.StatusFound)
-}
diff --git a/llm-gateway/internal/auth/store.go b/llm-gateway/internal/auth/store.go
deleted file mode 100644
index 4f62ff7..0000000
--- a/llm-gateway/internal/auth/store.go
+++ /dev/null
@@ -1,391 +0,0 @@
-package auth
-
-import (
- "crypto/rand"
- "crypto/sha256"
- "database/sql"
- "encoding/hex"
- "fmt"
- "time"
-
- "golang.org/x/crypto/bcrypt"
-)
-
-type User struct {
- ID int64 `json:"id"`
- Username string `json:"username"`
- Email string `json:"email"`
- PasswordHash string `json:"-"`
- IsAdmin bool `json:"is_admin"`
- TOTPSecret string `json:"-"`
- TOTPEnabled bool `json:"totp_enabled"`
- CreatedAt int64 `json:"created_at"`
- UpdatedAt int64 `json:"updated_at"`
-}
-
-type Session struct {
- ID string
- UserID int64
- CreatedAt int64
- ExpiresAt int64
-}
-
-type APIToken struct {
- ID int64 `json:"id"`
- Name string `json:"name"`
- KeyPrefix string `json:"key_prefix"`
- KeyHash string `json:"-"`
- UserID int64 `json:"user_id"`
- RateLimitRPM int `json:"rate_limit_rpm"`
- DailyBudgetUSD float64 `json:"daily_budget_usd"`
- MonthlyBudgetUSD float64 `json:"monthly_budget_usd"`
- MaxConcurrent int `json:"max_concurrent"`
- CreatedAt int64 `json:"created_at"`
- LastUsedAt int64 `json:"last_used_at"`
-}
-
-// StaticToken represents a token defined in config (checked in-memory, never stored in DB).
-type StaticToken struct {
- Name string
- Key string
- RateLimitRPM int
- DailyBudgetUSD float64
- MonthlyBudgetUSD float64
- MaxConcurrent int
-}
-
-type Store struct {
- db *sql.DB
- staticTokens []StaticToken
-}
-
-func NewStore(db *sql.DB, staticTokens []StaticToken) *Store {
- return &Store{db: db, staticTokens: staticTokens}
-}
-
-// SetStaticTokens updates the static tokens list (used for config hot-reload).
-func (s *Store) SetStaticTokens(tokens []StaticToken) {
- s.staticTokens = tokens
-}
-
-func (s *Store) HasAnyUser() bool {
- var count int
- s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
- return count > 0
-}
-
-func (s *Store) CreateUser(username, password string, isAdmin bool) (*User, error) {
- hash, err := bcrypt.GenerateFromPassword([]byte(password), 12)
- if err != nil {
- return nil, fmt.Errorf("hashing password: %w", err)
- }
-
- now := time.Now().Unix()
- adminInt := 0
- if isAdmin {
- adminInt = 1
- }
-
- result, err := s.db.Exec(
- "INSERT INTO users (username, password_hash, is_admin, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
- username, string(hash), adminInt, now, now,
- )
- if err != nil {
- return nil, fmt.Errorf("creating user: %w", err)
- }
-
- id, _ := result.LastInsertId()
- return &User{
- ID: id,
- Username: username,
- IsAdmin: isAdmin,
- CreatedAt: now,
- UpdatedAt: now,
- }, nil
-}
-
-func (s *Store) GetUserByUsername(username string) (*User, error) {
- return s.scanUser(s.db.QueryRow(
- "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE username = ?",
- username,
- ))
-}
-
-func (s *Store) GetUserByID(id int64) (*User, error) {
- return s.scanUser(s.db.QueryRow(
- "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE id = ?",
- id,
- ))
-}
-
-func (s *Store) scanUser(row *sql.Row) (*User, error) {
- var u User
- var isAdmin, totpEnabled int
- var totpSecret sql.NullString
- var email sql.NullString
- err := row.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt)
- if err != nil {
- return nil, err
- }
- u.Email = email.String
- u.IsAdmin = isAdmin == 1
- u.TOTPEnabled = totpEnabled == 1
- u.TOTPSecret = totpSecret.String
- return &u, nil
-}
-
-func (s *Store) ListUsers() ([]User, error) {
- rows, err := s.db.Query("SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users ORDER BY id")
- if err != nil {
- return nil, err
- }
- defer rows.Close()
-
- var users []User
- for rows.Next() {
- var u User
- var isAdmin, totpEnabled int
- var totpSecret sql.NullString
- var email sql.NullString
- if err := rows.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt); err != nil {
- return nil, err
- }
- u.Email = email.String
- u.IsAdmin = isAdmin == 1
- u.TOTPEnabled = totpEnabled == 1
- u.TOTPSecret = totpSecret.String
- users = append(users, u)
- }
- return users, nil
-}
-
-func (s *Store) DeleteUser(id int64) error {
- // Prevent deleting the last admin
- var adminCount int
- s.db.QueryRow("SELECT COUNT(*) FROM users WHERE is_admin = 1").Scan(&adminCount)
-
- var isAdmin int
- s.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", id).Scan(&isAdmin)
- if isAdmin == 1 && adminCount <= 1 {
- return fmt.Errorf("cannot delete the last admin user")
- }
-
- _, err := s.db.Exec("DELETE FROM users WHERE id = ?", id)
- return err
-}
-
-func (s *Store) UpdatePassword(userID int64, newPassword string) error {
- hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12)
- if err != nil {
- return fmt.Errorf("hashing password: %w", err)
- }
- _, err = s.db.Exec("UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?", string(hash), time.Now().Unix(), userID)
- return err
-}
-
-func (s *Store) CheckPassword(user *User, password string) bool {
- return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) == nil
-}
-
-func (s *Store) SetTOTPSecret(userID int64, secret string) error {
- _, err := s.db.Exec("UPDATE users SET totp_secret = ?, updated_at = ? WHERE id = ?", secret, time.Now().Unix(), userID)
- return err
-}
-
-func (s *Store) EnableTOTP(userID int64) error {
- _, err := s.db.Exec("UPDATE users SET totp_enabled = 1, updated_at = ? WHERE id = ?", time.Now().Unix(), userID)
- return err
-}
-
-func (s *Store) DisableTOTP(userID int64) error {
- _, err := s.db.Exec("UPDATE users SET totp_enabled = 0, totp_secret = '', updated_at = ? WHERE id = ?", time.Now().Unix(), userID)
- return err
-}
-
-// Session management
-
-func (s *Store) CreateSession(userID int64, ttl time.Duration) (string, error) {
- b := make([]byte, 32)
- if _, err := rand.Read(b); err != nil {
- return "", fmt.Errorf("generating session ID: %w", err)
- }
- id := hex.EncodeToString(b)
- now := time.Now().Unix()
- expiresAt := time.Now().Add(ttl).Unix()
-
- _, err := s.db.Exec(
- "INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
- id, userID, now, expiresAt,
- )
- if err != nil {
- return "", fmt.Errorf("creating session: %w", err)
- }
- return id, nil
-}
-
-func (s *Store) GetSession(sessionID string) (*Session, error) {
- var sess Session
- err := s.db.QueryRow(
- "SELECT id, user_id, created_at, expires_at FROM sessions WHERE id = ? AND expires_at > ?",
- sessionID, time.Now().Unix(),
- ).Scan(&sess.ID, &sess.UserID, &sess.CreatedAt, &sess.ExpiresAt)
- if err != nil {
- return nil, err
- }
- return &sess, nil
-}
-
-func (s *Store) DeleteSession(id string) error {
- _, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", id)
- return err
-}
-
-func (s *Store) CleanExpiredSessions() error {
- _, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now().Unix())
- return err
-}
-
-// API Token management
-
-func (s *Store) CreateAPIToken(userID int64, name string, rateLimitRPM int, dailyBudgetUSD float64) (string, *APIToken, error) {
- // Generate sk- prefixed random key
- b := make([]byte, 32)
- if _, err := rand.Read(b); err != nil {
- return "", nil, fmt.Errorf("generating token: %w", err)
- }
- plainKey := "sk-" + hex.EncodeToString(b)
- keyPrefix := plainKey[:11] // "sk-" + first 8 hex chars
-
- hash := sha256.Sum256([]byte(plainKey))
- keyHash := hex.EncodeToString(hash[:])
-
- now := time.Now().Unix()
- result, err := s.db.Exec(
- "INSERT INTO api_tokens (name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
- name, keyHash, keyPrefix, userID, rateLimitRPM, dailyBudgetUSD, now,
- )
- if err != nil {
- return "", nil, fmt.Errorf("creating API token: %w", err)
- }
-
- id, _ := result.LastInsertId()
- token := &APIToken{
- ID: id,
- Name: name,
- KeyPrefix: keyPrefix,
- KeyHash: keyHash,
- UserID: userID,
- RateLimitRPM: rateLimitRPM,
- DailyBudgetUSD: dailyBudgetUSD,
- CreatedAt: now,
- }
- return plainKey, token, nil
-}
-
-func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
- // Check static tokens first (from config, never stored in DB)
- for _, st := range s.staticTokens {
- if st.Key == key {
- prefix := st.Key
- if len(prefix) > 11 {
- prefix = prefix[:11]
- }
- return &APIToken{
- ID: -1, // sentinel: static token
- Name: st.Name,
- KeyPrefix: prefix,
- RateLimitRPM: st.RateLimitRPM,
- DailyBudgetUSD: st.DailyBudgetUSD,
- MonthlyBudgetUSD: st.MonthlyBudgetUSD,
- MaxConcurrent: st.MaxConcurrent,
- }, nil
- }
- }
-
- // Fall back to DB tokens
- hash := sha256.Sum256([]byte(key))
- keyHash := hex.EncodeToString(hash[:])
-
- var t APIToken
- err := s.db.QueryRow(
- "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
- keyHash,
- ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
- if err != nil {
- return nil, err
- }
- return &t, nil
-}
-
-func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
- // Include static tokens (shown for all users, not deletable)
- var tokens []APIToken
- for _, st := range s.staticTokens {
- prefix := st.Key
- if len(prefix) > 11 {
- prefix = prefix[:11]
- }
- tokens = append(tokens, APIToken{
- ID: -1, // sentinel: static token
- Name: st.Name,
- KeyPrefix: prefix,
- RateLimitRPM: st.RateLimitRPM,
- DailyBudgetUSD: st.DailyBudgetUSD,
- MonthlyBudgetUSD: st.MonthlyBudgetUSD,
- MaxConcurrent: st.MaxConcurrent,
- })
- }
-
- // DB tokens
- var rows *sql.Rows
- var err error
- if userID == 0 {
- rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
- } else {
- rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
- }
- if err != nil {
- return tokens, nil
- }
- defer rows.Close()
-
- for rows.Next() {
- var t APIToken
- if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
- return tokens, nil
- }
- tokens = append(tokens, t)
- }
- return tokens, nil
-}
-
-func (s *Store) DeleteAPIToken(id int64) error {
- _, err := s.db.Exec("DELETE FROM api_tokens WHERE id = ?", id)
- return err
-}
-
-func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
- var t APIToken
- err := s.db.QueryRow(
- "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
- id,
- ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
- if err != nil {
- return nil, err
- }
- return &t, nil
-}
-
-func (s *Store) UpdateAPITokenLastUsed(id int64) {
- s.db.Exec("UPDATE api_tokens SET last_used_at = ? WHERE id = ?", time.Now().Unix(), id)
-}
-
-func (s *Store) UpdateUsername(userID int64, newUsername string) error {
- _, err := s.db.Exec("UPDATE users SET username = ?, updated_at = ? WHERE id = ?", newUsername, time.Now().Unix(), userID)
- return err
-}
-
-func (s *Store) UpdateEmail(userID int64, email string) error {
- _, err := s.db.Exec("UPDATE users SET email = ?, updated_at = ? WHERE id = ?", email, time.Now().Unix(), userID)
- return err
-}
diff --git a/llm-gateway/internal/auth/store_test.go b/llm-gateway/internal/auth/store_test.go
deleted file mode 100644
index d566224..0000000
--- a/llm-gateway/internal/auth/store_test.go
+++ /dev/null
@@ -1,301 +0,0 @@
-package auth
-
-import (
- "database/sql"
- "testing"
- "time"
-
- _ "modernc.org/sqlite"
-)
-
-func setupTestDB(t *testing.T) *sql.DB {
- t.Helper()
- db, err := sql.Open("sqlite", ":memory:")
- if err != nil {
- t.Fatalf("opening test db: %v", err)
- }
-
- // Create tables
- _, err = db.Exec(`
- CREATE TABLE users (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- username TEXT UNIQUE NOT NULL,
- email TEXT DEFAULT '',
- password_hash TEXT NOT NULL,
- is_admin INTEGER DEFAULT 0,
- totp_secret TEXT DEFAULT '',
- totp_enabled INTEGER DEFAULT 0,
- created_at INTEGER NOT NULL,
- updated_at INTEGER NOT NULL
- );
- CREATE TABLE sessions (
- id TEXT PRIMARY KEY,
- user_id INTEGER NOT NULL,
- created_at INTEGER NOT NULL,
- expires_at INTEGER NOT NULL
- );
- CREATE TABLE api_tokens (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL,
- key_hash TEXT NOT NULL,
- key_prefix TEXT NOT NULL,
- user_id INTEGER NOT NULL,
- rate_limit_rpm INTEGER DEFAULT 0,
- daily_budget_usd REAL DEFAULT 0,
- monthly_budget_usd REAL DEFAULT 0,
- max_concurrent INTEGER DEFAULT 0,
- created_at INTEGER NOT NULL,
- last_used_at INTEGER DEFAULT 0
- );
- `)
- if err != nil {
- t.Fatalf("creating tables: %v", err)
- }
- return db
-}
-
-func TestCreateUser(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- user, err := store.CreateUser("alice", "password123", true)
- if err != nil {
- t.Fatalf("CreateUser: %v", err)
- }
- if user.Username != "alice" {
- t.Errorf("expected username 'alice', got '%s'", user.Username)
- }
- if !user.IsAdmin {
- t.Error("expected admin user")
- }
- if user.ID == 0 {
- t.Error("expected non-zero ID")
- }
-}
-
-func TestGetUserByUsername(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- store.CreateUser("bob", "password123", false)
-
- user, err := store.GetUserByUsername("bob")
- if err != nil {
- t.Fatalf("GetUserByUsername: %v", err)
- }
- if user.Username != "bob" {
- t.Errorf("expected 'bob', got '%s'", user.Username)
- }
-
- _, err = store.GetUserByUsername("nonexistent")
- if err == nil {
- t.Error("expected error for nonexistent user")
- }
-}
-
-func TestCheckPassword(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- store.CreateUser("charlie", "correctpassword", false)
- user, _ := store.GetUserByUsername("charlie")
-
- if !store.CheckPassword(user, "correctpassword") {
- t.Error("correct password should match")
- }
- if store.CheckPassword(user, "wrongpassword") {
- t.Error("wrong password should not match")
- }
-}
-
-func TestUpdatePassword(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- user, _ := store.CreateUser("dave", "oldpass12", false)
-
- if err := store.UpdatePassword(user.ID, "newpass12"); err != nil {
- t.Fatalf("UpdatePassword: %v", err)
- }
-
- user, _ = store.GetUserByUsername("dave")
- if store.CheckPassword(user, "oldpass12") {
- t.Error("old password should not work")
- }
- if !store.CheckPassword(user, "newpass12") {
- t.Error("new password should work")
- }
-}
-
-func TestDeleteUser(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- user1, _ := store.CreateUser("admin1", "password1234", true)
- user2, _ := store.CreateUser("user2", "password1234", false)
-
- // Can delete non-admin
- if err := store.DeleteUser(user2.ID); err != nil {
- t.Fatalf("DeleteUser: %v", err)
- }
-
- // Cannot delete last admin
- if err := store.DeleteUser(user1.ID); err == nil {
- t.Error("should not be able to delete last admin")
- }
-}
-
-func TestHasAnyUser(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- if store.HasAnyUser() {
- t.Error("should have no users initially")
- }
-
- store.CreateUser("first", "password1234", true)
-
- if !store.HasAnyUser() {
- t.Error("should have users after creation")
- }
-}
-
-func TestSessionCRUD(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- user, _ := store.CreateUser("sessuser", "password1234", false)
-
- sessionID, err := store.CreateSession(user.ID, 1*time.Hour)
- if err != nil {
- t.Fatalf("CreateSession: %v", err)
- }
-
- sess, err := store.GetSession(sessionID)
- if err != nil {
- t.Fatalf("GetSession: %v", err)
- }
- if sess.UserID != user.ID {
- t.Errorf("expected user ID %d, got %d", user.ID, sess.UserID)
- }
-
- if err := store.DeleteSession(sessionID); err != nil {
- t.Fatalf("DeleteSession: %v", err)
- }
-
- _, err = store.GetSession(sessionID)
- if err == nil {
- t.Error("session should be deleted")
- }
-}
-
-func TestStaticTokenLookup(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
-
- staticTokens := []StaticToken{
- {Name: "test-token", Key: "sk-test-key-12345678", RateLimitRPM: 60, DailyBudgetUSD: 10.0, MaxConcurrent: 5},
- }
- store := NewStore(db, staticTokens)
-
- token, err := store.LookupAPIToken("sk-test-key-12345678")
- if err != nil {
- t.Fatalf("LookupAPIToken: %v", err)
- }
- if token.Name != "test-token" {
- t.Errorf("expected 'test-token', got '%s'", token.Name)
- }
- if token.ID != -1 {
- t.Errorf("static token should have ID -1, got %d", token.ID)
- }
- if token.RateLimitRPM != 60 {
- t.Errorf("expected RPM 60, got %d", token.RateLimitRPM)
- }
- if token.MaxConcurrent != 5 {
- t.Errorf("expected max_concurrent 5, got %d", token.MaxConcurrent)
- }
-
- // Non-existent token
- _, err = store.LookupAPIToken("nonexistent")
- if err == nil {
- t.Error("should error on nonexistent token")
- }
-}
-
-func TestDBTokenCRUD(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
- store := NewStore(db, nil)
-
- user, _ := store.CreateUser("tokenuser", "password1234", false)
-
- plainKey, token, err := store.CreateAPIToken(user.ID, "my-token", 100, 5.0)
- if err != nil {
- t.Fatalf("CreateAPIToken: %v", err)
- }
- if plainKey == "" {
- t.Error("plain key should not be empty")
- }
- if token.Name != "my-token" {
- t.Errorf("expected 'my-token', got '%s'", token.Name)
- }
-
- // Lookup by key
- found, err := store.LookupAPIToken(plainKey)
- if err != nil {
- t.Fatalf("LookupAPIToken: %v", err)
- }
- if found.Name != "my-token" {
- t.Errorf("expected 'my-token', got '%s'", found.Name)
- }
-
- // List tokens
- tokens, err := store.ListAPITokens(user.ID)
- if err != nil {
- t.Fatalf("ListAPITokens: %v", err)
- }
- if len(tokens) != 1 {
- t.Errorf("expected 1 token, got %d", len(tokens))
- }
-
- // Delete
- if err := store.DeleteAPIToken(token.ID); err != nil {
- t.Fatalf("DeleteAPIToken: %v", err)
- }
-
- _, err = store.LookupAPIToken(plainKey)
- if err == nil {
- t.Error("token should be deleted")
- }
-}
-
-func TestSetStaticTokens(t *testing.T) {
- db := setupTestDB(t)
- defer db.Close()
-
- store := NewStore(db, nil)
-
- _, err := store.LookupAPIToken("key1")
- if err == nil {
- t.Error("should not find token before setting")
- }
-
- store.SetStaticTokens([]StaticToken{
- {Name: "new-token", Key: "key1"},
- })
-
- token, err := store.LookupAPIToken("key1")
- if err != nil {
- t.Fatalf("after SetStaticTokens: %v", err)
- }
- if token.Name != "new-token" {
- t.Errorf("expected 'new-token', got '%s'", token.Name)
- }
-}
diff --git a/llm-gateway/internal/auth/totp.go b/llm-gateway/internal/auth/totp.go
deleted file mode 100644
index b3092ff..0000000
--- a/llm-gateway/internal/auth/totp.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package auth
-
-import (
- "github.com/pquerna/otp"
- "github.com/pquerna/otp/totp"
-)
-
-func GenerateTOTPKey(username string) (*otp.Key, error) {
- return totp.Generate(totp.GenerateOpts{
- Issuer: "LLM Gateway",
- AccountName: username,
- })
-}
-
-func ValidateTOTPCode(secret, code string) bool {
- return totp.Validate(code, secret)
-}
diff --git a/llm-gateway/internal/cache/cache.go b/llm-gateway/internal/cache/cache.go
deleted file mode 100644
index 40852c5..0000000
--- a/llm-gateway/internal/cache/cache.go
+++ /dev/null
@@ -1,178 +0,0 @@
-package cache
-
-import (
- "context"
- "crypto/sha256"
- "fmt"
- "time"
-
- "github.com/redis/go-redis/v9"
-)
-
-type Cache struct {
- client *redis.Client
- ttl time.Duration
-}
-
-func New(addr string, ttlSeconds int) (*Cache, error) {
- client := redis.NewClient(&redis.Options{
- Addr: addr,
- })
-
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
-
- if err := client.Ping(ctx).Err(); err != nil {
- return nil, fmt.Errorf("connecting to Valkey: %w", err)
- }
-
- ttl := time.Duration(ttlSeconds) * time.Second
- if ttl == 0 {
- ttl = 1 * time.Hour
- }
-
- return &Cache{client: client, ttl: ttl}, nil
-}
-
-func (c *Cache) Get(ctx context.Context, model string, requestBody []byte) ([]byte, error) {
- key := c.cacheKey(model, requestBody)
- data, err := c.client.Get(ctx, key).Bytes()
- if err == redis.Nil {
- return nil, nil
- }
- return data, err
-}
-
-func (c *Cache) Set(ctx context.Context, model string, requestBody, responseBody []byte) error {
- key := c.cacheKey(model, requestBody)
- return c.client.Set(ctx, key, responseBody, c.ttl).Err()
-}
-
-func (c *Cache) Ping(ctx context.Context) error {
- return c.client.Ping(ctx).Err()
-}
-
-func (c *Cache) Close() error {
- return c.client.Close()
-}
-
-// CacheStats holds cache statistics from the Valkey/Redis server.
-type CacheStats struct {
- Hits int64 `json:"hits"`
- Misses int64 `json:"misses"`
- HitRate float64 `json:"hit_rate"`
- MemoryUsed string `json:"memory_used"`
- Keys int64 `json:"keys"`
- Connected bool `json:"connected"`
-}
-
-// Stats returns cache statistics by querying Valkey/Redis INFO.
-func (c *Cache) Stats(ctx context.Context) *CacheStats {
- stats := &CacheStats{}
-
- // Check connectivity
- if err := c.client.Ping(ctx).Err(); err != nil {
- return stats
- }
- stats.Connected = true
-
- // Parse INFO stats for hits/misses
- info, err := c.client.Info(ctx, "stats").Result()
- if err == nil {
- stats.Hits = parseInfoInt(info, "keyspace_hits")
- stats.Misses = parseInfoInt(info, "keyspace_misses")
- total := stats.Hits + stats.Misses
- if total > 0 {
- stats.HitRate = float64(stats.Hits) / float64(total)
- }
- }
-
- // Parse INFO memory
- memInfo, err := c.client.Info(ctx, "memory").Result()
- if err == nil {
- stats.MemoryUsed = parseInfoString(memInfo, "used_memory_human")
- }
-
- // Parse INFO keyspace
- ksInfo, err := c.client.Info(ctx, "keyspace").Result()
- if err == nil {
- stats.Keys = parseKeyspaceKeys(ksInfo)
- }
-
- return stats
-}
-
-func parseInfoInt(info, key string) int64 {
- prefix := key + ":"
- for _, line := range splitLines(info) {
- if len(line) > len(prefix) && line[:len(prefix)] == prefix {
- var v int64
- fmt.Sscanf(line[len(prefix):], "%d", &v)
- return v
- }
- }
- return 0
-}
-
-func parseInfoString(info, key string) string {
- prefix := key + ":"
- for _, line := range splitLines(info) {
- if len(line) > len(prefix) && line[:len(prefix)] == prefix {
- val := line[len(prefix):]
- // Trim trailing \r
- if len(val) > 0 && val[len(val)-1] == '\r' {
- val = val[:len(val)-1]
- }
- return val
- }
- }
- return ""
-}
-
-func parseKeyspaceKeys(info string) int64 {
- // Format: db0:keys=123,expires=45,avg_ttl=6789
- for _, line := range splitLines(info) {
- if len(line) > 3 && line[:2] == "db" {
- prefix := "keys="
- idx := -1
- for i := 0; i <= len(line)-len(prefix); i++ {
- if line[i:i+len(prefix)] == prefix {
- idx = i + len(prefix)
- break
- }
- }
- if idx >= 0 {
- end := idx
- for end < len(line) && line[end] >= '0' && line[end] <= '9' {
- end++
- }
- var v int64
- fmt.Sscanf(line[idx:end], "%d", &v)
- return v
- }
- }
- }
- return 0
-}
-
-func splitLines(s string) []string {
- var lines []string
- start := 0
- for i := 0; i < len(s); i++ {
- if s[i] == '\n' {
- lines = append(lines, s[start:i])
- start = i + 1
- }
- }
- if start < len(s) {
- lines = append(lines, s[start:])
- }
- return lines
-}
-
-func (c *Cache) cacheKey(model string, requestBody []byte) string {
- h := sha256.New()
- h.Write([]byte(model))
- h.Write(requestBody)
- return fmt.Sprintf("llm-gw:%x", h.Sum(nil))
-}
diff --git a/llm-gateway/internal/cache/cache_test.go b/llm-gateway/internal/cache/cache_test.go
deleted file mode 100644
index 9e72e37..0000000
--- a/llm-gateway/internal/cache/cache_test.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package cache
-
-import (
- "testing"
-)
-
-func TestCacheKey_Deterministic(t *testing.T) {
- c := &Cache{}
-
- model := "gpt-4"
- body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
-
- key1 := c.cacheKey(model, body)
- key2 := c.cacheKey(model, body)
-
- if key1 != key2 {
- t.Errorf("cache key not deterministic: %s != %s", key1, key2)
- }
-
- if key1 == "" {
- t.Error("cache key is empty")
- }
-}
-
-func TestCacheKey_DifferentInputs(t *testing.T) {
- c := &Cache{}
-
- body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)
-
- key1 := c.cacheKey("gpt-4", body)
- key2 := c.cacheKey("gpt-3.5", body)
-
- if key1 == key2 {
- t.Error("different models should produce different cache keys")
- }
-
- key3 := c.cacheKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"world"}]}`))
- if key1 == key3 {
- t.Error("different bodies should produce different cache keys")
- }
-}
-
-func TestCacheKey_HasPrefix(t *testing.T) {
- c := &Cache{}
- key := c.cacheKey("gpt-4", []byte("test"))
-
- if len(key) < 7 || key[:7] != "llm-gw:" {
- t.Errorf("cache key should start with 'llm-gw:', got: %s", key)
- }
-}
-
-func TestParseInfoInt(t *testing.T) {
- info := "keyspace_hits:42\nkeyspace_misses:10\n"
-
- hits := parseInfoInt(info, "keyspace_hits")
- if hits != 42 {
- t.Errorf("expected 42, got %d", hits)
- }
-
- misses := parseInfoInt(info, "keyspace_misses")
- if misses != 10 {
- t.Errorf("expected 10, got %d", misses)
- }
-
- unknown := parseInfoInt(info, "nonexistent")
- if unknown != 0 {
- t.Errorf("expected 0 for unknown key, got %d", unknown)
- }
-}
-
-func TestParseInfoString(t *testing.T) {
- info := "used_memory_human:1.5M\r\nother:value\r\n"
-
- mem := parseInfoString(info, "used_memory_human")
- if mem != "1.5M" {
- t.Errorf("expected '1.5M', got '%s'", mem)
- }
-
- unknown := parseInfoString(info, "nonexistent")
- if unknown != "" {
- t.Errorf("expected empty for unknown key, got '%s'", unknown)
- }
-}
-
-func TestParseKeyspaceKeys(t *testing.T) {
- info := "# Keyspace\ndb0:keys=123,expires=45,avg_ttl=6789\n"
-
- keys := parseKeyspaceKeys(info)
- if keys != 123 {
- t.Errorf("expected 123, got %d", keys)
- }
-
- empty := parseKeyspaceKeys("# Keyspace\n")
- if empty != 0 {
- t.Errorf("expected 0 for empty keyspace, got %d", empty)
- }
-}
-
-func TestSplitLines(t *testing.T) {
- lines := splitLines("a\nb\nc")
- if len(lines) != 3 {
- t.Errorf("expected 3 lines, got %d", len(lines))
- }
- if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
- t.Errorf("unexpected lines: %v", lines)
- }
-
- single := splitLines("hello")
- if len(single) != 1 || single[0] != "hello" {
- t.Errorf("single line: %v", single)
- }
-}
diff --git a/llm-gateway/internal/config/config.go b/llm-gateway/internal/config/config.go
deleted file mode 100644
index ffa934b..0000000
--- a/llm-gateway/internal/config/config.go
+++ /dev/null
@@ -1,312 +0,0 @@
-package config
-
-import (
- "crypto/rand"
- "encoding/hex"
- "fmt"
- "log"
- "os"
- "time"
-
- "gopkg.in/yaml.v3"
-)
-
-type Config struct {
- Server ServerConfig `yaml:"server"`
- Database DatabaseConfig `yaml:"database"`
- Cache CacheConfig `yaml:"cache"`
- Pricing PricingLookupConfig `yaml:"pricing_lookup"`
- CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"`
- Retry RetryConfig `yaml:"retry"`
- Debug DebugConfig `yaml:"debug"`
- CORS CORSConfig `yaml:"cors"`
- Dedup DedupConfig `yaml:"dedup"`
- Webhooks []WebhookConfig `yaml:"webhooks"`
- Providers []ProviderConfig `yaml:"providers"`
- Models []ModelConfig `yaml:"models"`
- Tokens []TokenConfig `yaml:"tokens"`
-}
-
-type DedupConfig struct {
- Enabled bool `yaml:"enabled"`
- Window time.Duration `yaml:"window"` // max time to wait for dedup result
-}
-
-type WebhookConfig struct {
- URL string `yaml:"url"`
- Events []string `yaml:"events"` // event types to send
- Secret string `yaml:"secret"` // optional HMAC secret
-}
-
-type PricingLookupConfig struct {
- URL string `yaml:"url"`
- RefreshInterval time.Duration `yaml:"refresh_interval"`
-}
-
-type DefaultAdminConfig struct {
- Username string `yaml:"username"`
- Password string `yaml:"password"`
-}
-
-type TokenConfig struct {
- Name string `yaml:"name"`
- Key string `yaml:"key"`
- RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
- DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
- MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited
- MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
-}
-
-type ServerConfig struct {
- Listen string `yaml:"listen"`
- RequestTimeout time.Duration `yaml:"request_timeout"`
- StreamingTimeout time.Duration `yaml:"streaming_timeout"`
- MaxRequestBodyMB int `yaml:"max_request_body_mb"`
- SessionSecret string `yaml:"session_secret"`
- DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
-}
-
-type CircuitBreakerConfig struct {
- Enabled bool `yaml:"enabled"`
- ErrorThreshold float64 `yaml:"error_threshold"`
- MinRequests int `yaml:"min_requests"`
- CooldownDuration time.Duration `yaml:"cooldown_duration"`
-}
-
-type RetryConfig struct {
- InitialBackoff time.Duration `yaml:"initial_backoff"`
- MaxBackoff time.Duration `yaml:"max_backoff"`
- Multiplier float64 `yaml:"multiplier"`
-}
-
-type DebugConfig struct {
- Enabled bool `yaml:"enabled"`
- MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
- RetentionDays int `yaml:"retention_days"`
- DataDir string `yaml:"data_dir"`
-}
-
-type CORSConfig struct {
- Enabled bool `yaml:"enabled"`
- AllowedOrigins []string `yaml:"allowed_origins"`
- AllowedMethods []string `yaml:"allowed_methods"`
- AllowedHeaders []string `yaml:"allowed_headers"`
- MaxAge int `yaml:"max_age"`
-}
-
-type DatabaseConfig struct {
- Path string `yaml:"path"`
- RetentionDays int `yaml:"retention_days"`
-}
-
-type CacheConfig struct {
- Enabled bool `yaml:"enabled"`
- Address string `yaml:"address"`
- TTL int `yaml:"ttl"` // seconds
-}
-
-type ProviderConfig struct {
- Name string `yaml:"name"`
- BaseURL string `yaml:"base_url"`
- APIKey string `yaml:"api_key"`
- Priority int `yaml:"priority"`
- Timeout time.Duration `yaml:"timeout"`
-}
-
-type ModelConfig struct {
- Name string `yaml:"name"`
- Aliases []string `yaml:"aliases"`
- Routes []RouteConfig `yaml:"routes"`
- LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
- RequestTimeout time.Duration `yaml:"request_timeout"` // per-model override; 0 = use server default
- StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default
-}
-
-type RouteConfig struct {
- Provider string `yaml:"provider"`
- Model string `yaml:"model"`
- Pricing PricingConfig `yaml:"pricing"`
-}
-
-type PricingConfig struct {
- Input float64 `yaml:"input"` // cost per 1M tokens
- Output float64 `yaml:"output"` // cost per 1M tokens
-}
-
-func Load(path string) (*Config, error) {
- data, err := os.ReadFile(path)
- if err != nil {
- return nil, fmt.Errorf("reading config: %w", err)
- }
-
- // Expand environment variables
- expanded := os.ExpandEnv(string(data))
-
- var cfg Config
- if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
- return nil, fmt.Errorf("parsing config: %w", err)
- }
-
- if err := cfg.Validate(); err != nil {
- return nil, fmt.Errorf("validating config: %w", err)
- }
-
- return &cfg, nil
-}
-
-// Validate checks the config for correctness and applies defaults.
-func (c *Config) Validate() error {
- if c.Server.Listen == "" {
- c.Server.Listen = "0.0.0.0:3000"
- }
- if c.Server.RequestTimeout == 0 {
- c.Server.RequestTimeout = 300 * time.Second
- }
- if c.Server.MaxRequestBodyMB == 0 {
- c.Server.MaxRequestBodyMB = 10
- }
- if c.Server.SessionSecret == "" {
- b := make([]byte, 32)
- rand.Read(b)
- c.Server.SessionSecret = hex.EncodeToString(b)
- log.Println("WARNING: no session_secret configured, generated random one (sessions won't survive restart)")
- }
- if c.Database.Path == "" {
- c.Database.Path = "gateway.db"
- }
- if c.Database.RetentionDays == 0 {
- c.Database.RetentionDays = 90
- }
- if c.Pricing.RefreshInterval == 0 {
- c.Pricing.RefreshInterval = 6 * time.Hour
- }
-
- // Server defaults
- if c.Server.StreamingTimeout == 0 {
- c.Server.StreamingTimeout = 5 * time.Minute
- }
-
- // Circuit breaker defaults
- if c.CircuitBreaker.ErrorThreshold == 0 {
- c.CircuitBreaker.ErrorThreshold = 0.5
- }
- if c.CircuitBreaker.MinRequests == 0 {
- c.CircuitBreaker.MinRequests = 5
- }
- if c.CircuitBreaker.CooldownDuration == 0 {
- c.CircuitBreaker.CooldownDuration = 30 * time.Second
- }
-
- // Retry defaults
- if c.Retry.InitialBackoff == 0 {
- c.Retry.InitialBackoff = 100 * time.Millisecond
- }
- if c.Retry.MaxBackoff == 0 {
- c.Retry.MaxBackoff = 5 * time.Second
- }
- if c.Retry.Multiplier == 0 {
- c.Retry.Multiplier = 2.0
- }
-
- // Debug defaults
- if c.Debug.RetentionDays == 0 {
- c.Debug.RetentionDays = 90
- }
-
- // CORS defaults
- if c.CORS.MaxAge == 0 {
- c.CORS.MaxAge = 300
- }
-
- // Dedup defaults
- if c.Dedup.Window == 0 {
- c.Dedup.Window = 30 * time.Second
- }
-
- if len(c.Providers) == 0 {
- return fmt.Errorf("at least one provider is required")
- }
- providerNames := make(map[string]bool)
- for i, p := range c.Providers {
- if p.Name == "" || p.BaseURL == "" || p.APIKey == "" {
- return fmt.Errorf("provider %d: name, base_url, and api_key are required", i)
- }
- if providerNames[p.Name] {
- return fmt.Errorf("duplicate provider name: %s", p.Name)
- }
- providerNames[p.Name] = true
- if c.Providers[i].Timeout == 0 {
- c.Providers[i].Timeout = 120 * time.Second
- }
- if c.Providers[i].Priority == 0 {
- c.Providers[i].Priority = 1
- }
- }
-
- if len(c.Models) == 0 {
- return fmt.Errorf("at least one model is required")
- }
- modelNames := make(map[string]bool)
- for i, m := range c.Models {
- if m.Name == "" {
- return fmt.Errorf("model %d: name is required", i)
- }
- if modelNames[m.Name] {
- return fmt.Errorf("duplicate model name: %s", m.Name)
- }
- modelNames[m.Name] = true
- for _, alias := range m.Aliases {
- if modelNames[alias] {
- return fmt.Errorf("model alias %s conflicts with existing model or alias", alias)
- }
- modelNames[alias] = true
- }
- if len(m.Routes) == 0 {
- return fmt.Errorf("model %s: at least one route is required", m.Name)
- }
- for j, r := range m.Routes {
- if r.Provider == "" || r.Model == "" {
- return fmt.Errorf("model %s route %d: provider and model are required", m.Name, j)
- }
- if !providerNames[r.Provider] {
- return fmt.Errorf("model %s route %d: unknown provider %s", m.Name, j, r.Provider)
- }
- }
- }
-
- // Validate tokens (optional section)
- for i, t := range c.Tokens {
- if t.Key == "" {
- log.Printf("WARNING: token %d (%s) has empty key, skipping", i, t.Name)
- continue
- }
- if t.Name == "" {
- c.Tokens[i].Name = fmt.Sprintf("token-%d", i)
- }
- }
-
- return nil
-}
-
-// ValidateBytes parses raw YAML and returns a list of validation errors.
-func ValidateBytes(data []byte) []string {
- expanded := os.ExpandEnv(string(data))
- var cfg Config
- if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
- return []string{"parse error: " + err.Error()}
- }
- if err := cfg.Validate(); err != nil {
- return []string{err.Error()}
- }
- return nil
-}
-
-// ProviderByName returns the provider config by name.
-func (c *Config) ProviderByName(name string) *ProviderConfig {
- for i := range c.Providers {
- if c.Providers[i].Name == name {
- return &c.Providers[i]
- }
- }
- return nil
-}
diff --git a/llm-gateway/internal/config/config_test.go b/llm-gateway/internal/config/config_test.go
deleted file mode 100644
index 3a65f7a..0000000
--- a/llm-gateway/internal/config/config_test.go
+++ /dev/null
@@ -1,737 +0,0 @@
-package config
-
-import (
- "fmt"
- "os"
- "path/filepath"
- "strings"
- "testing"
- "time"
-)
-
-// writeConfigFile creates a temporary YAML config file and returns its path.
-func writeConfigFile(t *testing.T, content string) string {
- t.Helper()
- f, err := os.CreateTemp(t.TempDir(), "config-*.yaml")
- if err != nil {
- t.Fatalf("creating temp file: %v", err)
- }
- if _, err := f.WriteString(content); err != nil {
- f.Close()
- t.Fatalf("writing temp file: %v", err)
- }
- f.Close()
- return f.Name()
-}
-
-// minimalValidConfig returns a minimal valid YAML config string.
-func minimalValidConfig() string {
- return `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-test-key
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`
-}
-
-func TestLoad_ValidConfig(t *testing.T) {
- path := writeConfigFile(t, `
-server:
- listen: "127.0.0.1:8080"
- request_timeout: 60s
- streaming_timeout: 120s
- max_request_body_mb: 5
- session_secret: "test-secret-1234567890abcdef1234567890abcdef"
-
-database:
- path: "/tmp/test.db"
- retention_days: 30
-
-pricing_lookup:
- url: "https://pricing.example.com"
- refresh_interval: 1h
-
-circuit_breaker:
- enabled: true
- error_threshold: 0.3
- min_requests: 10
- cooldown_duration: 60s
-
-retry:
- initial_backoff: 200ms
- max_backoff: 10s
- multiplier: 3.0
-
-debug:
- enabled: true
- max_body_bytes: 65536
- retention_days: 60
-
-cors:
- enabled: true
- allowed_origins:
- - "https://example.com"
- allowed_methods:
- - GET
- - POST
- allowed_headers:
- - Authorization
- max_age: 600
-
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-test-key
- priority: 2
- timeout: 60s
- - name: anthropic
- base_url: https://api.anthropic.com/v1
- api_key: sk-ant-test
- priority: 1
- timeout: 30s
-
-models:
- - name: gpt-4
- aliases:
- - gpt4
- routes:
- - provider: openai
- model: gpt-4
- pricing:
- input: 30.0
- output: 60.0
- load_balancing: first
- - name: claude-3
- routes:
- - provider: anthropic
- model: claude-3-opus-20240229
-
-tokens:
- - name: test-token
- key: tok-abc123
- rate_limit_rpm: 100
- daily_budget_usd: 10.0
- max_concurrent: 5
-`)
-
- cfg, err := Load(path)
- if err != nil {
- t.Fatalf("Load() returned error: %v", err)
- }
-
- // Server
- if cfg.Server.Listen != "127.0.0.1:8080" {
- t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080")
- }
- if cfg.Server.RequestTimeout != 60*time.Second {
- t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second)
- }
- if cfg.Server.StreamingTimeout != 120*time.Second {
- t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second)
- }
- if cfg.Server.MaxRequestBodyMB != 5 {
- t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5)
- }
- if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" {
- t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef")
- }
-
- // Database
- if cfg.Database.Path != "/tmp/test.db" {
- t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db")
- }
- if cfg.Database.RetentionDays != 30 {
- t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30)
- }
-
- // Pricing
- if cfg.Pricing.URL != "https://pricing.example.com" {
- t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com")
- }
- if cfg.Pricing.RefreshInterval != 1*time.Hour {
- t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour)
- }
-
- // Circuit breaker
- if !cfg.CircuitBreaker.Enabled {
- t.Error("CircuitBreaker.Enabled = false, want true")
- }
- if cfg.CircuitBreaker.ErrorThreshold != 0.3 {
- t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3)
- }
- if cfg.CircuitBreaker.MinRequests != 10 {
- t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10)
- }
- if cfg.CircuitBreaker.CooldownDuration != 60*time.Second {
- t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second)
- }
-
- // Retry
- if cfg.Retry.InitialBackoff != 200*time.Millisecond {
- t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond)
- }
- if cfg.Retry.MaxBackoff != 10*time.Second {
- t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second)
- }
- if cfg.Retry.Multiplier != 3.0 {
- t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0)
- }
-
- // Debug
- if !cfg.Debug.Enabled {
- t.Error("Debug.Enabled = false, want true")
- }
- if cfg.Debug.MaxBodyBytes != 65536 {
- t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536)
- }
- if cfg.Debug.RetentionDays != 60 {
- t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60)
- }
-
- // CORS
- if !cfg.CORS.Enabled {
- t.Error("CORS.Enabled = false, want true")
- }
- if cfg.CORS.MaxAge != 600 {
- t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600)
- }
-
- // Providers
- if len(cfg.Providers) != 2 {
- t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers))
- }
- if cfg.Providers[0].Name != "openai" {
- t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai")
- }
- if cfg.Providers[0].Timeout != 60*time.Second {
- t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second)
- }
-
- // Models
- if len(cfg.Models) != 2 {
- t.Fatalf("len(Models) = %d, want 2", len(cfg.Models))
- }
- if cfg.Models[0].LoadBalancing != "first" {
- t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first")
- }
- if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" {
- t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases)
- }
- if cfg.Models[0].Routes[0].Pricing.Input != 30.0 {
- t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input)
- }
-
- // Tokens
- if len(cfg.Tokens) != 1 {
- t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens))
- }
- if cfg.Tokens[0].Name != "test-token" {
- t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token")
- }
- if cfg.Tokens[0].RateLimitRPM != 100 {
- t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM)
- }
-}
-
-func TestValidate_Defaults(t *testing.T) {
- path := writeConfigFile(t, minimalValidConfig())
- cfg, err := Load(path)
- if err != nil {
- t.Fatalf("Load() returned error: %v", err)
- }
-
- tests := []struct {
- name string
- got any
- want any
- }{
- // Server defaults
- {"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"},
- {"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second},
- {"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute},
- {"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10},
-
- // Database defaults
- {"Database.Path", cfg.Database.Path, "gateway.db"},
- {"Database.RetentionDays", cfg.Database.RetentionDays, 90},
-
- // Pricing defaults
- {"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour},
-
- // Circuit breaker defaults
- {"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5},
- {"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5},
- {"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second},
-
- // Retry defaults
- {"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond},
- {"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second},
- {"Retry.Multiplier", cfg.Retry.Multiplier, 2.0},
-
- // Debug defaults
- {"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0},
- {"Debug.RetentionDays", cfg.Debug.RetentionDays, 90},
-
- // CORS defaults
- {"CORS.MaxAge", cfg.CORS.MaxAge, 300},
-
- // Provider defaults
- {"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second},
- {"Providers[0].Priority", cfg.Providers[0].Priority, 1},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Compare using formatted strings to handle different numeric types
- gotStr := formatValue(tt.got)
- wantStr := formatValue(tt.want)
- if gotStr != wantStr {
- t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
- }
- })
- }
-
- // SessionSecret should be auto-generated (non-empty, 64 hex chars)
- if cfg.Server.SessionSecret == "" {
- t.Error("SessionSecret should be auto-generated when empty")
- }
- if len(cfg.Server.SessionSecret) != 64 {
- t.Errorf("SessionSecret length = %d, want 64 hex chars", len(cfg.Server.SessionSecret))
- }
-}
-
-func formatValue(v any) string {
- switch val := v.(type) {
- case time.Duration:
- return val.String()
- case float64:
- return fmt.Sprintf("%g", val)
- case int:
- return fmt.Sprintf("%d", val)
- case string:
- return val
- default:
- return fmt.Sprintf("%v", val)
- }
-}
-
-func TestLoad_FileNotFound(t *testing.T) {
- _, err := Load(filepath.Join(t.TempDir(), "nonexistent.yaml"))
- if err == nil {
- t.Fatal("Load() should return error for nonexistent file")
- }
-}
-
-func TestLoad_InvalidYAML(t *testing.T) {
- path := writeConfigFile(t, `{{{invalid yaml`)
- _, err := Load(path)
- if err == nil {
- t.Fatal("Load() should return error for invalid YAML")
- }
-}
-
-func TestValidate_DuplicateProviderNames(t *testing.T) {
- path := writeConfigFile(t, `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key1
- - name: openai
- base_url: https://api.openai.com/v2
- api_key: sk-key2
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`)
-
- _, err := Load(path)
- if err == nil {
- t.Fatal("Load() should return error for duplicate provider names")
- }
- wantSubstr := "duplicate provider name: openai"
- if !strings.Contains(err.Error(), wantSubstr) {
- t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
- }
-}
-
-func TestValidate_DuplicateModelNames(t *testing.T) {
- path := writeConfigFile(t, `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key1
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4-turbo
-`)
-
- _, err := Load(path)
- if err == nil {
- t.Fatal("Load() should return error for duplicate model names")
- }
- wantSubstr := "duplicate model name: gpt-4"
- if !strings.Contains(err.Error(), wantSubstr) {
- t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
- }
-}
-
-func TestValidate_AliasConflicts(t *testing.T) {
- tests := []struct {
- name string
- config string
- wantErr string
- }{
- {
- name: "alias conflicts with model name",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key1
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
- - name: claude-3
- aliases:
- - gpt-4
- routes:
- - provider: openai
- model: claude-3
-`,
- wantErr: "model alias gpt-4 conflicts with existing model or alias",
- },
- {
- name: "alias conflicts with another alias",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key1
-
-models:
- - name: gpt-4
- aliases:
- - fast-model
- routes:
- - provider: openai
- model: gpt-4
- - name: claude-3
- aliases:
- - fast-model
- routes:
- - provider: openai
- model: claude-3
-`,
- wantErr: "model alias fast-model conflicts with existing model or alias",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- path := writeConfigFile(t, tt.config)
- _, err := Load(path)
- if err == nil {
- t.Fatal("Load() should return error for alias conflicts")
- }
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
- }
- })
- }
-}
-
-func TestValidate_MissingRequiredFields(t *testing.T) {
- tests := []struct {
- name string
- config string
- wantErr string
- }{
- {
- name: "no providers",
- config: `models: [{name: test, routes: [{provider: x, model: y}]}]`,
- wantErr: "at least one provider is required",
- },
- {
- name: "no models",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key
-`,
- wantErr: "at least one model is required",
- },
- {
- name: "provider missing name",
- config: `
-providers:
- - base_url: https://api.openai.com/v1
- api_key: sk-key
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`,
- wantErr: "provider 0: name, base_url, and api_key are required",
- },
- {
- name: "provider missing base_url",
- config: `
-providers:
- - name: openai
- api_key: sk-key
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`,
- wantErr: "provider 0: name, base_url, and api_key are required",
- },
- {
- name: "provider missing api_key",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`,
- wantErr: "provider 0: name, base_url, and api_key are required",
- },
- {
- name: "model missing name",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key
-
-models:
- - routes:
- - provider: openai
- model: gpt-4
-`,
- wantErr: "model 0: name is required",
- },
- {
- name: "model missing routes",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key
-
-models:
- - name: gpt-4
-`,
- wantErr: "model gpt-4: at least one route is required",
- },
- {
- name: "route missing provider",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key
-
-models:
- - name: gpt-4
- routes:
- - model: gpt-4
-`,
- wantErr: "model gpt-4 route 0: provider and model are required",
- },
- {
- name: "route missing model",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
-`,
- wantErr: "model gpt-4 route 0: provider and model are required",
- },
- {
- name: "route references unknown provider",
- config: `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-key
-
-models:
- - name: gpt-4
- routes:
- - provider: anthropic
- model: gpt-4
-`,
- wantErr: "model gpt-4 route 0: unknown provider anthropic",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- path := writeConfigFile(t, tt.config)
- _, err := Load(path)
- if err == nil {
- t.Fatalf("Load() should return error, want %q", tt.wantErr)
- }
- if !strings.Contains(err.Error(), tt.wantErr) {
- t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
- }
- })
- }
-}
-
-func TestProviderByName(t *testing.T) {
- path := writeConfigFile(t, `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: sk-openai
- - name: anthropic
- base_url: https://api.anthropic.com/v1
- api_key: sk-anthropic
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`)
-
- cfg, err := Load(path)
- if err != nil {
- t.Fatalf("Load() returned error: %v", err)
- }
-
- tests := []struct {
- name string
- lookup string
- wantNil bool
- wantName string
- }{
- {"existing provider openai", "openai", false, "openai"},
- {"existing provider anthropic", "anthropic", false, "anthropic"},
- {"nonexistent provider", "google", true, ""},
- {"empty name", "", true, ""},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- p := cfg.ProviderByName(tt.lookup)
- if tt.wantNil {
- if p != nil {
- t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p)
- }
- } else {
- if p == nil {
- t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup)
- }
- if p.Name != tt.wantName {
- t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName)
- }
- }
- })
- }
-
- // Verify returned pointer refers to the actual config entry
- p := cfg.ProviderByName("openai")
- if p.APIKey != "sk-openai" {
- t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai")
- }
-}
-
-func TestLoad_EnvironmentVariableExpansion(t *testing.T) {
- t.Setenv("TEST_API_KEY", "sk-from-env")
- t.Setenv("TEST_BASE_URL", "https://env.example.com/v1")
- t.Setenv("TEST_PROVIDER_NAME", "env-provider")
-
- path := writeConfigFile(t, `
-providers:
- - name: $TEST_PROVIDER_NAME
- base_url: ${TEST_BASE_URL}
- api_key: ${TEST_API_KEY}
-
-models:
- - name: test-model
- routes:
- - provider: env-provider
- model: gpt-4
-`)
-
- cfg, err := Load(path)
- if err != nil {
- t.Fatalf("Load() returned error: %v", err)
- }
-
- if cfg.Providers[0].Name != "env-provider" {
- t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider")
- }
- if cfg.Providers[0].BaseURL != "https://env.example.com/v1" {
- t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1")
- }
- if cfg.Providers[0].APIKey != "sk-from-env" {
- t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env")
- }
-}
-
-func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) {
- // Ensure the variable is not set
- t.Setenv("TEST_UNSET_VAR", "")
- os.Unsetenv("TEST_UNSET_VAR")
-
- path := writeConfigFile(t, `
-providers:
- - name: openai
- base_url: https://api.openai.com/v1
- api_key: ${TEST_UNSET_VAR}
-
-models:
- - name: gpt-4
- routes:
- - provider: openai
- model: gpt-4
-`)
-
- _, err := Load(path)
- if err == nil {
- t.Fatal("Load() should return error when env var expands to empty required field")
- }
- // api_key will be empty, so validation should catch it
- if !strings.Contains(err.Error(), "api_key are required") {
- t.Errorf("error = %q, want to contain api_key validation message", err.Error())
- }
-}
diff --git a/llm-gateway/internal/config/watcher.go b/llm-gateway/internal/config/watcher.go
deleted file mode 100644
index f63a478..0000000
--- a/llm-gateway/internal/config/watcher.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package config
-
-import (
- "log"
- "os"
- "os/signal"
- "syscall"
-)
-
-// WatchReload listens for SIGHUP and calls the callback with the new config.
-func WatchReload(configPath string, callback func(*Config)) {
- sighup := make(chan os.Signal, 1)
- signal.Notify(sighup, syscall.SIGHUP)
-
- go func() {
- for range sighup {
- log.Println("SIGHUP received, reloading config...")
- newCfg, err := Load(configPath)
- if err != nil {
- log.Printf("ERROR: config reload failed: %v", err)
- continue
- }
- callback(newCfg)
- log.Println("Config reloaded successfully")
- }
- }()
-}
diff --git a/llm-gateway/internal/dashboard/api.go b/llm-gateway/internal/dashboard/api.go
deleted file mode 100644
index fe75a2d..0000000
--- a/llm-gateway/internal/dashboard/api.go
+++ /dev/null
@@ -1,775 +0,0 @@
-package dashboard
-
-import (
- "encoding/json"
- "net/http"
- "os"
- "sort"
- "strconv"
- "time"
-
- "github.com/go-chi/chi/v5"
-
- "llm-gateway/internal/auth"
- "llm-gateway/internal/cache"
- "llm-gateway/internal/config"
- "llm-gateway/internal/provider"
- "llm-gateway/internal/storage"
-)
-
-// Exported types for template rendering and JSON API.
-
-type Period struct {
- Requests int `json:"requests"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
- Errors int `json:"errors"`
- CachedHits int `json:"cached_hits"`
-}
-
-type SummaryResult struct {
- Today *Period `json:"today"`
- Week *Period `json:"week"`
- Month *Period `json:"month"`
-}
-
-type ModelStats struct {
- Model string `json:"model"`
- Requests int `json:"requests"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
- AvgLatencyMS float64 `json:"avg_latency_ms"`
-}
-
-type ProviderStats struct {
- Provider string `json:"provider"`
- Requests int `json:"requests"`
- Successes int `json:"successes"`
- Errors int `json:"errors"`
- AvgLatencyMS float64 `json:"avg_latency_ms"`
- CostUSD float64 `json:"cost_usd"`
-}
-
-type TokenUsageStats struct {
- TokenName string `json:"token_name"`
- Requests int `json:"requests"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
-}
-
-// RequestLogEntry represents a single request log row.
-type RequestLogEntry struct {
- RequestID string `json:"request_id"`
- Timestamp int64 `json:"timestamp"`
- TokenName string `json:"token_name"`
- Model string `json:"model"`
- Provider string `json:"provider"`
- ProviderModel string `json:"provider_model"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
- LatencyMS int64 `json:"latency_ms"`
- Status string `json:"status"`
- ErrorMessage string `json:"error_message"`
- Streaming bool `json:"streaming"`
- Cached bool `json:"cached"`
-}
-
-// LogsResult holds paginated logs.
-type LogsResult struct {
- Logs []RequestLogEntry `json:"logs"`
- Page int `json:"page"`
- TotalPages int `json:"total_pages"`
- Total int `json:"total"`
-}
-
-// LatencyResult holds latency percentiles.
-type LatencyResult struct {
- P50 float64 `json:"p50"`
- P95 float64 `json:"p95"`
- P99 float64 `json:"p99"`
- Avg float64 `json:"avg"`
- Min float64 `json:"min"`
- Max float64 `json:"max"`
-}
-
-// CostBreakdownEntry holds cost data grouped by day and dimension.
-type CostBreakdownEntry struct {
- Day string `json:"day"`
- GroupBy string `json:"group_by"`
- CostUSD float64 `json:"cost_usd"`
- Requests int `json:"requests"`
-}
-
-type StatsAPI struct {
- db *storage.DB
- authStore *auth.Store
- healthTracker *provider.HealthTracker
- cache *cache.Cache
- auditLogger *storage.AuditLogger
- debugLogger *storage.DebugLogger
- configPath string
-}
-
-func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
- return &StatsAPI{db: db, authStore: authStore}
-}
-
-// SetHealthTracker sets the provider health tracker.
-func (s *StatsAPI) SetHealthTracker(ht *provider.HealthTracker) {
- s.healthTracker = ht
-}
-
-// SetCache sets the cache for stats.
-func (s *StatsAPI) SetCache(c *cache.Cache) {
- s.cache = c
-}
-
-// SetAuditLogger sets the audit logger.
-func (s *StatsAPI) SetAuditLogger(al *storage.AuditLogger) {
- s.auditLogger = al
-}
-
-// SetDebugLogger sets the debug logger.
-func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
- s.debugLogger = dl
-}
-
-// SetConfigPath sets the config file path for validation.
-func (s *StatsAPI) SetConfigPath(path string) {
- s.configPath = path
-}
-
-// TokenNamesForUser returns the token names that belong to the user.
-// Admins get nil (no filter), non-admins get their token names.
-func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
- if user == nil || user.IsAdmin {
- return nil
- }
- tokens, err := s.authStore.ListAPITokens(user.ID)
- if err != nil {
- return []string{"__none__"}
- }
- names := make([]string, len(tokens))
- for i, t := range tokens {
- names[i] = t.Name
- }
- if len(names) == 0 {
- return []string{"__none__"}
- }
- return names
-}
-
-// tokenNamesForUser returns token names from request context (for HTTP handlers).
-func (s *StatsAPI) tokenNamesForUser(r *http.Request) []string {
- user := auth.UserFromContext(r.Context())
- return s.TokenNamesForUser(user)
-}
-
-func buildTokenFilter(tokenNames []string) (string, []any) {
- if tokenNames == nil {
- return "", nil
- }
- placeholders := ""
- args := make([]any, len(tokenNames))
- for i, n := range tokenNames {
- if i > 0 {
- placeholders += ","
- }
- placeholders += "?"
- args[i] = n
- }
- return " AND token_name IN (" + placeholders + ")", args
-}
-
-// Data-fetching methods (used by both JSON handlers and template handlers).
-
-func (s *StatsAPI) GetSummary(tokenNames []string) *SummaryResult {
- now := time.Now()
- todayStart := now.Truncate(24 * time.Hour).Unix()
- weekStart := now.AddDate(0, 0, -7).Unix()
- monthStart := now.AddDate(0, -1, 0).Unix()
-
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- result := &SummaryResult{
- Today: &Period{},
- Week: &Period{},
- Month: &Period{},
- }
-
- periods := map[string]struct {
- since int64
- period *Period
- }{
- "today": {todayStart, result.Today},
- "week": {weekStart, result.Week},
- "month": {monthStart, result.Month},
- }
-
- for _, p := range periods {
- args := append([]any{p.since}, filterArgs...)
- row := s.db.QueryRow(`SELECT
- COUNT(*),
- COALESCE(SUM(input_tokens), 0),
- COALESCE(SUM(output_tokens), 0),
- COALESCE(SUM(cost_usd), 0),
- COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0),
- COALESCE(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END), 0)
- FROM request_logs WHERE timestamp >= ?`+tokenFilter, args...)
- row.Scan(&p.period.Requests, &p.period.InputTokens, &p.period.OutputTokens, &p.period.CostUSD, &p.period.Errors, &p.period.CachedHits)
- }
-
- return result
-}
-
-func (s *StatsAPI) GetModels(tokenNames []string) []ModelStats {
- since := time.Now().AddDate(0, 0, -30).Unix()
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- args := append([]any{since}, filterArgs...)
- rows, err := s.db.Query(`SELECT
- model,
- COUNT(*) as requests,
- COALESCE(SUM(input_tokens), 0) as input_tokens,
- COALESCE(SUM(output_tokens), 0) as output_tokens,
- COALESCE(SUM(cost_usd), 0) as cost,
- COALESCE(AVG(latency_ms), 0) as avg_latency
- FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
- GROUP BY model ORDER BY requests DESC`, args...)
- if err != nil {
- return nil
- }
- defer rows.Close()
-
- var results []ModelStats
- for rows.Next() {
- var m ModelStats
- rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS)
- results = append(results, m)
- }
- return results
-}
-
-func (s *StatsAPI) GetProviders(tokenNames []string) []ProviderStats {
- since := time.Now().AddDate(0, 0, -30).Unix()
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- args := append([]any{since}, filterArgs...)
- rows, err := s.db.Query(`SELECT
- provider,
- COUNT(*) as requests,
- COALESCE(SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END), 0) as successes,
- COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0) as errors,
- COALESCE(AVG(latency_ms), 0) as avg_latency,
- COALESCE(SUM(cost_usd), 0) as cost
- FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
- GROUP BY provider ORDER BY requests DESC`, args...)
- if err != nil {
- return nil
- }
- defer rows.Close()
-
- var results []ProviderStats
- for rows.Next() {
- var p ProviderStats
- rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD)
- results = append(results, p)
- }
- return results
-}
-
-func (s *StatsAPI) GetTokenUsage(tokenNames []string) []TokenUsageStats {
- since := time.Now().AddDate(0, 0, -30).Unix()
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- args := append([]any{since}, filterArgs...)
- rows, err := s.db.Query(`SELECT
- token_name,
- COUNT(*) as requests,
- COALESCE(SUM(input_tokens), 0) as input_tokens,
- COALESCE(SUM(output_tokens), 0) as output_tokens,
- COALESCE(SUM(cost_usd), 0) as cost
- FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
- GROUP BY token_name ORDER BY requests DESC`, args...)
- if err != nil {
- return nil
- }
- defer rows.Close()
-
- var results []TokenUsageStats
- for rows.Next() {
- var t TokenUsageStats
- rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD)
- results = append(results, t)
- }
- return results
-}
-
-// GetLogs returns paginated request logs with filters.
-func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status string) *LogsResult {
- if page < 1 {
- page = 1
- }
- limit := 50
- offset := (page - 1) * limit
-
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- where := "WHERE 1=1" + tokenFilter
- args := make([]any, 0)
- args = append(args, filterArgs...)
-
- if model != "" {
- where += " AND model = ?"
- args = append(args, model)
- }
- if token != "" {
- where += " AND token_name = ?"
- args = append(args, token)
- }
- if status != "" {
- where += " AND status = ?"
- args = append(args, status)
- }
-
- // Get total count
- var total int
- countArgs := make([]any, len(args))
- copy(countArgs, args)
- s.db.QueryRow("SELECT COUNT(*) FROM request_logs "+where, countArgs...).Scan(&total)
-
- totalPages := (total + limit - 1) / limit
- if totalPages < 1 {
- totalPages = 1
- }
-
- // Get page
- query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
- input_tokens, output_tokens, cost_usd, latency_ms, status,
- COALESCE(error_message, ''), streaming, cached
- FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
- args = append(args, limit, offset)
-
- rows, err := s.db.Query(query, args...)
- if err != nil {
- return &LogsResult{Logs: []RequestLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
- }
- defer rows.Close()
-
- var logs []RequestLogEntry
- for rows.Next() {
- var l RequestLogEntry
- var streaming, cached int
- rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
- &l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
- &l.ErrorMessage, &streaming, &cached)
- l.Streaming = streaming == 1
- l.Cached = cached == 1
- logs = append(logs, l)
- }
- if logs == nil {
- logs = []RequestLogEntry{}
- }
-
- return &LogsResult{
- Logs: logs,
- Page: page,
- TotalPages: totalPages,
- Total: total,
- }
-}
-
-// GetDistinctModels returns distinct model names from logs.
-func (s *StatsAPI) GetDistinctModels() []string {
- rows, err := s.db.Query("SELECT DISTINCT model FROM request_logs ORDER BY model")
- if err != nil {
- return nil
- }
- defer rows.Close()
- var models []string
- for rows.Next() {
- var m string
- rows.Scan(&m)
- models = append(models, m)
- }
- return models
-}
-
-// GetDistinctTokens returns distinct token names from logs.
-func (s *StatsAPI) GetDistinctTokens() []string {
- rows, err := s.db.Query("SELECT DISTINCT token_name FROM request_logs ORDER BY token_name")
- if err != nil {
- return nil
- }
- defer rows.Close()
- var tokens []string
- for rows.Next() {
- var t string
- rows.Scan(&t)
- tokens = append(tokens, t)
- }
- return tokens
-}
-
-// GetLatency computes latency percentiles from request_logs.
-func (s *StatsAPI) GetLatency(tokenNames []string, period, model, providerName string) *LatencyResult {
- var since int64
- switch period {
- case "7d":
- since = time.Now().AddDate(0, 0, -7).Unix()
- case "30d":
- since = time.Now().AddDate(0, -1, 0).Unix()
- default:
- since = time.Now().Add(-24 * time.Hour).Unix()
- }
-
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- where := "WHERE timestamp >= ? AND status = 'success'" + tokenFilter
- args := []any{since}
- args = append(args, filterArgs...)
-
- if model != "" {
- where += " AND model = ?"
- args = append(args, model)
- }
- if providerName != "" {
- where += " AND provider = ?"
- args = append(args, providerName)
- }
-
- rows, err := s.db.Query("SELECT latency_ms FROM request_logs "+where+" ORDER BY latency_ms", args...)
- if err != nil {
- return &LatencyResult{}
- }
- defer rows.Close()
-
- var latencies []float64
- for rows.Next() {
- var l float64
- rows.Scan(&l)
- latencies = append(latencies, l)
- }
-
- if len(latencies) == 0 {
- return &LatencyResult{}
- }
-
- sort.Float64s(latencies)
- n := len(latencies)
- var sum float64
- for _, l := range latencies {
- sum += l
- }
-
- return &LatencyResult{
- P50: latencies[n*50/100],
- P95: latencies[n*95/100],
- P99: latencies[min(n*99/100, n-1)],
- Avg: sum / float64(n),
- Min: latencies[0],
- Max: latencies[n-1],
- }
-}
-
-// GetCostBreakdown returns cost data grouped by day and dimension.
-func (s *StatsAPI) GetCostBreakdown(tokenNames []string, period, groupBy string) []CostBreakdownEntry {
- var since int64
- switch period {
- case "30d":
- since = time.Now().AddDate(0, -1, 0).Unix()
- case "7d":
- since = time.Now().AddDate(0, 0, -7).Unix()
- default:
- since = time.Now().Add(-24 * time.Hour).Unix()
- }
-
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- groupCol := "model"
- if groupBy == "token" {
- groupCol = "token_name"
- } else if groupBy == "provider" {
- groupCol = "provider"
- }
-
- args := []any{since}
- args = append(args, filterArgs...)
-
- query := `SELECT date(timestamp, 'unixepoch') as day, ` + groupCol + `,
- COALESCE(SUM(cost_usd), 0), COUNT(*)
- FROM request_logs WHERE timestamp >= ?` + tokenFilter + `
- GROUP BY day, ` + groupCol + ` ORDER BY day, ` + groupCol
-
- rows, err := s.db.Query(query, args...)
- if err != nil {
- return nil
- }
- defer rows.Close()
-
- var results []CostBreakdownEntry
- for rows.Next() {
- var e CostBreakdownEntry
- rows.Scan(&e.Day, &e.GroupBy, &e.CostUSD, &e.Requests)
- results = append(results, e)
- }
- return results
-}
-
-// JSON HTTP handlers (thin wrappers).
-
-func (s *StatsAPI) Summary(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- result := s.GetSummary(tokenNames)
- writeJSON(w, result)
-}
-
-func (s *StatsAPI) Models(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- results := s.GetModels(tokenNames)
- writeJSON(w, results)
-}
-
-func (s *StatsAPI) Providers(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- results := s.GetProviders(tokenNames)
- writeJSON(w, results)
-}
-
-func (s *StatsAPI) Tokens(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- results := s.GetTokenUsage(tokenNames)
- writeJSON(w, results)
-}
-
-func (s *StatsAPI) Timeseries(w http.ResponseWriter, r *http.Request) {
- period := r.URL.Query().Get("period")
- var since int64
- var groupFmt string
- switch period {
- case "7d":
- since = time.Now().AddDate(0, 0, -7).Unix()
- groupFmt = "%Y-%m-%d"
- case "30d":
- since = time.Now().AddDate(0, -1, 0).Unix()
- groupFmt = "%Y-%m-%d"
- default:
- since = time.Now().Add(-24 * time.Hour).Unix()
- groupFmt = "%Y-%m-%d %H:00"
- }
-
- tokenNames := s.tokenNamesForUser(r)
- tokenFilter, filterArgs := buildTokenFilter(tokenNames)
-
- args := append([]any{since}, filterArgs...)
- rows, err := s.db.Query(`SELECT
- strftime('`+groupFmt+`', timestamp, 'unixepoch') as bucket,
- COUNT(*) as requests,
- COALESCE(SUM(cost_usd), 0) as cost,
- COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens
- FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
- GROUP BY bucket ORDER BY bucket`, args...)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
- defer rows.Close()
-
- type point struct {
- Bucket string `json:"bucket"`
- Requests int `json:"requests"`
- CostUSD float64 `json:"cost_usd"`
- TotalTokens int `json:"total_tokens"`
- }
-
- var results []point
- for rows.Next() {
- var p point
- rows.Scan(&p.Bucket, &p.Requests, &p.CostUSD, &p.TotalTokens)
- results = append(results, p)
- }
- writeJSON(w, results)
-}
-
-// Logs serves the paginated logs API.
-func (s *StatsAPI) Logs(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- page, _ := strconv.Atoi(r.URL.Query().Get("page"))
- model := r.URL.Query().Get("model")
- token := r.URL.Query().Get("token")
- status := r.URL.Query().Get("status")
- result := s.GetLogs(tokenNames, page, model, token, status)
- writeJSON(w, result)
-}
-
-// Latency serves latency percentiles API.
-func (s *StatsAPI) Latency(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- period := r.URL.Query().Get("period")
- model := r.URL.Query().Get("model")
- providerName := r.URL.Query().Get("provider")
- result := s.GetLatency(tokenNames, period, model, providerName)
- writeJSON(w, result)
-}
-
-// CostBreakdown serves cost breakdown API.
-func (s *StatsAPI) CostBreakdown(w http.ResponseWriter, r *http.Request) {
- tokenNames := s.tokenNamesForUser(r)
- period := r.URL.Query().Get("period")
- groupBy := r.URL.Query().Get("group_by")
- if groupBy == "" {
- groupBy = "model"
- }
- result := s.GetCostBreakdown(tokenNames, period, groupBy)
- writeJSON(w, result)
-}
-
-// ProviderHealthHandler serves provider health status API.
-func (s *StatsAPI) ProviderHealthHandler(w http.ResponseWriter, r *http.Request) {
- if s.healthTracker == nil {
- writeJSON(w, []provider.ProviderHealth{})
- return
- }
- writeJSON(w, s.healthTracker.Status())
-}
-
-// CacheStats serves cache statistics API.
-func (s *StatsAPI) CacheStats(w http.ResponseWriter, r *http.Request) {
- if s.cache == nil {
- writeJSON(w, map[string]any{"enabled": false})
- return
- }
- stats := s.cache.Stats(r.Context())
- writeJSON(w, stats)
-}
-
-// AuditLogs serves the audit log API (admin-only).
-func (s *StatsAPI) AuditLogs(w http.ResponseWriter, r *http.Request) {
- if s.auditLogger == nil {
- writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
- return
- }
- page, _ := strconv.Atoi(r.URL.Query().Get("page"))
- action := r.URL.Query().Get("action")
- since := time.Now().AddDate(0, 0, -30).Unix()
- if sinceStr := r.URL.Query().Get("since"); sinceStr != "" {
- if s, err := strconv.ParseInt(sinceStr, 10, 64); err == nil {
- since = s
- }
- }
- result := s.auditLogger.Query(since, action, page, 50)
- writeJSON(w, result)
-}
-
-// DebugToggle enables/disables debug logging at runtime.
-func (s *StatsAPI) DebugToggle(w http.ResponseWriter, r *http.Request) {
- if s.debugLogger == nil {
- writeJSON(w, map[string]any{"error": "debug logger not configured"})
- return
- }
- var req struct {
- Enabled bool `json:"enabled"`
- }
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- w.WriteHeader(http.StatusBadRequest)
- writeJSON(w, map[string]string{"error": "invalid JSON"})
- return
- }
- s.debugLogger.SetEnabled(req.Enabled)
- writeJSON(w, map[string]any{"enabled": s.debugLogger.IsEnabled()})
-}
-
-// DebugStatus returns whether debug logging is enabled.
-func (s *StatsAPI) DebugStatus(w http.ResponseWriter, r *http.Request) {
- enabled := false
- if s.debugLogger != nil {
- enabled = s.debugLogger.IsEnabled()
- }
- writeJSON(w, map[string]any{"enabled": enabled})
-}
-
-// DebugLogs serves paginated debug log entries.
-func (s *StatsAPI) DebugLogs(w http.ResponseWriter, r *http.Request) {
- if s.debugLogger == nil {
- writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
- return
- }
- page, _ := strconv.Atoi(r.URL.Query().Get("page"))
- result := s.debugLogger.Query(page, 50)
- writeJSON(w, result)
-}
-
-// DebugLogByRequestID serves a single debug log entry by request ID.
-func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
- if s.debugLogger == nil {
- w.WriteHeader(http.StatusNotFound)
- writeJSON(w, map[string]string{"error": "debug logger not configured"})
- return
- }
- requestID := chi.URLParam(r, "requestID")
- entry := s.debugLogger.GetByRequestID(requestID)
- if entry == nil {
- w.WriteHeader(http.StatusNotFound)
- writeJSON(w, map[string]string{"error": "not found"})
- return
- }
- writeJSON(w, entry)
-}
-
-// ValidateConfig validates the config file at the stored path.
-// Returns HTML for HTMX requests, JSON otherwise.
-func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) {
- if s.configPath == "" {
- if r.Header.Get("HX-Request") == "true" {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
- w.Write([]byte(`Config path not set
`))
- } else {
- w.WriteHeader(http.StatusInternalServerError)
- writeJSON(w, map[string]any{"valid": false, "errors": []string{"config path not set"}})
- }
- return
- }
- data, err := os.ReadFile(s.configPath)
- if err != nil {
- msg := "failed to read config: " + err.Error()
- if r.Header.Get("HX-Request") == "true" {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
- w.Write([]byte(`` + msg + `
`))
- } else {
- w.WriteHeader(http.StatusInternalServerError)
- writeJSON(w, map[string]any{"valid": false, "errors": []string{msg}})
- }
- return
- }
- errs := config.ValidateBytes(data)
-
- if r.Header.Get("HX-Request") == "true" {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
- if len(errs) > 0 {
- html := `Configuration errors:
`
- for _, e := range errs {
- html += "- " + e + "
"
- }
- html += "
"
- w.Write([]byte(html))
- } else {
- w.Write([]byte(`Configuration is valid.
`))
- }
- return
- }
-
- if len(errs) > 0 {
- writeJSON(w, map[string]any{"valid": false, "errors": errs})
- return
- }
- writeJSON(w, map[string]any{"valid": true, "errors": []string{}})
-}
-
-func writeJSON(w http.ResponseWriter, v any) {
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(v)
-}
diff --git a/llm-gateway/internal/dashboard/export.go b/llm-gateway/internal/dashboard/export.go
deleted file mode 100644
index 2ac0875..0000000
--- a/llm-gateway/internal/dashboard/export.go
+++ /dev/null
@@ -1,297 +0,0 @@
-package dashboard
-
-import (
- "encoding/csv"
- "encoding/json"
- "fmt"
- "net/http"
- "strconv"
- "time"
-
- "llm-gateway/internal/auth"
- "llm-gateway/internal/storage"
-)
-
-type ExportHandler struct {
- db *storage.DB
- authStore *auth.Store
-}
-
-func NewExportHandler(db *storage.DB, authStore *auth.Store) *ExportHandler {
- return &ExportHandler{db: db, authStore: authStore}
-}
-
-// ExportLogs exports request logs as CSV or JSON.
-func (e *ExportHandler) ExportLogs(w http.ResponseWriter, r *http.Request) {
- format := r.URL.Query().Get("format")
- if format == "" {
- format = "json"
- }
-
- // Build query
- where := "WHERE 1=1"
- var args []any
-
- if from := r.URL.Query().Get("from"); from != "" {
- if ts, err := strconv.ParseInt(from, 10, 64); err == nil {
- where += " AND timestamp >= ?"
- args = append(args, ts)
- }
- }
- if to := r.URL.Query().Get("to"); to != "" {
- if ts, err := strconv.ParseInt(to, 10, 64); err == nil {
- where += " AND timestamp <= ?"
- args = append(args, ts)
- }
- }
- if model := r.URL.Query().Get("model"); model != "" {
- where += " AND model = ?"
- args = append(args, model)
- }
- if token := r.URL.Query().Get("token"); token != "" {
- where += " AND token_name = ?"
- args = append(args, token)
- }
- if status := r.URL.Query().Get("status"); status != "" {
- where += " AND status = ?"
- args = append(args, status)
- }
-
- // Token filtering for non-admins
- user := auth.UserFromContext(r.Context())
- if user != nil && !user.IsAdmin {
- tokens, err := e.authStore.ListAPITokens(user.ID)
- if err != nil || len(tokens) == 0 {
- where += " AND 1=0"
- } else {
- where += " AND token_name IN ("
- for i, t := range tokens {
- if i > 0 {
- where += ","
- }
- where += "?"
- args = append(args, t.Name)
- }
- where += ")"
- }
- }
-
- query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
- input_tokens, output_tokens, cost_usd, latency_ms, status,
- COALESCE(error_message, ''), streaming, cached
- FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT 100000`
-
- rows, err := e.db.Query(query, args...)
- if err != nil {
- http.Error(w, "query failed", http.StatusInternalServerError)
- return
- }
- defer rows.Close()
-
- type logRow struct {
- RequestID string `json:"request_id"`
- Timestamp int64 `json:"timestamp"`
- TokenName string `json:"token_name"`
- Model string `json:"model"`
- Provider string `json:"provider"`
- ProviderModel string `json:"provider_model"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
- LatencyMS int64 `json:"latency_ms"`
- Status string `json:"status"`
- ErrorMessage string `json:"error_message"`
- Streaming bool `json:"streaming"`
- Cached bool `json:"cached"`
- }
-
- var results []logRow
- for rows.Next() {
- var l logRow
- var streaming, cached int
- rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
- &l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
- &l.ErrorMessage, &streaming, &cached)
- l.Streaming = streaming == 1
- l.Cached = cached == 1
- results = append(results, l)
- }
-
- now := time.Now().Format("20060102-150405")
-
- switch format {
- case "csv":
- w.Header().Set("Content-Type", "text/csv")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", now))
- writer := csv.NewWriter(w)
- writer.Write([]string{"request_id", "timestamp", "token_name", "model", "provider", "provider_model",
- "input_tokens", "output_tokens", "cost_usd", "latency_ms", "status", "error_message", "streaming", "cached"})
- for _, l := range results {
- writer.Write([]string{
- l.RequestID,
- strconv.FormatInt(l.Timestamp, 10),
- l.TokenName, l.Model, l.Provider, l.ProviderModel,
- strconv.Itoa(l.InputTokens), strconv.Itoa(l.OutputTokens),
- fmt.Sprintf("%.8f", l.CostUSD),
- strconv.FormatInt(l.LatencyMS, 10),
- l.Status, l.ErrorMessage,
- strconv.FormatBool(l.Streaming), strconv.FormatBool(l.Cached),
- })
- }
- writer.Flush()
- default:
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.json", now))
- json.NewEncoder(w).Encode(results)
- }
-}
-
-// ExportStats exports aggregated stats as CSV or JSON.
-func (e *ExportHandler) ExportStats(w http.ResponseWriter, r *http.Request) {
- format := r.URL.Query().Get("format")
- if format == "" {
- format = "json"
- }
- statsType := r.URL.Query().Get("type")
- if statsType == "" {
- statsType = "summary"
- }
-
- now := time.Now().Format("20060102-150405")
- since := time.Now().AddDate(0, -1, 0).Unix()
-
- switch statsType {
- case "models":
- rows, err := e.db.Query(`SELECT model, COUNT(*) as requests,
- COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
- COALESCE(SUM(cost_usd), 0), COALESCE(AVG(latency_ms), 0)
- FROM request_logs WHERE timestamp >= ? GROUP BY model ORDER BY requests DESC`, since)
- if err != nil {
- http.Error(w, "query failed", http.StatusInternalServerError)
- return
- }
- defer rows.Close()
-
- type modelRow struct {
- Model string `json:"model"`
- Requests int `json:"requests"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
- AvgLatencyMS float64 `json:"avg_latency_ms"`
- }
- var results []modelRow
- for rows.Next() {
- var m modelRow
- rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS)
- results = append(results, m)
- }
-
- if format == "csv" {
- w.Header().Set("Content-Type", "text/csv")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.csv", now))
- writer := csv.NewWriter(w)
- writer.Write([]string{"model", "requests", "input_tokens", "output_tokens", "cost_usd", "avg_latency_ms"})
- for _, m := range results {
- writer.Write([]string{m.Model, strconv.Itoa(m.Requests), strconv.Itoa(m.InputTokens),
- strconv.Itoa(m.OutputTokens), fmt.Sprintf("%.8f", m.CostUSD), fmt.Sprintf("%.2f", m.AvgLatencyMS)})
- }
- writer.Flush()
- } else {
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.json", now))
- json.NewEncoder(w).Encode(results)
- }
-
- case "providers":
- rows, err := e.db.Query(`SELECT provider, COUNT(*) as requests,
- COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0),
- COALESCE(SUM(CASE WHEN status='error' THEN 1 ELSE 0 END), 0),
- COALESCE(AVG(latency_ms), 0), COALESCE(SUM(cost_usd), 0)
- FROM request_logs WHERE timestamp >= ? GROUP BY provider ORDER BY requests DESC`, since)
- if err != nil {
- http.Error(w, "query failed", http.StatusInternalServerError)
- return
- }
- defer rows.Close()
-
- type providerRow struct {
- Provider string `json:"provider"`
- Requests int `json:"requests"`
- Successes int `json:"successes"`
- Errors int `json:"errors"`
- AvgLatencyMS float64 `json:"avg_latency_ms"`
- CostUSD float64 `json:"cost_usd"`
- }
- var results []providerRow
- for rows.Next() {
- var p providerRow
- rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD)
- results = append(results, p)
- }
-
- if format == "csv" {
- w.Header().Set("Content-Type", "text/csv")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.csv", now))
- writer := csv.NewWriter(w)
- writer.Write([]string{"provider", "requests", "successes", "errors", "avg_latency_ms", "cost_usd"})
- for _, p := range results {
- writer.Write([]string{p.Provider, strconv.Itoa(p.Requests), strconv.Itoa(p.Successes),
- strconv.Itoa(p.Errors), fmt.Sprintf("%.2f", p.AvgLatencyMS), fmt.Sprintf("%.8f", p.CostUSD)})
- }
- writer.Flush()
- } else {
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.json", now))
- json.NewEncoder(w).Encode(results)
- }
-
- case "tokens":
- rows, err := e.db.Query(`SELECT token_name, COUNT(*) as requests,
- COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
- COALESCE(SUM(cost_usd), 0)
- FROM request_logs WHERE timestamp >= ? GROUP BY token_name ORDER BY requests DESC`, since)
- if err != nil {
- http.Error(w, "query failed", http.StatusInternalServerError)
- return
- }
- defer rows.Close()
-
- type tokenRow struct {
- TokenName string `json:"token_name"`
- Requests int `json:"requests"`
- InputTokens int `json:"input_tokens"`
- OutputTokens int `json:"output_tokens"`
- CostUSD float64 `json:"cost_usd"`
- }
- var results []tokenRow
- for rows.Next() {
- var t tokenRow
- rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD)
- results = append(results, t)
- }
-
- if format == "csv" {
- w.Header().Set("Content-Type", "text/csv")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.csv", now))
- writer := csv.NewWriter(w)
- writer.Write([]string{"token_name", "requests", "input_tokens", "output_tokens", "cost_usd"})
- for _, t := range results {
- writer.Write([]string{t.TokenName, strconv.Itoa(t.Requests), strconv.Itoa(t.InputTokens),
- strconv.Itoa(t.OutputTokens), fmt.Sprintf("%.8f", t.CostUSD)})
- }
- writer.Flush()
- } else {
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.json", now))
- json.NewEncoder(w).Encode(results)
- }
-
- default: // summary
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-summary-%s.json", now))
- statsAPI := NewStatsAPI(e.db, e.authStore)
- result := statsAPI.GetSummary(nil)
- json.NewEncoder(w).Encode(result)
- }
-}
diff --git a/llm-gateway/internal/dashboard/handler.go b/llm-gateway/internal/dashboard/handler.go
deleted file mode 100644
index e672637..0000000
--- a/llm-gateway/internal/dashboard/handler.go
+++ /dev/null
@@ -1,426 +0,0 @@
-package dashboard
-
-import (
- "embed"
- "fmt"
- "html/template"
- "net/http"
- "strconv"
- "time"
-
- "llm-gateway/internal/auth"
- "llm-gateway/internal/cache"
- "llm-gateway/internal/provider"
- "llm-gateway/internal/storage"
-)
-
-//go:embed templates/*.html templates/partials/*.html
-var templateFiles embed.FS
-
-// templateFuncs is the helper set made available to every dashboard template:
-// timestamp and money formatting, small integer arithmetic, budget gauges and
-// the page-window computation used by the pagination controls.
-var templateFuncs = func() template.FuncMap {
-	// renderUnix formats a Unix timestamp with layout; 0 means "never".
-	renderUnix := func(ts int64, layout string) string {
-		if ts != 0 {
-			return time.Unix(ts, 0).Format(layout)
-		}
-		return "never"
-	}
-	// windowStart is the first page of a window centred on page, clamped to 1.
-	windowStart := func(page int) int {
-		if first := page - 2; first >= 1 {
-			return first
-		}
-		return 1
-	}
-
-	funcs := template.FuncMap{}
-
-	funcs["formatTime"] = func(ts int64) string { return renderUnix(ts, "2006-01-02") }
-	funcs["formatTimeDetail"] = func(ts int64) string { return renderUnix(ts, "2006-01-02 15:04:05") }
-
-	funcs["addInt"] = func(a, b int) int { return a + b }
-	funcs["subInt"] = func(a, b int) int { return a - b }
-
-	// Costs below one cent get six decimals so tiny per-request amounts stay visible.
-	funcs["formatCost"] = func(v float64) string {
-		switch {
-		case v == 0:
-			return "$0.00"
-		case v < 0.01:
-			return fmt.Sprintf("$%.6f", v)
-		default:
-			return fmt.Sprintf("$%.4f", v)
-		}
-	}
-	funcs["formatPrice"] = func(v float64) string {
-		if v != 0 {
-			return fmt.Sprintf("$%.2f", v)
-		}
-		return "-"
-	}
-	// formatPct takes a fraction (0.25) and prints a percentage ("25.0%").
-	funcs["formatPct"] = func(v float64) string {
-		return fmt.Sprintf("%.1f%%", v*100)
-	}
-
-	// budgetPct is the share of budget already spent, in percent; 0 without a budget.
-	funcs["budgetPct"] = func(spend, budget float64) float64 {
-		if budget <= 0 {
-			return 0
-		}
-		return spend / budget * 100
-	}
-	// budgetColor picks red from 80%, amber from 50%, green below that.
-	funcs["budgetColor"] = func(pct float64) string {
-		switch {
-		case pct >= 80:
-			return "#f87171"
-		case pct >= 50:
-			return "#fbbf24"
-		default:
-			return "#4ade80"
-		}
-	}
-
-	// seq yields start..end inclusive, or nil when start > end.
-	funcs["seq"] = func(start, end int) []int {
-		var numbers []int
-		for n := start; n <= end; n++ {
-			numbers = append(numbers, n)
-		}
-		return numbers
-	}
-
-	// paginationStart shifts the window left near the last page so that five
-	// page links are still shown when there are more than four pages.
-	funcs["paginationStart"] = func(page, totalPages int) int {
-		first := windowStart(page)
-		if totalPages > 4 && totalPages-first < 4 {
-			first = totalPages - 4
-		}
-		return first
-	}
-	funcs["paginationEnd"] = func(page, totalPages int) int {
-		last := windowStart(page) + 4
-		if last > totalPages {
-			return totalPages
-		}
-		return last
-	}
-
-	return funcs
-}()
-
-// PageData is the common data passed to all templates.
-//
-// It is a single view model shared by every dashboard page: each page handler
-// fills ActivePage and User plus only the group of fields its own partial
-// reads, and leaves the remaining fields at their zero value.
-type PageData struct {
- // ActivePage is the page identifier set by the handler ("dashboard", "logs", ...).
- ActivePage string
- // User is the signed-in user the page is rendered for.
- User *auth.User
- // Dashboard data
- Summary *SummaryResult
- Models []ModelStats
- Providers []ProviderStats
- TokenStats []TokenUsageStats
- ProviderHealth []provider.ProviderHealth
- Latency *LatencyResult
- CacheEnabled bool
- CacheInfo *cache.CacheStats
- // Tokens page data
- Tokens []auth.APIToken
- // TokenSpend is indexed by token name in the tokens template.
- TokenSpend map[string]float64
- // Users page data
- Users []auth.User
- // Logs page data
- LogsResult *LogsResult
- LogModels []string
- LogTokens []string
- // FilterModel, FilterToken and FilterStatus echo the query parameters of the logs page.
- FilterModel string
- FilterToken string
- FilterStatus string
- // Models routing page data
- ModelRoutes []provider.ModelRouteInfo
- // Audit page data
- AuditResult *storage.AuditQueryResult
- AuditFilterActions []string
- FilterAction string
- // Debug page data
- DebugResult *storage.DebugLogQueryResult
- DebugEnabled bool
-}
-
-// Dashboard serves the HTMX-based dashboard pages.
-//
-// templates, authStore and statsAPI are set by NewDashboard. The remaining
-// fields are optional collaborators that stay nil until the matching Set*
-// method is called; the page handlers check them for nil before use.
-type Dashboard struct {
- // templates holds the templates parsed once at construction (login and setup pages use it).
- templates *template.Template
- authStore *auth.Store
- statsAPI *StatsAPI
- // registry supplies the model routes shown on the models page (optional).
- registry *provider.Registry
- // cache supplies the cache statistics shown on the main dashboard (optional).
- cache *cache.Cache
- // auditLogger backs the audit page (optional).
- auditLogger *storage.AuditLogger
- // debugLogger backs the debug page (optional).
- debugLogger *storage.DebugLogger
-}
-
-// NewDashboard builds a Dashboard around the given auth store and stats API.
-// All embedded page and partial templates are parsed up front; a template that
-// fails to parse makes the constructor panic (template.Must).
-func NewDashboard(authStore *auth.Store, statsAPI *StatsAPI) *Dashboard {
-	patterns := []string{
-		"templates/*.html",
-		"templates/partials/*.html",
-	}
-	parsed := template.Must(template.New("").Funcs(templateFuncs).ParseFS(templateFiles, patterns...))
-
-	d := new(Dashboard)
-	d.templates = parsed
-	d.authStore = authStore
-	d.statsAPI = statsAPI
-	return d
-}
-
-// SetRegistry sets the provider registry for model routing display.
-// ModelsPage lists routes only when a registry has been set; without one the
-// page is rendered with no route data.
-func (d *Dashboard) SetRegistry(r *provider.Registry) {
- d.registry = r
-}
-
-// SetCache sets the cache reference for cache stats display.
-// DashboardPage reports the cache as enabled and shows its statistics only
-// when a cache has been set.
-func (d *Dashboard) SetCache(c *cache.Cache) {
- d.cache = c
-}
-
-// SetAuditLogger sets the audit logger for the audit page.
-// Without one, AuditPage renders an empty single-page result.
-func (d *Dashboard) SetAuditLogger(al *storage.AuditLogger) {
- d.auditLogger = al
-}
-
-// SetDebugLogger sets the debug logger for the debug page.
-// Without one, DebugPage renders an empty single-page result and reports
-// debug mode as disabled.
-func (d *Dashboard) SetDebugLogger(dl *storage.DebugLogger) {
- d.debugLogger = dl
-}
-
-// LoginPage renders the login form. Visitors are redirected to /setup while
-// no user account exists yet, and to /dashboard when they already carry a
-// valid session.
-func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) {
-	switch {
-	case !d.authStore.HasAnyUser():
-		http.Redirect(w, r, "/setup", http.StatusFound)
-	case d.getSessionUser(r) != nil:
-		http.Redirect(w, r, "/dashboard", http.StatusFound)
-	default:
-		w.Header().Set("Content-Type", "text/html; charset=utf-8")
-		d.templates.ExecuteTemplate(w, "login", nil)
-	}
-}
-
-// SetupPage renders the first-run setup form. Once any user exists the page
-// is no longer reachable and visitors are redirected to /login.
-func (d *Dashboard) SetupPage(w http.ResponseWriter, r *http.Request) {
-	if !d.authStore.HasAnyUser() {
-		w.Header().Set("Content-Type", "text/html; charset=utf-8")
-		d.templates.ExecuteTemplate(w, "setup", nil)
-		return
-	}
-	http.Redirect(w, r, "/login", http.StatusFound)
-}
-
-// DashboardPage serves the main dashboard view.
-func (d *Dashboard) DashboardPage(w http.ResponseWriter, r *http.Request) {
- user := auth.UserFromContext(r.Context())
- tokenNames := d.statsAPI.TokenNamesForUser(user)
-
- data := PageData{
- ActivePage: "dashboard",
- User: user,
- Summary: d.statsAPI.GetSummary(tokenNames),
- Models: d.statsAPI.GetModels(tokenNames),
- Providers: d.statsAPI.GetProviders(tokenNames),
- TokenStats: d.statsAPI.GetTokenUsage(tokenNames),
- Latency: d.statsAPI.GetLatency(tokenNames, "24h", "", ""),
- }
-
- // Provider health
- if d.statsAPI.healthTracker != nil {
- data.ProviderHealth = d.statsAPI.healthTracker.Status()
- }
-
- // Cache stats
- if d.cache != nil {
- data.CacheEnabled = true
- data.CacheInfo = d.cache.Stats(r.Context())
- }
-
- d.renderDashboardPage(w, r, "partials/dashboard.html", data)
-}
-
-// LogsPage serves the request logs view.
-func (d *Dashboard) LogsPage(w http.ResponseWriter, r *http.Request) {
- user := auth.UserFromContext(r.Context())
- tokenNames := d.statsAPI.TokenNamesForUser(user)
-
- page, _ := strconv.Atoi(r.URL.Query().Get("page"))
- if page < 1 {
- page = 1
- }
- model := r.URL.Query().Get("model")
- token := r.URL.Query().Get("token")
- status := r.URL.Query().Get("status")
-
- data := PageData{
- ActivePage: "logs",
- User: user,
- LogsResult: d.statsAPI.GetLogs(tokenNames, page, model, token, status),
- LogModels: d.statsAPI.GetDistinctModels(),
- LogTokens: d.statsAPI.GetDistinctTokens(),
- FilterModel: model,
- FilterToken: token,
- FilterStatus: status,
- }
-
- d.renderDashboardPage(w, r, "partials/logs.html", data)
-}
-
-// ModelsPage serves the model routing table view.
-func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) {
- user := auth.UserFromContext(r.Context())
-
- data := PageData{
- ActivePage: "models",
- User: user,
- }
-
- if d.registry != nil {
- data.ModelRoutes = d.registry.AllRoutes()
- }
-
- if d.statsAPI.healthTracker != nil {
- data.ProviderHealth = d.statsAPI.healthTracker.Status()
- }
-
- d.renderDashboardPage(w, r, "partials/models-page.html", data)
-}
-
-// TokensPage renders the API token list together with today's spend per
-// token. Non-admin users list tokens by their own user ID; admins pass the
-// zero ID to ListAPITokens. Lookup errors are ignored and shown as empty data.
-func (d *Dashboard) TokensPage(w http.ResponseWriter, r *http.Request) {
-	viewer := auth.UserFromContext(r.Context())
-
-	// ownerID stays 0 for admins.
-	var ownerID int64
-	if !viewer.IsAdmin {
-		ownerID = viewer.ID
-	}
-
-	list, _ := d.authStore.ListAPITokens(ownerID)
-	if list == nil {
-		list = []auth.APIToken{}
-	}
-
-	// Today's spend feeds the budget bars in the template.
-	todaySpend, _ := d.statsAPI.db.TodaySpendAll()
-	if todaySpend == nil {
-		todaySpend = make(map[string]float64)
-	}
-
-	data := PageData{ActivePage: "tokens", User: viewer}
-	data.Tokens = list
-	data.TokenSpend = todaySpend
-	d.renderDashboardPage(w, r, "partials/tokens.html", data)
-}
-
-// UsersPage renders the user management view (admin only). A failure to list
-// users is ignored and results in an empty table.
-func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) {
-	data := PageData{
-		ActivePage: "users",
-		User:       auth.UserFromContext(r.Context()),
-	}
-	data.Users, _ = d.authStore.ListUsers()
-
-	d.renderDashboardPage(w, r, "partials/users.html", data)
-}
-
-// AuditPage serves the audit log view (admin only).
-func (d *Dashboard) AuditPage(w http.ResponseWriter, r *http.Request) {
- user := auth.UserFromContext(r.Context())
-
- page, _ := strconv.Atoi(r.URL.Query().Get("page"))
- if page < 1 {
- page = 1
- }
- action := r.URL.Query().Get("action")
- since := time.Now().AddDate(0, 0, -30).Unix()
-
- var auditResult *storage.AuditQueryResult
- if d.auditLogger != nil {
- auditResult = d.auditLogger.Query(since, action, page, 50)
- } else {
- auditResult = &storage.AuditQueryResult{Entries: []storage.AuditEntry{}, Page: 1, TotalPages: 1}
- }
-
- // Common audit action types for the filter dropdown
- actions := []string{"login", "logout", "create_user", "delete_user", "create_token", "delete_token", "change_password", "setup_totp", "disable_totp"}
-
- d.renderDashboardPage(w, r, "partials/audit.html", PageData{
- ActivePage: "audit",
- User: user,
- AuditResult: auditResult,
- AuditFilterActions: actions,
- FilterAction: action,
- })
-}
-
-// DebugPage serves the debug logging view (admin only).
-func (d *Dashboard) DebugPage(w http.ResponseWriter, r *http.Request) {
- user := auth.UserFromContext(r.Context())
-
- page, _ := strconv.Atoi(r.URL.Query().Get("page"))
- if page < 1 {
- page = 1
- }
-
- var debugResult *storage.DebugLogQueryResult
- debugEnabled := false
- if d.debugLogger != nil {
- debugResult = d.debugLogger.QueryFull(page, 50)
- debugEnabled = d.debugLogger.IsEnabled()
- } else {
- debugResult = &storage.DebugLogQueryResult{Entries: []storage.DebugLogEntry{}, Page: 1, TotalPages: 1}
- }
-
- d.renderDashboardPage(w, r, "partials/debug.html", PageData{
- ActivePage: "debug",
- User: user,
- DebugResult: debugResult,
- DebugEnabled: debugEnabled,
- })
-}
-
-// SettingsPage serves the settings view.
-//
-// The user is re-read from the auth store before rendering. If that lookup
-// fails, the user already attached to the request context is kept: the
-// settings template reads .User.TOTPEnabled and .User.IsAdmin, so replacing
-// the user with the nil result of a failed lookup would break rendering.
-func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) {
-	user := auth.UserFromContext(r.Context())
-	if fresh, err := d.authStore.GetUserByID(user.ID); err == nil && fresh != nil {
-		user = fresh
-	}
-
-	d.renderDashboardPage(w, r, "partials/settings.html", PageData{
-		ActivePage: "settings",
-		User:       user,
-	})
-}
-
-// renderDashboardPage renders either the full layout or just the content partial.
-//
-// Requests carrying "HX-Request: true" receive only the "content" template of
-// partialFile; all other requests receive the "layout" template, which embeds
-// that content. partialFile is relative to the embedded "templates/" directory.
-//
-// Every partial defines a template named "content", so the partial is parsed
-// per request rather than served from one shared template set. A parse failure
-// is answered with a 500 response instead of panicking inside the handler.
-func (d *Dashboard) renderDashboardPage(w http.ResponseWriter, r *http.Request, partialFile string, data PageData) {
-	entry := "layout"
-	files := []string{"templates/layout.html", "templates/" + partialFile}
-	if r.Header.Get("HX-Request") == "true" {
-		// HTMX swaps only the content area, so the layout is not needed.
-		entry = "content"
-		files = files[1:]
-	}
-
-	tmpl, err := template.New("").Funcs(templateFuncs).ParseFS(templateFiles, files...)
-	if err != nil {
-		http.Error(w, "template error", http.StatusInternalServerError)
-		return
-	}
-
-	w.Header().Set("Content-Type", "text/html; charset=utf-8")
-	// Part of the page may already be on the wire when execution fails, so the
-	// error cannot be turned into a clean error response at this point.
-	_ = tmpl.ExecuteTemplate(w, entry, data)
-}
-
-// getSessionUser resolves the "llmgw_session" cookie to its user. It returns
-// nil when the cookie is missing or empty, or when the session or the user
-// cannot be loaded.
-func (d *Dashboard) getSessionUser(r *http.Request) *auth.User {
-	cookie, err := r.Cookie("llmgw_session")
-	if err != nil || cookie.Value == "" {
-		return nil
-	}
-	if sess, err := d.authStore.GetSession(cookie.Value); err == nil {
-		if user, err := d.authStore.GetUserByID(sess.UserID); err == nil {
-			return user
-		}
-	}
-	return nil
-}
diff --git a/llm-gateway/internal/dashboard/sse.go b/llm-gateway/internal/dashboard/sse.go
deleted file mode 100644
index aafd88d..0000000
--- a/llm-gateway/internal/dashboard/sse.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package dashboard
-
-import (
- "fmt"
- "net/http"
- "sync"
-)
-
-// SSEBroker fans a refresh signal out to all connected Server-Sent Events
-// clients. Each connection owns one signal channel registered in clients.
-type SSEBroker struct {
-	mu      sync.RWMutex
-	clients map[chan struct{}]struct{}
-}
-
-// NewSSEBroker returns a broker with no connected clients.
-func NewSSEBroker() *SSEBroker {
-	b := new(SSEBroker)
-	b.clients = map[chan struct{}]struct{}{}
-	return b
-}
-
-// Notify signals every connected client to refresh. The send never blocks: a
-// client whose one-slot channel is already full is skipped.
-func (b *SSEBroker) Notify() {
-	b.mu.RLock()
-	defer b.mu.RUnlock()
-
-	for client := range b.clients {
-		select {
-		case client <- struct{}{}:
-		default:
-			// A signal is already pending for this client.
-		}
-	}
-}
-
-// ServeHTTP streams events to one client until its request context ends. It
-// emits a "connected" event immediately and a "refresh" event for each signal
-// received through Notify. Writers that cannot flush get a 500 response.
-func (b *SSEBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
-	flusher, canFlush := w.(http.Flusher)
-	if !canFlush {
-		http.Error(w, "streaming not supported", http.StatusInternalServerError)
-		return
-	}
-
-	hdr := w.Header()
-	hdr.Set("Content-Type", "text/event-stream")
-	hdr.Set("Cache-Control", "no-cache")
-	hdr.Set("Connection", "keep-alive")
-	hdr.Set("X-Accel-Buffering", "no")
-
-	// One buffered slot lets Notify leave a pending signal without blocking.
-	signal := make(chan struct{}, 1)
-
-	b.mu.Lock()
-	b.clients[signal] = struct{}{}
-	b.mu.Unlock()
-	defer func() {
-		b.mu.Lock()
-		delete(b.clients, signal)
-		b.mu.Unlock()
-	}()
-
-	// emit writes one complete event frame and pushes it to the client.
-	emit := func(frame string) {
-		fmt.Fprint(w, frame)
-		flusher.Flush()
-	}
-
-	emit("event: connected\ndata: ok\n\n")
-
-	done := r.Context().Done()
-	for {
-		select {
-		case <-done:
-			return
-		case <-signal:
-			emit("event: refresh\ndata: updated\n\n")
-		}
-	}
-}
diff --git a/llm-gateway/internal/dashboard/templates/layout.html b/llm-gateway/internal/dashboard/templates/layout.html
deleted file mode 100644
index 1b35132..0000000
--- a/llm-gateway/internal/dashboard/templates/layout.html
+++ /dev/null
@@ -1,366 +0,0 @@
-{{define "layout"}}
-
-
-
-
-
-LLM Gateway
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- {{template "content" .}}
-
-
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/login.html b/llm-gateway/internal/dashboard/templates/login.html
deleted file mode 100644
index 500cc79..0000000
--- a/llm-gateway/internal/dashboard/templates/login.html
+++ /dev/null
@@ -1,44 +0,0 @@
-{{define "login"}}
-
-
-
-
-
-Login - LLM Gateway
-
-
-
-
-
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/audit.html b/llm-gateway/internal/dashboard/templates/partials/audit.html
deleted file mode 100644
index 7697964..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/audit.html
+++ /dev/null
@@ -1,83 +0,0 @@
-{{define "content"}}
-
-
-
-
-
-
-
-
-
-
-
- | Time |
- User |
- Action |
- Target |
- Details |
- IP |
-
-
-
- {{range .AuditResult.Entries}}
-
- | {{formatTimeDetail .Timestamp}} |
- {{.Username}} |
- {{.Action}} |
- {{if .TargetType}}{{.TargetType}}{{if .TargetID}}/{{.TargetID}}{{end}}{{else}}-{{end}} |
- {{if .Details}}{{.Details}}{{else}}-{{end}} |
- {{if .IPAddress}}{{.IPAddress}}{{else}}-{{end}} |
-
- {{end}}
- {{if not .AuditResult.Entries}}
- | No audit log entries |
- {{end}}
-
-
-
- {{if gt .AuditResult.TotalPages 1}}
-
- {{end}}
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/dashboard.html b/llm-gateway/internal/dashboard/templates/partials/dashboard.html
deleted file mode 100644
index 715313b..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/dashboard.html
+++ /dev/null
@@ -1,230 +0,0 @@
-{{define "content"}}
-
-
-
- {{with .Summary.Today}}
-
Requests Today
{{.Requests}}
-
Cost Today
{{formatCost .CostUSD}}
-
Tokens Today
{{addInt .InputTokens .OutputTokens}}
{{.InputTokens}} in / {{.OutputTokens}} out
-
-
Cache Hits
{{.CachedHits}}
- {{end}}
- {{with .Summary.Week}}
-
Cost (7d)
{{formatCost .CostUSD}}
- {{end}}
-
-
-{{if .ProviderHealth}}
-
-
Provider Health
-
- {{range .ProviderHealth}}
-
- {{.Provider}}
- {{.Status}}
- {{if eq .CircuitState "open"}}circuit open{{end}}
- {{if eq .CircuitState "half-open"}}half-open{{end}}
- {{printf "%.0f" .AvgLatency}}ms avg | {{formatPct .ErrorRate}} errors
-
- {{end}}
-
-
-{{end}}
-
-{{if .Latency}}{{if gt .Latency.Max 0.0}}
-
-
P50 Latency
{{printf "%.0f" .Latency.P50}}ms
-
P95 Latency
{{printf "%.0f" .Latency.P95}}ms
-
P99 Latency
{{printf "%.0f" .Latency.P99}}ms
-
Avg Latency
{{printf "%.0f" .Latency.Avg}}ms
-
-{{end}}{{end}}
-
-{{if .CacheEnabled}}{{if .CacheInfo}}{{if .CacheInfo.Connected}}
-
-
Cache Hit Rate
{{formatPct .CacheInfo.HitRate}}
{{.CacheInfo.Hits}} hits / {{.CacheInfo.Misses}} misses
-
Cache Memory
{{.CacheInfo.MemoryUsed}}
-
Cached Keys
{{.CacheInfo.Keys}}
-
-{{end}}{{end}}{{end}}
-
-
-
-
-
-
-
-
Requests & Cost
-
-
-
-
-
Cost Breakdown
-
-
-
-
-
-
-
-
-{{if .Models}}
-
-
-
- | Model | Requests | Tokens (in/out) | Cost | Avg Latency |
-
- {{range .Models}}
-
- | {{.Model}} |
- {{.Requests}} |
- {{.InputTokens}} / {{.OutputTokens}} |
- {{formatCost .CostUSD}} |
- {{printf "%.0f" .AvgLatencyMS}}ms |
-
- {{end}}
-
-
-
-{{end}}
-
-{{if .Providers}}
-
-
-
- | Provider | Requests | Success | Errors | Avg Latency | Cost |
-
- {{range .Providers}}
-
- | {{.Provider}} |
- {{.Requests}} |
- {{.Successes}} |
- {{.Errors}} |
- {{printf "%.0f" .AvgLatencyMS}}ms |
- {{formatCost .CostUSD}} |
-
- {{end}}
-
-
-
-{{end}}
-
-{{if .TokenStats}}
-
-
API Token UsageCSVJSON
-
- | Token | Requests | Tokens (in/out) | Cost |
-
- {{range .TokenStats}}
-
- | {{.TokenName}} |
- {{.Requests}} |
- {{.InputTokens}} / {{.OutputTokens}} |
- {{formatCost .CostUSD}} |
-
- {{end}}
-
-
-
-{{end}}
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/debug.html b/llm-gateway/internal/dashboard/templates/partials/debug.html
deleted file mode 100644
index 4b55fd5..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/debug.html
+++ /dev/null
@@ -1,94 +0,0 @@
-{{define "content"}}
-
-
-
- Debug Mode
-
- {{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}}
-
-
-
-
-
-
- |
- Time |
- Request ID |
- Token |
- Model |
- Provider |
- Status |
-
-
-
- {{range $i, $entry := .DebugResult.Entries}}
-
- | ▶ |
- {{formatTimeDetail $entry.Timestamp}} |
- {{$entry.RequestID}} |
- {{$entry.TokenName}} |
- {{$entry.Model}} |
- {{$entry.Provider}} |
-
- {{if and (ge $entry.ResponseStatus 200) (lt $entry.ResponseStatus 300)}}{{$entry.ResponseStatus}}
- {{else if ge $entry.ResponseStatus 400}}{{$entry.ResponseStatus}}
- {{else}}{{$entry.ResponseStatus}}{{end}}
- |
-
-
-
-
- Request Headers:
- {{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}
- Request Body:
- {{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}
- Response Body:
- {{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}
-
- |
-
- {{end}}
- {{if not .DebugResult.Entries}}
- | No debug log entries |
- {{end}}
-
-
-
- {{if gt .DebugResult.TotalPages 1}}
-
- {{end}}
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/logs.html b/llm-gateway/internal/dashboard/templates/partials/logs.html
deleted file mode 100644
index bf8c640..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/logs.html
+++ /dev/null
@@ -1,133 +0,0 @@
-{{define "content"}}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- | Time |
- Token |
- Model |
- Provider |
- Status |
- Latency |
- Tokens |
- Cost |
-
-
-
- {{range $i, $log := .LogsResult.Logs}}
-
- | {{formatTimeDetail $log.Timestamp}} |
- {{$log.TokenName}} |
- {{$log.Model}} |
- {{$log.Provider}} |
-
- {{if eq $log.Status "success"}}success
- {{else if eq $log.Status "error"}}error
- {{else if eq $log.Status "cached"}}cached
- {{else}}{{$log.Status}}{{end}}
- {{if $log.Streaming}} stream{{end}}
- |
- {{$log.LatencyMS}}ms |
- {{$log.InputTokens}} / {{$log.OutputTokens}} |
- {{formatCost $log.CostUSD}} |
-
- {{if $log.ErrorMessage}}
-
- |
- {{$log.ErrorMessage}}
- |
-
- {{end}}
- {{end}}
- {{if not .LogsResult.Logs}}
- | No logs found |
- {{end}}
-
-
-
- {{if gt .LogsResult.TotalPages 1}}
-
- {{end}}
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/models-page.html b/llm-gateway/internal/dashboard/templates/partials/models-page.html
deleted file mode 100644
index 6d4a60b..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/models-page.html
+++ /dev/null
@@ -1,53 +0,0 @@
-{{define "content"}}
-
-
-{{if .ModelRoutes}}
-{{range .ModelRoutes}}
-
-
{{.Name}}{{if .Aliases}} aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}{{end}}
-
-
-
- | Provider |
- Provider Model |
- Priority |
- Input Price (per 1M) |
- Output Price (per 1M) |
- Health |
-
-
-
- {{$health := $.ProviderHealth}}
- {{range .Routes}}
-
- | {{.ProviderName}} |
- {{.ProviderModel}} |
- {{.Priority}} |
- {{formatPrice .InputPrice}} |
- {{formatPrice .OutputPrice}} |
-
- {{$pname := .ProviderName}}
- {{range $health}}
- {{if eq .Provider $pname}}
- {{if eq .Status "healthy"}}healthy
- {{else if eq .Status "degraded"}}degraded
- {{else}}down
- {{end}}
- {{end}}
- {{end}}
- {{if not $health}}-{{end}}
- |
-
- {{end}}
-
-
-
-{{end}}
-{{else}}
-
- No models configured
-
-{{end}}
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/settings.html b/llm-gateway/internal/dashboard/templates/partials/settings.html
deleted file mode 100644
index f63c0a5..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/settings.html
+++ /dev/null
@@ -1,141 +0,0 @@
-{{define "content"}}
-
-
-
-
-
-
-
-
Two-Factor Authentication
-
- {{if .User.TOTPEnabled}}
-
Two-factor authentication is enabled.
-
- {{else}}
-
Two-factor authentication is not enabled.
-
- {{end}}
-
-
-
Scan this QR code with your authenticator app, then enter the code below to verify.
-
-
-
-
-
-
-{{if .User.IsAdmin}}
-
-
Config Validation
-
Validate the current gateway configuration file for errors.
-
-
-
-{{end}}
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/tokens.html b/llm-gateway/internal/dashboard/templates/partials/tokens.html
deleted file mode 100644
index 974f5e4..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/tokens.html
+++ /dev/null
@@ -1,142 +0,0 @@
-{{define "content"}}
-
-
-
-
-
-
Static Tokens (from config, managed via environment variables)
-
- | Name | Prefix | Rate Limit | Budget | Today's Spend | |
-
- {{range .Tokens}}{{if lt .ID 0}}
-
- | {{.Name}} |
- {{.KeyPrefix}}... |
- {{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}} |
-
- {{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
- {{if gt .MonthlyBudgetUSD 0.0}} ${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
- |
-
- {{$spend := index $.TokenSpend .Name}}
- {{if gt .DailyBudgetUSD 0.0}}
- {{$pct := budgetPct $spend .DailyBudgetUSD}}
-
-
- ${{printf "%.4f" $spend}} / ${{printf "%.2f" .DailyBudgetUSD}} ({{printf "%.1f" $pct}}%)
-
- {{else}}
- {{if gt $spend 0.0}}{{formatCost $spend}}{{else}}-{{end}}
- {{end}}
- |
- config |
-
- {{end}}{{end}}
-
-
-
-
-
-
Dynamic Tokens (created via dashboard)
-
- | Name | Prefix | Rate Limit | Budget | Today's Spend | Created | Last Used | |
-
- {{range .Tokens}}{{if gt .ID 0}}
-
- | {{.Name}} |
- {{.KeyPrefix}}... |
- {{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}} |
-
- {{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
- {{if gt .MonthlyBudgetUSD 0.0}} ${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
- |
-
- {{$spend := index $.TokenSpend .Name}}
- {{if gt .DailyBudgetUSD 0.0}}
- {{$pct := budgetPct $spend .DailyBudgetUSD}}
-
-
- ${{printf "%.4f" $spend}} / ${{printf "%.2f" .DailyBudgetUSD}} ({{printf "%.1f" $pct}}%)
-
- {{else}}
- {{if gt $spend 0.0}}{{formatCost $spend}}{{else}}-{{end}}
- {{end}}
- |
- {{formatTime .CreatedAt}} |
- {{if gt .LastUsedAt 0}}{{formatTime .LastUsedAt}}{{else}}never{{end}} |
- |
-
- {{end}}{{end}}
-
-
-
-
-
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/users.html b/llm-gateway/internal/dashboard/templates/partials/users.html
deleted file mode 100644
index 2402b49..0000000
--- a/llm-gateway/internal/dashboard/templates/partials/users.html
+++ /dev/null
@@ -1,70 +0,0 @@
-{{define "content"}}
-
-
-
-
- | ID | Username | Role | 2FA | Created | |
-
- {{range .Users}}
-
- | {{.ID}} |
- {{.Username}} |
- {{if .IsAdmin}}Admin{{else}}User{{end}} |
- {{if .TOTPEnabled}}Enabled{{else}}Off{{end}} |
- {{formatTime .CreatedAt}} |
- {{if ne .ID $.User.ID}}{{end}} |
-
- {{end}}
-
-
-
-
-
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/setup.html b/llm-gateway/internal/dashboard/templates/setup.html
deleted file mode 100644
index 11c359f..0000000
--- a/llm-gateway/internal/dashboard/templates/setup.html
+++ /dev/null
@@ -1,59 +0,0 @@
-{{define "setup"}}
-
-
-
-
-
-Setup - LLM Gateway
-
-
-
-
-
-
-
LLM Gateway Setup
-
Create the first admin account
-
-
-
-
-
-
-{{end}}
diff --git a/llm-gateway/internal/metrics/prometheus.go b/llm-gateway/internal/metrics/prometheus.go
deleted file mode 100644
index 620f51e..0000000
--- a/llm-gateway/internal/metrics/prometheus.go
+++ /dev/null
@@ -1,73 +0,0 @@
-package metrics
-
-import (
- "github.com/prometheus/client_golang/prometheus"
- "github.com/prometheus/client_golang/prometheus/promauto"
-)
-
-// Metrics groups the Prometheus collectors of the gateway. Instances are
-// created with New and updated through the Record* methods.
-type Metrics struct {
- // requestsTotal counts requests by model, provider, token_name and status.
- requestsTotal *prometheus.CounterVec
- // requestDuration observes request latency in milliseconds by model and provider.
- requestDuration *prometheus.HistogramVec
- // tokensTotal counts tokens by model, provider and type ("input" or "output").
- tokensTotal *prometheus.CounterVec
- // costTotal accumulates cost in USD by model, provider and token_name.
- costTotal *prometheus.CounterVec
- cacheHits prometheus.Counter
- cacheMisses prometheus.Counter
-}
-
-// New creates the gateway's collectors through promauto and returns them
-// bundled in a Metrics value. Latency buckets range from 100ms to 120s.
-func New() *Metrics {
-	m := new(Metrics)
-
-	m.requestsTotal = promauto.NewCounterVec(
-		prometheus.CounterOpts{
-			Name: "llm_gateway_requests_total",
-			Help: "Total number of LLM requests",
-		},
-		[]string{"model", "provider", "token_name", "status"},
-	)
-
-	m.requestDuration = promauto.NewHistogramVec(
-		prometheus.HistogramOpts{
-			Name:    "llm_gateway_request_duration_ms",
-			Help:    "Request duration in milliseconds",
-			Buckets: []float64{100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000},
-		},
-		[]string{"model", "provider"},
-	)
-
-	m.tokensTotal = promauto.NewCounterVec(
-		prometheus.CounterOpts{
-			Name: "llm_gateway_tokens_total",
-			Help: "Total tokens processed",
-		},
-		[]string{"model", "provider", "type"},
-	)
-
-	m.costTotal = promauto.NewCounterVec(
-		prometheus.CounterOpts{
-			Name: "llm_gateway_cost_usd_total",
-			Help: "Total cost in USD",
-		},
-		[]string{"model", "provider", "token_name"},
-	)
-
-	m.cacheHits = promauto.NewCounter(prometheus.CounterOpts{
-		Name: "llm_gateway_cache_hits_total",
-		Help: "Total number of cache hits",
-	})
-
-	m.cacheMisses = promauto.NewCounter(prometheus.CounterOpts{
-		Name: "llm_gateway_cache_misses_total",
-		Help: "Total number of cache misses",
-	})
-
-	return m
-}
-
-// RecordRequest records one finished request: it increments the request
-// counter, observes the latency, and adds token counts and cost. Token and
-// cost series are only touched for strictly positive values.
-func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, latencyMS int64, inputTokens, outputTokens int, cost float64) {
-	m.requestsTotal.WithLabelValues(model, providerName, tokenName, status).Inc()
-	m.requestDuration.WithLabelValues(model, providerName).Observe(float64(latencyMS))
-
-	tokenCounts := [...]struct {
-		kind  string
-		count int
-	}{
-		{"input", inputTokens},
-		{"output", outputTokens},
-	}
-	for _, tc := range tokenCounts {
-		if tc.count > 0 {
-			m.tokensTotal.WithLabelValues(model, providerName, tc.kind).Add(float64(tc.count))
-		}
-	}
-
-	if cost > 0 {
-		m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost)
-	}
-}
-
-// RecordCacheHit increments the cache-hit counter by one.
-func (m *Metrics) RecordCacheHit() {
- m.cacheHits.Inc()
-}
-
-// RecordCacheMiss increments the cache-miss counter by one.
-func (m *Metrics) RecordCacheMiss() {
- m.cacheMisses.Inc()
-}
diff --git a/llm-gateway/internal/pricing/pricing.go b/llm-gateway/internal/pricing/pricing.go
deleted file mode 100644
index 69b01d8..0000000
--- a/llm-gateway/internal/pricing/pricing.go
+++ /dev/null
@@ -1,191 +0,0 @@
-package pricing
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "log"
- "net/http"
- "sync"
- "time"
-)
-
-// defaultPricesURL is the pricing dataset fetched when NewLookup is given an empty url.
-const defaultPricesURL = "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json"
-
-// Provider represents one provider entry of the pricing dataset; the dataset
-// is decoded as a JSON array of these.
-type Provider struct {
- ID string `json:"id"`
- Models []Model `json:"models"`
-}
-
-// Model represents a model entry with pricing.
-// Prices is kept as raw JSON because the field comes in more than one shape
-// (object or array); parsePrices decodes it.
-type Model struct {
- ID string `json:"id"`
- Prices json.RawMessage `json:"prices"`
-}
-
-// Lookup provides pricing data fetched from genai-prices.
-// A nil *Lookup is accepted by Get and FillMissing and behaves as an empty table.
-type Lookup struct {
- // mu guards prices, which refresh replaces wholesale.
- mu sync.RWMutex
- // prices maps "provider:model" to {input, output} price per 1M tokens.
- prices map[string][2]float64
- // url is the dataset location used by refresh.
- url string
- // stopCh is closed by Close to stop the background refresh goroutine.
- stopCh chan struct{}
-}
-
-// NewLookup creates a Lookup that fetches pricing data immediately and
-// refreshes it every interval in a background goroutine until Close is called.
-//
-// If url is empty, the default genai-prices URL is used. If interval is not
-// positive, only the initial fetch is performed and no background refresh is
-// started (time.NewTicker would panic on such a duration).
-//
-// The returned Lookup is usable even if the initial fetch fails; prices stay
-// empty until a later refresh succeeds.
-func NewLookup(url string, interval time.Duration) *Lookup {
-	if url == "" {
-		url = defaultPricesURL
-	}
-	l := &Lookup{
-		prices: make(map[string][2]float64),
-		url:    url,
-		stopCh: make(chan struct{}),
-	}
-
-	// Initial fetch, done synchronously so prices are available on return.
-	l.refresh()
-
-	if interval <= 0 {
-		log.Printf("WARNING: pricing refresh interval %v is not positive; periodic refresh disabled", interval)
-		return l
-	}
-
-	// Background refresh; ends when Close closes stopCh.
-	go func() {
-		ticker := time.NewTicker(interval)
-		defer ticker.Stop()
-		for {
-			select {
-			case <-ticker.C:
-				l.refresh()
-			case <-l.stopCh:
-				return
-			}
-		}
-	}()
-
-	return l
-}
-
-// Close stops the background refresh goroutine.
-//
-// Like Get and FillMissing it accepts a nil receiver, and calling it again
-// after the lookup has been closed is a no-op instead of a panic. It is not
-// meant to be called from several goroutines at the same time.
-func (l *Lookup) Close() {
-	if l == nil || l.stopCh == nil {
-		return
-	}
-	select {
-	case <-l.stopCh:
-		// Already closed.
-	default:
-		close(l.stopCh)
-	}
-}
-
-// Get looks up the prices stored for provider and model and returns them as
-// (inputPer1M, outputPer1M). Unknown pairs and a nil receiver yield (0, 0).
-func (l *Lookup) Get(provider, model string) (float64, float64) {
-	if l == nil {
-		return 0, 0
-	}
-
-	key := fmt.Sprintf("%s:%s", provider, model)
-
-	l.mu.RLock()
-	entry, found := l.prices[key]
-	l.mu.RUnlock()
-
-	if !found {
-		return 0, 0
-	}
-	return entry[0], entry[1]
-}
-
-// FillMissing fills in zero-value pricing from the lookup data.
-//
-// input and output point at the configured prices; each one that is exactly 0
-// is replaced by the looked-up value for provider and model when that value is
-// non-zero. Prices that are already set are never overwritten.
-//
-// It reports whether at least one price was actually filled in. A nil
-// receiver, or both prices already being positive, yields false.
-func (l *Lookup) FillMissing(provider, model string, input, output *float64) bool {
-	if l == nil || (*input > 0 && *output > 0) {
-		return false
-	}
-
-	in, out := l.Get(provider, model)
-
-	filled := false
-	if *input == 0 && in != 0 {
-		*input = in
-		filled = true
-	}
-	if *output == 0 && out != 0 {
-		*output = out
-		filled = true
-	}
-	return filled
-}
-
-// refresh downloads the pricing dataset from l.url and replaces the price
-// table with its contents.
-//
-// Any failure (request error, non-200 status, read or parse error) is logged
-// as a warning and leaves the current table untouched. The same applies when
-// the dataset yields no usable price, so a bad or empty upstream response does
-// not wipe prices loaded earlier. The response body is read through a size
-// limit so a misbehaving endpoint cannot exhaust memory; an oversized body is
-// truncated and then rejected by the JSON parser.
-func (l *Lookup) refresh() {
-	// Upper bound for the downloaded dataset.
-	const maxBodyBytes = 32 << 20
-
-	client := &http.Client{Timeout: 30 * time.Second}
-	resp, err := client.Get(l.url)
-	if err != nil {
-		log.Printf("WARNING: failed to fetch pricing data: %v", err)
-		return
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		log.Printf("WARNING: pricing data fetch returned %d", resp.StatusCode)
-		return
-	}
-
-	body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
-	if err != nil {
-		log.Printf("WARNING: failed to read pricing data: %v", err)
-		return
-	}
-
-	var providers []Provider
-	if err := json.Unmarshal(body, &providers); err != nil {
-		log.Printf("WARNING: failed to parse pricing data: %v", err)
-		return
-	}
-
-	// Build the new table off to the side; entries without any price are skipped.
-	prices := make(map[string][2]float64)
-	for _, p := range providers {
-		for _, m := range p.Models {
-			input, output := parsePrices(m.Prices)
-			if input > 0 || output > 0 {
-				key := fmt.Sprintf("%s:%s", p.ID, m.ID)
-				prices[key] = [2]float64{input, output}
-			}
-		}
-	}
-
-	if len(prices) == 0 {
-		log.Printf("WARNING: pricing data contained no usable prices; keeping previous data")
-		return
-	}
-
-	l.mu.Lock()
-	l.prices = prices
-	l.mu.Unlock()
-
-	log.Printf("Loaded pricing data: %d model prices from genai-prices", len(prices))
-}
-
-// parsePrices decodes the "prices" field of a model entry, which comes in two shapes:
-//   - object: {"input_mtok": 0.5, "output_mtok": 1.0}
-//   - array:  [{"prices": {"input_mtok": 0.5, ...}}, ...] (time-of-day; the first entry is used)
-//
-// Anything else, including empty input, yields (0, 0).
-func parsePrices(raw json.RawMessage) (input, output float64) {
-	if len(raw) == 0 {
-		return 0, 0
-	}
-
-	// The plain object is the common case, so it is attempted first.
-	var flat map[string]any
-	if err := json.Unmarshal(raw, &flat); err == nil {
-		return extractPrice(flat, "input_mtok"), extractPrice(flat, "output_mtok")
-	}
-
-	// Otherwise expect a list of time-of-day windows and take the first one.
-	var windows []struct {
-		Prices map[string]any `json:"prices"`
-	}
-	if err := json.Unmarshal(raw, &windows); err != nil || len(windows) == 0 {
-		return 0, 0
-	}
-	first := windows[0].Prices
-	return extractPrice(first, "input_mtok"), extractPrice(first, "output_mtok")
-}
-
-// extractPrice reads prices[key], accepting either a plain number or a tiered
-// object whose "base" member is a number. Missing keys and any other value
-// type yield 0.
-func extractPrice(prices map[string]any, key string) float64 {
-	switch val := prices[key].(type) {
-	case float64:
-		return val
-	case map[string]any:
-		if base, ok := val["base"].(float64); ok {
-			return base
-		}
-	}
-	return 0
-}
diff --git a/llm-gateway/internal/provider/balancer.go b/llm-gateway/internal/provider/balancer.go
deleted file mode 100644
index 602d885..0000000
--- a/llm-gateway/internal/provider/balancer.go
+++ /dev/null
@@ -1,144 +0,0 @@
-package provider
-
-import (
- "math/rand"
- "sort"
- "sync/atomic"
-)
-
-// LoadBalancer reorders routes for load distribution.
-type LoadBalancer interface {
- Reorder(routes []Route) []Route
-}
-
-// NewLoadBalancer creates a load balancer by strategy name.
-func NewLoadBalancer(strategy string) LoadBalancer {
- switch strategy {
- case "round-robin":
- return &RoundRobinBalancer{}
- case "random":
- return &RandomBalancer{}
- case "least-cost":
- return &LeastCostBalancer{}
- default:
- return &FirstBalancer{}
- }
-}
-
-// FirstBalancer is a no-op that preserves original order.
-type FirstBalancer struct{}
-
-func (b *FirstBalancer) Reorder(routes []Route) []Route {
- return routes
-}
-
-// RoundRobinBalancer rotates routes within same-priority groups.
-type RoundRobinBalancer struct {
- counter atomic.Uint64
-}
-
-func (b *RoundRobinBalancer) Reorder(routes []Route) []Route {
- if len(routes) <= 1 {
- return routes
- }
-
- result := make([]Route, len(routes))
- copy(result, routes)
-
- // Group by priority and rotate within each group
- groups := groupByPriority(result)
- idx := 0
- count := b.counter.Add(1)
- for _, group := range groups {
- if len(group) > 1 {
- offset := int(count) % len(group)
- for j := 0; j < len(group); j++ {
- result[idx] = group[(j+offset)%len(group)]
- idx++
- }
- } else {
- result[idx] = group[0]
- idx++
- }
- }
-
- return result
-}
-
-// RandomBalancer shuffles routes within same-priority groups.
-type RandomBalancer struct{}
-
-func (b *RandomBalancer) Reorder(routes []Route) []Route {
- if len(routes) <= 1 {
- return routes
- }
-
- result := make([]Route, len(routes))
- copy(result, routes)
-
- groups := groupByPriority(result)
- idx := 0
- for _, group := range groups {
- rand.Shuffle(len(group), func(i, j int) {
- group[i], group[j] = group[j], group[i]
- })
- for _, r := range group {
- result[idx] = r
- idx++
- }
- }
-
- return result
-}
-
-// LeastCostBalancer sorts by price within same-priority groups.
-type LeastCostBalancer struct{}
-
-func (b *LeastCostBalancer) Reorder(routes []Route) []Route {
- if len(routes) <= 1 {
- return routes
- }
-
- result := make([]Route, len(routes))
- copy(result, routes)
-
- groups := groupByPriority(result)
- idx := 0
- for _, group := range groups {
- sort.Slice(group, func(i, j int) bool {
- costI := group[i].InputPrice + group[i].OutputPrice
- costJ := group[j].InputPrice + group[j].OutputPrice
- return costI < costJ
- })
- for _, r := range group {
- result[idx] = r
- idx++
- }
- }
-
- return result
-}
-
-// groupByPriority splits routes into groups of same priority, preserving order.
-func groupByPriority(routes []Route) [][]Route {
- if len(routes) == 0 {
- return nil
- }
-
- var groups [][]Route
- currentPriority := routes[0].Priority
- currentGroup := []Route{routes[0]}
-
- for i := 1; i < len(routes); i++ {
- if routes[i].Priority == currentPriority {
- currentGroup = append(currentGroup, routes[i])
- } else {
- groups = append(groups, currentGroup)
- currentPriority = routes[i].Priority
- currentGroup = []Route{routes[i]}
- }
- }
- groups = append(groups, currentGroup)
-
- return groups
-}
diff --git a/llm-gateway/internal/provider/balancer_test.go b/llm-gateway/internal/provider/balancer_test.go
deleted file mode 100644
index cc5378e..0000000
--- a/llm-gateway/internal/provider/balancer_test.go
+++ /dev/null
@@ -1,294 +0,0 @@
-package provider
-
-import (
- "fmt"
- "testing"
-)
-
-type routeSpec struct {
- name string
- priority int
- input float64
- output float64
-}
-
-func makeRoutes(specs ...routeSpec) []Route {
- routes := make([]Route, len(specs))
- for i, s := range specs {
- routes[i] = Route{
- Provider: &mockProvider{name: s.name},
- ProviderModel: s.name + "-model",
- Priority: s.priority,
- InputPrice: s.input,
- OutputPrice: s.output,
- }
- }
- return routes
-}
-
-func routeNames(routes []Route) []string {
- names := make([]string, len(routes))
- for i, r := range routes {
- names[i] = r.Provider.Name()
- }
- return names
-}
-
-func TestFirstBalancer_PreservesOrder(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"a", 1, 1.0, 1.0},
- routeSpec{"b", 1, 2.0, 2.0},
- routeSpec{"c", 1, 3.0, 3.0},
- )
-
- b := &FirstBalancer{}
- result := b.Reorder(routes)
-
- names := routeNames(result)
- if names[0] != "a" || names[1] != "b" || names[2] != "c" {
- t.Fatalf("expected [a b c], got %v", names)
- }
-}
-
-func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"a", 1, 1.0, 1.0},
- routeSpec{"b", 1, 1.0, 1.0},
- routeSpec{"c", 1, 1.0, 1.0},
- )
-
- b := &RoundRobinBalancer{}
-
- // Collect the first element from multiple calls
- seen := make(map[string]bool)
- for i := 0; i < 6; i++ {
- result := b.Reorder(routes)
- seen[result[0].Provider.Name()] = true
- }
-
- // All routes should have appeared as first at some point
- for _, name := range []string{"a", "b", "c"} {
- if !seen[name] {
- t.Errorf("expected %q to appear as first element in rotation", name)
- }
- }
-}
-
-func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"a", 1, 1.0, 1.0},
- routeSpec{"b", 1, 1.0, 1.0},
- routeSpec{"c", 2, 1.0, 1.0},
- )
-
- b := &RoundRobinBalancer{}
-
- // Priority 2 route should always be last
- for i := 0; i < 5; i++ {
- result := b.Reorder(routes)
- if result[2].Provider.Name() != "c" {
- t.Fatalf("expected priority-2 route 'c' at the end, got %q", result[2].Provider.Name())
- }
- }
-}
-
-func TestRandomBalancer_AllRoutesPresent(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"a", 1, 1.0, 1.0},
- routeSpec{"b", 1, 1.0, 1.0},
- routeSpec{"c", 1, 1.0, 1.0},
- )
-
- b := &RandomBalancer{}
-
- for i := 0; i < 10; i++ {
- result := b.Reorder(routes)
- if len(result) != 3 {
- t.Fatalf("expected 3 routes, got %d", len(result))
- }
-
- names := make(map[string]bool)
- for _, r := range result {
- names[r.Provider.Name()] = true
- }
- for _, want := range []string{"a", "b", "c"} {
- if !names[want] {
- t.Errorf("missing route %q in result", want)
- }
- }
- }
-}
-
-func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"a", 1, 1.0, 1.0},
- routeSpec{"b", 1, 1.0, 1.0},
- routeSpec{"c", 2, 1.0, 1.0},
- )
-
- b := &RandomBalancer{}
-
- for i := 0; i < 10; i++ {
- result := b.Reorder(routes)
- if result[2].Provider.Name() != "c" {
- t.Fatalf("expected priority-2 route 'c' last, got %q", result[2].Provider.Name())
- }
- }
-}
-
-func TestLeastCostBalancer_SortsByCost(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"expensive", 1, 10.0, 10.0},
- routeSpec{"cheap", 1, 1.0, 1.0},
- routeSpec{"medium", 1, 5.0, 5.0},
- )
-
- b := &LeastCostBalancer{}
- result := b.Reorder(routes)
-
- names := routeNames(result)
- expected := []string{"cheap", "medium", "expensive"}
- for i, want := range expected {
- if names[i] != want {
- t.Errorf("position %d: got %q, want %q", i, names[i], want)
- }
- }
-}
-
-func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) {
- routes := makeRoutes(
- routeSpec{"expensive-p1", 1, 10.0, 10.0},
- routeSpec{"cheap-p1", 1, 1.0, 1.0},
- routeSpec{"cheap-p2", 2, 0.5, 0.5},
- )
-
- b := &LeastCostBalancer{}
- result := b.Reorder(routes)
-
- names := routeNames(result)
- // Within priority 1, cheap should come first; priority 2 always last
- if names[0] != "cheap-p1" {
- t.Errorf("expected cheap-p1 first, got %q", names[0])
- }
- if names[1] != "expensive-p1" {
- t.Errorf("expected expensive-p1 second, got %q", names[1])
- }
- if names[2] != "cheap-p2" {
- t.Errorf("expected cheap-p2 last, got %q", names[2])
- }
-}
-
-func TestGroupByPriority(t *testing.T) {
- tests := []struct {
- name string
- priorities []int
- wantGroups [][]int
- }{
- {
- name: "empty",
- priorities: nil,
- wantGroups: nil,
- },
- {
- name: "single",
- priorities: []int{1},
- wantGroups: [][]int{{1}},
- },
- {
- name: "all same",
- priorities: []int{1, 1, 1},
- wantGroups: [][]int{{1, 1, 1}},
- },
- {
- name: "two groups",
- priorities: []int{1, 1, 2, 2},
- wantGroups: [][]int{{1, 1}, {2, 2}},
- },
- {
- name: "three groups",
- priorities: []int{1, 2, 2, 3},
- wantGroups: [][]int{{1}, {2, 2}, {3}},
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- var routes []Route
- for _, p := range tt.priorities {
- routes = append(routes, Route{Priority: p})
- }
-
- groups := groupByPriority(routes)
-
- if tt.wantGroups == nil {
- if groups != nil {
- t.Fatalf("expected nil groups, got %v", groups)
- }
- return
- }
-
- if len(groups) != len(tt.wantGroups) {
- t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups))
- }
-
- for i, wg := range tt.wantGroups {
- if len(groups[i]) != len(wg) {
- t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i]))
- continue
- }
- for j, wp := range wg {
- if groups[i][j].Priority != wp {
- t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority)
- }
- }
- }
- })
- }
-}
-
-func TestBalancer_SingleRoute(t *testing.T) {
- routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0})
-
- balancers := []struct {
- name string
- balancer LoadBalancer
- }{
- {"first", &FirstBalancer{}},
- {"round-robin", &RoundRobinBalancer{}},
- {"random", &RandomBalancer{}},
- {"least-cost", &LeastCostBalancer{}},
- }
-
- for _, bb := range balancers {
- t.Run(bb.name, func(t *testing.T) {
- result := bb.balancer.Reorder(routes)
- if len(result) != 1 || result[0].Provider.Name() != "only" {
- t.Fatalf("expected single route 'only', got %v", routeNames(result))
- }
- })
- }
-}
-
-func TestNewLoadBalancer(t *testing.T) {
- tests := []struct {
- strategy string
- wantType string
- }{
- {"round-robin", "*provider.RoundRobinBalancer"},
- {"random", "*provider.RandomBalancer"},
- {"least-cost", "*provider.LeastCostBalancer"},
- {"first", "*provider.FirstBalancer"},
- {"unknown", "*provider.FirstBalancer"},
- {"", "*provider.FirstBalancer"},
- }
-
- for _, tt := range tests {
- t.Run(tt.strategy, func(t *testing.T) {
- b := NewLoadBalancer(tt.strategy)
- got := fmt.Sprintf("%T", b)
- if got != tt.wantType {
- t.Errorf("NewLoadBalancer(%q) = %s, want %s", tt.strategy, got, tt.wantType)
- }
- })
- }
-}
diff --git a/llm-gateway/internal/provider/health.go b/llm-gateway/internal/provider/health.go
deleted file mode 100644
index ad3638a..0000000
--- a/llm-gateway/internal/provider/health.go
+++ /dev/null
@@ -1,264 +0,0 @@
-package provider
-
-import (
- "sync"
- "time"
-
- "llm-gateway/internal/config"
-)
-
-// CircuitState represents the state of a circuit breaker.
-type CircuitState int
-
-const (
- CircuitClosed CircuitState = iota // normal operation
- CircuitOpen // blocking requests
- CircuitHalfOpen // testing with probe request
-)
-
-func (s CircuitState) String() string {
- switch s {
- case CircuitClosed:
- return "closed"
- case CircuitOpen:
- return "open"
- case CircuitHalfOpen:
- return "half-open"
- default:
- return "unknown"
- }
-}
-
-// ProviderCircuit tracks circuit breaker state for a single provider.
-type ProviderCircuit struct {
- State CircuitState
- OpenedAt time.Time
- LastProbe time.Time
-}
-
-// HealthEvent represents a single request outcome for a provider.
-type HealthEvent struct {
- Timestamp time.Time
- LatencyMS int64
- IsError bool
- ErrorMsg string
-}
-
-// ProviderHealth is the computed health status for a provider.
-type ProviderHealth struct {
- Provider string `json:"provider"`
- Status string `json:"status"` // healthy, degraded, down
- ErrorRate float64 `json:"error_rate"`
- AvgLatency float64 `json:"avg_latency_ms"`
- Total int `json:"total"`
- Errors int `json:"errors"`
- CircuitState string `json:"circuit_state"`
-}
-
-// HealthTracker tracks per-provider health using a sliding window.
-type HealthTracker struct {
- mu sync.RWMutex
- windows map[string][]HealthEvent
- windowDu time.Duration
- circuits map[string]*ProviderCircuit
- cbConfig config.CircuitBreakerConfig
- OnStateChange func(provider string, from, to CircuitState)
-}
-
-// NewHealthTracker creates a health tracker with the given window duration.
-func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker {
- if window == 0 {
- window = 5 * time.Minute
- }
- return &HealthTracker{
- windows: make(map[string][]HealthEvent),
- circuits: make(map[string]*ProviderCircuit),
- windowDu: window,
- cbConfig: cbCfg,
- }
-}
-
-// IsAvailable returns true if the provider's circuit breaker allows requests.
-func (h *HealthTracker) IsAvailable(provider string) bool {
- if !h.cbConfig.Enabled {
- return true
- }
-
- h.mu.RLock()
- defer h.mu.RUnlock()
-
- circuit, ok := h.circuits[provider]
- if !ok {
- return true // no circuit = closed = available
- }
-
- switch circuit.State {
- case CircuitOpen:
- // Check if cooldown has elapsed -> transition to half-open
- if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
- return true // will transition to half-open on next record
- }
- return false
- case CircuitHalfOpen:
- return true // allow probe
- default:
- return true
- }
-}
-
-// Record adds a health event for a provider and evaluates circuit transitions.
-func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
- event := HealthEvent{
- Timestamp: time.Now(),
- LatencyMS: latencyMS,
- IsError: err != nil,
- }
- if err != nil {
- event.ErrorMsg = err.Error()
- }
-
- h.mu.Lock()
- defer h.mu.Unlock()
-
- h.windows[provider] = append(h.windows[provider], event)
- h.prune(provider)
-
- if h.cbConfig.Enabled {
- h.evaluateCircuit(provider, err)
- }
-}
-
-// evaluateCircuit transitions circuit breaker state. Must be called with lock held.
-func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
- circuit, ok := h.circuits[providerName]
- if !ok {
- circuit = &ProviderCircuit{State: CircuitClosed}
- h.circuits[providerName] = circuit
- }
-
- prevState := circuit.State
-
- switch circuit.State {
- case CircuitClosed:
- // Check if error threshold exceeded
- errorRate, total := h.errorRateUnlocked(providerName)
- if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold {
- circuit.State = CircuitOpen
- circuit.OpenedAt = time.Now()
- }
- case CircuitOpen:
- // Check if cooldown elapsed -> half-open
- if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
- circuit.State = CircuitHalfOpen
- circuit.LastProbe = time.Now()
- // Evaluate the probe result immediately
- if lastErr == nil {
- circuit.State = CircuitClosed
- } else {
- circuit.State = CircuitOpen
- circuit.OpenedAt = time.Now()
- }
- }
- case CircuitHalfOpen:
- if lastErr == nil {
- circuit.State = CircuitClosed
- } else {
- circuit.State = CircuitOpen
- circuit.OpenedAt = time.Now()
- }
- }
-
- if circuit.State != prevState && h.OnStateChange != nil {
- cb := h.OnStateChange
- from, to := prevState, circuit.State
- // Call outside lock to avoid deadlocks
- go cb(providerName, from, to)
- }
-}
-
-// errorRateUnlocked computes error rate within window. Must be called with lock held.
-func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) {
- cutoff := time.Now().Add(-h.windowDu)
- events := h.windows[provider]
- var total, errors int
- for _, e := range events {
- if e.Timestamp.Before(cutoff) {
- continue
- }
- total++
- if e.IsError {
- errors++
- }
- }
- if total == 0 {
- return 0, 0
- }
- return float64(errors) / float64(total), total
-}
-
-// Status returns computed health for all tracked providers.
-func (h *HealthTracker) Status() []ProviderHealth {
- h.mu.RLock()
- defer h.mu.RUnlock()
-
- cutoff := time.Now().Add(-h.windowDu)
- var results []ProviderHealth
-
- for provider, events := range h.windows {
- var total, errors int
- var totalLatency int64
-
- for _, e := range events {
- if e.Timestamp.Before(cutoff) {
- continue
- }
- total++
- totalLatency += e.LatencyMS
- if e.IsError {
- errors++
- }
- }
-
- if total == 0 {
- continue
- }
-
- errorRate := float64(errors) / float64(total)
- status := "healthy"
- if errorRate >= 0.5 {
- status = "down"
- } else if errorRate >= 0.1 {
- status = "degraded"
- }
-
- circuitState := "closed"
- if circuit, ok := h.circuits[provider]; ok {
- circuitState = circuit.State.String()
- }
-
- results = append(results, ProviderHealth{
- Provider: provider,
- Status: status,
- ErrorRate: errorRate,
- AvgLatency: float64(totalLatency) / float64(total),
- Total: total,
- Errors: errors,
- CircuitState: circuitState,
- })
- }
-
- return results
-}
-
-// prune removes events outside the window. Must be called with lock held.
-func (h *HealthTracker) prune(provider string) {
- cutoff := time.Now().Add(-h.windowDu)
- events := h.windows[provider]
- i := 0
- for i < len(events) && events[i].Timestamp.Before(cutoff) {
- i++
- }
- if i > 0 {
- h.windows[provider] = events[i:]
- }
-}
diff --git a/llm-gateway/internal/provider/health_test.go b/llm-gateway/internal/provider/health_test.go
deleted file mode 100644
index 8dda99e..0000000
--- a/llm-gateway/internal/provider/health_test.go
+++ /dev/null
@@ -1,345 +0,0 @@
-package provider
-
-import (
- "errors"
- "testing"
- "time"
-
- "llm-gateway/internal/config"
-)
-
-func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker {
- return NewHealthTracker(window, cb)
-}
-
-func defaultCBConfig() config.CircuitBreakerConfig {
- return config.CircuitBreakerConfig{
- Enabled: true,
- ErrorThreshold: 0.5,
- MinRequests: 3,
- CooldownDuration: 100 * time.Millisecond,
- }
-}
-
-func TestHealthTracker_Record(t *testing.T) {
- ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
-
- ht.Record("provA", 100, nil)
- ht.Record("provA", 200, errors.New("fail"))
- ht.Record("provB", 50, nil)
-
- ht.mu.RLock()
- defer ht.mu.RUnlock()
-
- if len(ht.windows["provA"]) != 2 {
- t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"]))
- }
- if len(ht.windows["provB"]) != 1 {
- t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"]))
- }
-
- // Verify event fields
- ev := ht.windows["provA"][1]
- if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 {
- t.Fatalf("unexpected event fields: %+v", ev)
- }
-}
-
-func TestHealthTracker_Status(t *testing.T) {
- tests := []struct {
- name string
- successCount int
- errorCount int
- wantStatus string
- wantErrorRate float64
- wantTotal int
- wantErrors int
- }{
- {
- name: "healthy - no errors",
- successCount: 10,
- errorCount: 0,
- wantStatus: "healthy",
- wantErrorRate: 0.0,
- wantTotal: 10,
- wantErrors: 0,
- },
- {
- name: "healthy - below 10% errors",
- successCount: 19,
- errorCount: 1,
- wantStatus: "healthy",
- wantErrorRate: 0.05,
- wantTotal: 20,
- wantErrors: 1,
- },
- {
- name: "degraded - 20% errors",
- successCount: 8,
- errorCount: 2,
- wantStatus: "degraded",
- wantErrorRate: 0.2,
- wantTotal: 10,
- wantErrors: 2,
- },
- {
- name: "degraded - exactly 10% errors",
- successCount: 9,
- errorCount: 1,
- wantStatus: "degraded",
- wantErrorRate: 0.1,
- wantTotal: 10,
- wantErrors: 1,
- },
- {
- name: "down - 50% errors",
- successCount: 5,
- errorCount: 5,
- wantStatus: "down",
- wantErrorRate: 0.5,
- wantTotal: 10,
- wantErrors: 5,
- },
- {
- name: "down - all errors",
- successCount: 0,
- errorCount: 5,
- wantStatus: "down",
- wantErrorRate: 1.0,
- wantTotal: 5,
- wantErrors: 5,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
-
- for i := 0; i < tt.successCount; i++ {
- ht.Record("prov", 100, nil)
- }
- for i := 0; i < tt.errorCount; i++ {
- ht.Record("prov", 100, errors.New("err"))
- }
-
- statuses := ht.Status()
- if len(statuses) != 1 {
- t.Fatalf("expected 1 status, got %d", len(statuses))
- }
-
- s := statuses[0]
- if s.Status != tt.wantStatus {
- t.Errorf("status = %q, want %q", s.Status, tt.wantStatus)
- }
- if s.Total != tt.wantTotal {
- t.Errorf("total = %d, want %d", s.Total, tt.wantTotal)
- }
- if s.Errors != tt.wantErrors {
- t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors)
- }
- // Allow small float tolerance
- if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 {
- t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate)
- }
- })
- }
-}
-
-func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) {
- cb := defaultCBConfig()
- cb.MinRequests = 3
- cb.ErrorThreshold = 0.5
-
- ht := newTestTracker(5*time.Minute, cb)
-
- // Record errors to exceed threshold (3 errors out of 3 = 100% > 50%)
- ht.Record("prov", 100, errors.New("err"))
- ht.Record("prov", 100, errors.New("err"))
- ht.Record("prov", 100, errors.New("err"))
-
- ht.mu.RLock()
- state := ht.circuits["prov"].State
- ht.mu.RUnlock()
-
- if state != CircuitOpen {
- t.Fatalf("expected CircuitOpen, got %s", state)
- }
-
- if ht.IsAvailable("prov") {
- t.Fatal("expected IsAvailable=false when circuit is open")
- }
-}
-
-func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) {
- cb := defaultCBConfig()
- cb.CooldownDuration = 50 * time.Millisecond
-
- ht := newTestTracker(5*time.Minute, cb)
-
- // Trip the circuit
- for i := 0; i < 5; i++ {
- ht.Record("prov", 100, errors.New("err"))
- }
-
- if ht.IsAvailable("prov") {
- t.Fatal("expected circuit open, IsAvailable should be false")
- }
-
- // Wait for cooldown
- time.Sleep(60 * time.Millisecond)
-
- // After cooldown, IsAvailable should return true (will transition to half-open)
- if !ht.IsAvailable("prov") {
- t.Fatal("expected IsAvailable=true after cooldown")
- }
-}
-
-func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
- cb := defaultCBConfig()
- cb.CooldownDuration = 10 * time.Millisecond
-
- ht := newTestTracker(5*time.Minute, cb)
-
- // Trip the circuit
- for i := 0; i < 5; i++ {
- ht.Record("prov", 100, errors.New("err"))
- }
-
- // Wait for cooldown so next Record transitions through Open->HalfOpen
- time.Sleep(20 * time.Millisecond)
-
- // A successful record should transition: Open -> HalfOpen -> Closed
- ht.Record("prov", 100, nil)
-
- ht.mu.RLock()
- state := ht.circuits["prov"].State
- ht.mu.RUnlock()
-
- if state != CircuitClosed {
- t.Fatalf("expected CircuitClosed after success in half-open, got %s", state)
- }
-
- if !ht.IsAvailable("prov") {
- t.Fatal("expected IsAvailable=true after circuit closed")
- }
-}
-
-func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
- cb := defaultCBConfig()
- cb.CooldownDuration = 10 * time.Millisecond
-
- ht := newTestTracker(5*time.Minute, cb)
-
- // Trip the circuit
- for i := 0; i < 5; i++ {
- ht.Record("prov", 100, errors.New("err"))
- }
-
- // Wait for cooldown
- time.Sleep(20 * time.Millisecond)
-
- // A failed record should transition: Open -> HalfOpen -> Open
- ht.Record("prov", 100, errors.New("still failing"))
-
- ht.mu.RLock()
- state := ht.circuits["prov"].State
- ht.mu.RUnlock()
-
- if state != CircuitOpen {
- t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state)
- }
-}
-
-func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) {
- ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false})
-
- // Even with errors, IsAvailable should return true when CB is disabled
- for i := 0; i < 10; i++ {
- ht.Record("prov", 100, errors.New("err"))
- }
-
- if !ht.IsAvailable("prov") {
- t.Fatal("expected IsAvailable=true when circuit breaker disabled")
- }
-}
-
-func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) {
- ht := newTestTracker(5*time.Minute, defaultCBConfig())
-
- if !ht.IsAvailable("unknown") {
- t.Fatal("expected IsAvailable=true for unknown provider (no circuit)")
- }
-}
-
-func TestHealthTracker_WindowPruning(t *testing.T) {
- // Use a tiny window so events expire quickly
- ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
-
- ht.Record("prov", 100, nil)
- ht.Record("prov", 200, nil)
-
- // Wait for events to expire
- time.Sleep(60 * time.Millisecond)
-
- // Record a new event to trigger pruning
- ht.Record("prov", 300, nil)
-
- ht.mu.RLock()
- count := len(ht.windows["prov"])
- ht.mu.RUnlock()
-
- if count != 1 {
- t.Fatalf("expected 1 event after pruning, got %d", count)
- }
-}
-
-func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) {
- ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
-
- ht.Record("prov", 100, nil)
-
- // Wait for events to expire
- time.Sleep(60 * time.Millisecond)
-
- statuses := ht.Status()
- if len(statuses) != 0 {
- t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses))
- }
-}
-
-func TestHealthTracker_Status_AvgLatency(t *testing.T) {
- ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
-
- ht.Record("prov", 100, nil)
- ht.Record("prov", 200, nil)
- ht.Record("prov", 300, nil)
-
- statuses := ht.Status()
- if len(statuses) != 1 {
- t.Fatalf("expected 1 status, got %d", len(statuses))
- }
-
- want := 200.0
- if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 {
- t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want)
- }
-}
-
-func TestHealthTracker_Status_CircuitStateReported(t *testing.T) {
- cb := defaultCBConfig()
- ht := newTestTracker(5*time.Minute, cb)
-
- // Trip the circuit
- for i := 0; i < 5; i++ {
- ht.Record("prov", 100, errors.New("err"))
- }
-
- statuses := ht.Status()
- if len(statuses) != 1 {
- t.Fatalf("expected 1 status, got %d", len(statuses))
- }
-
- if statuses[0].CircuitState != "open" {
- t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open")
- }
-}
diff --git a/llm-gateway/internal/provider/openai.go b/llm-gateway/internal/provider/openai.go
deleted file mode 100644
index 9c096b3..0000000
--- a/llm-gateway/internal/provider/openai.go
+++ /dev/null
@@ -1,178 +0,0 @@
-package provider
-
-import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "time"
-)
-
-// OpenAIProvider is a generic OpenAI-compatible HTTP client.
-type OpenAIProvider struct {
- name string
- baseURL string
- apiKey string
- client *http.Client
-}
-
-func NewOpenAIProvider(name, baseURL, apiKey string, timeout time.Duration) *OpenAIProvider {
- return &OpenAIProvider{
- name: name,
- baseURL: baseURL,
- apiKey: apiKey,
- client: &http.Client{
- Timeout: timeout,
- },
- }
-}
-
-func (p *OpenAIProvider) Name() string { return p.name }
-
-func (p *OpenAIProvider) ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) {
- reqCopy := *req
- reqCopy.Model = model
- reqCopy.Stream = false
-
- body, err := json.Marshal(reqCopy)
- if err != nil {
- return nil, fmt.Errorf("marshaling request: %w", err)
- }
-
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("creating request: %w", err)
- }
- p.setHeaders(httpReq)
-
- resp, err := p.client.Do(httpReq)
- if err != nil {
- return nil, fmt.Errorf("sending request: %w", err)
- }
- defer resp.Body.Close()
-
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("reading response: %w", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- return nil, &ProviderError{
- StatusCode: resp.StatusCode,
- Body: string(respBody),
- Provider: p.name,
- }
- }
-
- var chatResp ChatResponse
- if err := json.Unmarshal(respBody, &chatResp); err != nil {
- return nil, fmt.Errorf("unmarshaling response: %w", err)
- }
-
- return &chatResp, nil
-}
-
-func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) {
- reqCopy := *req
- reqCopy.Model = model
- reqCopy.Stream = true
-
- body, err := json.Marshal(reqCopy)
- if err != nil {
- return nil, fmt.Errorf("marshaling request: %w", err)
- }
-
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("creating request: %w", err)
- }
- p.setHeaders(httpReq)
-
- resp, err := p.client.Do(httpReq)
- if err != nil {
- return nil, fmt.Errorf("sending request: %w", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- defer resp.Body.Close()
- respBody, _ := io.ReadAll(resp.Body)
- return nil, &ProviderError{
- StatusCode: resp.StatusCode,
- Body: string(respBody),
- Provider: p.name,
- }
- }
-
- return resp.Body, nil
-}
-
-func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) {
- reqCopy := *req
- reqCopy.Model = model
-
- body, err := json.Marshal(reqCopy)
- if err != nil {
- return nil, fmt.Errorf("marshaling request: %w", err)
- }
-
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("creating request: %w", err)
- }
- p.setHeaders(httpReq)
-
- resp, err := p.client.Do(httpReq)
- if err != nil {
- return nil, fmt.Errorf("sending request: %w", err)
- }
- defer resp.Body.Close()
-
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("reading response: %w", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- return nil, &ProviderError{
- StatusCode: resp.StatusCode,
- Body: string(respBody),
- Provider: p.name,
- }
- }
-
- var embResp EmbeddingResponse
- if err := json.Unmarshal(respBody, &embResp); err != nil {
- return nil, fmt.Errorf("unmarshaling response: %w", err)
- }
-
- return &embResp, nil
-}
-
-func (p *OpenAIProvider) setHeaders(req *http.Request) {
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Authorization", "Bearer "+p.apiKey)
- // Forward request ID if present in context
- if reqID := req.Context().Value("requestID"); reqID != nil {
- if id, ok := reqID.(string); ok && id != "" {
- req.Header.Set("X-Request-ID", id)
- }
- }
-}
-
-// ProviderError represents a non-200 response from a provider.
-type ProviderError struct {
- StatusCode int
- Body string
- Provider string
-}
-
-func (e *ProviderError) Error() string {
- return fmt.Sprintf("provider %s returned %d: %s", e.Provider, e.StatusCode, e.Body)
-}
-
-// IsRetryable returns true if the error is a server-side error worth retrying with another provider.
-func (e *ProviderError) IsRetryable() bool {
- return e.StatusCode >= 500 || e.StatusCode == 429
-}
diff --git a/llm-gateway/internal/provider/provider.go b/llm-gateway/internal/provider/provider.go
deleted file mode 100644
index 3d552ea..0000000
--- a/llm-gateway/internal/provider/provider.go
+++ /dev/null
@@ -1,89 +0,0 @@
-package provider
-
-import (
- "context"
- "io"
-)
-
-// ChatRequest is the OpenAI-compatible chat completion request.
-type ChatRequest struct {
- Model string `json:"model"`
- Messages []Message `json:"messages"`
- Temperature *float64 `json:"temperature,omitempty"`
- MaxTokens *int `json:"max_tokens,omitempty"`
- TopP *float64 `json:"top_p,omitempty"`
- Stream bool `json:"stream"`
- Stop any `json:"stop,omitempty"`
- FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
- PresencePenalty *float64 `json:"presence_penalty,omitempty"`
- N *int `json:"n,omitempty"`
- Tools []any `json:"tools,omitempty"`
- ToolChoice any `json:"tool_choice,omitempty"`
- ResponseFormat any `json:"response_format,omitempty"`
- Extra map[string]any `json:"-"` // pass through unknown fields
-}
-
-type Message struct {
- Role string `json:"role"`
- Content any `json:"content"` // string or []ContentPart
- Name string `json:"name,omitempty"`
- ToolCalls []any `json:"tool_calls,omitempty"`
- ToolCallID string `json:"tool_call_id,omitempty"`
-}
-
-type ChatResponse struct {
- ID string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- Choices []Choice `json:"choices"`
- Usage *Usage `json:"usage,omitempty"`
-}
-
-type Choice struct {
- Index int `json:"index"`
- Message Message `json:"message"`
- FinishReason string `json:"finish_reason"`
-}
-
-type Usage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-// EmbeddingRequest is the OpenAI-compatible embedding request.
-type EmbeddingRequest struct {
- Model string `json:"model"`
- Input any `json:"input"` // string or []string
- EncodingFormat string `json:"encoding_format,omitempty"`
-}
-
-// EmbeddingResponse is the OpenAI-compatible embedding response.
-type EmbeddingResponse struct {
- Object string `json:"object"`
- Data []EmbeddingData `json:"data"`
- Model string `json:"model"`
- Usage *EmbeddingUsage `json:"usage,omitempty"`
-}
-
-// EmbeddingData holds a single embedding vector.
-type EmbeddingData struct {
- Object string `json:"object"`
- Embedding []float64 `json:"embedding"`
- Index int `json:"index"`
-}
-
-// EmbeddingUsage reports token usage for embeddings.
-type EmbeddingUsage struct {
- PromptTokens int `json:"prompt_tokens"`
- TotalTokens int `json:"total_tokens"`
-}
-
-// Provider sends requests to an LLM API.
-type Provider interface {
- Name() string
- ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error)
- ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error)
- Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error)
-}
diff --git a/llm-gateway/internal/provider/registry.go b/llm-gateway/internal/provider/registry.go
deleted file mode 100644
index 2096738..0000000
--- a/llm-gateway/internal/provider/registry.go
+++ /dev/null
@@ -1,214 +0,0 @@
-package provider
-
-import (
- "fmt"
- "sort"
- "sync"
- "time"
-
- "llm-gateway/internal/config"
-)
-
-// ModelTimeouts holds per-model timeout overrides; an entry exists only when at least one value is positive.
-type ModelTimeouts struct {
-	RequestTimeout   time.Duration // copied from the model's RequestTimeout config value
-	StreamingTimeout time.Duration // copied from the model's StreamingTimeout config value
-}
-
-// Route maps a model to a specific provider with pricing.
-type Route struct {
-	Provider      Provider
-	ProviderModel string  // model name as configured for this provider's route
-	Priority      int     // taken from the provider's config; lower sorts first
-	InputPrice    float64 // per 1M tokens
-	OutputPrice   float64 // per 1M tokens
-}
-
-// Registry maps model names to provider routes. All maps are replaced together under mu on (re)build.
-type Registry struct {
-	mu        sync.RWMutex // guards every field below
-	routes    map[string][]Route // keyed by canonical model name
-	balancers map[string]LoadBalancer // keyed by canonical model name
-	aliases   map[string]string // alias -> canonical name
-	order     []string          // preserves config order (canonical names only)
-	timeouts  map[string]*ModelTimeouts // only models with an override have an entry
-}
-
-// NewRegistry creates a Registry and populates it from cfg. It returns the
-// error from buildFromConfig unchanged when the configuration is rejected.
-func NewRegistry(cfg *config.Config) (*Registry, error) {
-	reg := new(Registry)
-	err := reg.buildFromConfig(cfg)
-	if err != nil {
-		return nil, err
-	}
-	return reg, nil
-}
-
-// buildFromConfig rebuilds every lookup table from cfg and installs them
-// together under the write lock, so readers holding the read lock never see a
-// partially updated registry. If cfg references an unknown provider the
-// function returns an error before taking the lock and the registry is left
-// as it was.
-func (r *Registry) buildFromConfig(cfg *config.Config) error {
-	// Build providers
-	providers := make(map[string]Provider, len(cfg.Providers))
-	for _, pc := range cfg.Providers {
-		providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout)
-	}
-
-	// Build routes
-	routes := make(map[string][]Route, len(cfg.Models))
-	balancers := make(map[string]LoadBalancer, len(cfg.Models))
-	aliases := make(map[string]string)
-	order := make([]string, 0, len(cfg.Models))
-	timeouts := make(map[string]*ModelTimeouts)
-
-	for _, mc := range cfg.Models {
-		var modelRoutes []Route
-		for _, rc := range mc.Routes {
-			p, ok := providers[rc.Provider]
-			if !ok {
-				return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
-			}
-			// A route inherits its priority from the provider it points at.
-			pc := cfg.ProviderByName(rc.Provider)
-			modelRoutes = append(modelRoutes, Route{
-				Provider:      p,
-				ProviderModel: rc.Model,
-				Priority:      pc.Priority,
-				InputPrice:    rc.Pricing.Input,
-				OutputPrice:   rc.Pricing.Output,
-			})
-		}
-		// Sort by priority (lower = higher priority). The sort must be stable:
-		// several providers may share a priority, and routes that tie should
-		// keep the order in which they were written in the config.
-		sort.SliceStable(modelRoutes, func(i, j int) bool {
-			return modelRoutes[i].Priority < modelRoutes[j].Priority
-		})
-		routes[mc.Name] = modelRoutes
-		order = append(order, mc.Name)
-
-		// Load balancer
-		balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
-
-		// Per-model timeouts: only stored when at least one override is set.
-		if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 {
-			timeouts[mc.Name] = &ModelTimeouts{
-				RequestTimeout:   mc.RequestTimeout,
-				StreamingTimeout: mc.StreamingTimeout,
-			}
-		}
-
-		// Register aliases
-		for _, alias := range mc.Aliases {
-			aliases[alias] = mc.Name
-		}
-	}
-
-	r.mu.Lock()
-	r.routes = routes
-	r.balancers = balancers
-	r.aliases = aliases
-	r.order = order
-	r.timeouts = timeouts
-	r.mu.Unlock()
-
-	return nil
-}
-
-// Reload replaces the registry contents with tables built from cfg. It is the
-// hot-reload entry point and reports any error from buildFromConfig.
-func (r *Registry) Reload(cfg *config.Config) error {
-	if err := r.buildFromConfig(cfg); err != nil {
-		return err
-	}
-	return nil
-}
-
-// Lookup resolves model (which may be an alias) to its canonical name and
-// returns that model's routes. The boolean is false when the model is unknown.
-// When a load balancer is registered for the model, the routes are passed
-// through its Reorder method before being returned.
-func (r *Registry) Lookup(model string) ([]Route, bool) {
-	r.mu.RLock()
-	defer r.mu.RUnlock()
-
-	name := model
-	if target, isAlias := r.aliases[model]; isAlias {
-		name = target
-	}
-
-	found, exists := r.routes[name]
-	if !exists {
-		return nil, false
-	}
-
-	lb, hasBalancer := r.balancers[name]
-	if !hasBalancer {
-		return found, true
-	}
-	return lb.Reorder(found), true
-}
-
-// ModelNames returns all registered model names in config order (including aliases).
-func (r *Registry) ModelNames() []string {
- r.mu.RLock()
- defer r.mu.RUnlock()
-
- var names []string
- for _, name := range r.order {
- names = append(names, name)
- }
- // Add aliases
- for alias := range r.aliases {
- names = append(names, alias)
- }
- return names
-}
-
-// ModelTimeoutsFor returns the timeout overrides stored for model, following
-// an alias to its canonical name first. The result is nil when the model has
-// no overrides.
-func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts {
-	r.mu.RLock()
-	defer r.mu.RUnlock()
-
-	if target, isAlias := r.aliases[model]; isAlias {
-		return r.timeouts[target]
-	}
-	return r.timeouts[model]
-}
-
-// RouteInfo exposes route details for dashboard display; it mirrors Route with the provider reduced to its name.
-type RouteInfo struct {
-	ProviderName  string  `json:"provider_name"` // filled from Route.Provider.Name()
-	ProviderModel string  `json:"provider_model"`
-	Priority      int     `json:"priority"`
-	InputPrice    float64 `json:"input_price"`  // same unit as Route.InputPrice
-	OutputPrice   float64 `json:"output_price"` // same unit as Route.OutputPrice
-}
-
-// ModelRouteInfo exposes a model and its routes for dashboard display.
-type ModelRouteInfo struct {
-	Name    string      `json:"name"`              // canonical model name
-	Aliases []string    `json:"aliases,omitempty"` // left out of the JSON when the model has no aliases
-	Routes  []RouteInfo `json:"routes"`
-}
-
-// AllRoutes returns one ModelRouteInfo per canonical model, in config order.
-// Each model's aliases are sorted alphabetically so the output is stable
-// between calls (the alias table is a map). Routes is always non-nil, so a
-// model without routes serializes as an empty JSON array rather than null.
-func (r *Registry) AllRoutes() []ModelRouteInfo {
-	r.mu.RLock()
-	defer r.mu.RUnlock()
-
-	// Build reverse alias map: canonical name -> its aliases.
-	modelAliases := make(map[string][]string)
-	for alias, canonical := range r.aliases {
-		modelAliases[canonical] = append(modelAliases[canonical], alias)
-	}
-	for _, list := range modelAliases {
-		sort.Strings(list) // in place; the map entry shares this backing array
-	}
-
-	results := make([]ModelRouteInfo, 0, len(r.order))
-	for _, name := range r.order {
-		routes := r.routes[name]
-		info := ModelRouteInfo{
-			Name:    name,
-			Aliases: modelAliases[name],
-			Routes:  make([]RouteInfo, 0, len(routes)),
-		}
-		for _, rt := range routes {
-			info.Routes = append(info.Routes, RouteInfo{
-				ProviderName:  rt.Provider.Name(),
-				ProviderModel: rt.ProviderModel,
-				Priority:      rt.Priority,
-				InputPrice:    rt.InputPrice,
-				OutputPrice:   rt.OutputPrice,
-			})
-		}
-		results = append(results, info)
-	}
-	return results
-}
diff --git a/llm-gateway/internal/provider/registry_test.go b/llm-gateway/internal/provider/registry_test.go
deleted file mode 100644
index 6c2a5cc..0000000
--- a/llm-gateway/internal/provider/registry_test.go
+++ /dev/null
@@ -1,286 +0,0 @@
-package provider
-
-import (
- "context"
- "io"
- "testing"
-
- "llm-gateway/internal/config"
-)
-
-// mockProvider is a stand-in Provider for the registry tests. Only Name
-// carries data; the request methods return nil results and nil errors.
-type mockProvider struct {
-	name string
-}
-
-// Name reports the name the mock was constructed with.
-func (p *mockProvider) Name() string {
-	return p.name
-}
-
-// ChatCompletion is a no-op.
-func (p *mockProvider) ChatCompletion(context.Context, string, *ChatRequest) (*ChatResponse, error) {
-	return nil, nil
-}
-
-// ChatCompletionStream is a no-op.
-func (p *mockProvider) ChatCompletionStream(context.Context, string, *ChatRequest) (io.ReadCloser, error) {
-	return nil, nil
-}
-
-// Embedding is a no-op.
-func (p *mockProvider) Embedding(context.Context, string, *EmbeddingRequest) (*EmbeddingResponse, error) {
-	return nil, nil
-}
-
-// newTestRegistry assembles a Registry straight from test fixtures, bypassing
-// config parsing. Routes are stored exactly as given (no priority sort) and
-// every model gets a FirstBalancer.
-func newTestRegistry(models []testModel) *Registry {
-	reg := &Registry{
-		routes:    map[string][]Route{},
-		balancers: map[string]LoadBalancer{},
-		aliases:   map[string]string{},
-	}
-
-	for i := range models {
-		tm := &models[i]
-		reg.order = append(reg.order, tm.name)
-		reg.routes[tm.name] = tm.routes
-		reg.balancers[tm.name] = &FirstBalancer{}
-		for _, a := range tm.aliases {
-			reg.aliases[a] = tm.name
-		}
-	}
-
-	return reg
-}
-
-type testModel struct { // testModel is the fixture consumed by newTestRegistry.
-	name    string   // canonical model name
-	aliases []string // alternative names that resolve to name
-	routes  []Route  // stored in the registry in exactly this order
-}
-
-// TestRegistry_Lookup_Canonical checks that a canonical model name resolves
-// to its single configured route.
-func TestRegistry_Lookup_Canonical(t *testing.T) {
-	openai := &mockProvider{name: "openai"}
-	reg := newTestRegistry([]testModel{{
-		name:   "gpt-4",
-		routes: []Route{{Provider: openai, ProviderModel: "gpt-4", Priority: 1}},
-	}})
-
-	got, found := reg.Lookup("gpt-4")
-	if !found {
-		t.Fatal("expected Lookup to find gpt-4")
-	}
-	if n := len(got); n != 1 {
-		t.Fatalf("expected 1 route, got %d", n)
-	}
-	if name := got[0].Provider.Name(); name != "openai" {
-		t.Errorf("expected provider 'openai', got %q", name)
-	}
-}
-
-func TestRegistry_Lookup_Alias(t *testing.T) {
- reg := newTestRegistry([]testModel{
- {
- name: "gpt-4",
- aliases: []string{"gpt4", "gpt-4-latest"},
- routes: []Route{
- {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
- },
- },
- })
-
- tests := []struct {
- name string
- model string
- found bool
- }{
- {"canonical", "gpt-4", true},
- {"alias1", "gpt4", true},
- {"alias2", "gpt-4-latest", true},
- {"unknown", "gpt-5", false},
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- routes, ok := reg.Lookup(tt.model)
- if ok != tt.found {
- t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found)
- }
- if tt.found && len(routes) != 1 {
- t.Fatalf("expected 1 route, got %d", len(routes))
- }
- })
- }
-}
-
-// TestRegistry_ModelNames_IncludesAliases checks that ModelNames lists the
-// canonical names plus every alias, with no extra entries.
-func TestRegistry_ModelNames_IncludesAliases(t *testing.T) {
-	gpt := testModel{
-		name:    "gpt-4",
-		aliases: []string{"gpt4"},
-		routes:  []Route{{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1}},
-	}
-	claude := testModel{
-		name:   "claude-3",
-		routes: []Route{{Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1}},
-	}
-	reg := newTestRegistry([]testModel{gpt, claude})
-
-	names := reg.ModelNames()
-
-	want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true}
-	seen := make(map[string]bool, len(names))
-	for _, n := range names {
-		seen[n] = true
-	}
-
-	for expected := range want {
-		if _, present := seen[expected]; !present {
-			t.Errorf("expected %q in ModelNames, not found", expected)
-		}
-	}
-
-	if len(names) != len(want) {
-		t.Errorf("expected %d names, got %d: %v", len(want), len(names), names)
-	}
-}
-
-// TestRegistry_AllRoutes_ShowsAliases checks that AllRoutes reports a model's
-// aliases and lists its routes in stored order.
-func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) {
-	fixture := testModel{
-		name:    "gpt-4",
-		aliases: []string{"gpt4", "gpt-4-latest"},
-		routes: []Route{
-			{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
-			{Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2},
-		},
-	}
-	reg := newTestRegistry([]testModel{fixture})
-
-	all := reg.AllRoutes()
-	if len(all) != 1 {
-		t.Fatalf("expected 1 model, got %d", len(all))
-	}
-
-	model := all[0]
-	if model.Name != "gpt-4" {
-		t.Errorf("expected name 'gpt-4', got %q", model.Name)
-	}
-
-	// Alias order is not asserted, only membership.
-	have := make(map[string]bool, len(model.Aliases))
-	for _, alias := range model.Aliases {
-		have[alias] = true
-	}
-	if !(have["gpt4"] && have["gpt-4-latest"]) {
-		t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", model.Aliases)
-	}
-
-	if len(model.Routes) != 2 {
-		t.Fatalf("expected 2 routes, got %d", len(model.Routes))
-	}
-	if first := model.Routes[0].ProviderName; first != "openai" {
-		t.Errorf("expected first route provider 'openai', got %q", first)
-	}
-	if second := model.Routes[1].ProviderName; second != "azure" {
-		t.Errorf("expected second route provider 'azure', got %q", second)
-	}
-}
-
-// TestRegistry_AllRoutes_ConfigOrder checks that AllRoutes keeps registration
-// order rather than sorting models by name.
-func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) {
-	single := func(providerModel string) []Route {
-		return []Route{{Provider: &mockProvider{name: "prov"}, ProviderModel: providerModel, Priority: 1}}
-	}
-	reg := newTestRegistry([]testModel{
-		{name: "model-b", routes: single("b")},
-		{name: "model-a", routes: single("a")},
-	})
-
-	all := reg.AllRoutes()
-	if len(all) != 2 {
-		t.Fatalf("expected 2 models, got %d", len(all))
-	}
-	if got := all[0].Name; got != "model-b" {
-		t.Errorf("expected first model 'model-b', got %q", got)
-	}
-	if got := all[1].Name; got != "model-a" {
-		t.Errorf("expected second model 'model-a', got %q", got)
-	}
-}
-
-// TestRegistry_PrioritySorting checks that AllRoutes reports all three routes
-// and each of their priorities. newTestRegistry stores routes as given (the
-// sort only happens in buildFromConfig), so ordering is not asserted here.
-func TestRegistry_PrioritySorting(t *testing.T) {
-	fixture := testModel{
-		name: "multi-provider",
-		routes: []Route{
-			{Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3},
-			{Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1},
-			{Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2},
-		},
-	}
-	reg := newTestRegistry([]testModel{fixture})
-
-	all := reg.AllRoutes()
-	if len(all) != 1 {
-		t.Fatalf("expected 1 model, got %d", len(all))
-	}
-
-	got := all[0].Routes
-	if len(got) != 3 {
-		t.Fatalf("expected 3 routes, got %d", len(got))
-	}
-
-	// Every priority from the fixture must be present.
-	seen := map[int]bool{}
-	for i := range got {
-		seen[got[i].Priority] = true
-	}
-	for p := 1; p <= 3; p++ {
-		if !seen[p] {
-			t.Errorf("expected priority %d in routes", p)
-		}
-	}
-}
-
-// TestRegistry_NewRegistry_UnknownProvider checks that a route naming a
-// provider missing from the config makes NewRegistry fail.
-func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) {
-	model := config.ModelConfig{
-		Name:   "test-model",
-		Routes: []config.RouteConfig{{Provider: "nonexistent", Model: "m"}},
-	}
-	cfg := &config.Config{Models: []config.ModelConfig{model}}
-
-	if _, err := NewRegistry(cfg); err == nil {
-		t.Fatal("expected error for unknown provider, got nil")
-	}
-}
-
-// TestRegistry_Lookup_NotFound checks that Lookup reports false for a name
-// that is neither a model nor an alias.
-func TestRegistry_Lookup_NotFound(t *testing.T) {
-	openai := &mockProvider{name: "openai"}
-	reg := newTestRegistry([]testModel{{
-		name:   "gpt-4",
-		routes: []Route{{Provider: openai, ProviderModel: "gpt-4", Priority: 1}},
-	}})
-
-	if _, found := reg.Lookup("nonexistent"); found {
-		t.Fatal("expected Lookup to return false for nonexistent model")
-	}
-}
diff --git a/llm-gateway/internal/proxy/auth.go b/llm-gateway/internal/proxy/auth.go
deleted file mode 100644
index ca0944f..0000000
--- a/llm-gateway/internal/proxy/auth.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package proxy
-
-import (
- "net/http"
- "strings"
-
- "llm-gateway/internal/auth"
-)
-
-type AuthMiddleware struct { // AuthMiddleware authenticates requests by bearer token.
-	authStore *auth.Store // used to look up API tokens and record their last use
-}
-
-// NewAuthMiddleware returns an AuthMiddleware that validates tokens against authStore.
-func NewAuthMiddleware(authStore *auth.Store) *AuthMiddleware {
-	mw := new(AuthMiddleware)
-	mw.authStore = authStore
-	return mw
-}
-
-// Authenticate validates the bearer token against the DB and sets token info in context.
-func (a *AuthMiddleware) Authenticate(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- hdr := r.Header.Get("Authorization")
- if !strings.HasPrefix(hdr, "Bearer ") {
- writeError(w, http.StatusUnauthorized, "missing or invalid Authorization header")
- return
- }
- key := strings.TrimPrefix(hdr, "Bearer ")
-
- token, err := a.authStore.LookupAPIToken(key)
- if err != nil {
- writeError(w, http.StatusUnauthorized, "invalid API key")
- return
- }
-
- // Update last used asynchronously (skip for static tokens)
- if token.ID > 0 {
- go a.authStore.UpdateAPITokenLastUsed(token.ID)
- }
-
- ctx := withTokenName(r.Context(), token.Name)
- ctx = withAPIToken(ctx, token)
- next.ServeHTTP(w, r.WithContext(ctx))
- })
-}
diff --git a/llm-gateway/internal/proxy/concurrency.go b/llm-gateway/internal/proxy/concurrency.go
deleted file mode 100644
index 4f28262..0000000
--- a/llm-gateway/internal/proxy/concurrency.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package proxy
-
-import (
- "net/http"
- "sync"
- "sync/atomic"
-)
-
-// ConcurrencyLimiter enforces per-token concurrent request limits, keyed by token name.
-type ConcurrencyLimiter struct {
-	mu       sync.Mutex // guards the counters map (not the counters themselves)
-	counters map[string]*atomic.Int64 // in-flight request count per token name; entries are created on first use
-}
-
-// NewConcurrencyLimiter returns a limiter with an empty counter table.
-func NewConcurrencyLimiter() *ConcurrencyLimiter {
-	limiter := new(ConcurrencyLimiter)
-	limiter.counters = map[string]*atomic.Int64{}
-	return limiter
-}
-
-// getCounter returns the in-flight counter for tokenName, creating and
-// registering a zeroed one the first time the name is seen.
-func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 {
-	cl.mu.Lock()
-	defer cl.mu.Unlock()
-
-	if existing, found := cl.counters[tokenName]; found {
-		return existing
-	}
-	fresh := new(atomic.Int64)
-	cl.counters[tokenName] = fresh
-	return fresh
-}
-
-// Check wraps next with the per-token concurrency limit. Requests without an
-// API token in their context, or whose token has MaxConcurrent <= 0, pass
-// straight through. Otherwise the token's counter is incremented for the
-// lifetime of the request, and the request is answered with 429 when that
-// pushes the count above MaxConcurrent.
-func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler {
-	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		tok := getAPIToken(r.Context())
-		unlimited := tok == nil || tok.MaxConcurrent <= 0
-
-		if !unlimited {
-			inFlight := cl.getCounter(tok.Name)
-			active := inFlight.Add(1)
-			// Released when the handler returns, whether admitted or rejected.
-			defer inFlight.Add(-1)
-
-			if active > int64(tok.MaxConcurrent) {
-				writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded")
-				return
-			}
-		}
-
-		next.ServeHTTP(w, r)
-	})
-}
diff --git a/llm-gateway/internal/proxy/concurrency_test.go b/llm-gateway/internal/proxy/concurrency_test.go
deleted file mode 100644
index fa6ccbf..0000000
--- a/llm-gateway/internal/proxy/concurrency_test.go
+++ /dev/null
@@ -1,317 +0,0 @@
-package proxy
-
-import (
- "net/http"
- "net/http/httptest"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "llm-gateway/internal/auth"
-)
-
-func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) { // every request is admitted when numRequests <= maxConcurrent
-	tests := []struct {
-		name          string
-		maxConcurrent int
-		numRequests   int
-		wantAllowed   int
-	}{
-		{
-			name:          "single request within limit",
-			maxConcurrent: 5,
-			numRequests:   1,
-			wantAllowed:   1,
-		},
-		{
-			name:          "all requests within limit",
-			maxConcurrent: 5,
-			numRequests:   5,
-			wantAllowed:   5,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			cl := NewConcurrencyLimiter()
-
-			token := &auth.APIToken{
-				Name:          "conc-token",
-				MaxConcurrent: tt.maxConcurrent,
-			}
-
-			var allowed atomic.Int64 // number of requests that reached the inner handler
-			var wg sync.WaitGroup
-			// Use a channel to hold all goroutines inside the handler simultaneously.
-			gate := make(chan struct{})
-
-			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				allowed.Add(1)
-				<-gate // Block until released.
-				w.WriteHeader(http.StatusOK)
-			}))
-
-			for i := 0; i < tt.numRequests; i++ {
-				wg.Add(1)
-				go func() {
-					defer wg.Done()
-					rec := httptest.NewRecorder()
-					req := httptest.NewRequest(http.MethodGet, "/", nil)
-					ctx := withAPIToken(req.Context(), token)
-					req = req.WithContext(ctx)
-					handler.ServeHTTP(rec, req)
-				}()
-			}
-
-			// Wait for goroutines to enter the handler (timing-based, not a hard synchronization).
-			time.Sleep(50 * time.Millisecond)
-			close(gate) // release every blocked handler at once
-			wg.Wait()
-
-			if int(allowed.Load()) != tt.wantAllowed {
-				t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed)
-			}
-		})
-	}
-}
-
-func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) { // requests beyond maxConcurrent are answered with 429
-	tests := []struct {
-		name          string
-		maxConcurrent int
-		numRequests   int
-		wantDenied    int
-	}{
-		{
-			name:          "one over limit",
-			maxConcurrent: 2,
-			numRequests:   3,
-			wantDenied:    1,
-		},
-		{
-			name:          "many over limit",
-			maxConcurrent: 1,
-			numRequests:   5,
-			wantDenied:    4,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			cl := NewConcurrencyLimiter()
-
-			token := &auth.APIToken{
-				Name:          "conc-token",
-				MaxConcurrent: tt.maxConcurrent,
-			}
-
-			var denied atomic.Int64 // number of 429 responses observed
-			var wg sync.WaitGroup
-			gate := make(chan struct{}) // keeps admitted requests in flight until closed
-
-			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				<-gate
-				w.WriteHeader(http.StatusOK)
-			}))
-
-			results := make([]int, tt.numRequests) // per-request status codes; written below but not read by any assertion
-			for i := 0; i < tt.numRequests; i++ {
-				wg.Add(1)
-				go func(idx int) {
-					defer wg.Done()
-					rec := httptest.NewRecorder()
-					req := httptest.NewRequest(http.MethodGet, "/", nil)
-					ctx := withAPIToken(req.Context(), token)
-					req = req.WithContext(ctx)
-					handler.ServeHTTP(rec, req)
-					results[idx] = rec.Code
-					if rec.Code == http.StatusTooManyRequests {
-						denied.Add(1)
-					}
-				}(i)
-			}
-
-			// Wait for goroutines to reach the handler or be rejected (timing-based).
-			time.Sleep(50 * time.Millisecond)
-			close(gate)
-			wg.Wait()
-
-			if int(denied.Load()) != tt.wantDenied {
-				t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied)
-			}
-		})
-	}
-}
-
-// TestConcurrencyLimiter_CounterDecrementsAfterCompletion checks that a
-// finished request frees its slot: with a limit of one, two sequential
-// requests both succeed and the counter ends at zero.
-func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) {
-	cl := NewConcurrencyLimiter()
-	token := &auth.APIToken{Name: "decrement-token", MaxConcurrent: 1}
-
-	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusOK)
-	})
-	handler := cl.Check(ok)
-
-	// serve issues one request carrying the token and returns the status code.
-	serve := func() int {
-		rec := httptest.NewRecorder()
-		req := httptest.NewRequest(http.MethodGet, "/", nil)
-		handler.ServeHTTP(rec, req.WithContext(withAPIToken(req.Context(), token)))
-		return rec.Code
-	}
-
-	// The first request completes and must release its slot.
-	if code := serve(); code != http.StatusOK {
-		t.Fatalf("first request: status = %d, want %d", code, http.StatusOK)
-	}
-
-	// The slot is free again, so a second request is admitted too.
-	if code := serve(); code != http.StatusOK {
-		t.Errorf("second request after first completed: status = %d, want %d", code, http.StatusOK)
-	}
-
-	// The internal counter must be back at zero.
-	if val := cl.getCounter(token.Name).Load(); val != 0 {
-		t.Errorf("counter = %d, want 0 after all requests completed", val)
-	}
-}
-
-func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) { // MaxConcurrent <= 0 disables the limit
-	tests := []struct {
-		name          string
-		maxConcurrent int
-		numRequests   int
-	}{
-		{
-			name:          "zero allows unlimited concurrent requests",
-			maxConcurrent: 0,
-			numRequests:   50,
-		},
-		{
-			name:          "negative allows unlimited concurrent requests",
-			maxConcurrent: -1,
-			numRequests:   50,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			cl := NewConcurrencyLimiter()
-
-			token := &auth.APIToken{
-				Name:          "unlimited-token",
-				MaxConcurrent: tt.maxConcurrent,
-			}
-
-			var allowed atomic.Int64 // number of requests that reached the inner handler
-			var wg sync.WaitGroup
-			gate := make(chan struct{}) // holds all requests in flight at the same time
-
-			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				allowed.Add(1)
-				<-gate
-				w.WriteHeader(http.StatusOK)
-			}))
-
-			for i := 0; i < tt.numRequests; i++ {
-				wg.Add(1)
-				go func() {
-					defer wg.Done()
-					rec := httptest.NewRecorder()
-					req := httptest.NewRequest(http.MethodGet, "/", nil)
-					ctx := withAPIToken(req.Context(), token)
-					req = req.WithContext(ctx)
-					handler.ServeHTTP(rec, req)
-				}()
-			}
-
-			// Give goroutines time to enter the handler (timing-based).
-			time.Sleep(100 * time.Millisecond)
-			close(gate)
-			wg.Wait()
-
-			if int(allowed.Load()) != tt.numRequests {
-				t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests)
-			}
-		})
-	}
-}
-
-// TestConcurrencyLimiter_NoToken checks that a request with no API token in
-// its context is passed through to the wrapped handler.
-func TestConcurrencyLimiter_NoToken(t *testing.T) {
-	ok := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusOK)
-	})
-	handler := NewConcurrencyLimiter().Check(ok)
-
-	// The request deliberately carries no API token.
-	rec := httptest.NewRecorder()
-	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
-
-	if got := rec.Code; got != http.StatusOK {
-		t.Errorf("status = %d, want %d (should pass through without token)", got, http.StatusOK)
-	}
-}
-
-func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) { // one token's in-flight request must not consume another token's slot
-	cl := NewConcurrencyLimiter()
-
-	tokenA := &auth.APIToken{
-		Name:          "token-a",
-		MaxConcurrent: 1,
-	}
-	tokenB := &auth.APIToken{
-		Name:          "token-b",
-		MaxConcurrent: 1,
-	}
-
-	gateA := make(chan struct{}) // closed at the end to release token A's request
-	var wg sync.WaitGroup
-
-	handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		tok := getAPIToken(r.Context())
-		if tok.Name == "token-a" {
-			<-gateA // Block token A's request.
-		}
-		w.WriteHeader(http.StatusOK)
-	}))
-
-	// Start a request for token A that blocks.
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		rec := httptest.NewRecorder()
-		req := httptest.NewRequest(http.MethodGet, "/", nil)
-		ctx := withAPIToken(req.Context(), tokenA)
-		req = req.WithContext(ctx)
-		handler.ServeHTTP(rec, req)
-	}()
-
-	// Give token A's goroutine time to enter handler (timing-based).
-	time.Sleep(50 * time.Millisecond)
-
-	// Token B should not be affected by token A's in-flight request.
-	rec := httptest.NewRecorder()
-	req := httptest.NewRequest(http.MethodGet, "/", nil)
-	ctx := withAPIToken(req.Context(), tokenB)
-	req = req.WithContext(ctx)
-	handler.ServeHTTP(rec, req)
-
-	if rec.Code != http.StatusOK {
-		t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK)
-	}
-
-	close(gateA)
-	wg.Wait()
-}
diff --git a/llm-gateway/internal/proxy/dedup.go b/llm-gateway/internal/proxy/dedup.go
deleted file mode 100644
index d55b6c5..0000000
--- a/llm-gateway/internal/proxy/dedup.go
+++ /dev/null
@@ -1,107 +0,0 @@
-package proxy
-
-import (
- "crypto/sha256"
- "encoding/hex"
- "sync"
- "time"
-)
-
-// inflight represents an in-progress deduplicated request shared by a leader and its followers.
-type inflight struct {
-	done       chan struct{} // closed by Complete, or by cleanup when the entry goes stale
-	result     []byte        // set by Complete before done is closed; stays nil if cleanup closed done
-	statusCode int           // set by Complete before done is closed; stays 0 if cleanup closed done
-	createdAt  time.Time     // set by TryJoin; used by cleanup to detect stale entries
-}
-
-// Deduplicator coalesces identical concurrent non-streaming requests.
-type Deduplicator struct {
-	mu      sync.Mutex // guards flights
-	flights map[string]*inflight // in-progress requests keyed by dedup key
-	window  time.Duration // cleanup interval; entries older than twice this are evicted
-	done    chan struct{} // closed by Close to stop the cleanup goroutine
-}
-
-// NewDeduplicator creates a new request deduplicator and starts its
-// background cleanup goroutine; call Close to stop it.
-//
-// window sets the cleanup interval. A zero or negative window falls back to
-// 30 seconds: the cleanup loop feeds window to time.NewTicker, which panics
-// on non-positive durations.
-func NewDeduplicator(window time.Duration) *Deduplicator {
-	if window <= 0 {
-		window = 30 * time.Second
-	}
-	d := &Deduplicator{
-		flights: make(map[string]*inflight),
-		window:  window,
-		done:    make(chan struct{}),
-	}
-	go d.cleanup()
-	return d
-}
-
-// DedupKey derives the dedup key for a request: the hex-encoded SHA-256 of
-// the model name, a single NUL separator byte, and the raw request body.
-func DedupKey(model string, body []byte) string {
-	hasher := sha256.New()
-	// Model name and the NUL separator go in as one write.
-	hasher.Write(append([]byte(model), 0))
-	hasher.Write(body)
-
-	digest := hasher.Sum(nil)
-	return hex.EncodeToString(digest)
-}
-
-// TryJoin registers the caller against key. If a request with the same key is
-// already in flight, its entry is returned with false (the caller is a
-// follower). Otherwise a new entry is created and stored, and returned with
-// true (the caller is the leader).
-func (d *Deduplicator) TryJoin(key string) (*inflight, bool) {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-
-	existing, found := d.flights[key]
-	if found {
-		return existing, false // follower
-	}
-
-	entry := &inflight{
-		createdAt: time.Now(),
-		done:      make(chan struct{}),
-	}
-	d.flights[key] = entry
-	return entry, true // leader
-}
-
-// Complete finishes the in-flight request stored under key: the entry is
-// removed from the table, result and statusCode are written to it, and its
-// done channel is closed. It does nothing beyond the removal attempt when no
-// entry exists for key.
-func (d *Deduplicator) Complete(key string, result []byte, statusCode int) {
-	d.mu.Lock()
-	entry, found := d.flights[key]
-	delete(d.flights, key)
-	d.mu.Unlock()
-
-	if !found {
-		return
-	}
-
-	// Fields are written before the close that signals waiters.
-	entry.result, entry.statusCode = result, statusCode
-	close(entry.done)
-}
-
-// Close stops the background cleanup goroutine. It is safe to call more than
-// once: later calls are no-ops instead of panicking on an already-closed
-// channel.
-func (d *Deduplicator) Close() {
-	d.mu.Lock()
-	defer d.mu.Unlock()
-
-	select {
-	case <-d.done:
-		// Already closed by an earlier call.
-	default:
-		close(d.done)
-	}
-}
-
-// cleanup runs until the done channel is closed. On every tick (one per
-// window) it evicts entries older than twice the window and closes their done
-// channel, so followers of a request that was never completed stop waiting.
-func (d *Deduplicator) cleanup() {
-	tick := time.NewTicker(d.window)
-	defer tick.Stop()
-
-	for {
-		select {
-		case <-tick.C:
-			func() {
-				d.mu.Lock()
-				defer d.mu.Unlock()
-
-				staleAfter := 2 * d.window
-				now := time.Now()
-				for key, entry := range d.flights {
-					if now.Sub(entry.createdAt) <= staleAfter {
-						continue
-					}
-					delete(d.flights, key)
-					close(entry.done) // unblock any waiting followers
-				}
-			}()
-		case <-d.done:
-			return
-		}
-	}
-}
diff --git a/llm-gateway/internal/proxy/dedup_test.go b/llm-gateway/internal/proxy/dedup_test.go
deleted file mode 100644
index 900a53d..0000000
--- a/llm-gateway/internal/proxy/dedup_test.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package proxy
-
-import (
- "sync"
- "testing"
- "time"
-)
-
-func TestDedupKey(t *testing.T) {
- k1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
- k2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
- k3 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`))
-
- if k1 != k2 {
- t.Error("identical requests should produce the same key")
- }
- if k1 == k3 {
- t.Error("different requests should produce different keys")
- }
-}
-
-// TestDeduplicator_LeaderFollower checks that the first caller for a key
-// leads, the second follows on the same entry, and the follower sees the
-// leader's result and status once Complete is called.
-func TestDeduplicator_LeaderFollower(t *testing.T) {
-	d := NewDeduplicator(5 * time.Second)
-	defer d.Close()
-
-	key := DedupKey("gpt-4", []byte(`test`))
-
-	// The first caller becomes the leader.
-	leaderEntry, leads := d.TryJoin(key)
-	if !leads {
-		t.Fatal("first caller should be leader")
-	}
-
-	// A second caller with the same key joins as a follower.
-	followerEntry, leads := d.TryJoin(key)
-	if leads {
-		t.Fatal("second caller should be follower")
-	}
-	if leaderEntry != followerEntry {
-		t.Fatal("follower should get same inflight entry")
-	}
-
-	// The follower waits in the background while the leader completes.
-	var waiters sync.WaitGroup
-	waiters.Add(1)
-	go func() {
-		defer waiters.Done()
-		<-followerEntry.done
-		if got := string(followerEntry.result); got != "response" {
-			t.Error("follower should receive leader's result")
-		}
-		if followerEntry.statusCode != 200 {
-			t.Error("follower should receive leader's status code")
-		}
-	}()
-
-	d.Complete(key, []byte("response"), 200)
-	waiters.Wait()
-}
-
-// TestDeduplicator_DifferentKeys checks that requests with distinct keys do
-// not coalesce: each caller is the leader for its own key.
-func TestDeduplicator_DifferentKeys(t *testing.T) {
-	d := NewDeduplicator(5 * time.Second)
-	defer d.Close()
-
-	_, firstLeads := d.TryJoin("key1")
-	_, secondLeads := d.TryJoin("key2")
-
-	if !(firstLeads && secondLeads) {
-		t.Error("different keys should both be leaders")
-	}
-
-	d.Complete("key1", []byte("r1"), 200)
-	d.Complete("key2", []byte("r2"), 200)
-}
diff --git a/llm-gateway/internal/proxy/handler.go b/llm-gateway/internal/proxy/handler.go
deleted file mode 100644
index 1b24cc2..0000000
--- a/llm-gateway/internal/proxy/handler.go
+++ /dev/null
@@ -1,514 +0,0 @@
-package proxy
-
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "log"
- "net/http"
- "sort"
- "strings"
- "time"
-
- "github.com/go-chi/chi/v5/middleware"
-
- "llm-gateway/internal/auth"
- "llm-gateway/internal/cache"
- "llm-gateway/internal/config"
- "llm-gateway/internal/metrics"
- "llm-gateway/internal/provider"
- "llm-gateway/internal/storage"
-)
-
-type contextKey string
-
-const tokenNameKey contextKey = "token_name"
-const apiTokenKey contextKey = "api_token"
-
-func withTokenName(ctx context.Context, name string) context.Context {
- return context.WithValue(ctx, tokenNameKey, name)
-}
-
-func getTokenName(ctx context.Context) string {
- name, _ := ctx.Value(tokenNameKey).(string)
- return name
-}
-
-func withAPIToken(ctx context.Context, token *auth.APIToken) context.Context {
- return context.WithValue(ctx, apiTokenKey, token)
-}
-
-func getAPIToken(ctx context.Context) *auth.APIToken {
- t, _ := ctx.Value(apiTokenKey).(*auth.APIToken)
- return t
-}
-
-type Handler struct {
- registry *provider.Registry
- logger *storage.AsyncLogger
- cache *cache.Cache
- metrics *metrics.Metrics
- cfg *config.Config
- healthTracker *provider.HealthTracker
- debugLogger *storage.DebugLogger
- dedup *Deduplicator
-}
-
-func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
- return &Handler{
- registry: registry,
- logger: logger,
- cache: c,
- metrics: m,
- cfg: cfg,
- healthTracker: ht,
- }
-}
-
-func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
- h.debugLogger = dl
-}
-
-func (h *Handler) SetDeduplicator(d *Deduplicator) {
- h.dedup = d
-}
-
-func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
- body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
- if err != nil {
- writeError(w, http.StatusBadRequest, "failed to read request body")
- return
- }
-
- var req provider.ChatRequest
- if err := json.Unmarshal(body, &req); err != nil {
- writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
- return
- }
-
- if req.Model == "" {
- writeError(w, http.StatusBadRequest, "model is required")
- return
- }
-
- routes, ok := h.registry.Lookup(req.Model)
- if !ok {
- writeError(w, http.StatusNotFound, "model not found: "+req.Model)
- return
- }
-
- // Filter healthy routes (circuit breaker)
- routes = h.filterHealthyRoutes(routes)
-
- tokenName := getTokenName(r.Context())
- requestID := middleware.GetReqID(r.Context())
-
- // Check cache for non-streaming requests
- if !req.Stream && h.cache != nil {
- if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil {
- h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
- if h.metrics != nil {
- h.metrics.RecordCacheHit()
- }
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("X-Cache", "HIT")
- w.Header().Set("X-Request-ID", requestID)
- w.Write(cached)
- return
- }
- if h.metrics != nil {
- h.metrics.RecordCacheMiss()
- }
- }
-
- // Apply per-model timeout for non-streaming requests
- modelTimeouts := h.registry.ModelTimeoutsFor(req.Model)
-
- if req.Stream {
- h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts)
- return
- }
-
- // Request deduplication for non-streaming requests
- if h.dedup != nil {
- dedupKey := DedupKey(req.Model, body)
- flight, isLeader := h.dedup.TryJoin(dedupKey)
- if !isLeader {
- // Wait for the leader to complete
- select {
- case <-flight.done:
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("X-Request-ID", requestID)
- w.Header().Set("X-Dedup", "HIT")
- w.WriteHeader(flight.statusCode)
- w.Write(flight.result)
- return
- case <-r.Context().Done():
- writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup")
- return
- }
- }
- // Leader: proceed normally, but capture response for followers
- defer func() {
- // If we haven't completed yet (e.g., panic), clean up
- }()
- h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey)
- return
- }
-
- if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
- ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
- defer cancel()
- r = r.WithContext(ctx)
- }
-
- h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
-}
-
-func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) {
- var lastErr error
-
- for i, route := range routes {
- // Retry backoff between attempts (not before first attempt)
- if i > 0 {
- backoff := backoffDuration(i, h.cfg.Retry)
- select {
- case <-time.After(backoff):
- case <-r.Context().Done():
- writeError(w, http.StatusGatewayTimeout, "request cancelled")
- return
- }
- }
-
- start := time.Now()
- resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req)
- latency := time.Since(start).Milliseconds()
-
- if err != nil {
- var pe *provider.ProviderError
- if errors.As(err, &pe) && !pe.IsRetryable() {
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, err)
- }
- w.Header().Set("X-Request-ID", requestID)
- writeErrorRaw(w, pe.StatusCode, pe.Body)
- return
- }
- lastErr = err
- log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err)
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, err)
- }
- continue
- }
-
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, nil)
- }
-
- inputTokens, outputTokens := 0, 0
- if resp.Usage != nil {
- inputTokens = resp.Usage.PromptTokens
- outputTokens = resp.Usage.CompletionTokens
- }
- cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
-
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
- h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
-
- resp.Model = req.Model
-
- respBytes, err := json.Marshal(resp)
- if err != nil {
- writeError(w, http.StatusInternalServerError, "failed to marshal response")
- return
- }
-
- if h.cache != nil {
- h.cache.Set(r.Context(), req.Model, rawBody, respBytes)
- }
-
- // Debug logging
- if h.debugLogger != nil && h.debugLogger.IsEnabled() {
- reqBody := string(rawBody)
- respBody := string(respBytes)
- if h.cfg.Debug.MaxBodyBytes > 0 {
- if len(reqBody) > h.cfg.Debug.MaxBodyBytes {
- reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes]
- }
- if len(respBody) > h.cfg.Debug.MaxBodyBytes {
- respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
- }
- }
- h.debugLogger.Log(storage.DebugLogEntry{
- RequestID: requestID,
- TokenName: tokenName,
- Model: req.Model,
- Provider: route.Provider.Name(),
- RequestBody: reqBody,
- ResponseBody: respBody,
- RequestHeaders: formatHeaders(r.Header),
- ResponseStatus: http.StatusOK,
- })
- }
-
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("X-Cache", "MISS")
- w.Header().Set("X-Request-ID", requestID)
- w.Write(respBytes)
- return
- }
-
- if lastErr != nil {
- w.Header().Set("X-Request-ID", requestID)
- writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
- } else {
- w.Header().Set("X-Request-ID", requestID)
- writeError(w, http.StatusBadGateway, "all providers failed")
- }
-}
-
-// handleNonStreamDedup wraps handleNonStream to capture the response for dedup followers.
-func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
- body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
- if err != nil {
- writeError(w, http.StatusBadRequest, "failed to read request body")
- return
- }
-
- var req provider.EmbeddingRequest
- if err := json.Unmarshal(body, &req); err != nil {
- writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
- return
- }
-
- if req.Model == "" {
- writeError(w, http.StatusBadRequest, "model is required")
- return
- }
-
- routes, ok := h.registry.Lookup(req.Model)
- if !ok {
- writeError(w, http.StatusNotFound, "model not found: "+req.Model)
- return
- }
-
- routes = h.filterHealthyRoutes(routes)
- tokenName := getTokenName(r.Context())
- requestID := middleware.GetReqID(r.Context())
-
- var lastErr error
- for i, route := range routes {
- if i > 0 {
- backoff := backoffDuration(i, h.cfg.Retry)
- select {
- case <-time.After(backoff):
- case <-r.Context().Done():
- writeError(w, http.StatusGatewayTimeout, "request cancelled")
- return
- }
- }
-
- start := time.Now()
- resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req)
- latency := time.Since(start).Milliseconds()
-
- if err != nil {
- var pe *provider.ProviderError
- if errors.As(err, &pe) && !pe.IsRetryable() {
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, err)
- }
- w.Header().Set("X-Request-ID", requestID)
- writeErrorRaw(w, pe.StatusCode, pe.Body)
- return
- }
- lastErr = err
- log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err)
- h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, err)
- }
- continue
- }
-
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, nil)
- }
-
- promptTokens := 0
- if resp.Usage != nil {
- promptTokens = resp.Usage.PromptTokens
- }
- cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice
-
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost)
- h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "")
-
- resp.Model = req.Model
-
- respBytes, err := json.Marshal(resp)
- if err != nil {
- writeError(w, http.StatusInternalServerError, "failed to marshal response")
- return
- }
-
- w.Header().Set("Content-Type", "application/json")
- w.Header().Set("X-Request-ID", requestID)
- w.Write(respBytes)
- return
- }
-
- w.Header().Set("X-Request-ID", requestID)
- if lastErr != nil {
- writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
- } else {
- writeError(w, http.StatusBadGateway, "all providers failed")
- }
-}
-
-func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) {
- h.logger.Log(storage.RequestLog{
- RequestID: requestID,
- Timestamp: time.Now().Unix(),
- TokenName: tokenName,
- Model: model,
- Provider: providerName,
- ProviderModel: providerModel,
- InputTokens: inputTokens,
- CostUSD: cost,
- LatencyMS: latencyMS,
- Status: status,
- ErrorMessage: errMsg,
- RequestType: "embedding",
- })
-}
-
-func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) {
- if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
- ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
- defer cancel()
- r = r.WithContext(ctx)
- }
-
- rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
- h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID)
- h.dedup.Complete(dedupKey, rec.body, rec.statusCode)
-}
-
-// responseRecorder captures the response for dedup.
-type responseRecorder struct {
- http.ResponseWriter
- statusCode int
- body []byte
-}
-
-func (r *responseRecorder) WriteHeader(code int) {
- r.statusCode = code
- r.ResponseWriter.WriteHeader(code)
-}
-
-func (r *responseRecorder) Write(b []byte) (int, error) {
- r.body = append(r.body, b...)
- return r.ResponseWriter.Write(b)
-}
-
-// filterHealthyRoutes removes providers with open circuit breakers.
-// If all are filtered out, returns original routes as fallback.
-func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
- if h.healthTracker == nil {
- return routes
- }
- var healthy []provider.Route
- for _, r := range routes {
- if h.healthTracker.IsAvailable(r.Provider.Name()) {
- healthy = append(healthy, r)
- }
- }
- if len(healthy) == 0 {
- return routes // all-down fallback
- }
- return healthy
-}
-
-// backoffDuration computes exponential backoff for the given attempt.
-func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration {
- d := cfg.InitialBackoff
- for i := 1; i < attempt; i++ {
- d = time.Duration(float64(d) * cfg.Multiplier)
- if d > cfg.MaxBackoff {
- d = cfg.MaxBackoff
- break
- }
- }
- return d
-}
-
-func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
- h.logger.Log(storage.RequestLog{
- RequestID: requestID,
- Timestamp: time.Now().Unix(),
- TokenName: tokenName,
- Model: model,
- Provider: providerName,
- ProviderModel: providerModel,
- InputTokens: inputTokens,
- OutputTokens: outputTokens,
- CostUSD: cost,
- LatencyMS: latencyMS,
- Status: status,
- ErrorMessage: errMsg,
- Streaming: streaming,
- Cached: cached,
- })
-}
-
-func computeCost(inputTokens, outputTokens int, inputPrice, outputPrice float64) float64 {
- return (float64(inputTokens) / 1_000_000.0 * inputPrice) + (float64(outputTokens) / 1_000_000.0 * outputPrice)
-}
-
-func writeError(w http.ResponseWriter, code int, msg string) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(code)
- json.NewEncoder(w).Encode(map[string]any{
- "error": map[string]any{
- "message": msg,
- "type": "error",
- "code": code,
- },
- })
-}
-
-func writeErrorRaw(w http.ResponseWriter, code int, body string) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(code)
- w.Write([]byte(body))
-}
-
-// formatHeaders serializes HTTP headers to a readable string, sorted by key.
-// Sensitive headers (Authorization) are redacted.
-func formatHeaders(h http.Header) string {
- keys := make([]string, 0, len(h))
- for k := range h {
- keys = append(keys, k)
- }
- sort.Strings(keys)
-
- var b strings.Builder
- for _, k := range keys {
- val := strings.Join(h[k], ", ")
- if strings.EqualFold(k, "Authorization") {
- val = "[REDACTED]"
- }
- fmt.Fprintf(&b, "%s: %s\n", k, val)
- }
- return b.String()
-}
diff --git a/llm-gateway/internal/proxy/models.go b/llm-gateway/internal/proxy/models.go
deleted file mode 100644
index 184fbcc..0000000
--- a/llm-gateway/internal/proxy/models.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package proxy
-
-import (
- "encoding/json"
- "net/http"
- "time"
-
- "llm-gateway/internal/config"
- "llm-gateway/internal/provider"
-)
-
-type ModelsHandler struct {
- registry *provider.Registry
- healthTracker *provider.HealthTracker
- cfg *config.Config
-}
-
-func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler {
- return &ModelsHandler{
- registry: registry,
- healthTracker: healthTracker,
- cfg: cfg,
- }
-}
-
-func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) {
- allRoutes := h.registry.AllRoutes()
- models := make([]map[string]any, 0, len(allRoutes))
-
- for _, m := range allRoutes {
- providers := make([]map[string]any, 0, len(m.Routes))
- for _, rt := range m.Routes {
- healthy := true
- if h.healthTracker != nil {
- healthy = h.healthTracker.IsAvailable(rt.ProviderName)
- }
- providers = append(providers, map[string]any{
- "name": rt.ProviderName,
- "model": rt.ProviderModel,
- "input_price": rt.InputPrice,
- "output_price": rt.OutputPrice,
- "priority": rt.Priority,
- "healthy": healthy,
- })
- }
-
- // Find load balancing strategy from config
- loadBalancing := "first"
- for _, mc := range h.cfg.Models {
- if mc.Name == m.Name {
- if mc.LoadBalancing != "" {
- loadBalancing = mc.LoadBalancing
- }
- break
- }
- }
-
- models = append(models, map[string]any{
- "id": m.Name,
- "object": "model",
- "created": time.Now().Unix(),
- "owned_by": "llm-gateway",
- "providers": providers,
- "provider_count": len(providers),
- "load_balancing": loadBalancing,
- "aliases": m.Aliases,
- })
- }
-
- w.Header().Set("Content-Type", "application/json")
- json.NewEncoder(w).Encode(map[string]any{
- "object": "list",
- "data": models,
- })
-}
diff --git a/llm-gateway/internal/proxy/ratelimit.go b/llm-gateway/internal/proxy/ratelimit.go
deleted file mode 100644
index bc06174..0000000
--- a/llm-gateway/internal/proxy/ratelimit.go
+++ /dev/null
@@ -1,169 +0,0 @@
-package proxy
-
-import (
- "fmt"
- "math"
- "net/http"
- "sync"
- "time"
-
- "llm-gateway/internal/storage"
- "llm-gateway/internal/webhook"
-)
-
-type RateLimiter struct {
- db *storage.DB
- mu sync.Mutex
- buckets map[string]*tokenBucket
- notifier *webhook.Notifier
- budgetNotified sync.Map // tracks which token+budget combos have been notified
-}
-
-type tokenBucket struct {
- tokens float64
- maxTokens float64
- refillRate float64 // tokens per second
- lastRefill time.Time
-}
-
-func NewRateLimiter(db *storage.DB) *RateLimiter {
- return &RateLimiter{
- db: db,
- buckets: make(map[string]*tokenBucket),
- }
-}
-
-// SetNotifier sets the webhook notifier for budget threshold alerts.
-func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) {
- rl.notifier = n
-}
-
-func (rl *RateLimiter) Check(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- apiToken := getAPIToken(r.Context())
- if apiToken == nil {
- next.ServeHTTP(w, r)
- return
- }
-
- tokenName := apiToken.Name
-
- // Check rate limit
- if apiToken.RateLimitRPM > 0 {
- allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM)
-
- // Set rate limit headers on all responses
- w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM))
- w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
- w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt))
-
- if !allowed {
- retryAfter := resetAt - time.Now().Unix()
- if retryAfter < 1 {
- retryAfter = 1
- }
- w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
- writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
- return
- }
- }
-
- // Check daily budget
- if apiToken.DailyBudgetUSD > 0 {
- spent, err := rl.db.TodaySpend(tokenName)
- if err == nil {
- if spent >= apiToken.DailyBudgetUSD {
- writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
- return
- }
- rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD)
- }
- }
-
- // Check monthly budget
- if apiToken.MonthlyBudgetUSD > 0 {
- spent, err := rl.db.MonthSpend(tokenName)
- if err == nil {
- if spent >= apiToken.MonthlyBudgetUSD {
- writeError(w, http.StatusTooManyRequests, "monthly budget exceeded")
- return
- }
- rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD)
- }
- }
-
- next.ServeHTTP(w, r)
- })
-}
-
-// checkBudgetThreshold fires a webhook notification when spend reaches 80% of budget.
-func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) {
- if rl.notifier == nil || budget <= 0 {
- return
- }
- if spent/budget < 0.8 {
- return
- }
- key := tokenName + ":" + budgetType
- if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded {
- return // already notified
- }
- rl.notifier.Notify(webhook.Event{
- Type: webhook.EventBudgetThreshold,
- Data: map[string]any{
- "token": tokenName,
- "budget_type": budgetType,
- "spent": spent,
- "budget": budget,
- "percent": spent / budget * 100,
- },
- })
-}
-
-func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
- rl.mu.Lock()
- defer rl.mu.Unlock()
-
- bucket, ok := rl.buckets[tokenName]
- if !ok {
- bucket = &tokenBucket{
- tokens: float64(rateLimitRPM),
- maxTokens: float64(rateLimitRPM),
- refillRate: float64(rateLimitRPM) / 60.0,
- lastRefill: time.Now(),
- }
- rl.buckets[tokenName] = bucket
- }
-
- now := time.Now()
- elapsed := now.Sub(bucket.lastRefill).Seconds()
- bucket.tokens += elapsed * bucket.refillRate
- if bucket.tokens > bucket.maxTokens {
- bucket.tokens = bucket.maxTokens
- }
- bucket.lastRefill = now
-
- remaining := int(math.Floor(bucket.tokens))
- if remaining < 0 {
- remaining = 0
- }
-
- // Compute reset time: when bucket would be full again
- deficit := bucket.maxTokens - bucket.tokens
- var resetAt int64
- if deficit > 0 && bucket.refillRate > 0 {
- resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix()
- } else {
- resetAt = now.Unix()
- }
-
- if bucket.tokens < 1 {
- return false, 0, resetAt
- }
- bucket.tokens--
- remaining = int(math.Floor(bucket.tokens))
- if remaining < 0 {
- remaining = 0
- }
- return true, remaining, resetAt
-}
diff --git a/llm-gateway/internal/proxy/ratelimit_test.go b/llm-gateway/internal/proxy/ratelimit_test.go
deleted file mode 100644
index 20cb311..0000000
--- a/llm-gateway/internal/proxy/ratelimit_test.go
+++ /dev/null
@@ -1,374 +0,0 @@
-package proxy
-
-import (
- "context"
- "database/sql"
- "net/http"
- "net/http/httptest"
- "strconv"
- "testing"
- "time"
-
- _ "modernc.org/sqlite"
-
- "llm-gateway/internal/auth"
- "llm-gateway/internal/storage"
-)
-
-// newTestDB creates an in-memory SQLite database wrapped in storage.DB.
-// It creates the request_logs table needed by TodaySpend.
-func newTestDB(t *testing.T) *storage.DB {
- t.Helper()
- sqlDB, err := sql.Open("sqlite", ":memory:")
- if err != nil {
- t.Fatalf("opening in-memory sqlite: %v", err)
- }
- t.Cleanup(func() { sqlDB.Close() })
-
- // Create the minimal table needed for TodaySpend queries.
- _, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- token_name TEXT,
- cost_usd REAL,
- timestamp INTEGER
- )`)
- if err != nil {
- t.Fatalf("creating request_logs table: %v", err)
- }
- return &storage.DB{DB: sqlDB}
-}
-
-// okHandler is a simple handler that writes 200 OK.
-var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
-})
-
-func TestRateLimiter_Allow(t *testing.T) {
- tests := []struct {
- name string
- rateLimitRPM int
- numRequests int
- wantAllowed int
- wantDenied int
- }{
- {
- name: "allows requests within limit",
- rateLimitRPM: 10,
- numRequests: 5,
- wantAllowed: 5,
- wantDenied: 0,
- },
- {
- name: "denies requests over limit",
- rateLimitRPM: 3,
- numRequests: 6,
- wantAllowed: 3,
- wantDenied: 3,
- },
- {
- name: "allows exactly up to limit",
- rateLimitRPM: 5,
- numRequests: 5,
- wantAllowed: 5,
- wantDenied: 0,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- allowed := 0
- denied := 0
- for i := 0; i < tt.numRequests; i++ {
- ok, _, _ := rl.allow("test-token", tt.rateLimitRPM)
- if ok {
- allowed++
- } else {
- denied++
- }
- }
-
- if allowed != tt.wantAllowed {
- t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed)
- }
- if denied != tt.wantDenied {
- t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
- }
- })
- }
-}
-
-func TestRateLimiter_TokenRefillsOverTime(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- rpm := 60 // 1 token per second refill rate
-
- // Exhaust all tokens.
- for i := 0; i < rpm; i++ {
- ok, _, _ := rl.allow("refill-token", rpm)
- if !ok {
- t.Fatalf("request %d should have been allowed", i)
- }
- }
-
- // Next request should be denied.
- ok, _, _ := rl.allow("refill-token", rpm)
- if ok {
- t.Fatal("request should have been denied after exhausting tokens")
- }
-
- // Manually advance the bucket's lastRefill to simulate time passing.
- rl.mu.Lock()
- bucket := rl.buckets["refill-token"]
- bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second)
- rl.mu.Unlock()
-
- // After 2 seconds at 1 token/sec, we should have ~2 tokens refilled.
- ok, remaining, _ := rl.allow("refill-token", rpm)
- if !ok {
- t.Fatal("request should have been allowed after token refill")
- }
- // We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0.
- if remaining < 0 {
- t.Errorf("remaining = %d, want >= 0", remaining)
- }
-}
-
-func TestRateLimiter_AllowReturnValues(t *testing.T) {
- tests := []struct {
- name string
- rateLimitRPM int
- numRequests int
- wantLastAllowed bool
- wantLastRemaining int
- }{
- {
- name: "remaining decrements correctly",
- rateLimitRPM: 5,
- numRequests: 1,
- wantLastAllowed: true,
- wantLastRemaining: 4,
- },
- {
- name: "remaining is zero at limit",
- rateLimitRPM: 3,
- numRequests: 3,
- wantLastAllowed: true,
- wantLastRemaining: 0,
- },
- {
- name: "denied returns zero remaining",
- rateLimitRPM: 2,
- numRequests: 3,
- wantLastAllowed: false,
- wantLastRemaining: 0,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- var allowed bool
- var remaining int
- for i := 0; i < tt.numRequests; i++ {
- allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM)
- }
-
- if allowed != tt.wantLastAllowed {
- t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed)
- }
- if remaining != tt.wantLastRemaining {
- t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining)
- }
- })
- }
-}
-
-func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
- tests := []struct {
- name string
- rateLimitRPM int
- numRequests int
- wantStatusCode int
- wantLimitHeader string
- wantRetryAfter bool
- }{
- {
- name: "sets rate limit headers on allowed request",
- rateLimitRPM: 10,
- numRequests: 1,
- wantStatusCode: http.StatusOK,
- wantLimitHeader: "10",
- wantRetryAfter: false,
- },
- {
- name: "sets Retry-After header on 429",
- rateLimitRPM: 2,
- numRequests: 3,
- wantStatusCode: http.StatusTooManyRequests,
- wantLimitHeader: "2",
- wantRetryAfter: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- token := &auth.APIToken{
- Name: "header-test-token",
- RateLimitRPM: tt.rateLimitRPM,
- }
-
- handler := rl.Check(okHandler)
-
- var rec *httptest.ResponseRecorder
- for i := 0; i < tt.numRequests; i++ {
- rec = httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- ctx := withAPIToken(req.Context(), token)
- req = req.WithContext(ctx)
- handler.ServeHTTP(rec, req)
- }
-
- // Check the last response.
- if rec.Code != tt.wantStatusCode {
- t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode)
- }
-
- // X-RateLimit-Limit header.
- limitHeader := rec.Header().Get("X-RateLimit-Limit")
- if limitHeader != tt.wantLimitHeader {
- t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader)
- }
-
- // X-RateLimit-Remaining header must be present and numeric.
- remainingHeader := rec.Header().Get("X-RateLimit-Remaining")
- if remainingHeader == "" {
- t.Error("X-RateLimit-Remaining header is missing")
- } else if _, err := strconv.Atoi(remainingHeader); err != nil {
- t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader)
- }
-
- // X-RateLimit-Reset header must be present and numeric.
- resetHeader := rec.Header().Get("X-RateLimit-Reset")
- if resetHeader == "" {
- t.Error("X-RateLimit-Reset header is missing")
- } else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil {
- t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader)
- }
-
- // Retry-After header.
- retryAfter := rec.Header().Get("Retry-After")
- if tt.wantRetryAfter && retryAfter == "" {
- t.Error("Retry-After header is missing on 429 response")
- }
- if !tt.wantRetryAfter && retryAfter != "" {
- t.Errorf("Retry-After header should not be present, got %q", retryAfter)
- }
- })
- }
-}
-
-func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- handler := rl.Check(okHandler)
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- // No API token in context.
- handler.ServeHTTP(rec, req)
-
- if rec.Code != http.StatusOK {
- t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
- }
-}
-
-func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- token := &auth.APIToken{
- Name: "unlimited-token",
- RateLimitRPM: 0, // zero means unlimited
- }
-
- handler := rl.Check(okHandler)
-
- for i := 0; i < 100; i++ {
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- ctx := withAPIToken(req.Context(), token)
- req = req.WithContext(ctx)
- handler.ServeHTTP(rec, req)
-
- if rec.Code != http.StatusOK {
- t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK)
- }
- }
-}
-
-func TestRateLimiter_PerTokenIsolation(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- rpm := 2
-
- // Exhaust token A.
- for i := 0; i < rpm; i++ {
- rl.allow("token-a", rpm)
- }
- ok, _, _ := rl.allow("token-a", rpm)
- if ok {
- t.Fatal("token-a should be rate limited")
- }
-
- // Token B should still have its own bucket.
- ok, _, _ = rl.allow("token-b", rpm)
- if !ok {
- t.Fatal("token-b should not be affected by token-a's rate limit")
- }
-}
-
-func TestRateLimiter_ResetAtIsFuture(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- // Consume one token so there's a deficit.
- _, _, resetAt := rl.allow("reset-token", 10)
- now := time.Now().Unix()
-
- if resetAt < now {
- t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now)
- }
-}
-
-func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) {
- db := newTestDB(t)
- rl := NewRateLimiter(db)
-
- token := &auth.APIToken{
- Name: "ctx-token",
- RateLimitRPM: 10,
- }
-
- handler := rl.Check(okHandler)
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodGet, "/", nil)
- ctx, cancel := context.WithCancel(req.Context())
- ctx = withAPIToken(ctx, token)
- cancel() // Cancel immediately.
- req = req.WithContext(ctx)
-
- // Should still process (rate limiter does not check context cancellation).
- handler.ServeHTTP(rec, req)
- // The handler itself may or may not respect cancelled context;
- // the key point is no panic occurs.
-}
diff --git a/llm-gateway/internal/proxy/stream.go b/llm-gateway/internal/proxy/stream.go
deleted file mode 100644
index 3b5181d..0000000
--- a/llm-gateway/internal/proxy/stream.go
+++ /dev/null
@@ -1,195 +0,0 @@
-package proxy
-
-import (
- "bufio"
- "context"
- "encoding/json"
- "errors"
- "log"
- "net/http"
- "strings"
- "time"
-
- "llm-gateway/internal/provider"
- "llm-gateway/internal/storage"
-)
-
-func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) {
- flusher, ok := w.(http.Flusher)
- if !ok {
- writeError(w, http.StatusInternalServerError, "streaming not supported")
- return
- }
-
- var lastErr error
-
- for i, route := range routes {
- // Retry backoff between attempts
- if i > 0 {
- backoff := backoffDuration(i, h.cfg.Retry)
- select {
- case <-time.After(backoff):
- case <-r.Context().Done():
- writeError(w, http.StatusGatewayTimeout, "request cancelled")
- return
- }
- }
-
- start := time.Now()
- body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
-
- if err != nil {
- var pe *provider.ProviderError
- if errors.As(err, &pe) && !pe.IsRetryable() {
- latency := time.Since(start).Milliseconds()
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, err)
- }
- w.Header().Set("X-Request-ID", requestID)
- writeErrorRaw(w, pe.StatusCode, pe.Body)
- return
- }
- lastErr = err
- latency := time.Since(start).Milliseconds()
- log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
- h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, err)
- }
- continue
- }
-
- // Apply streaming timeout (per-model override takes precedence)
- streamingTimeout := h.cfg.Server.StreamingTimeout
- if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 {
- streamingTimeout = modelTimeouts.StreamingTimeout
- }
- var streamCtx context.Context
- var streamCancel context.CancelFunc
- if streamingTimeout > 0 {
- streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout)
- } else {
- streamCtx, streamCancel = context.WithCancel(r.Context())
- }
-
- // Stream the response
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("X-Accel-Buffering", "no")
- w.Header().Set("X-Request-ID", requestID)
- w.WriteHeader(http.StatusOK)
-
- inputTokens, outputTokens := 0, 0
- scanner := bufio.NewScanner(body)
- scanner.Buffer(make([]byte, 64*1024), 256*1024)
-
- // Capture streamed lines for debug logging
- debugEnabled := h.debugLogger != nil && h.debugLogger.IsEnabled()
- var debugLines []string
-
- scanDone := make(chan struct{})
- go func() {
- defer close(scanDone)
- for scanner.Scan() {
- select {
- case <-streamCtx.Done():
- return
- default:
- }
-
- line := scanner.Text()
-
- if strings.HasPrefix(line, "data: ") {
- data := strings.TrimPrefix(line, "data: ")
- if data != "[DONE]" {
- var chunk streamChunk
- if json.Unmarshal([]byte(data), &chunk) == nil {
- if chunk.Usage != nil {
- inputTokens = chunk.Usage.PromptTokens
- outputTokens = chunk.Usage.CompletionTokens
- }
- if chunk.Model != "" {
- chunk.Model = req.Model
- if rewritten, err := json.Marshal(chunk); err == nil {
- line = "data: " + string(rewritten)
- }
- }
- }
- }
- }
-
- if debugEnabled {
- debugLines = append(debugLines, line)
- }
-
- w.Write([]byte(line + "\n"))
- flusher.Flush()
- }
- }()
-
- select {
- case <-scanDone:
- // Normal completion
- case <-streamCtx.Done():
- log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
- }
-
- body.Close()
- streamCancel()
-
- latency := time.Since(start).Milliseconds()
- cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
- h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
- h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
- if h.healthTracker != nil {
- h.healthTracker.Record(route.Provider.Name(), latency, nil)
- }
-
- // Debug logging for streaming requests
- if debugEnabled && len(debugLines) > 0 {
- respBody := strings.Join(debugLines, "\n")
- reqBody, _ := json.Marshal(req)
- reqBodyStr := string(reqBody)
- if h.cfg.Debug.MaxBodyBytes > 0 {
- if len(reqBodyStr) > h.cfg.Debug.MaxBodyBytes {
- reqBodyStr = reqBodyStr[:h.cfg.Debug.MaxBodyBytes]
- }
- if len(respBody) > h.cfg.Debug.MaxBodyBytes {
- respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
- }
- }
- h.debugLogger.Log(storage.DebugLogEntry{
- RequestID: requestID,
- TokenName: tokenName,
- Model: req.Model,
- Provider: route.Provider.Name(),
- RequestBody: reqBodyStr,
- ResponseBody: respBody,
- RequestHeaders: formatHeaders(r.Header),
- ResponseStatus: http.StatusOK,
- })
- }
-
- return
- }
-
- // All providers failed
- w.Header().Set("X-Request-ID", requestID)
- if lastErr != nil {
- writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
- } else {
- writeError(w, http.StatusBadGateway, "all providers failed")
- }
-}
-
-type streamChunk struct {
- ID string `json:"id,omitempty"`
- Object string `json:"object,omitempty"`
- Created int64 `json:"created,omitempty"`
- Model string `json:"model,omitempty"`
- Choices []any `json:"choices,omitempty"`
- Usage *provider.Usage `json:"usage,omitempty"`
-}
diff --git a/llm-gateway/internal/storage/audit.go b/llm-gateway/internal/storage/audit.go
deleted file mode 100644
index 2b41a10..0000000
--- a/llm-gateway/internal/storage/audit.go
+++ /dev/null
@@ -1,105 +0,0 @@
-package storage
-
-import (
- "log"
- "time"
-)
-
-type AuditEntry struct {
- ID int64 `json:"id"`
- Timestamp int64 `json:"timestamp"`
- UserID int64 `json:"user_id"`
- Username string `json:"username"`
- Action string `json:"action"`
- TargetType string `json:"target_type"`
- TargetID string `json:"target_id"`
- Details string `json:"details"`
- IPAddress string `json:"ip_address"`
- RequestID string `json:"request_id"`
-}
-
-type AuditLogger struct {
- db *DB
- OnWrite func()
-}
-
-func NewAuditLogger(db *DB) *AuditLogger {
- return &AuditLogger{db: db}
-}
-
-func (a *AuditLogger) Log(entry AuditEntry) {
- if entry.Timestamp == 0 {
- entry.Timestamp = time.Now().Unix()
- }
- _, err := a.db.Exec(`INSERT INTO audit_log
- (timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
- entry.Timestamp, entry.UserID, entry.Username, entry.Action,
- entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID,
- )
- if err != nil {
- log.Printf("ERROR: audit log: %v", err)
- } else if a.OnWrite != nil {
- a.OnWrite()
- }
-}
-
-type AuditQueryResult struct {
- Entries []AuditEntry `json:"entries"`
- Page int `json:"page"`
- TotalPages int `json:"total_pages"`
- Total int `json:"total"`
-}
-
-func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult {
- if page < 1 {
- page = 1
- }
- if limit <= 0 {
- limit = 50
- }
- offset := (page - 1) * limit
-
- where := "WHERE timestamp >= ?"
- args := []any{since}
-
- if action != "" {
- where += " AND action = ?"
- args = append(args, action)
- }
-
- var total int
- countArgs := make([]any, len(args))
- copy(countArgs, args)
- a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total)
-
- totalPages := (total + limit - 1) / limit
- if totalPages < 1 {
- totalPages = 1
- }
-
- query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action,
- COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''),
- COALESCE(ip_address, ''), COALESCE(request_id, '')
- FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
- args = append(args, limit, offset)
-
- rows, err := a.db.Query(query, args...)
- if err != nil {
- return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total}
- }
- defer rows.Close()
-
- var entries []AuditEntry
- for rows.Next() {
- var e AuditEntry
- rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action,
- &e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID)
- entries = append(entries, e)
- }
- if entries == nil {
- entries = []AuditEntry{}
- }
-
- return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
-}
diff --git a/llm-gateway/internal/storage/db.go b/llm-gateway/internal/storage/db.go
deleted file mode 100644
index f7fc480..0000000
--- a/llm-gateway/internal/storage/db.go
+++ /dev/null
@@ -1,142 +0,0 @@
-package storage
-
-import (
- "database/sql"
- "fmt"
- "log"
- "path/filepath"
- "time"
-
- "github.com/golang-migrate/migrate/v4"
- "github.com/golang-migrate/migrate/v4/database/sqlite"
- "github.com/golang-migrate/migrate/v4/source/iofs"
- _ "modernc.org/sqlite"
-
- "llm-gateway/internal/storage/migrations"
-)
-
-type DB struct {
- *sql.DB
-}
-
-func Open(path string) (*DB, error) {
- dir := filepath.Dir(path)
- if dir != "." && dir != "" {
- // Ensure directory exists — caller should create it if needed
- }
-
- db, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000")
- if err != nil {
- return nil, fmt.Errorf("opening database: %w", err)
- }
-
- // Performance pragmas
- for _, pragma := range []string{
- "PRAGMA foreign_keys = ON",
- "PRAGMA temp_store = MEMORY",
- "PRAGMA mmap_size = 268435456",
- } {
- if _, err := db.Exec(pragma); err != nil {
- return nil, fmt.Errorf("setting pragma %s: %w", pragma, err)
- }
- }
-
- db.SetMaxOpenConns(1) // SQLite is single-writer
- db.SetMaxIdleConns(1)
-
- if err := runMigrations(db); err != nil {
- return nil, fmt.Errorf("running migrations: %w", err)
- }
-
- return &DB{db}, nil
-}
-
-func runMigrations(db *sql.DB) error {
- sourceDriver, err := iofs.New(migrations.FS, ".")
- if err != nil {
- return fmt.Errorf("creating migration source: %w", err)
- }
-
- dbDriver, err := sqlite.WithInstance(db, &sqlite.Config{})
- if err != nil {
- return fmt.Errorf("creating migration db driver: %w", err)
- }
-
- m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite", dbDriver)
- if err != nil {
- return fmt.Errorf("creating migrator: %w", err)
- }
-
- if err := m.Up(); err != nil && err != migrate.ErrNoChange {
- return fmt.Errorf("applying migrations: %w", err)
- }
-
- return nil
-}
-
-// CleanupOldRecords deletes records older than retentionDays.
-func (db *DB) CleanupOldRecords(retentionDays int) error {
- cutoff := time.Now().AddDate(0, 0, -retentionDays).Unix()
- result, err := db.Exec("DELETE FROM request_logs WHERE timestamp < ?", cutoff)
- if err != nil {
- return err
- }
- affected, _ := result.RowsAffected()
- if affected > 0 {
- log.Printf("Cleaned up %d old request log records", affected)
- }
- return nil
-}
-
-// TodaySpend returns the total cost in USD for a given token today.
-func (db *DB) TodaySpend(tokenName string) (float64, error) {
- startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
- var total sql.NullFloat64
- err := db.QueryRow(
- "SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
- tokenName, startOfDay,
- ).Scan(&total)
- if err != nil {
- return 0, err
- }
- return total.Float64, nil
-}
-
-// MonthSpend returns the total cost in USD for a given token this month.
-func (db *DB) MonthSpend(tokenName string) (float64, error) {
- now := time.Now()
- startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix()
- var total sql.NullFloat64
- err := db.QueryRow(
- "SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
- tokenName, startOfMonth,
- ).Scan(&total)
- if err != nil {
- return 0, err
- }
- return total.Float64, nil
-}
-
-// TodaySpendAll returns today's spend for all tokens as a map.
-func (db *DB) TodaySpendAll() (map[string]float64, error) {
- startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
- rows, err := db.Query(
- "SELECT token_name, SUM(cost_usd) FROM request_logs WHERE timestamp >= ? GROUP BY token_name",
- startOfDay,
- )
- if err != nil {
- return nil, err
- }
- defer rows.Close()
-
- result := make(map[string]float64)
- for rows.Next() {
- var name string
- var total float64
- if err := rows.Scan(&name, &total); err != nil {
- continue
- }
- result[name] = total
- }
- return result, nil
-}
diff --git a/llm-gateway/internal/storage/debuglog.go b/llm-gateway/internal/storage/debuglog.go
deleted file mode 100644
index 7fe5a83..0000000
--- a/llm-gateway/internal/storage/debuglog.go
+++ /dev/null
@@ -1,253 +0,0 @@
-package storage
-
-import (
- "encoding/json"
- "fmt"
- "log"
- "os"
- "path/filepath"
- "sort"
- "strings"
- "sync/atomic"
- "time"
-)
-
-type DebugLogEntry struct {
- ID int64 `json:"id"`
- RequestID string `json:"request_id"`
- Timestamp int64 `json:"timestamp"`
- TokenName string `json:"token_name"`
- Model string `json:"model"`
- Provider string `json:"provider"`
- RequestBody string `json:"request_body"`
- ResponseBody string `json:"response_body"`
- RequestHeaders string `json:"request_headers"`
- ResponseStatus int `json:"response_status"`
- FilePath string `json:"-"`
-}
-
-// debugFile is the JSON structure written to disk.
-type debugFile struct {
- RequestHeaders string `json:"request_headers"`
- RequestBody string `json:"request_body"`
- ResponseBody string `json:"response_body"`
-}
-
-type DebugLogger struct {
- db *DB
- enabled atomic.Bool
- dataDir string
- OnWrite func()
-}
-
-func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger {
- dl := &DebugLogger{db: db, dataDir: dataDir}
- dl.enabled.Store(enabled)
- return dl
-}
-
-func (d *DebugLogger) SetEnabled(v bool) {
- d.enabled.Store(v)
-}
-
-func (d *DebugLogger) IsEnabled() bool {
- return d.enabled.Load()
-}
-
-// debugLogDir returns the base directory for debug log files.
-func (d *DebugLogger) debugLogDir() string {
- return filepath.Join(d.dataDir, "debug-logs")
-}
-
-// debugFilePath builds the file path for a debug log entry.
-func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string {
- date := ts.Format("2006-01-02")
- return filepath.Join(d.debugLogDir(), date, requestID+".json")
-}
-
-func (d *DebugLogger) Log(entry DebugLogEntry) {
- if !d.IsEnabled() {
- return
- }
- if entry.Timestamp == 0 {
- entry.Timestamp = time.Now().Unix()
- }
-
- ts := time.Unix(entry.Timestamp, 0)
- fp := d.debugFilePath(entry.RequestID, ts)
-
- // Write body file
- if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil {
- log.Printf("ERROR: debug log mkdir: %v", err)
- return
- }
-
- df := debugFile{
- RequestHeaders: entry.RequestHeaders,
- RequestBody: entry.RequestBody,
- ResponseBody: entry.ResponseBody,
- }
- data, err := json.Marshal(df)
- if err != nil {
- log.Printf("ERROR: debug log marshal: %v", err)
- return
- }
- if err := os.WriteFile(fp, data, 0644); err != nil {
- log.Printf("ERROR: debug log write: %v", err)
- return
- }
-
- // Insert metadata into DB (no bodies)
- _, err = d.db.Exec(`INSERT INTO debug_log
- (request_id, timestamp, token_name, model, provider, response_status, file_path)
- VALUES (?, ?, ?, ?, ?, ?, ?)`,
- entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model,
- entry.Provider, entry.ResponseStatus, fp,
- )
- if err != nil {
- log.Printf("ERROR: debug log db insert: %v", err)
- } else if d.OnWrite != nil {
- d.OnWrite()
- }
-}
-
-type DebugLogQueryResult struct {
- Entries []DebugLogEntry `json:"entries"`
- Page int `json:"page"`
- TotalPages int `json:"total_pages"`
- Total int `json:"total"`
-}
-
-// Query returns paginated debug log metadata (no bodies — fast).
-func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult {
- if page < 1 {
- page = 1
- }
- if limit <= 0 {
- limit = 50
- }
- offset := (page - 1) * limit
-
- var total int
- d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total)
-
- totalPages := (total + limit - 1) / limit
- if totalPages < 1 {
- totalPages = 1
- }
-
- rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
- COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
- FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset)
- if err != nil {
- return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
- }
- defer rows.Close()
-
- var entries []DebugLogEntry
- for rows.Next() {
- var e DebugLogEntry
- rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
- &e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
- entries = append(entries, e)
- }
- if entries == nil {
- entries = []DebugLogEntry{}
- }
-
- return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
-}
-
-// QueryFull returns paginated debug log entries including request/response bodies read from files.
-func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult {
- result := d.Query(page, limit)
- for i := range result.Entries {
- d.populateFromFile(&result.Entries[i])
- }
- return result
-}
-
-// GetByRequestID returns a single debug log entry with bodies read from file.
-func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry {
- var e DebugLogEntry
- err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
- COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
- FROM debug_log WHERE request_id = ?`, requestID).Scan(
- &e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
- &e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
- if err != nil {
- return nil
- }
- d.populateFromFile(&e)
- return &e
-}
-
-// populateFromFile reads body data from the debug file on disk.
-// Falls back to DB columns for pre-migration entries that have no file_path.
-func (d *DebugLogger) populateFromFile(e *DebugLogEntry) {
- if e.FilePath == "" {
- // Legacy entry: try reading bodies from DB columns
- d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '')
- FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders)
- return
- }
- data, err := os.ReadFile(e.FilePath)
- if err != nil {
- log.Printf("WARN: debug log read file %s: %v", e.FilePath, err)
- return
- }
- var df debugFile
- if err := json.Unmarshal(data, &df); err != nil {
- log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err)
- return
- }
- e.RequestHeaders = df.RequestHeaders
- e.RequestBody = df.RequestBody
- e.ResponseBody = df.ResponseBody
-}
-
-// Cleanup removes debug log entries and files older than retentionDays.
-func (d *DebugLogger) Cleanup(retentionDays int) error {
- cutoff := time.Now().AddDate(0, 0, -retentionDays)
- cutoffUnix := cutoff.Unix()
-
- // Delete old DB rows
- result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix)
- if err != nil {
- return fmt.Errorf("delete old debug rows: %w", err)
- }
- affected, _ := result.RowsAffected()
- if affected > 0 {
- log.Printf("Cleaned up %d old debug log entries", affected)
- }
-
- // Remove old date directories
- baseDir := d.debugLogDir()
- dirs, err := os.ReadDir(baseDir)
- if err != nil {
- if os.IsNotExist(err) {
- return nil
- }
- return fmt.Errorf("read debug log dir: %w", err)
- }
-
- cutoffDate := cutoff.Format("2006-01-02")
- sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
-
- for _, dir := range dirs {
- if !dir.IsDir() {
- continue
- }
- // Date directories are named YYYY-MM-DD; string comparison works
- if strings.Compare(dir.Name(), cutoffDate) < 0 {
- dirPath := filepath.Join(baseDir, dir.Name())
- if err := os.RemoveAll(dirPath); err != nil {
- log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err)
- } else {
- log.Printf("Removed old debug log directory: %s", dir.Name())
- }
- }
- }
-
- return nil
-}
diff --git a/llm-gateway/internal/storage/logger.go b/llm-gateway/internal/storage/logger.go
deleted file mode 100644
index ce54675..0000000
--- a/llm-gateway/internal/storage/logger.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package storage
-
-import (
- "log"
- "time"
-)
-
-type RequestLog struct {
- RequestID string
- Timestamp int64
- TokenName string
- Model string
- Provider string
- ProviderModel string
- InputTokens int
- OutputTokens int
- CostUSD float64
- LatencyMS int64
- Status string // success, error, cached
- ErrorMessage string
- Streaming bool
- Cached bool
- RequestType string // "chat" or "embedding"
-}
-
-type AsyncLogger struct {
- db *DB
- ch chan RequestLog
- done chan struct{}
- OnFlush func() // called after successful flush, if set
-}
-
-func NewAsyncLogger(db *DB, bufferSize int) *AsyncLogger {
- if bufferSize == 0 {
- bufferSize = 1000
- }
- l := &AsyncLogger{
- db: db,
- ch: make(chan RequestLog, bufferSize),
- done: make(chan struct{}),
- }
- go l.run()
- return l
-}
-
-func (l *AsyncLogger) Log(r RequestLog) {
- select {
- case l.ch <- r:
- default:
- log.Println("WARNING: request log buffer full, dropping entry")
- }
-}
-
-func (l *AsyncLogger) Close() {
- close(l.ch)
- <-l.done
-}
-
-func (l *AsyncLogger) run() {
- defer close(l.done)
-
- batch := make([]RequestLog, 0, 100)
- ticker := time.NewTicker(1 * time.Second)
- defer ticker.Stop()
-
- for {
- select {
- case r, ok := <-l.ch:
- if !ok {
- // Channel closed, flush remaining
- if len(batch) > 0 {
- l.flush(batch)
- }
- return
- }
- batch = append(batch, r)
- if len(batch) >= 100 {
- l.flush(batch)
- batch = batch[:0]
- }
- case <-ticker.C:
- if len(batch) > 0 {
- l.flush(batch)
- batch = batch[:0]
- }
- }
- }
-}
-
-func (l *AsyncLogger) flush(batch []RequestLog) {
- tx, err := l.db.Begin()
- if err != nil {
- log.Printf("ERROR: starting log transaction: %v", err)
- return
- }
-
- stmt, err := tx.Prepare(`INSERT INTO request_logs
- (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
- if err != nil {
- log.Printf("ERROR: preparing log statement: %v", err)
- tx.Rollback()
- return
- }
- defer stmt.Close()
-
- for _, r := range batch {
- streaming := 0
- if r.Streaming {
- streaming = 1
- }
- cached := 0
- if r.Cached {
- cached = 1
- }
- reqType := r.RequestType
- if reqType == "" {
- reqType = "chat"
- }
- _, err := stmt.Exec(
- r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
- r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
- r.Status, r.ErrorMessage, streaming, cached, reqType,
- )
- if err != nil {
- log.Printf("ERROR: inserting log: %v", err)
- }
- }
-
- if err := tx.Commit(); err != nil {
- log.Printf("ERROR: committing log batch: %v", err)
- return
- }
-
- if l.OnFlush != nil {
- l.OnFlush()
- }
-}
diff --git a/llm-gateway/internal/storage/migrations/001_init.down.sql b/llm-gateway/internal/storage/migrations/001_init.down.sql
deleted file mode 100644
index bd1fad9..0000000
--- a/llm-gateway/internal/storage/migrations/001_init.down.sql
+++ /dev/null
@@ -1 +0,0 @@
-DROP TABLE IF EXISTS request_logs;
diff --git a/llm-gateway/internal/storage/migrations/001_init.up.sql b/llm-gateway/internal/storage/migrations/001_init.up.sql
deleted file mode 100644
index 610d375..0000000
--- a/llm-gateway/internal/storage/migrations/001_init.up.sql
+++ /dev/null
@@ -1,20 +0,0 @@
-CREATE TABLE IF NOT EXISTS request_logs (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- timestamp INTEGER NOT NULL,
- token_name TEXT NOT NULL,
- model TEXT NOT NULL,
- provider TEXT NOT NULL,
- provider_model TEXT NOT NULL,
- input_tokens INTEGER DEFAULT 0,
- output_tokens INTEGER DEFAULT 0,
- cost_usd REAL DEFAULT 0,
- latency_ms INTEGER DEFAULT 0,
- status TEXT NOT NULL,
- error_message TEXT DEFAULT '',
- streaming INTEGER DEFAULT 0,
- cached INTEGER DEFAULT 0
-);
-
-CREATE INDEX IF NOT EXISTS idx_timestamp ON request_logs(timestamp);
-CREATE INDEX IF NOT EXISTS idx_token ON request_logs(token_name);
-CREATE INDEX IF NOT EXISTS idx_model ON request_logs(model);
diff --git a/llm-gateway/internal/storage/migrations/002_users.down.sql b/llm-gateway/internal/storage/migrations/002_users.down.sql
deleted file mode 100644
index 12f43a1..0000000
--- a/llm-gateway/internal/storage/migrations/002_users.down.sql
+++ /dev/null
@@ -1,3 +0,0 @@
-DROP TABLE IF EXISTS api_tokens;
-DROP TABLE IF EXISTS sessions;
-DROP TABLE IF EXISTS users;
diff --git a/llm-gateway/internal/storage/migrations/002_users.up.sql b/llm-gateway/internal/storage/migrations/002_users.up.sql
deleted file mode 100644
index fb8eb11..0000000
--- a/llm-gateway/internal/storage/migrations/002_users.up.sql
+++ /dev/null
@@ -1,33 +0,0 @@
-CREATE TABLE users (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- username TEXT NOT NULL UNIQUE,
- password_hash TEXT NOT NULL,
- is_admin INTEGER NOT NULL DEFAULT 0,
- totp_secret TEXT DEFAULT '',
- totp_enabled INTEGER NOT NULL DEFAULT 0,
- created_at INTEGER NOT NULL,
- updated_at INTEGER NOT NULL
-);
-
-CREATE TABLE sessions (
- id TEXT PRIMARY KEY,
- user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
- created_at INTEGER NOT NULL,
- expires_at INTEGER NOT NULL
-);
-CREATE INDEX idx_sessions_user ON sessions(user_id);
-CREATE INDEX idx_sessions_expires ON sessions(expires_at);
-
-CREATE TABLE api_tokens (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- name TEXT NOT NULL,
- key_hash TEXT NOT NULL,
- key_prefix TEXT NOT NULL,
- user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
- rate_limit_rpm INTEGER DEFAULT 60,
- daily_budget_usd REAL DEFAULT 0,
- created_at INTEGER NOT NULL,
- last_used_at INTEGER DEFAULT 0
-);
-CREATE UNIQUE INDEX idx_api_tokens_hash ON api_tokens(key_hash);
-CREATE INDEX idx_api_tokens_user ON api_tokens(user_id);
diff --git a/llm-gateway/internal/storage/migrations/003_user_email.down.sql b/llm-gateway/internal/storage/migrations/003_user_email.down.sql
deleted file mode 100644
index e00308d..0000000
--- a/llm-gateway/internal/storage/migrations/003_user_email.down.sql
+++ /dev/null
@@ -1,5 +0,0 @@
--- SQLite doesn't support DROP COLUMN before 3.35.0, so we recreate
-CREATE TABLE users_backup AS SELECT id, username, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users;
-DROP TABLE users;
-ALTER TABLE users_backup RENAME TO users;
-CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
diff --git a/llm-gateway/internal/storage/migrations/003_user_email.up.sql b/llm-gateway/internal/storage/migrations/003_user_email.up.sql
deleted file mode 100644
index 9468211..0000000
--- a/llm-gateway/internal/storage/migrations/003_user_email.up.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE users ADD COLUMN email TEXT DEFAULT '';
diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql
deleted file mode 100644
index e11bb40..0000000
--- a/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql
+++ /dev/null
@@ -1,4 +0,0 @@
--- SQLite doesn't support DROP COLUMN in older versions, so we recreate the table
-CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens;
-DROP TABLE api_tokens;
-ALTER TABLE api_tokens_backup RENAME TO api_tokens;
diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql
deleted file mode 100644
index ccf0549..0000000
--- a/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0;
diff --git a/llm-gateway/internal/storage/migrations/005_request_id.down.sql b/llm-gateway/internal/storage/migrations/005_request_id.down.sql
deleted file mode 100644
index 819b90b..0000000
--- a/llm-gateway/internal/storage/migrations/005_request_id.down.sql
+++ /dev/null
@@ -1 +0,0 @@
-DROP INDEX IF EXISTS idx_request_logs_request_id;
diff --git a/llm-gateway/internal/storage/migrations/005_request_id.up.sql b/llm-gateway/internal/storage/migrations/005_request_id.up.sql
deleted file mode 100644
index ff54384..0000000
--- a/llm-gateway/internal/storage/migrations/005_request_id.up.sql
+++ /dev/null
@@ -1,2 +0,0 @@
-ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT '';
-CREATE INDEX idx_request_logs_request_id ON request_logs(request_id);
diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.down.sql b/llm-gateway/internal/storage/migrations/006_audit_log.down.sql
deleted file mode 100644
index b750c3b..0000000
--- a/llm-gateway/internal/storage/migrations/006_audit_log.down.sql
+++ /dev/null
@@ -1 +0,0 @@
-DROP TABLE IF EXISTS audit_log;
diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.up.sql b/llm-gateway/internal/storage/migrations/006_audit_log.up.sql
deleted file mode 100644
index c2fc48a..0000000
--- a/llm-gateway/internal/storage/migrations/006_audit_log.up.sql
+++ /dev/null
@@ -1,14 +0,0 @@
-CREATE TABLE audit_log (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- timestamp INTEGER NOT NULL,
- user_id INTEGER,
- username TEXT NOT NULL DEFAULT '',
- action TEXT NOT NULL,
- target_type TEXT DEFAULT '',
- target_id TEXT DEFAULT '',
- details TEXT DEFAULT '',
- ip_address TEXT DEFAULT '',
- request_id TEXT DEFAULT ''
-);
-CREATE INDEX idx_audit_timestamp ON audit_log(timestamp);
-CREATE INDEX idx_audit_action ON audit_log(action);
diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.down.sql b/llm-gateway/internal/storage/migrations/007_debug_log.down.sql
deleted file mode 100644
index 41353f5..0000000
--- a/llm-gateway/internal/storage/migrations/007_debug_log.down.sql
+++ /dev/null
@@ -1 +0,0 @@
-DROP TABLE IF EXISTS debug_log;
diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.up.sql b/llm-gateway/internal/storage/migrations/007_debug_log.up.sql
deleted file mode 100644
index 9a8441a..0000000
--- a/llm-gateway/internal/storage/migrations/007_debug_log.up.sql
+++ /dev/null
@@ -1,14 +0,0 @@
-CREATE TABLE debug_log (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- request_id TEXT NOT NULL,
- timestamp INTEGER NOT NULL,
- token_name TEXT DEFAULT '',
- model TEXT DEFAULT '',
- provider TEXT DEFAULT '',
- request_body TEXT DEFAULT '',
- response_body TEXT DEFAULT '',
- request_headers TEXT DEFAULT '',
- response_status INTEGER DEFAULT 0
-);
-CREATE INDEX idx_debug_request_id ON debug_log(request_id);
-CREATE INDEX idx_debug_timestamp ON debug_log(timestamp);
diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql
deleted file mode 100644
index 032a37d..0000000
--- a/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql
+++ /dev/null
@@ -1 +0,0 @@
--- no-op: file_path column is harmless to keep
diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql
deleted file mode 100644
index 7a5bf8b..0000000
--- a/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT '';
diff --git a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql
deleted file mode 100644
index f288a97..0000000
--- a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd;
diff --git a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql
deleted file mode 100644
index 7d9dfbc..0000000
--- a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0;
diff --git a/llm-gateway/internal/storage/migrations/010_request_type.down.sql b/llm-gateway/internal/storage/migrations/010_request_type.down.sql
deleted file mode 100644
index 52e41cb..0000000
--- a/llm-gateway/internal/storage/migrations/010_request_type.down.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE request_logs DROP COLUMN request_type;
diff --git a/llm-gateway/internal/storage/migrations/010_request_type.up.sql b/llm-gateway/internal/storage/migrations/010_request_type.up.sql
deleted file mode 100644
index 20bca14..0000000
--- a/llm-gateway/internal/storage/migrations/010_request_type.up.sql
+++ /dev/null
@@ -1 +0,0 @@
-ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat';
diff --git a/llm-gateway/internal/storage/migrations/embed.go b/llm-gateway/internal/storage/migrations/embed.go
deleted file mode 100644
index 91cca1c..0000000
--- a/llm-gateway/internal/storage/migrations/embed.go
+++ /dev/null
@@ -1,6 +0,0 @@
-package migrations
-
-import "embed"
-
-//go:embed *.sql
-var FS embed.FS
diff --git a/llm-gateway/internal/webhook/webhook.go b/llm-gateway/internal/webhook/webhook.go
deleted file mode 100644
index 098a0e2..0000000
--- a/llm-gateway/internal/webhook/webhook.go
+++ /dev/null
@@ -1,123 +0,0 @@
-package webhook
-
-import (
- "bytes"
- "crypto/hmac"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "log"
- "net/http"
- "time"
-
- "llm-gateway/internal/config"
-)
-
-// Event types.
-const (
- EventCircuitBreakerOpen = "circuit_breaker.open"
- EventCircuitBreakerClosed = "circuit_breaker.closed"
- EventBudgetThreshold = "budget.threshold"
-)
-
-// Event represents a webhook notification payload.
-type Event struct {
- Type string `json:"type"`
- Timestamp time.Time `json:"timestamp"`
- Data map[string]any `json:"data"`
-}
-
-// Notifier sends webhook notifications.
-type Notifier struct {
- webhooks []config.WebhookConfig
- ch chan Event
- done chan struct{}
- client *http.Client
-}
-
-// NewNotifier creates a webhook notifier from config.
-func NewNotifier(webhooks []config.WebhookConfig) *Notifier {
- n := &Notifier{
- webhooks: webhooks,
- ch: make(chan Event, 100),
- done: make(chan struct{}),
- client: &http.Client{Timeout: 10 * time.Second},
- }
- go n.run()
- return n
-}
-
-// Notify queues an event for delivery (non-blocking).
-func (n *Notifier) Notify(evt Event) {
- if evt.Timestamp.IsZero() {
- evt.Timestamp = time.Now()
- }
- select {
- case n.ch <- evt:
- default:
- log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type)
- }
-}
-
-// Close drains pending events and shuts down.
-func (n *Notifier) Close() {
- close(n.ch)
- <-n.done
-}
-
-func (n *Notifier) run() {
- defer close(n.done)
- for evt := range n.ch {
- for _, wh := range n.webhooks {
- if !n.shouldSend(wh, evt.Type) {
- continue
- }
- n.send(wh, evt)
- }
- }
-}
-
-func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool {
- if len(wh.Events) == 0 {
- return true // no filter = send all
- }
- for _, e := range wh.Events {
- if e == eventType {
- return true
- }
- }
- return false
-}
-
-func (n *Notifier) send(wh config.WebhookConfig, evt Event) {
- body, err := json.Marshal(evt)
- if err != nil {
- log.Printf("ERROR: webhook marshal: %v", err)
- return
- }
-
- req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body))
- if err != nil {
- log.Printf("ERROR: webhook request: %v", err)
- return
- }
- req.Header.Set("Content-Type", "application/json")
-
- if wh.Secret != "" {
- mac := hmac.New(sha256.New, []byte(wh.Secret))
- mac.Write(body)
- sig := hex.EncodeToString(mac.Sum(nil))
- req.Header.Set("X-Webhook-Signature", "sha256="+sig)
- }
-
- resp, err := n.client.Do(req)
- if err != nil {
- log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err)
- return
- }
- resp.Body.Close()
-
- if resp.StatusCode >= 400 {
- log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode)
- }
-}