diff --git a/docker-compose.yml b/docker-compose.yml index e45d16a..607fbbb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -81,25 +81,53 @@ services: # retries: 5 # start_period: 30s - # ── LLM API proxy (new-api) ── - new-api: - image: calciumion/new-api:latest + # # ── LLM API proxy (DEPRECATED — replaced by llm-gateway) ── + # new-api: + # image: calciumion/new-api:latest + # ports: + # - "0.0.0.0:4000:3000" + # volumes: + # - new-api-data:/data + # environment: + # - SQL_DSN= + # - TZ=UTC + # - ENABLE_METRIC=true + # - LANG=en_US.UTF-8 + # restart: unless-stopped + # healthcheck: + # test: ["CMD", "wget", "-q", "-O", "/dev/null", "http://localhost:3000/"] + # interval: 15s + # timeout: 5s + # retries: 5 + # start_period: 10s + + # ── LLM API proxy ── + llm-gateway: + build: ./llm-gateway ports: - "0.0.0.0:4000:3000" volumes: - - new-api-data:/data + - llm-gateway-data:/data + - ./llm-gateway/configs/config.yaml:/etc/llm-gateway/config.yaml:ro environment: - - SQL_DSN= - - TZ=UTC - - ENABLE_METRIC=true - - LANG=en_US.UTF-8 + - DASHBOARD_TOKEN=${DASHBOARD_TOKEN} + - OPENWEBUI_API_KEY=${OPENWEBUI_API_KEY} + - PERSONAL_API_KEY=${PERSONAL_API_KEY} + - DEEPINFRA_API_KEY=${DEEPINFRA_API_KEY} + - SILICONFLOW_API_KEY=${SILICONFLOW_API_KEY} + - OPENROUTER_API_KEY=${OPENROUTER_API_KEY} + - GROQ_API_KEY=${GROQ_API_KEY} + - CEREBRAS_API_KEY=${CEREBRAS_API_KEY} + depends_on: + valkey: + condition: service_healthy restart: unless-stopped healthcheck: - test: ["CMD", "wget", "-q", "-O", "/dev/null", "http://localhost:3000/"] + test: ["CMD", "wget", "-q", "-O", "/dev/null", "http://localhost:3000/health"] interval: 15s timeout: 5s retries: 5 - start_period: 10s + start_period: 5s # ── Chat UI ── open-webui: @@ -110,7 +138,7 @@ services: - "0.0.0.0:3000:8080" environment: - OLLAMA_BASE_URL= - - OPENAI_API_BASE_URL=http://new-api:3000/v1 + - OPENAI_API_BASE_URL=http://llm-gateway:3000/v1 - OPENAI_API_KEY=${OPENWEBUI_API_KEY} - ENABLE_RAG_WEB_SEARCH=true - RAG_WEB_SEARCH_ENGINE=searxng @@ -119,7 +147,7 @@ services: - CHROMA_HTTP_PORT=8000 - WEBUI_AUTH=true depends_on: - new-api: + llm-gateway: condition: service_healthy restart: unless-stopped @@ -166,19 +194,19 @@ services: - "127.0.0.1:8428:8428" restart: unless-stopped - # ── Dashboards ── - grafana: - image: grafana/grafana:latest - volumes: - - grafana-data:/var/lib/grafana - ports: - - "0.0.0.0:3001:3000" - environment: - - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD} - - GF_USERS_ALLOW_SIGN_UP=false - depends_on: - - victoriametrics - restart: unless-stopped + # # ── Dashboards (DEPRECATED — replaced by llm-gateway built-in dashboard) ── + # grafana: + # image: grafana/grafana:latest + # volumes: + # - grafana-data:/var/lib/grafana + # ports: + # - "0.0.0.0:3001:3000" + # environment: + # - GF_SECURITY_ADMIN_PASSWORD=${GRAFANA_ADMIN_PASSWORD} + # - GF_USERS_ALLOW_SIGN_UP=false + # depends_on: + # - victoriametrics + # restart: unless-stopped # ── Host system metrics ── node-exporter: @@ -210,6 +238,7 @@ volumes: chromadb-data: litellm-db-data: new-api-data: + llm-gateway-data: open-webui-data: tailscale-state: victoriametrics-data: diff --git a/llm-gateway/.env.example b/llm-gateway/.env.example new file mode 100644 index 0000000..a1acce2 --- /dev/null +++ b/llm-gateway/.env.example @@ -0,0 +1,19 @@ +# LLM Gateway Environment Variables + +# Session secret (required for persistent sessions) +SESSION_SECRET=change-me-to-a-random-string + +# Default admin (created on first run if no users exist) +ADMIN_USERNAME=admin +ADMIN_PASSWORD=change-me-min-8-chars + +# Static API tokens (seeded on startup) +OPENWEBUI_API_KEY=sk-your-openwebui-key +PERSONAL_API_KEY=sk-your-personal-key + +# Provider API keys +DEEPINFRA_API_KEY= +SILICONFLOW_API_KEY= +OPENROUTER_API_KEY= +GROQ_API_KEY= +CEREBRAS_API_KEY= diff --git a/llm-gateway/.gitignore b/llm-gateway/.gitignore new file mode 100644 index 0000000..8bf05e4 --- /dev/null +++ b/llm-gateway/.gitignore @@ -0,0 +1,15 @@ +# Binaries +gateway +llm-gateway + +# Database +*.db +*.db-journal +*.db-wal +*.db-shm + +# Local config +configs/config.local.yaml + +# Environment +.env diff --git a/llm-gateway/Dockerfile b/llm-gateway/Dockerfile new file mode 100644 index 0000000..ff7b982 --- /dev/null +++ b/llm-gateway/Dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.23-alpine AS builder +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o /llm-gateway ./cmd/gateway + +FROM alpine:3.19 +RUN apk add --no-cache ca-certificates tzdata +COPY --from=builder /llm-gateway /usr/local/bin/llm-gateway +RUN mkdir -p /data +VOLUME /data +EXPOSE 3000 +ENTRYPOINT ["llm-gateway"] +CMD ["-config", "/etc/llm-gateway/config.yaml"] diff --git a/llm-gateway/Makefile b/llm-gateway/Makefile new file mode 100644 index 0000000..f042edf --- /dev/null +++ b/llm-gateway/Makefile @@ -0,0 +1,16 @@ +.PHONY: build run clean docker + +BINARY=llm-gateway +VERSION=$(shell git describe --tags --always --dirty 2>/dev/null || echo dev) + +build: + go build -ldflags="-s -w -X main.version=$(VERSION)" -o $(BINARY) ./cmd/gateway + +run: build + ./$(BINARY) -config configs/config.yaml + +clean: + rm -f $(BINARY) + +docker: + docker build -t llm-gateway:latest . diff --git a/llm-gateway/cmd/gateway/main.go b/llm-gateway/cmd/gateway/main.go new file mode 100644 index 0000000..cb33f18 --- /dev/null +++ b/llm-gateway/cmd/gateway/main.go @@ -0,0 +1,281 @@ +package main + +import ( + "context" + "flag" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/prometheus/client_golang/prometheus/promhttp" + + "llm-gateway/internal/auth" + "llm-gateway/internal/cache" + "llm-gateway/internal/config" + "llm-gateway/internal/dashboard" + "llm-gateway/internal/metrics" + "llm-gateway/internal/pricing" + "llm-gateway/internal/provider" + "llm-gateway/internal/proxy" + "llm-gateway/internal/storage" +) + +var version = "dev" + +func main() { + configPath := flag.String("config", "configs/config.yaml", "path to config file") + flag.Parse() + + log.Printf("llm-gateway %s starting", version) + + cfg, err := config.Load(*configPath) + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + // Pricing lookup (fetches from URL, refreshes periodically) + pricingLookup := pricing.NewLookup(cfg.Pricing.URL, cfg.Pricing.RefreshInterval) + defer pricingLookup.Close() + + // Auto-fill missing pricing from fetched data + for i, m := range cfg.Models { + for j, r := range m.Routes { + if r.Pricing.Input == 0 && r.Pricing.Output == 0 { + if pricingLookup.FillMissing(r.Provider, r.Model, &cfg.Models[i].Routes[j].Pricing.Input, &cfg.Models[i].Routes[j].Pricing.Output) { + log.Printf("Auto-filled pricing for %s via %s: $%.2f/$%.2f per 1M tokens", + m.Name, r.Provider, cfg.Models[i].Routes[j].Pricing.Input, cfg.Models[i].Routes[j].Pricing.Output) + } + } + } + } + + // Database + db, err := storage.Open(cfg.Database.Path) + if err != nil { + log.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + if err := db.CleanupOldRecords(cfg.Database.RetentionDays); err != nil { + log.Printf("WARNING: retention cleanup failed: %v", err) + } + + asyncLogger := storage.NewAsyncLogger(db, 1000) + defer asyncLogger.Close() + + // SSE broker for real-time dashboard updates + sseBroker := dashboard.NewSSEBroker() + asyncLogger.OnFlush = sseBroker.Notify + + // Cache (optional) + var c *cache.Cache + if cfg.Cache.Enabled { + c, err = cache.New(cfg.Cache.Address, cfg.Cache.TTL) + if err != nil { + log.Printf("WARNING: cache disabled: %v", err) + } else { + log.Printf("Cache connected to %s", cfg.Cache.Address) + } + } + + // Provider registry + registry, err := provider.NewRegistry(cfg) + if err != nil { + log.Fatalf("Failed to build provider registry: %v", err) + } + log.Printf("Registered %d models", len(cfg.Models)) + + // Auth store + authStore := auth.NewStore(db.DB) + authMiddleware := auth.NewMiddleware(authStore) + authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret) + + // Seed default admin and static tokens + seedAdminAndTokens(cfg, authStore) + + // Metrics + m := metrics.New() + + // Handlers + proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg) + modelsHandler := proxy.NewModelsHandler(registry) + proxyAuth := proxy.NewAuthMiddleware(authStore) + rateLimiter := proxy.NewRateLimiter(db) + statsAPI := dashboard.NewStatsAPI(db, authStore) + dash := dashboard.NewDashboard(authStore, statsAPI) + + // Router + r := chi.NewRouter() + r.Use(middleware.RealIP) + r.Use(middleware.Recoverer) + r.Use(middleware.RequestID) + + // Health & metrics (public) + r.Get("/health", func(w http.ResponseWriter, r *http.Request) { + if err := db.Ping(); err != nil { + http.Error(w, "database unhealthy", http.StatusServiceUnavailable) + return + } + if c != nil { + if err := c.Ping(r.Context()); err != nil { + http.Error(w, "cache unhealthy", http.StatusServiceUnavailable) + return + } + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + r.Handle("/metrics", promhttp.Handler()) + + // OpenAI-compatible API (API token auth via Bearer header) + r.Group(func(r chi.Router) { + r.Use(proxyAuth.Authenticate) + r.Use(rateLimiter.Check) + r.Post("/v1/chat/completions", proxyHandler.ChatCompletions) + r.Get("/v1/models", modelsHandler.ListModels) + }) + + // Auth pages (public) + r.Get("/login", dash.LoginPage) + r.Get("/setup", dash.SetupPage) + + // Auth API endpoints (public) + r.Post("/api/auth/login", authHandlers.Login) + r.Post("/api/auth/setup", authHandlers.Setup) + r.Post("/api/auth/login/totp", authHandlers.LoginTOTP) + + // Root redirect + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/dashboard", http.StatusFound) + }) + + // Authenticated pages and API + r.Group(func(r chi.Router) { + r.Use(authMiddleware.RequireAuth) + + // Dashboard pages (HTMX) + r.Get("/dashboard", dash.DashboardPage) + r.Get("/tokens", dash.TokensPage) + r.Get("/settings", dash.SettingsPage) + + // Admin-only pages + r.Group(func(r chi.Router) { + r.Use(authMiddleware.RequireAdmin) + r.Get("/users", dash.UsersPage) + }) + + // Auth API + r.Post("/api/auth/logout", authHandlers.Logout) + r.Get("/api/auth/me", authHandlers.Me) + r.Put("/api/auth/me/password", authHandlers.ChangePassword) + r.Put("/api/auth/me/username", authHandlers.ChangeUsername) + r.Put("/api/auth/me/email", authHandlers.ChangeEmail) + r.Post("/api/auth/totp/setup", authHandlers.TOTPSetup) + r.Post("/api/auth/totp/verify", authHandlers.TOTPVerify) + r.Delete("/api/auth/totp", authHandlers.TOTPDisable) + + // API token management + r.Get("/api/tokens", authHandlers.ListTokens) + r.Post("/api/tokens", authHandlers.CreateToken) + r.Delete("/api/tokens/{id}", authHandlers.DeleteToken) + + // SSE events + r.Get("/api/events", sseBroker.ServeHTTP) + + // Dashboard stats + r.Get("/api/stats/summary", statsAPI.Summary) + r.Get("/api/stats/models", statsAPI.Models) + r.Get("/api/stats/providers", statsAPI.Providers) + r.Get("/api/stats/tokens", statsAPI.Tokens) + r.Get("/api/stats/timeseries", statsAPI.Timeseries) + + // Admin-only: user management + r.Group(func(r chi.Router) { + r.Use(authMiddleware.RequireAdmin) + r.Get("/api/auth/users", authHandlers.ListUsers) + r.Post("/api/auth/users", authHandlers.CreateUser) + r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser) + }) + }) + + // Periodic session cleanup + go func() { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + for range ticker.C { + if err := authStore.CleanExpiredSessions(); err != nil { + log.Printf("WARNING: session cleanup failed: %v", err) + } + } + }() + + // Server + srv := &http.Server{ + Addr: cfg.Server.Listen, + Handler: r, + ReadTimeout: 30 * time.Second, + WriteTimeout: cfg.Server.RequestTimeout + 10*time.Second, + IdleTimeout: 120 * time.Second, + } + + // Graceful shutdown + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGTERM) + + go func() { + log.Printf("Listening on %s", cfg.Server.Listen) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Server failed: %v", err) + } + }() + + <-done + log.Println("Shutting down...") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + srv.Shutdown(ctx) + + log.Println("Stopped") +} + +// seedAdminAndTokens creates the default admin and seeds static tokens from config. +func seedAdminAndTokens(cfg *config.Config, authStore *auth.Store) { + // Seed default admin if no users exist + if !authStore.HasAnyUser() { + da := cfg.Server.DefaultAdmin + if da.Username != "" && da.Password != "" { + user, err := authStore.CreateUser(da.Username, da.Password, true) + if err != nil { + log.Printf("WARNING: failed to create default admin: %v", err) + } else { + log.Printf("Created default admin user: %s (id=%d)", user.Username, user.ID) + } + } + } + + // Seed static tokens from config + if len(cfg.Tokens) > 0 { + admin, err := authStore.GetFirstAdmin() + if err != nil { + log.Printf("WARNING: no admin user found, cannot seed static tokens") + return + } + + for _, t := range cfg.Tokens { + if t.Key == "" { + continue + } + if err := authStore.SeedStaticToken(admin.ID, t.Name, t.Key, t.RateLimitRPM, t.DailyBudgetUSD); err != nil { + log.Printf("WARNING: failed to seed token %q: %v", t.Name, err) + } else { + log.Printf("Seeded static token: %s", t.Name) + } + } + } +} diff --git a/llm-gateway/configs/config.yaml b/llm-gateway/configs/config.yaml new file mode 100644 index 0000000..91f1913 --- /dev/null +++ b/llm-gateway/configs/config.yaml @@ -0,0 +1,140 @@ +server: + listen: "0.0.0.0:3000" + request_timeout: 300s + max_request_body_mb: 10 + session_secret: "${SESSION_SECRET}" + default_admin: + username: "${ADMIN_USERNAME}" + password: "${ADMIN_PASSWORD}" + +tokens: + - name: "open-webui" + key: "${OPENWEBUI_API_KEY}" + rate_limit_rpm: 0 # unlimited + daily_budget_usd: 5.0 + - name: "rayandrew" + key: "${PERSONAL_API_KEY}" + rate_limit_rpm: 0 # unlimited + daily_budget_usd: 10.0 + +pricing_lookup: + # url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default + refresh_interval: 6h + +database: + path: "/data/gateway.db" + retention_days: 90 + +cache: + enabled: true + address: "valkey:6379" + ttl: 3600 + +providers: + - name: deepinfra + base_url: "https://api.deepinfra.com/v1/openai" + api_key: "${DEEPINFRA_API_KEY}" + priority: 1 + timeout: 120s + - name: siliconflow + base_url: "https://api.siliconflow.com/v1" + api_key: "${SILICONFLOW_API_KEY}" + priority: 2 + timeout: 120s + - name: openrouter + base_url: "https://openrouter.ai/api/v1" + api_key: "${OPENROUTER_API_KEY}" + priority: 3 + timeout: 120s + - name: groq + base_url: "https://api.groq.com/openai/v1" + api_key: "${GROQ_API_KEY}" + priority: 1 + timeout: 120s + - name: cerebras + base_url: "https://api.cerebras.ai/v1" + api_key: "${CEREBRAS_API_KEY}" + priority: 1 + timeout: 120s + +models: + - name: "deepseek-v3.2" + routes: + - provider: deepinfra + model: "deepseek-ai/DeepSeek-V3.2" + pricing: { input: 0.26, output: 0.38 } + - provider: siliconflow + model: "deepseek-ai/DeepSeek-V3.2" + pricing: { input: 0.27, output: 0.42 } + - provider: openrouter + model: "deepseek/deepseek-chat-v3-0324" + pricing: { input: 0.30, output: 0.88 } + + - name: "llama-3.3-70b" + routes: + - provider: groq + model: "llama-3.3-70b-versatile" + pricing: { input: 0, output: 0 } + - provider: deepinfra + model: "meta-llama/Llama-3.3-70B-Instruct" + pricing: { input: 0.23, output: 0.40 } + + - name: "llama-3.1-8b" + routes: + - provider: groq + model: "llama-3.1-8b-instant" + pricing: { input: 0, output: 0 } + - provider: cerebras + model: "llama-3.1-8b" + pricing: { input: 0, output: 0 } + - provider: deepinfra + model: "meta-llama/Meta-Llama-3.1-8B-Instruct" + pricing: { input: 0.03, output: 0.05 } + + - name: "qwen-2.5-72b" + routes: + - provider: groq + model: "qwen-2.5-72b" + pricing: { input: 0, output: 0 } + - provider: deepinfra + model: "Qwen/Qwen2.5-72B-Instruct" + pricing: { input: 0.23, output: 0.40 } + + - name: "qwen-2.5-coder-32b" + routes: + - provider: groq + model: "qwen-2.5-coder-32b" + pricing: { input: 0, output: 0 } + - provider: deepinfra + model: "Qwen/Qwen2.5-Coder-32B-Instruct" + pricing: { input: 0.07, output: 0.16 } + + - name: "gemma-2-9b" + routes: + - provider: groq + model: "gemma2-9b-it" + pricing: { input: 0, output: 0 } + + - name: "deepseek-r1" + routes: + - provider: deepinfra + model: "deepseek-ai/DeepSeek-R1" + pricing: { input: 0.40, output: 1.60 } + - provider: openrouter + model: "deepseek/deepseek-r1" + pricing: { input: 0.55, output: 2.19 } + + - name: "deepseek-r1-distill-llama-70b" + routes: + - provider: groq + model: "deepseek-r1-distill-llama-70b" + pricing: { input: 0, output: 0 } + - provider: deepinfra + model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B" + pricing: { input: 0.23, output: 0.69 } + + - name: "deepseek-r1-distill-qwen-32b" + routes: + - provider: deepinfra + model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" + pricing: { input: 0.07, output: 0.16 } diff --git a/llm-gateway/go.mod b/llm-gateway/go.mod new file mode 100644 index 0000000..1d54b54 --- /dev/null +++ b/llm-gateway/go.mod @@ -0,0 +1,38 @@ +module llm-gateway + +go 1.24.0 + +require ( + github.com/go-chi/chi/v5 v5.2.5 + github.com/golang-migrate/migrate/v4 v4.19.1 + github.com/pquerna/otp v1.5.0 + github.com/prometheus/client_golang v1.23.2 + github.com/redis/go-redis/v9 v9.17.3 + golang.org/x/crypto v0.48.0 + gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.45.0 +) + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/sys v0.41.0 // indirect + google.golang.org/protobuf v1.36.8 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/llm-gateway/go.sum b/llm-gateway/go.sum new file mode 100644 index 0000000..28784f6 --- /dev/null +++ b/llm-gateway/go.sum @@ -0,0 +1,121 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= +github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs= +github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= +github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= +github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= +golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc= +google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.45.0 h1:r51cSGzKpbptxnby+EIIz5fop4VuE4qFoVEjNvWoObs= +modernc.org/sqlite v1.45.0/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/llm-gateway/internal/auth/handlers.go b/llm-gateway/internal/auth/handlers.go new file mode 100644 index 0000000..afbc867 --- /dev/null +++ b/llm-gateway/internal/auth/handlers.go @@ -0,0 +1,714 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" +) + +type Handlers struct { + store *Store + sessionSecret string + loginLimiter *loginRateLimiter +} + +func NewHandlers(store *Store, sessionSecret string) *Handlers { + return &Handlers{ + store: store, + sessionSecret: sessionSecret, + loginLimiter: newLoginRateLimiter(), + } +} + +// Login brute-force protection +type loginRateLimiter struct { + mu sync.Mutex + attempts map[string][]time.Time +} + +func newLoginRateLimiter() *loginRateLimiter { + return &loginRateLimiter{attempts: make(map[string][]time.Time)} +} + +func (l *loginRateLimiter) allow(ip string) bool { + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-1 * time.Minute) + + // Clean old entries + recent := l.attempts[ip][:0] + for _, t := range l.attempts[ip] { + if t.After(cutoff) { + recent = append(recent, t) + } + } + l.attempts[ip] = recent + + if len(recent) >= 5 { + return false + } + l.attempts[ip] = append(l.attempts[ip], now) + return true +} + +func (h *Handlers) Status(w http.ResponseWriter, r *http.Request) { + initialized := h.store.HasAnyUser() + + resp := map[string]any{ + "initialized": initialized, + "logged_in": false, + } + + cookie, err := r.Cookie(sessionCookieName) + if err == nil && cookie.Value != "" { + sess, err := h.store.GetSession(cookie.Value) + if err == nil { + user, err := h.store.GetUserByID(sess.UserID) + if err == nil { + resp["logged_in"] = true + resp["user"] = map[string]any{ + "id": user.ID, + "username": user.Username, + "is_admin": user.IsAdmin, + "totp_enabled": user.TOTPEnabled, + } + } + } + } + + writeJSON(w, resp) +} + +func (h *Handlers) Setup(w http.ResponseWriter, r *http.Request) { + if h.store.HasAnyUser() { + writeError(w, http.StatusBadRequest, "already initialized") + return + } + + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Username == "" || req.Password == "" { + writeError(w, http.StatusBadRequest, "username and password required") + return + } + if len(req.Password) < 8 { + writeError(w, http.StatusBadRequest, "password must be at least 8 characters") + return + } + + user, err := h.store.CreateUser(req.Username, req.Password, true) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to create user: "+err.Error()) + return + } + + // Auto-login + sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to create session") + return + } + + h.setSessionCookie(w, sessionID) + writeJSON(w, map[string]any{ + "user": map[string]any{ + "id": user.ID, + "username": user.Username, + "is_admin": user.IsAdmin, + }, + }) +} + +func (h *Handlers) Login(w http.ResponseWriter, r *http.Request) { + ip := r.RemoteAddr + if fwd := r.Header.Get("X-Real-IP"); fwd != "" { + ip = fwd + } + if !h.loginLimiter.allow(ip) { + writeError(w, http.StatusTooManyRequests, "too many login attempts, try again in a minute") + return + } + + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + user, err := h.store.GetUserByUsername(req.Username) + if err != nil { + writeError(w, http.StatusUnauthorized, "invalid credentials") + return + } + + if !h.store.CheckPassword(user, req.Password) { + writeError(w, http.StatusUnauthorized, "invalid credentials") + return + } + + if user.TOTPEnabled { + // Set pending cookie for TOTP step + pending := h.signPendingToken(user.ID) + http.SetCookie(w, &http.Cookie{ + Name: "llmgw_pending", + Value: pending, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: 300, // 5 minutes + }) + writeJSON(w, map[string]any{"require_totp": true}) + return + } + + sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to create session") + return + } + + h.setSessionCookie(w, sessionID) + writeJSON(w, map[string]any{ + "require_totp": false, + "user": map[string]any{ + "id": user.ID, + "username": user.Username, + "is_admin": user.IsAdmin, + }, + }) +} + +func (h *Handlers) LoginTOTP(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("llmgw_pending") + if err != nil || cookie.Value == "" { + writeError(w, http.StatusBadRequest, "no pending login") + return + } + + userID, err := h.verifyPendingToken(cookie.Value) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid or expired pending login") + return + } + + var req struct { + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + user, err := h.store.GetUserByID(userID) + if err != nil { + writeError(w, http.StatusBadRequest, "user not found") + return + } + + if !ValidateTOTPCode(user.TOTPSecret, req.Code) { + writeError(w, http.StatusUnauthorized, "invalid TOTP code") + return + } + + // Clear pending cookie + http.SetCookie(w, &http.Cookie{ + Name: "llmgw_pending", + Value: "", + Path: "/", + HttpOnly: true, + MaxAge: -1, + }) + + sessionID, err := h.store.CreateSession(user.ID, 7*24*time.Hour) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to create session") + return + } + + h.setSessionCookie(w, sessionID) + writeJSON(w, map[string]any{ + "user": map[string]any{ + "id": user.ID, + "username": user.Username, + "is_admin": user.IsAdmin, + }, + }) +} + +func (h *Handlers) Logout(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(sessionCookieName) + if err == nil { + h.store.DeleteSession(cookie.Value) + } + + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: "", + Path: "/", + HttpOnly: true, + MaxAge: -1, + }) + + writeJSON(w, map[string]string{"status": "ok"}) +} + +func (h *Handlers) Me(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + writeJSON(w, map[string]any{ + "id": user.ID, + "username": user.Username, + "is_admin": user.IsAdmin, + "totp_enabled": user.TOTPEnabled, + }) +} + +func (h *Handlers) TOTPSetup(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + key, err := GenerateTOTPKey(user.Username) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to generate TOTP key") + return + } + + if err := h.store.SetTOTPSecret(user.ID, key.Secret()); err != nil { + writeError(w, http.StatusInternalServerError, "failed to save TOTP secret") + return + } + + writeJSON(w, map[string]string{ + "secret": key.Secret(), + "uri": key.URL(), + }) +} + +func (h *Handlers) TOTPVerify(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var req struct { + Code string `json:"code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + // Re-fetch user to get latest TOTP secret + user, err := h.store.GetUserByID(user.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to fetch user") + return + } + + if user.TOTPSecret == "" { + writeError(w, http.StatusBadRequest, "TOTP not set up, call /api/auth/totp/setup first") + return + } + + if !ValidateTOTPCode(user.TOTPSecret, req.Code) { + writeError(w, http.StatusBadRequest, "invalid TOTP code") + return + } + + if err := h.store.EnableTOTP(user.ID); err != nil { + writeError(w, http.StatusInternalServerError, "failed to enable TOTP") + return + } + + writeJSON(w, map[string]string{"status": "totp_enabled"}) +} + +func (h *Handlers) TOTPDisable(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + if err := h.store.DisableTOTP(user.ID); err != nil { + writeError(w, http.StatusInternalServerError, "failed to disable TOTP") + return + } + + writeJSON(w, map[string]string{"status": "totp_disabled"}) +} + +// User management (admin only) + +func (h *Handlers) ListUsers(w http.ResponseWriter, r *http.Request) { + users, err := h.store.ListUsers() + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list users") + return + } + + // Strip sensitive fields + type safeUser struct { + ID int64 `json:"id"` + Username string `json:"username"` + IsAdmin bool `json:"is_admin"` + TOTPEnabled bool `json:"totp_enabled"` + CreatedAt int64 `json:"created_at"` + } + result := make([]safeUser, len(users)) + for i, u := range users { + result[i] = safeUser{ + ID: u.ID, + Username: u.Username, + IsAdmin: u.IsAdmin, + TOTPEnabled: u.TOTPEnabled, + CreatedAt: u.CreatedAt, + } + } + writeJSON(w, result) +} + +func (h *Handlers) CreateUser(w http.ResponseWriter, r *http.Request) { + var req struct { + Username string `json:"username"` + Password string `json:"password"` + IsAdmin bool `json:"is_admin"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Username == "" || req.Password == "" { + writeError(w, http.StatusBadRequest, "username and password required") + return + } + if len(req.Password) < 8 { + writeError(w, http.StatusBadRequest, "password must be at least 8 characters") + return + } + + user, err := h.store.CreateUser(req.Username, req.Password, req.IsAdmin) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to create user: "+err.Error()) + return + } + + writeJSON(w, map[string]any{ + "id": user.ID, + "username": user.Username, + "is_admin": user.IsAdmin, + }) +} + +func (h *Handlers) DeleteUser(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid user ID") + return + } + + // Prevent deleting yourself + user := UserFromContext(r.Context()) + if user != nil && user.ID == id { + writeError(w, http.StatusBadRequest, "cannot delete yourself") + return + } + + if err := h.store.DeleteUser(id); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + + writeJSON(w, map[string]string{"status": "deleted"}) +} + +// API Token management + +func (h *Handlers) ListTokens(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var userID int64 + if !user.IsAdmin { + userID = user.ID + } + // userID=0 means list all (admin) + + tokens, err := h.store.ListAPITokens(userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to list tokens") + return + } + if tokens == nil { + tokens = []APIToken{} + } + writeJSON(w, tokens) +} + +func (h *Handlers) CreateToken(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var req struct { + Name string `json:"name"` + RateLimitRPM int `json:"rate_limit_rpm"` + DailyBudgetUSD float64 `json:"daily_budget_usd"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Name == "" { + writeError(w, http.StatusBadRequest, "name is required") + return + } + // RateLimitRPM: 0 = unlimited, negative treated as 0 + if req.RateLimitRPM < 0 { + req.RateLimitRPM = 0 + } + + plainKey, token, err := h.store.CreateAPIToken(user.ID, req.Name, req.RateLimitRPM, req.DailyBudgetUSD) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to create token: "+err.Error()) + return + } + + writeJSON(w, map[string]any{ + "key": plainKey, + "token": token, + }) +} + +func (h *Handlers) DeleteToken(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + idStr := chi.URLParam(r, "id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid token ID") + return + } + + // Non-admin can only delete own tokens + if !user.IsAdmin { + token, err := h.store.GetAPIToken(id) + if err != nil { + writeError(w, http.StatusNotFound, "token not found") + return + } + if token.UserID != user.ID { + writeError(w, http.StatusForbidden, "not your token") + return + } + } + + if err := h.store.DeleteAPIToken(id); err != nil { + writeError(w, http.StatusInternalServerError, "failed to delete token") + return + } + + writeJSON(w, map[string]string{"status": "deleted"}) +} + +// Self-service endpoints + +func (h *Handlers) ChangePassword(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var req struct { + CurrentPassword string `json:"current_password"` + NewPassword string `json:"new_password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.NewPassword == "" || len(req.NewPassword) < 8 { + writeError(w, http.StatusBadRequest, "new password must be at least 8 characters") + return + } + + // Re-fetch user to get password hash + user, err := h.store.GetUserByID(user.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to fetch user") + return + } + + if !h.store.CheckPassword(user, req.CurrentPassword) { + writeError(w, http.StatusUnauthorized, "current password is incorrect") + return + } + + if err := h.store.UpdatePassword(user.ID, req.NewPassword); err != nil { + writeError(w, http.StatusInternalServerError, "failed to update password") + return + } + + writeJSON(w, map[string]string{"status": "password_updated"}) +} + +func (h *Handlers) ChangeUsername(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var req struct { + NewUsername string `json:"new_username"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.NewUsername == "" { + writeError(w, http.StatusBadRequest, "username is required") + return + } + + // Check uniqueness + existing, err := h.store.GetUserByUsername(req.NewUsername) + if err == nil && existing.ID != user.ID { + writeError(w, http.StatusConflict, "username already taken") + return + } + + if err := h.store.UpdateUsername(user.ID, req.NewUsername); err != nil { + writeError(w, http.StatusInternalServerError, "failed to update username") + return + } + + writeJSON(w, map[string]string{"status": "username_updated"}) +} + +func (h *Handlers) ChangeEmail(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil { + writeError(w, http.StatusUnauthorized, "not authenticated") + return + } + + var req struct { + Email string `json:"email"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + if err := h.store.UpdateEmail(user.ID, req.Email); err != nil { + writeError(w, http.StatusInternalServerError, "failed to update email") + return + } + + writeJSON(w, map[string]string{"status": "email_updated"}) +} + +// Helpers + +func (h *Handlers) setSessionCookie(w http.ResponseWriter, sessionID string) { + http.SetCookie(w, &http.Cookie{ + Name: sessionCookieName, + Value: sessionID, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: sessionTTLDays * 24 * 60 * 60, + }) +} + +func (h *Handlers) signPendingToken(userID int64) string { + data := fmt.Sprintf("%d:%d", userID, time.Now().Unix()) + mac := hmac.New(sha256.New, []byte(h.sessionSecret)) + mac.Write([]byte(data)) + sig := hex.EncodeToString(mac.Sum(nil)) + return data + ":" + sig +} + +func (h *Handlers) verifyPendingToken(token string) (int64, error) { + parts := strings.SplitN(token, ":", 3) + if len(parts) != 3 { + return 0, fmt.Errorf("invalid format") + } + + userID, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid user ID") + } + + ts, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid timestamp") + } + + // Check expiry (5 minutes) + if time.Now().Unix()-ts > 300 { + return 0, fmt.Errorf("expired") + } + + // Verify HMAC + data := parts[0] + ":" + parts[1] + mac := hmac.New(sha256.New, []byte(h.sessionSecret)) + mac.Write([]byte(data)) + expectedSig := hex.EncodeToString(mac.Sum(nil)) + + if !hmac.Equal([]byte(parts[2]), []byte(expectedSig)) { + return 0, fmt.Errorf("invalid signature") + } + + return userID, nil +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, code int, msg string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]string{"error": msg}) +} diff --git a/llm-gateway/internal/auth/middleware.go b/llm-gateway/internal/auth/middleware.go new file mode 100644 index 0000000..bceadf1 --- /dev/null +++ b/llm-gateway/internal/auth/middleware.go @@ -0,0 +1,83 @@ +package auth + +import ( + "context" + "encoding/json" + "net/http" + "strings" +) + +type contextKey string + +const userContextKey contextKey = "auth_user" + +const ( + sessionCookieName = "llmgw_session" + sessionTTLDays = 7 +) + +type Middleware struct { + store *Store +} + +func NewMiddleware(store *Store) *Middleware { + return &Middleware{store: store} +} + +func UserFromContext(ctx context.Context) *User { + u, _ := ctx.Value(userContextKey).(*User) + return u +} + +func (m *Middleware) RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(sessionCookieName) + if err != nil || cookie.Value == "" { + m.unauthorized(w, r) + return + } + + sess, err := m.store.GetSession(cookie.Value) + if err != nil { + m.unauthorized(w, r) + return + } + + user, err := m.store.GetUserByID(sess.UserID) + if err != nil { + m.unauthorized(w, r) + return + } + + ctx := context.WithValue(r.Context(), userContextKey, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func (m *Middleware) RequireAdmin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := UserFromContext(r.Context()) + if user == nil || !user.IsAdmin { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + json.NewEncoder(w).Encode(map[string]string{"error": "admin access required"}) + return + } + next.ServeHTTP(w, r) + }) +} + +func (m *Middleware) unauthorized(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", "/login") + w.WriteHeader(http.StatusUnauthorized) + return + } + if strings.HasPrefix(r.URL.Path, "/api/") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"}) + return + } + http.Redirect(w, r, "/login", http.StatusFound) +} diff --git a/llm-gateway/internal/auth/store.go b/llm-gateway/internal/auth/store.go new file mode 100644 index 0000000..c829444 --- /dev/null +++ b/llm-gateway/internal/auth/store.go @@ -0,0 +1,367 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/hex" + "fmt" + "time" + + "golang.org/x/crypto/bcrypt" +) + +type User struct { + ID int64 `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + PasswordHash string `json:"-"` + IsAdmin bool `json:"is_admin"` + TOTPSecret string `json:"-"` + TOTPEnabled bool `json:"totp_enabled"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type Session struct { + ID string + UserID int64 + CreatedAt int64 + ExpiresAt int64 +} + +type APIToken struct { + ID int64 `json:"id"` + Name string `json:"name"` + KeyPrefix string `json:"key_prefix"` + KeyHash string `json:"-"` + UserID int64 `json:"user_id"` + RateLimitRPM int `json:"rate_limit_rpm"` + DailyBudgetUSD float64 `json:"daily_budget_usd"` + CreatedAt int64 `json:"created_at"` + LastUsedAt int64 `json:"last_used_at"` +} + +type Store struct { + db *sql.DB +} + +func NewStore(db *sql.DB) *Store { + return &Store{db: db} +} + +func (s *Store) HasAnyUser() bool { + var count int + s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) + return count > 0 +} + +func (s *Store) CreateUser(username, password string, isAdmin bool) (*User, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) + if err != nil { + return nil, fmt.Errorf("hashing password: %w", err) + } + + now := time.Now().Unix() + adminInt := 0 + if isAdmin { + adminInt = 1 + } + + result, err := s.db.Exec( + "INSERT INTO users (username, password_hash, is_admin, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", + username, string(hash), adminInt, now, now, + ) + if err != nil { + return nil, fmt.Errorf("creating user: %w", err) + } + + id, _ := result.LastInsertId() + return &User{ + ID: id, + Username: username, + IsAdmin: isAdmin, + CreatedAt: now, + UpdatedAt: now, + }, nil +} + +func (s *Store) GetUserByUsername(username string) (*User, error) { + return s.scanUser(s.db.QueryRow( + "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE username = ?", + username, + )) +} + +func (s *Store) GetUserByID(id int64) (*User, error) { + return s.scanUser(s.db.QueryRow( + "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE id = ?", + id, + )) +} + +func (s *Store) GetFirstAdmin() (*User, error) { + return s.scanUser(s.db.QueryRow( + "SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE is_admin = 1 ORDER BY id LIMIT 1", + )) +} + +func (s *Store) scanUser(row *sql.Row) (*User, error) { + var u User + var isAdmin, totpEnabled int + var totpSecret sql.NullString + var email sql.NullString + err := row.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt) + if err != nil { + return nil, err + } + u.Email = email.String + u.IsAdmin = isAdmin == 1 + u.TOTPEnabled = totpEnabled == 1 + u.TOTPSecret = totpSecret.String + return &u, nil +} + +func (s *Store) ListUsers() ([]User, error) { + rows, err := s.db.Query("SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users ORDER BY id") + if err != nil { + return nil, err + } + defer rows.Close() + + var users []User + for rows.Next() { + var u User + var isAdmin, totpEnabled int + var totpSecret sql.NullString + var email sql.NullString + if err := rows.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt); err != nil { + return nil, err + } + u.Email = email.String + u.IsAdmin = isAdmin == 1 + u.TOTPEnabled = totpEnabled == 1 + u.TOTPSecret = totpSecret.String + users = append(users, u) + } + return users, nil +} + +func (s *Store) DeleteUser(id int64) error { + // Prevent deleting the last admin + var adminCount int + s.db.QueryRow("SELECT COUNT(*) FROM users WHERE is_admin = 1").Scan(&adminCount) + + var isAdmin int + s.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", id).Scan(&isAdmin) + if isAdmin == 1 && adminCount <= 1 { + return fmt.Errorf("cannot delete the last admin user") + } + + _, err := s.db.Exec("DELETE FROM users WHERE id = ?", id) + return err +} + +func (s *Store) UpdatePassword(userID int64, newPassword string) error { + hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) + if err != nil { + return fmt.Errorf("hashing password: %w", err) + } + _, err = s.db.Exec("UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?", string(hash), time.Now().Unix(), userID) + return err +} + +func (s *Store) CheckPassword(user *User, password string) bool { + return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) == nil +} + +func (s *Store) SetTOTPSecret(userID int64, secret string) error { + _, err := s.db.Exec("UPDATE users SET totp_secret = ?, updated_at = ? WHERE id = ?", secret, time.Now().Unix(), userID) + return err +} + +func (s *Store) EnableTOTP(userID int64) error { + _, err := s.db.Exec("UPDATE users SET totp_enabled = 1, updated_at = ? WHERE id = ?", time.Now().Unix(), userID) + return err +} + +func (s *Store) DisableTOTP(userID int64) error { + _, err := s.db.Exec("UPDATE users SET totp_enabled = 0, totp_secret = '', updated_at = ? WHERE id = ?", time.Now().Unix(), userID) + return err +} + +// Session management + +func (s *Store) CreateSession(userID int64, ttl time.Duration) (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generating session ID: %w", err) + } + id := hex.EncodeToString(b) + now := time.Now().Unix() + expiresAt := time.Now().Add(ttl).Unix() + + _, err := s.db.Exec( + "INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)", + id, userID, now, expiresAt, + ) + if err != nil { + return "", fmt.Errorf("creating session: %w", err) + } + return id, nil +} + +func (s *Store) GetSession(sessionID string) (*Session, error) { + var sess Session + err := s.db.QueryRow( + "SELECT id, user_id, created_at, expires_at FROM sessions WHERE id = ? AND expires_at > ?", + sessionID, time.Now().Unix(), + ).Scan(&sess.ID, &sess.UserID, &sess.CreatedAt, &sess.ExpiresAt) + if err != nil { + return nil, err + } + return &sess, nil +} + +func (s *Store) DeleteSession(id string) error { + _, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", id) + return err +} + +func (s *Store) CleanExpiredSessions() error { + _, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now().Unix()) + return err +} + +// API Token management + +func (s *Store) CreateAPIToken(userID int64, name string, rateLimitRPM int, dailyBudgetUSD float64) (string, *APIToken, error) { + // Generate sk- prefixed random key + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", nil, fmt.Errorf("generating token: %w", err) + } + plainKey := "sk-" + hex.EncodeToString(b) + keyPrefix := plainKey[:11] // "sk-" + first 8 hex chars + + hash := sha256.Sum256([]byte(plainKey)) + keyHash := hex.EncodeToString(hash[:]) + + now := time.Now().Unix() + result, err := s.db.Exec( + "INSERT INTO api_tokens (name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + name, keyHash, keyPrefix, userID, rateLimitRPM, dailyBudgetUSD, now, + ) + if err != nil { + return "", nil, fmt.Errorf("creating API token: %w", err) + } + + id, _ := result.LastInsertId() + token := &APIToken{ + ID: id, + Name: name, + KeyPrefix: keyPrefix, + KeyHash: keyHash, + UserID: userID, + RateLimitRPM: rateLimitRPM, + DailyBudgetUSD: dailyBudgetUSD, + CreatedAt: now, + } + return plainKey, token, nil +} + +func (s *Store) LookupAPIToken(key string) (*APIToken, error) { + hash := sha256.Sum256([]byte(key)) + keyHash := hex.EncodeToString(hash[:]) + + var t APIToken + err := s.db.QueryRow( + "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE key_hash = ?", + keyHash, + ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt) + if err != nil { + return nil, err + } + return &t, nil +} + +func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { + var rows *sql.Rows + var err error + if userID == 0 { + // Admin: list all + rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens ORDER BY id") + } else { + rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) + } + if err != nil { + return nil, err + } + defer rows.Close() + + var tokens []APIToken + for rows.Next() { + var t APIToken + if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt); err != nil { + return nil, err + } + tokens = append(tokens, t) + } + return tokens, nil +} + +func (s *Store) DeleteAPIToken(id int64) error { + _, err := s.db.Exec("DELETE FROM api_tokens WHERE id = ?", id) + return err +} + +func (s *Store) GetAPIToken(id int64) (*APIToken, error) { + var t APIToken + err := s.db.QueryRow( + "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens WHERE id = ?", + id, + ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.CreatedAt, &t.LastUsedAt) + if err != nil { + return nil, err + } + return &t, nil +} + +func (s *Store) UpdateAPITokenLastUsed(id int64) { + s.db.Exec("UPDATE api_tokens SET last_used_at = ? WHERE id = ?", time.Now().Unix(), id) +} + +// SeedStaticToken creates a token by name if it doesn't already exist (idempotent). +func (s *Store) SeedStaticToken(userID int64, name, plainKey string, rateLimitRPM int, dailyBudgetUSD float64) error { + // Check if token with this name already exists + var count int + s.db.QueryRow("SELECT COUNT(*) FROM api_tokens WHERE name = ?", name).Scan(&count) + if count > 0 { + return nil // already seeded + } + + keyPrefix := plainKey + if len(keyPrefix) > 11 { + keyPrefix = keyPrefix[:11] + } + + hash := sha256.Sum256([]byte(plainKey)) + keyHash := hex.EncodeToString(hash[:]) + + now := time.Now().Unix() + _, err := s.db.Exec( + "INSERT INTO api_tokens (name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)", + name, keyHash, keyPrefix, userID, rateLimitRPM, dailyBudgetUSD, now, + ) + return err +} + +func (s *Store) UpdateUsername(userID int64, newUsername string) error { + _, err := s.db.Exec("UPDATE users SET username = ?, updated_at = ? WHERE id = ?", newUsername, time.Now().Unix(), userID) + return err +} + +func (s *Store) UpdateEmail(userID int64, email string) error { + _, err := s.db.Exec("UPDATE users SET email = ?, updated_at = ? WHERE id = ?", email, time.Now().Unix(), userID) + return err +} diff --git a/llm-gateway/internal/auth/totp.go b/llm-gateway/internal/auth/totp.go new file mode 100644 index 0000000..b3092ff --- /dev/null +++ b/llm-gateway/internal/auth/totp.go @@ -0,0 +1,17 @@ +package auth + +import ( + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" +) + +func GenerateTOTPKey(username string) (*otp.Key, error) { + return totp.Generate(totp.GenerateOpts{ + Issuer: "LLM Gateway", + AccountName: username, + }) +} + +func ValidateTOTPCode(secret, code string) bool { + return totp.Validate(code, secret) +} diff --git a/llm-gateway/internal/cache/cache.go b/llm-gateway/internal/cache/cache.go new file mode 100644 index 0000000..9d9df2a --- /dev/null +++ b/llm-gateway/internal/cache/cache.go @@ -0,0 +1,64 @@ +package cache + +import ( + "context" + "crypto/sha256" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +type Cache struct { + client *redis.Client + ttl time.Duration +} + +func New(addr string, ttlSeconds int) (*Cache, error) { + client := redis.NewClient(&redis.Options{ + Addr: addr, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("connecting to Valkey: %w", err) + } + + ttl := time.Duration(ttlSeconds) * time.Second + if ttl == 0 { + ttl = 1 * time.Hour + } + + return &Cache{client: client, ttl: ttl}, nil +} + +func (c *Cache) Get(ctx context.Context, model string, requestBody []byte) ([]byte, error) { + key := c.cacheKey(model, requestBody) + data, err := c.client.Get(ctx, key).Bytes() + if err == redis.Nil { + return nil, nil + } + return data, err +} + +func (c *Cache) Set(ctx context.Context, model string, requestBody, responseBody []byte) error { + key := c.cacheKey(model, requestBody) + return c.client.Set(ctx, key, responseBody, c.ttl).Err() +} + +func (c *Cache) Ping(ctx context.Context) error { + return c.client.Ping(ctx).Err() +} + +func (c *Cache) Close() error { + return c.client.Close() +} + +func (c *Cache) cacheKey(model string, requestBody []byte) string { + h := sha256.New() + h.Write([]byte(model)) + h.Write(requestBody) + return fmt.Sprintf("llm-gw:%x", h.Sum(nil)) +} diff --git a/llm-gateway/internal/config/config.go b/llm-gateway/internal/config/config.go new file mode 100644 index 0000000..2aaec2c --- /dev/null +++ b/llm-gateway/internal/config/config.go @@ -0,0 +1,198 @@ +package config + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "log" + "os" + "time" + + "gopkg.in/yaml.v3" +) + +type Config struct { + Server ServerConfig `yaml:"server"` + Database DatabaseConfig `yaml:"database"` + Cache CacheConfig `yaml:"cache"` + Pricing PricingLookupConfig `yaml:"pricing_lookup"` + Providers []ProviderConfig `yaml:"providers"` + Models []ModelConfig `yaml:"models"` + Tokens []TokenConfig `yaml:"tokens"` +} + +type PricingLookupConfig struct { + URL string `yaml:"url"` + RefreshInterval time.Duration `yaml:"refresh_interval"` +} + +type DefaultAdminConfig struct { + Username string `yaml:"username"` + Password string `yaml:"password"` +} + +type TokenConfig struct { + Name string `yaml:"name"` + Key string `yaml:"key"` + RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited + DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited +} + +type ServerConfig struct { + Listen string `yaml:"listen"` + RequestTimeout time.Duration `yaml:"request_timeout"` + MaxRequestBodyMB int `yaml:"max_request_body_mb"` + SessionSecret string `yaml:"session_secret"` + DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` +} + +type DatabaseConfig struct { + Path string `yaml:"path"` + RetentionDays int `yaml:"retention_days"` +} + +type CacheConfig struct { + Enabled bool `yaml:"enabled"` + Address string `yaml:"address"` + TTL int `yaml:"ttl"` // seconds +} + +type ProviderConfig struct { + Name string `yaml:"name"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + Priority int `yaml:"priority"` + Timeout time.Duration `yaml:"timeout"` +} + +type ModelConfig struct { + Name string `yaml:"name"` + Routes []RouteConfig `yaml:"routes"` +} + +type RouteConfig struct { + Provider string `yaml:"provider"` + Model string `yaml:"model"` + Pricing PricingConfig `yaml:"pricing"` +} + +type PricingConfig struct { + Input float64 `yaml:"input"` // cost per 1M tokens + Output float64 `yaml:"output"` // cost per 1M tokens +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading config: %w", err) + } + + // Expand environment variables + expanded := os.ExpandEnv(string(data)) + + var cfg Config + if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + if err := cfg.validate(); err != nil { + return nil, fmt.Errorf("validating config: %w", err) + } + + return &cfg, nil +} + +func (c *Config) validate() error { + if c.Server.Listen == "" { + c.Server.Listen = "0.0.0.0:3000" + } + if c.Server.RequestTimeout == 0 { + c.Server.RequestTimeout = 300 * time.Second + } + if c.Server.MaxRequestBodyMB == 0 { + c.Server.MaxRequestBodyMB = 10 + } + if c.Server.SessionSecret == "" { + b := make([]byte, 32) + rand.Read(b) + c.Server.SessionSecret = hex.EncodeToString(b) + log.Println("WARNING: no session_secret configured, generated random one (sessions won't survive restart)") + } + if c.Database.Path == "" { + c.Database.Path = "gateway.db" + } + if c.Database.RetentionDays == 0 { + c.Database.RetentionDays = 90 + } + if c.Pricing.RefreshInterval == 0 { + c.Pricing.RefreshInterval = 6 * time.Hour + } + + if len(c.Providers) == 0 { + return fmt.Errorf("at least one provider is required") + } + providerNames := make(map[string]bool) + for i, p := range c.Providers { + if p.Name == "" || p.BaseURL == "" || p.APIKey == "" { + return fmt.Errorf("provider %d: name, base_url, and api_key are required", i) + } + if providerNames[p.Name] { + return fmt.Errorf("duplicate provider name: %s", p.Name) + } + providerNames[p.Name] = true + if c.Providers[i].Timeout == 0 { + c.Providers[i].Timeout = 120 * time.Second + } + if c.Providers[i].Priority == 0 { + c.Providers[i].Priority = 1 + } + } + + if len(c.Models) == 0 { + return fmt.Errorf("at least one model is required") + } + modelNames := make(map[string]bool) + for i, m := range c.Models { + if m.Name == "" { + return fmt.Errorf("model %d: name is required", i) + } + if modelNames[m.Name] { + return fmt.Errorf("duplicate model name: %s", m.Name) + } + modelNames[m.Name] = true + if len(m.Routes) == 0 { + return fmt.Errorf("model %s: at least one route is required", m.Name) + } + for j, r := range m.Routes { + if r.Provider == "" || r.Model == "" { + return fmt.Errorf("model %s route %d: provider and model are required", m.Name, j) + } + if !providerNames[r.Provider] { + return fmt.Errorf("model %s route %d: unknown provider %s", m.Name, j, r.Provider) + } + } + } + + // Validate tokens (optional section) + for i, t := range c.Tokens { + if t.Key == "" { + log.Printf("WARNING: token %d (%s) has empty key, skipping", i, t.Name) + continue + } + if t.Name == "" { + c.Tokens[i].Name = fmt.Sprintf("token-%d", i) + } + } + + return nil +} + +// ProviderByName returns the provider config by name. +func (c *Config) ProviderByName(name string) *ProviderConfig { + for i := range c.Providers { + if c.Providers[i].Name == name { + return &c.Providers[i] + } + } + return nil +} diff --git a/llm-gateway/internal/dashboard/api.go b/llm-gateway/internal/dashboard/api.go new file mode 100644 index 0000000..452ef27 --- /dev/null +++ b/llm-gateway/internal/dashboard/api.go @@ -0,0 +1,308 @@ +package dashboard + +import ( + "encoding/json" + "net/http" + "time" + + "llm-gateway/internal/auth" + "llm-gateway/internal/storage" +) + +// Exported types for template rendering and JSON API. + +type Period struct { + Requests int `json:"requests"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CostUSD float64 `json:"cost_usd"` + Errors int `json:"errors"` + CachedHits int `json:"cached_hits"` +} + +type SummaryResult struct { + Today *Period `json:"today"` + Week *Period `json:"week"` + Month *Period `json:"month"` +} + +type ModelStats struct { + Model string `json:"model"` + Requests int `json:"requests"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CostUSD float64 `json:"cost_usd"` + AvgLatencyMS float64 `json:"avg_latency_ms"` +} + +type ProviderStats struct { + Provider string `json:"provider"` + Requests int `json:"requests"` + Successes int `json:"successes"` + Errors int `json:"errors"` + AvgLatencyMS float64 `json:"avg_latency_ms"` + CostUSD float64 `json:"cost_usd"` +} + +type TokenUsageStats struct { + TokenName string `json:"token_name"` + Requests int `json:"requests"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CostUSD float64 `json:"cost_usd"` +} + +type StatsAPI struct { + db *storage.DB + authStore *auth.Store +} + +func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI { + return &StatsAPI{db: db, authStore: authStore} +} + +// TokenNamesForUser returns the token names that belong to the user. +// Admins get nil (no filter), non-admins get their token names. +func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string { + if user == nil || user.IsAdmin { + return nil + } + tokens, err := s.authStore.ListAPITokens(user.ID) + if err != nil { + return []string{"__none__"} + } + names := make([]string, len(tokens)) + for i, t := range tokens { + names[i] = t.Name + } + if len(names) == 0 { + return []string{"__none__"} + } + return names +} + +// tokenNamesForUser returns token names from request context (for HTTP handlers). +func (s *StatsAPI) tokenNamesForUser(r *http.Request) []string { + user := auth.UserFromContext(r.Context()) + return s.TokenNamesForUser(user) +} + +func buildTokenFilter(tokenNames []string) (string, []any) { + if tokenNames == nil { + return "", nil + } + placeholders := "" + args := make([]any, len(tokenNames)) + for i, n := range tokenNames { + if i > 0 { + placeholders += "," + } + placeholders += "?" + args[i] = n + } + return " AND token_name IN (" + placeholders + ")", args +} + +// Data-fetching methods (used by both JSON handlers and template handlers). + +func (s *StatsAPI) GetSummary(tokenNames []string) *SummaryResult { + now := time.Now() + todayStart := now.Truncate(24 * time.Hour).Unix() + weekStart := now.AddDate(0, 0, -7).Unix() + monthStart := now.AddDate(0, -1, 0).Unix() + + tokenFilter, filterArgs := buildTokenFilter(tokenNames) + + result := &SummaryResult{ + Today: &Period{}, + Week: &Period{}, + Month: &Period{}, + } + + periods := map[string]struct { + since int64 + period *Period + }{ + "today": {todayStart, result.Today}, + "week": {weekStart, result.Week}, + "month": {monthStart, result.Month}, + } + + for _, p := range periods { + args := append([]any{p.since}, filterArgs...) + row := s.db.QueryRow(`SELECT + COUNT(*), + COALESCE(SUM(input_tokens), 0), + COALESCE(SUM(output_tokens), 0), + COALESCE(SUM(cost_usd), 0), + COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END), 0) + FROM request_logs WHERE timestamp >= ?`+tokenFilter, args...) + row.Scan(&p.period.Requests, &p.period.InputTokens, &p.period.OutputTokens, &p.period.CostUSD, &p.period.Errors, &p.period.CachedHits) + } + + return result +} + +func (s *StatsAPI) GetModels(tokenNames []string) []ModelStats { + since := time.Now().AddDate(0, 0, -30).Unix() + tokenFilter, filterArgs := buildTokenFilter(tokenNames) + + args := append([]any{since}, filterArgs...) + rows, err := s.db.Query(`SELECT + model, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cost_usd), 0) as cost, + COALESCE(AVG(latency_ms), 0) as avg_latency + FROM request_logs WHERE timestamp >= ?`+tokenFilter+` + GROUP BY model ORDER BY requests DESC`, args...) + if err != nil { + return nil + } + defer rows.Close() + + var results []ModelStats + for rows.Next() { + var m ModelStats + rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS) + results = append(results, m) + } + return results +} + +func (s *StatsAPI) GetProviders(tokenNames []string) []ProviderStats { + since := time.Now().AddDate(0, 0, -30).Unix() + tokenFilter, filterArgs := buildTokenFilter(tokenNames) + + args := append([]any{since}, filterArgs...) + rows, err := s.db.Query(`SELECT + provider, + COUNT(*) as requests, + COALESCE(SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END), 0) as successes, + COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0) as errors, + COALESCE(AVG(latency_ms), 0) as avg_latency, + COALESCE(SUM(cost_usd), 0) as cost + FROM request_logs WHERE timestamp >= ?`+tokenFilter+` + GROUP BY provider ORDER BY requests DESC`, args...) + if err != nil { + return nil + } + defer rows.Close() + + var results []ProviderStats + for rows.Next() { + var p ProviderStats + rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD) + results = append(results, p) + } + return results +} + +func (s *StatsAPI) GetTokenUsage(tokenNames []string) []TokenUsageStats { + since := time.Now().AddDate(0, 0, -30).Unix() + tokenFilter, filterArgs := buildTokenFilter(tokenNames) + + args := append([]any{since}, filterArgs...) + rows, err := s.db.Query(`SELECT + token_name, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cost_usd), 0) as cost + FROM request_logs WHERE timestamp >= ?`+tokenFilter+` + GROUP BY token_name ORDER BY requests DESC`, args...) + if err != nil { + return nil + } + defer rows.Close() + + var results []TokenUsageStats + for rows.Next() { + var t TokenUsageStats + rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD) + results = append(results, t) + } + return results +} + +// JSON HTTP handlers (thin wrappers). + +func (s *StatsAPI) Summary(w http.ResponseWriter, r *http.Request) { + tokenNames := s.tokenNamesForUser(r) + result := s.GetSummary(tokenNames) + writeJSON(w, result) +} + +func (s *StatsAPI) Models(w http.ResponseWriter, r *http.Request) { + tokenNames := s.tokenNamesForUser(r) + results := s.GetModels(tokenNames) + writeJSON(w, results) +} + +func (s *StatsAPI) Providers(w http.ResponseWriter, r *http.Request) { + tokenNames := s.tokenNamesForUser(r) + results := s.GetProviders(tokenNames) + writeJSON(w, results) +} + +func (s *StatsAPI) Tokens(w http.ResponseWriter, r *http.Request) { + tokenNames := s.tokenNamesForUser(r) + results := s.GetTokenUsage(tokenNames) + writeJSON(w, results) +} + +func (s *StatsAPI) Timeseries(w http.ResponseWriter, r *http.Request) { + period := r.URL.Query().Get("period") + var since int64 + var groupFmt string + switch period { + case "7d": + since = time.Now().AddDate(0, 0, -7).Unix() + groupFmt = "%Y-%m-%d" + case "30d": + since = time.Now().AddDate(0, -1, 0).Unix() + groupFmt = "%Y-%m-%d" + default: + since = time.Now().Add(-24 * time.Hour).Unix() + groupFmt = "%Y-%m-%d %H:00" + } + + tokenNames := s.tokenNamesForUser(r) + tokenFilter, filterArgs := buildTokenFilter(tokenNames) + + args := append([]any{since}, filterArgs...) + rows, err := s.db.Query(`SELECT + strftime('`+groupFmt+`', timestamp, 'unixepoch') as bucket, + COUNT(*) as requests, + COALESCE(SUM(cost_usd), 0) as cost, + COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens + FROM request_logs WHERE timestamp >= ?`+tokenFilter+` + GROUP BY bucket ORDER BY bucket`, args...) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rows.Close() + + type point struct { + Bucket string `json:"bucket"` + Requests int `json:"requests"` + CostUSD float64 `json:"cost_usd"` + TotalTokens int `json:"total_tokens"` + } + + var results []point + for rows.Next() { + var p point + rows.Scan(&p.Bucket, &p.Requests, &p.CostUSD, &p.TotalTokens) + results = append(results, p) + } + writeJSON(w, results) +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} diff --git a/llm-gateway/internal/dashboard/handler.go b/llm-gateway/internal/dashboard/handler.go new file mode 100644 index 0000000..6b246cd --- /dev/null +++ b/llm-gateway/internal/dashboard/handler.go @@ -0,0 +1,192 @@ +package dashboard + +import ( + "embed" + "fmt" + "html/template" + "net/http" + "time" + + "llm-gateway/internal/auth" +) + +//go:embed templates/*.html templates/partials/*.html +var templateFiles embed.FS + +var templateFuncs = template.FuncMap{ + "formatTime": func(ts int64) string { + if ts == 0 { + return "never" + } + return time.Unix(ts, 0).Format("2006-01-02") + }, + "addInt": func(a, b int) int { + return a + b + }, + "formatCost": func(v float64) string { + if v == 0 { + return "$0.00" + } + if v < 0.01 { + return fmt.Sprintf("$%.6f", v) + } + return fmt.Sprintf("$%.4f", v) + }, +} + +// PageData is the common data passed to all templates. +type PageData struct { + ActivePage string + User *auth.User + // Page-specific data + Summary *SummaryResult + Models []ModelStats + Providers []ProviderStats + TokenStats []TokenUsageStats + Tokens []auth.APIToken + Users []auth.User +} + +// Dashboard serves the HTMX-based dashboard pages. +type Dashboard struct { + templates *template.Template + authStore *auth.Store + statsAPI *StatsAPI +} + +// NewDashboard creates a new Dashboard handler. +func NewDashboard(authStore *auth.Store, statsAPI *StatsAPI) *Dashboard { + tmpl := template.Must( + template.New("").Funcs(templateFuncs).ParseFS(templateFiles, + "templates/*.html", + "templates/partials/*.html", + ), + ) + + return &Dashboard{ + templates: tmpl, + authStore: authStore, + statsAPI: statsAPI, + } +} + +// LoginPage serves the login page. +func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) { + if !d.authStore.HasAnyUser() { + http.Redirect(w, r, "/setup", http.StatusFound) + return + } + if user := d.getSessionUser(r); user != nil { + http.Redirect(w, r, "/dashboard", http.StatusFound) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + d.templates.ExecuteTemplate(w, "login", nil) +} + +// SetupPage serves the initial setup page. +func (d *Dashboard) SetupPage(w http.ResponseWriter, r *http.Request) { + if d.authStore.HasAnyUser() { + http.Redirect(w, r, "/login", http.StatusFound) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + d.templates.ExecuteTemplate(w, "setup", nil) +} + +// DashboardPage serves the main dashboard view. +func (d *Dashboard) DashboardPage(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromContext(r.Context()) + tokenNames := d.statsAPI.TokenNamesForUser(user) + + data := PageData{ + ActivePage: "dashboard", + User: user, + Summary: d.statsAPI.GetSummary(tokenNames), + Models: d.statsAPI.GetModels(tokenNames), + Providers: d.statsAPI.GetProviders(tokenNames), + TokenStats: d.statsAPI.GetTokenUsage(tokenNames), + } + + d.renderDashboardPage(w, r, "partials/dashboard.html", data) +} + +// TokensPage serves the tokens management view. +func (d *Dashboard) TokensPage(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromContext(r.Context()) + + var userID int64 + if !user.IsAdmin { + userID = user.ID + } + + tokens, _ := d.authStore.ListAPITokens(userID) + if tokens == nil { + tokens = []auth.APIToken{} + } + + d.renderDashboardPage(w, r, "partials/tokens.html", PageData{ + ActivePage: "tokens", + User: user, + Tokens: tokens, + }) +} + +// UsersPage serves the user management view (admin only). +func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromContext(r.Context()) + users, _ := d.authStore.ListUsers() + + d.renderDashboardPage(w, r, "partials/users.html", PageData{ + ActivePage: "users", + User: user, + Users: users, + }) +} + +// SettingsPage serves the settings view. +func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromContext(r.Context()) + user, _ = d.authStore.GetUserByID(user.ID) + + d.renderDashboardPage(w, r, "partials/settings.html", PageData{ + ActivePage: "settings", + User: user, + }) +} + +// renderDashboardPage renders either the full layout or just the content partial. +func (d *Dashboard) renderDashboardPage(w http.ResponseWriter, r *http.Request, partialFile string, data PageData) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + + if r.Header.Get("HX-Request") == "true" { + tmpl := template.Must( + template.New("").Funcs(templateFuncs).ParseFS(templateFiles, "templates/"+partialFile), + ) + tmpl.ExecuteTemplate(w, "content", data) + } else { + tmpl := template.Must( + template.New("").Funcs(templateFuncs).ParseFS(templateFiles, + "templates/layout.html", + "templates/"+partialFile, + ), + ) + tmpl.ExecuteTemplate(w, "layout", data) + } +} + +func (d *Dashboard) getSessionUser(r *http.Request) *auth.User { + cookie, err := r.Cookie("llmgw_session") + if err != nil || cookie.Value == "" { + return nil + } + sess, err := d.authStore.GetSession(cookie.Value) + if err != nil { + return nil + } + user, err := d.authStore.GetUserByID(sess.UserID) + if err != nil { + return nil + } + return user +} diff --git a/llm-gateway/internal/dashboard/sse.go b/llm-gateway/internal/dashboard/sse.go new file mode 100644 index 0000000..aafd88d --- /dev/null +++ b/llm-gateway/internal/dashboard/sse.go @@ -0,0 +1,73 @@ +package dashboard + +import ( + "fmt" + "net/http" + "sync" +) + +// SSEBroker manages Server-Sent Events connections. +type SSEBroker struct { + mu sync.RWMutex + clients map[chan struct{}]struct{} +} + +// NewSSEBroker creates a new SSE broker. +func NewSSEBroker() *SSEBroker { + return &SSEBroker{ + clients: make(map[chan struct{}]struct{}), + } +} + +// Notify sends a refresh signal to all connected SSE clients. +func (b *SSEBroker) Notify() { + b.mu.RLock() + defer b.mu.RUnlock() + for ch := range b.clients { + select { + case ch <- struct{}{}: + default: + // Client not ready, skip + } + } +} + +// ServeHTTP handles SSE connections. +func (b *SSEBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + ch := make(chan struct{}, 1) + b.mu.Lock() + b.clients[ch] = struct{}{} + b.mu.Unlock() + + defer func() { + b.mu.Lock() + delete(b.clients, ch) + b.mu.Unlock() + }() + + // Send initial connection event + fmt.Fprintf(w, "event: connected\ndata: ok\n\n") + flusher.Flush() + + ctx := r.Context() + for { + select { + case <-ctx.Done(): + return + case <-ch: + fmt.Fprintf(w, "event: refresh\ndata: updated\n\n") + flusher.Flush() + } + } +} diff --git a/llm-gateway/internal/dashboard/templates/layout.html b/llm-gateway/internal/dashboard/templates/layout.html new file mode 100644 index 0000000..d80d41f --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/layout.html @@ -0,0 +1,121 @@ +{{define "layout"}} + + + + + +LLM Gateway + + + + + + + + +
+
+ {{template "content" .}} +
+
+ + + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/login.html b/llm-gateway/internal/dashboard/templates/login.html new file mode 100644 index 0000000..52097b8 --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/login.html @@ -0,0 +1,89 @@ +{{define "login"}} + + + + + +Login - LLM Gateway + + + + +
+

LLM Gateway

+
+
+
+ + +
+
+ + +
+ +
+ +
+ + + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/dashboard.html b/llm-gateway/internal/dashboard/templates/partials/dashboard.html new file mode 100644 index 0000000..df9e0ec --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/partials/dashboard.html @@ -0,0 +1,128 @@ +{{define "content"}} +
+ + +
+ {{with .Summary.Today}} +
Requests Today
{{.Requests}}
+
Cost Today
{{formatCost .CostUSD}}
+
Tokens Today
{{addInt .InputTokens .OutputTokens}}
{{.InputTokens}} in / {{.OutputTokens}} out
+
Errors Today
{{.Errors}}
+
Cache Hits
{{.CachedHits}}
+ {{end}} + {{with .Summary.Week}} +
Cost (7d)
{{formatCost .CostUSD}}
+ {{end}} +
+ +
+ + + +
+
+

Requests & Cost

+ +
+ +{{if .Models}} +
+

Models

+ + + + {{range .Models}} + + + + + + + + {{end}} + +
ModelRequestsTokens (in/out)CostAvg Latency
{{.Model}}{{.Requests}}{{.InputTokens}} / {{.OutputTokens}}{{formatCost .CostUSD}}{{printf "%.0f" .AvgLatencyMS}}ms
+
+{{end}} + +{{if .Providers}} +
+

Providers

+ + + + {{range .Providers}} + + + + + + + + + {{end}} + +
ProviderRequestsSuccessErrorsAvg LatencyCost
{{.Provider}}{{.Requests}}{{.Successes}}{{.Errors}}{{printf "%.0f" .AvgLatencyMS}}ms{{formatCost .CostUSD}}
+
+{{end}} + +{{if .TokenStats}} +
+

API Token Usage

+ + + + {{range .TokenStats}} + + + + + + + {{end}} + +
TokenRequestsTokens (in/out)Cost
{{.TokenName}}{{.Requests}}{{.InputTokens}} / {{.OutputTokens}}{{formatCost .CostUSD}}
+
+{{end}} + + +
+{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/settings.html b/llm-gateway/internal/dashboard/templates/partials/settings.html new file mode 100644 index 0000000..faf009b --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/partials/settings.html @@ -0,0 +1,171 @@ +{{define "content"}} + + +
+

Profile

+
+
+
+ +
+ + +
+
+
+
+
+ +
+ + +
+
+
+
+ +
+

Change Password

+
+
+
+ + +
+
+ + +
+
+ + +
+ +
+
+ +
+

Two-Factor Authentication

+
+ {{if .User.TOTPEnabled}} +

Two-factor authentication is enabled.

+ + {{else}} +

Two-factor authentication is not enabled.

+ + {{end}} +
+ +
+ + + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/tokens.html b/llm-gateway/internal/dashboard/templates/partials/tokens.html new file mode 100644 index 0000000..550616a --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/partials/tokens.html @@ -0,0 +1,104 @@ +{{define "content"}} + + + + +
+ + + + {{range .Tokens}} + + + + + + + + + + {{else}} + + {{end}} + +
NamePrefixRate LimitBudgetCreatedLast Used
{{.Name}}{{.KeyPrefix}}...{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}{{formatTime .CreatedAt}}{{if gt .LastUsedAt 0}}{{formatTime .LastUsedAt}}{{else}}never{{end}}
No API tokens yet. Create one to get started.
+
+ + + + + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/users.html b/llm-gateway/internal/dashboard/templates/partials/users.html new file mode 100644 index 0000000..df56a4f --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/partials/users.html @@ -0,0 +1,93 @@ +{{define "content"}} + + +
+ + + + {{range .Users}} + + + + + + + + + {{end}} + +
IDUsernameRole2FACreated
{{.ID}}{{.Username}}{{if .IsAdmin}}Admin{{else}}User{{end}}{{if .TOTPEnabled}}Enabled{{else}}Off{{end}}{{formatTime .CreatedAt}}{{if ne .ID $.User.ID}}{{end}}
+
+ + + + + +{{end}} diff --git a/llm-gateway/internal/dashboard/templates/setup.html b/llm-gateway/internal/dashboard/templates/setup.html new file mode 100644 index 0000000..ceb7339 --- /dev/null +++ b/llm-gateway/internal/dashboard/templates/setup.html @@ -0,0 +1,72 @@ +{{define "setup"}} + + + + + +Setup - LLM Gateway + + + +
+

LLM Gateway Setup

+

Create the first admin account

+
+
+
+ + +
+
+ + +
+
+ + +
+ +
+
+ + + +{{end}} diff --git a/llm-gateway/internal/metrics/prometheus.go b/llm-gateway/internal/metrics/prometheus.go new file mode 100644 index 0000000..7e8b635 --- /dev/null +++ b/llm-gateway/internal/metrics/prometheus.go @@ -0,0 +1,53 @@ +package metrics + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +type Metrics struct { + requestsTotal *prometheus.CounterVec + requestDuration *prometheus.HistogramVec + tokensTotal *prometheus.CounterVec + costTotal *prometheus.CounterVec +} + +func New() *Metrics { + return &Metrics{ + requestsTotal: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "llm_gateway_requests_total", + Help: "Total number of LLM requests", + }, []string{"model", "provider", "token_name", "status"}), + + requestDuration: promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "llm_gateway_request_duration_ms", + Help: "Request duration in milliseconds", + Buckets: []float64{100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000}, + }, []string{"model", "provider"}), + + tokensTotal: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "llm_gateway_tokens_total", + Help: "Total tokens processed", + }, []string{"model", "provider", "type"}), + + costTotal: promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "llm_gateway_cost_usd_total", + Help: "Total cost in USD", + }, []string{"model", "provider", "token_name"}), + } +} + +func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, latencyMS int64, inputTokens, outputTokens int, cost float64) { + m.requestsTotal.WithLabelValues(model, providerName, tokenName, status).Inc() + m.requestDuration.WithLabelValues(model, providerName).Observe(float64(latencyMS)) + + if inputTokens > 0 { + m.tokensTotal.WithLabelValues(model, providerName, "input").Add(float64(inputTokens)) + } + if outputTokens > 0 { + m.tokensTotal.WithLabelValues(model, providerName, "output").Add(float64(outputTokens)) + } + if cost > 0 { + m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost) + } +} diff --git a/llm-gateway/internal/pricing/pricing.go b/llm-gateway/internal/pricing/pricing.go new file mode 100644 index 0000000..69b01d8 --- /dev/null +++ b/llm-gateway/internal/pricing/pricing.go @@ -0,0 +1,191 @@ +package pricing + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "sync" + "time" +) + +const defaultPricesURL = "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" + +// Provider represents a provider entry in genai_prices.json. +type Provider struct { + ID string `json:"id"` + Models []Model `json:"models"` +} + +// Model represents a model entry with pricing. +type Model struct { + ID string `json:"id"` + Prices json.RawMessage `json:"prices"` +} + +// Lookup provides pricing data fetched from genai-prices. +type Lookup struct { + mu sync.RWMutex + prices map[string][2]float64 + url string + stopCh chan struct{} +} + +// NewLookup creates a Lookup that fetches pricing data immediately and refreshes every interval. +// If url is empty, uses the default genai-prices URL. +// Returns a usable Lookup even if the initial fetch fails (prices will be empty until next refresh). +func NewLookup(url string, interval time.Duration) *Lookup { + if url == "" { + url = defaultPricesURL + } + l := &Lookup{ + prices: make(map[string][2]float64), + url: url, + stopCh: make(chan struct{}), + } + + // Initial fetch + l.refresh() + + // Background refresh + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + l.refresh() + case <-l.stopCh: + return + } + } + }() + + return l +} + +// Close stops the background refresh goroutine. +func (l *Lookup) Close() { + close(l.stopCh) +} + +// Get returns (inputPer1M, outputPer1M) for a provider:model pair. +// Returns (0, 0) if not found. +func (l *Lookup) Get(provider, model string) (float64, float64) { + if l == nil { + return 0, 0 + } + l.mu.RLock() + defer l.mu.RUnlock() + key := fmt.Sprintf("%s:%s", provider, model) + if p, ok := l.prices[key]; ok { + return p[0], p[1] + } + return 0, 0 +} + +// FillMissing fills in zero-value pricing from the lookup data. +// Returns the number of prices filled. +func (l *Lookup) FillMissing(provider, model string, input, output *float64) bool { + if l == nil || (*input > 0 && *output > 0) { + return false + } + i, o := l.Get(provider, model) + if i == 0 && o == 0 { + return false + } + if *input == 0 { + *input = i + } + if *output == 0 { + *output = o + } + return true +} + +func (l *Lookup) refresh() { + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Get(l.url) + if err != nil { + log.Printf("WARNING: failed to fetch pricing data: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("WARNING: pricing data fetch returned %d", resp.StatusCode) + return + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("WARNING: failed to read pricing data: %v", err) + return + } + + var providers []Provider + if err := json.Unmarshal(body, &providers); err != nil { + log.Printf("WARNING: failed to parse pricing data: %v", err) + return + } + + prices := make(map[string][2]float64) + for _, p := range providers { + for _, m := range p.Models { + input, output := parsePrices(m.Prices) + if input > 0 || output > 0 { + key := fmt.Sprintf("%s:%s", p.ID, m.ID) + prices[key] = [2]float64{input, output} + } + } + } + + l.mu.Lock() + l.prices = prices + l.mu.Unlock() + + log.Printf("Loaded pricing data: %d model prices from genai-prices", len(prices)) +} + +// parsePrices handles the different shapes of the "prices" field: +// - object: {"input_mtok": 0.5, "output_mtok": 1.0} +// - array: [{"prices": {"input_mtok": 0.5, ...}}, ...] (time-of-day; use first entry) +func parsePrices(raw json.RawMessage) (input, output float64) { + if len(raw) == 0 { + return 0, 0 + } + + // Try as object first (most common) + var obj map[string]any + if json.Unmarshal(raw, &obj) == nil { + return extractPrice(obj, "input_mtok"), extractPrice(obj, "output_mtok") + } + + // Try as array (time-of-day pricing) — use first entry + var arr []struct { + Prices map[string]any `json:"prices"` + } + if json.Unmarshal(raw, &arr) == nil && len(arr) > 0 { + return extractPrice(arr[0].Prices, "input_mtok"), extractPrice(arr[0].Prices, "output_mtok") + } + + return 0, 0 +} + +// extractPrice handles both simple float and tiered pricing (uses base price). +func extractPrice(prices map[string]any, key string) float64 { + v, ok := prices[key] + if !ok { + return 0 + } + switch val := v.(type) { + case float64: + return val + case map[string]any: + if base, ok := val["base"].(float64); ok { + return base + } + } + return 0 +} diff --git a/llm-gateway/internal/provider/openai.go b/llm-gateway/internal/provider/openai.go new file mode 100644 index 0000000..1278eea --- /dev/null +++ b/llm-gateway/internal/provider/openai.go @@ -0,0 +1,130 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// OpenAIProvider is a generic OpenAI-compatible HTTP client. +type OpenAIProvider struct { + name string + baseURL string + apiKey string + client *http.Client +} + +func NewOpenAIProvider(name, baseURL, apiKey string, timeout time.Duration) *OpenAIProvider { + return &OpenAIProvider{ + name: name, + baseURL: baseURL, + apiKey: apiKey, + client: &http.Client{ + Timeout: timeout, + }, + } +} + +func (p *OpenAIProvider) Name() string { return p.name } + +func (p *OpenAIProvider) ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) { + reqCopy := *req + reqCopy.Model = model + reqCopy.Stream = false + + body, err := json.Marshal(reqCopy) + if err != nil { + return nil, fmt.Errorf("marshaling request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + p.setHeaders(httpReq) + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, &ProviderError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + Provider: p.name, + } + } + + var chatResp ChatResponse + if err := json.Unmarshal(respBody, &chatResp); err != nil { + return nil, fmt.Errorf("unmarshaling response: %w", err) + } + + return &chatResp, nil +} + +func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) { + reqCopy := *req + reqCopy.Model = model + reqCopy.Stream = true + + body, err := json.Marshal(reqCopy) + if err != nil { + return nil, fmt.Errorf("marshaling request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + p.setHeaders(httpReq) + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + return nil, &ProviderError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + Provider: p.name, + } + } + + return resp.Body, nil +} + +func (p *OpenAIProvider) setHeaders(req *http.Request) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+p.apiKey) +} + +// ProviderError represents a non-200 response from a provider. +type ProviderError struct { + StatusCode int + Body string + Provider string +} + +func (e *ProviderError) Error() string { + return fmt.Sprintf("provider %s returned %d: %s", e.Provider, e.StatusCode, e.Body) +} + +// IsRetryable returns true if the error is a server-side error worth retrying with another provider. +func (e *ProviderError) IsRetryable() bool { + return e.StatusCode >= 500 || e.StatusCode == 429 +} diff --git a/llm-gateway/internal/provider/provider.go b/llm-gateway/internal/provider/provider.go new file mode 100644 index 0000000..5686865 --- /dev/null +++ b/llm-gateway/internal/provider/provider.go @@ -0,0 +1,60 @@ +package provider + +import ( + "context" + "io" +) + +// ChatRequest is the OpenAI-compatible chat completion request. +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stream bool `json:"stream"` + Stop any `json:"stop,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + N *int `json:"n,omitempty"` + Tools []any `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Extra map[string]any `json:"-"` // pass through unknown fields +} + +type Message struct { + Role string `json:"role"` + Content any `json:"content"` // string or []ContentPart + Name string `json:"name,omitempty"` + ToolCalls []any `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +type Choice struct { + Index int `json:"index"` + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Provider sends requests to an LLM API. +type Provider interface { + Name() string + ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) + ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) +} diff --git a/llm-gateway/internal/provider/registry.go b/llm-gateway/internal/provider/registry.go new file mode 100644 index 0000000..1a15819 --- /dev/null +++ b/llm-gateway/internal/provider/registry.go @@ -0,0 +1,74 @@ +package provider + +import ( + "fmt" + "sort" + + "llm-gateway/internal/config" +) + +// Route maps a model to a specific provider with pricing. +type Route struct { + Provider Provider + ProviderModel string + Priority int + InputPrice float64 // per 1M tokens + OutputPrice float64 // per 1M tokens +} + +// Registry maps model names to provider routes. +type Registry struct { + routes map[string][]Route +} + +func NewRegistry(cfg *config.Config) (*Registry, error) { + // Build providers + providers := make(map[string]Provider) + for _, pc := range cfg.Providers { + providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout) + } + + // Build routes + routes := make(map[string][]Route) + for _, mc := range cfg.Models { + var modelRoutes []Route + for _, rc := range mc.Routes { + p, ok := providers[rc.Provider] + if !ok { + return nil, fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider) + } + pc := cfg.ProviderByName(rc.Provider) + priority := pc.Priority + modelRoutes = append(modelRoutes, Route{ + Provider: p, + ProviderModel: rc.Model, + Priority: priority, + InputPrice: rc.Pricing.Input, + OutputPrice: rc.Pricing.Output, + }) + } + // Sort by priority (lower = higher priority) + sort.Slice(modelRoutes, func(i, j int) bool { + return modelRoutes[i].Priority < modelRoutes[j].Priority + }) + routes[mc.Name] = modelRoutes + } + + return &Registry{routes: routes}, nil +} + +// Lookup returns the routes for a model name. +func (r *Registry) Lookup(model string) ([]Route, bool) { + routes, ok := r.routes[model] + return routes, ok +} + +// ModelNames returns all registered model names. +func (r *Registry) ModelNames() []string { + names := make([]string, 0, len(r.routes)) + for name := range r.routes { + names = append(names, name) + } + sort.Strings(names) + return names +} diff --git a/llm-gateway/internal/proxy/auth.go b/llm-gateway/internal/proxy/auth.go new file mode 100644 index 0000000..b885550 --- /dev/null +++ b/llm-gateway/internal/proxy/auth.go @@ -0,0 +1,41 @@ +package proxy + +import ( + "net/http" + "strings" + + "llm-gateway/internal/auth" +) + +type AuthMiddleware struct { + authStore *auth.Store +} + +func NewAuthMiddleware(authStore *auth.Store) *AuthMiddleware { + return &AuthMiddleware{authStore: authStore} +} + +// Authenticate validates the bearer token against the DB and sets token info in context. +func (a *AuthMiddleware) Authenticate(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hdr := r.Header.Get("Authorization") + if !strings.HasPrefix(hdr, "Bearer ") { + writeError(w, http.StatusUnauthorized, "missing or invalid Authorization header") + return + } + key := strings.TrimPrefix(hdr, "Bearer ") + + token, err := a.authStore.LookupAPIToken(key) + if err != nil { + writeError(w, http.StatusUnauthorized, "invalid API key") + return + } + + // Update last used asynchronously + go a.authStore.UpdateAPITokenLastUsed(token.ID) + + ctx := withTokenName(r.Context(), token.Name) + ctx = withAPIToken(ctx, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} diff --git a/llm-gateway/internal/proxy/handler.go b/llm-gateway/internal/proxy/handler.go new file mode 100644 index 0000000..99eb125 --- /dev/null +++ b/llm-gateway/internal/proxy/handler.go @@ -0,0 +1,207 @@ +package proxy + +import ( + "context" + "encoding/json" + "errors" + "io" + "log" + "net/http" + "time" + + "llm-gateway/internal/auth" + "llm-gateway/internal/cache" + "llm-gateway/internal/config" + "llm-gateway/internal/metrics" + "llm-gateway/internal/provider" + "llm-gateway/internal/storage" +) + +type contextKey string + +const tokenNameKey contextKey = "token_name" +const apiTokenKey contextKey = "api_token" + +func withTokenName(ctx context.Context, name string) context.Context { + return context.WithValue(ctx, tokenNameKey, name) +} + +func getTokenName(ctx context.Context) string { + name, _ := ctx.Value(tokenNameKey).(string) + return name +} + +func withAPIToken(ctx context.Context, token *auth.APIToken) context.Context { + return context.WithValue(ctx, apiTokenKey, token) +} + +func getAPIToken(ctx context.Context) *auth.APIToken { + t, _ := ctx.Value(apiTokenKey).(*auth.APIToken) + return t +} + +type Handler struct { + registry *provider.Registry + logger *storage.AsyncLogger + cache *cache.Cache + metrics *metrics.Metrics + cfg *config.Config +} + +func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config) *Handler { + return &Handler{ + registry: registry, + logger: logger, + cache: c, + metrics: m, + cfg: cfg, + } +} + +func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return + } + + var req provider.ChatRequest + if err := json.Unmarshal(body, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error()) + return + } + + if req.Model == "" { + writeError(w, http.StatusBadRequest, "model is required") + return + } + + routes, ok := h.registry.Lookup(req.Model) + if !ok { + writeError(w, http.StatusNotFound, "model not found: "+req.Model) + return + } + + tokenName := getTokenName(r.Context()) + + // Check cache for non-streaming requests + if !req.Stream && h.cache != nil { + if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil { + h.logRequest(tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true) + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Cache", "HIT") + w.Write(cached) + return + } + } + + if req.Stream { + h.handleStream(w, r, &req, routes, tokenName) + return + } + + h.handleNonStream(w, r, &req, routes, tokenName, body) +} + +func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte) { + var lastErr error + + for _, route := range routes { + start := time.Now() + resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req) + latency := time.Since(start).Milliseconds() + + if err != nil { + var pe *provider.ProviderError + if errors.As(err, &pe) && !pe.IsRetryable() { + // Client error — don't retry + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) + h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) + writeErrorRaw(w, pe.StatusCode, pe.Body) + return + } + lastErr = err + log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err) + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) + h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false) + continue + } + + // Compute cost + inputTokens, outputTokens := 0, 0 + if resp.Usage != nil { + inputTokens = resp.Usage.PromptTokens + outputTokens = resp.Usage.CompletionTokens + } + cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice) + + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost) + h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false) + + // Override model name in response to match the requested model + resp.Model = req.Model + + respBytes, err := json.Marshal(resp) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to marshal response") + return + } + + // Cache the response + if h.cache != nil { + h.cache.Set(r.Context(), req.Model, rawBody, respBytes) + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Cache", "MISS") + w.Write(respBytes) + return + } + + // All providers failed + if lastErr != nil { + writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) + } else { + writeError(w, http.StatusBadGateway, "all providers failed") + } +} + +func (h *Handler) logRequest(tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) { + h.logger.Log(storage.RequestLog{ + Timestamp: time.Now().Unix(), + TokenName: tokenName, + Model: model, + Provider: providerName, + ProviderModel: providerModel, + InputTokens: inputTokens, + OutputTokens: outputTokens, + CostUSD: cost, + LatencyMS: latencyMS, + Status: status, + ErrorMessage: errMsg, + Streaming: streaming, + Cached: cached, + }) +} + +func computeCost(inputTokens, outputTokens int, inputPrice, outputPrice float64) float64 { + return (float64(inputTokens) / 1_000_000.0 * inputPrice) + (float64(outputTokens) / 1_000_000.0 * outputPrice) +} + +func writeError(w http.ResponseWriter, code int, msg string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": msg, + "type": "error", + "code": code, + }, + }) +} + +func writeErrorRaw(w http.ResponseWriter, code int, body string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + w.Write([]byte(body)) +} diff --git a/llm-gateway/internal/proxy/models.go b/llm-gateway/internal/proxy/models.go new file mode 100644 index 0000000..56e1aba --- /dev/null +++ b/llm-gateway/internal/proxy/models.go @@ -0,0 +1,36 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "time" + + "llm-gateway/internal/provider" +) + +type ModelsHandler struct { + registry *provider.Registry +} + +func NewModelsHandler(registry *provider.Registry) *ModelsHandler { + return &ModelsHandler{registry: registry} +} + +func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) { + names := h.registry.ModelNames() + models := make([]map[string]any, len(names)) + for i, name := range names { + models[i] = map[string]any{ + "id": name, + "object": "model", + "created": time.Now().Unix(), + "owned_by": "llm-gateway", + } + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "object": "list", + "data": models, + }) +} diff --git a/llm-gateway/internal/proxy/ratelimit.go b/llm-gateway/internal/proxy/ratelimit.go new file mode 100644 index 0000000..2278ea1 --- /dev/null +++ b/llm-gateway/internal/proxy/ratelimit.go @@ -0,0 +1,90 @@ +package proxy + +import ( + "net/http" + "sync" + "time" + + "llm-gateway/internal/storage" +) + +type RateLimiter struct { + db *storage.DB + mu sync.Mutex + buckets map[string]*tokenBucket +} + +type tokenBucket struct { + tokens float64 + maxTokens float64 + refillRate float64 // tokens per second + lastRefill time.Time +} + +func NewRateLimiter(db *storage.DB) *RateLimiter { + return &RateLimiter{ + db: db, + buckets: make(map[string]*tokenBucket), + } +} + +func (rl *RateLimiter) Check(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiToken := getAPIToken(r.Context()) + if apiToken == nil { + next.ServeHTTP(w, r) + return + } + + tokenName := apiToken.Name + + // Check rate limit + if apiToken.RateLimitRPM > 0 { + if !rl.allow(tokenName, apiToken.RateLimitRPM) { + writeError(w, http.StatusTooManyRequests, "rate limit exceeded") + return + } + } + + // Check daily budget + if apiToken.DailyBudgetUSD > 0 { + spent, err := rl.db.TodaySpend(tokenName) + if err == nil && spent >= apiToken.DailyBudgetUSD { + writeError(w, http.StatusTooManyRequests, "daily budget exceeded") + return + } + } + + next.ServeHTTP(w, r) + }) +} + +func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + bucket, ok := rl.buckets[tokenName] + if !ok { + bucket = &tokenBucket{ + tokens: float64(rateLimitRPM), + maxTokens: float64(rateLimitRPM), + refillRate: float64(rateLimitRPM) / 60.0, + lastRefill: time.Now(), + } + rl.buckets[tokenName] = bucket + } + + now := time.Now() + elapsed := now.Sub(bucket.lastRefill).Seconds() + bucket.tokens += elapsed * bucket.refillRate + if bucket.tokens > bucket.maxTokens { + bucket.tokens = bucket.maxTokens + } + bucket.lastRefill = now + + if bucket.tokens < 1 { + return false + } + bucket.tokens-- + return true +} diff --git a/llm-gateway/internal/proxy/stream.go b/llm-gateway/internal/proxy/stream.go new file mode 100644 index 0000000..39fb596 --- /dev/null +++ b/llm-gateway/internal/proxy/stream.go @@ -0,0 +1,105 @@ +package proxy + +import ( + "bufio" + "encoding/json" + "errors" + "log" + "net/http" + "strings" + "time" + + "llm-gateway/internal/provider" +) + +func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) { + flusher, ok := w.(http.Flusher) + if !ok { + writeError(w, http.StatusInternalServerError, "streaming not supported") + return + } + + var lastErr error + + for _, route := range routes { + start := time.Now() + body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req) + + if err != nil { + var pe *provider.ProviderError + if errors.As(err, &pe) && !pe.IsRetryable() { + latency := time.Since(start).Milliseconds() + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) + h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false) + writeErrorRaw(w, pe.StatusCode, pe.Body) + return + } + lastErr = err + log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err) + h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, time.Since(start).Milliseconds(), "error", err.Error(), true, false) + continue + } + + // Stream the response + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(http.StatusOK) + + inputTokens, outputTokens := 0, 0 + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 64*1024), 256*1024) + + for scanner.Scan() { + line := scanner.Text() + + // Parse usage from the final chunk if available + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + if data != "[DONE]" { + var chunk streamChunk + if json.Unmarshal([]byte(data), &chunk) == nil { + if chunk.Usage != nil { + inputTokens = chunk.Usage.PromptTokens + outputTokens = chunk.Usage.CompletionTokens + } + // Override model name in chunk + if chunk.Model != "" { + chunk.Model = req.Model + if rewritten, err := json.Marshal(chunk); err == nil { + line = "data: " + string(rewritten) + } + } + } + } + } + + w.Write([]byte(line + "\n")) + flusher.Flush() + } + body.Close() + + latency := time.Since(start).Milliseconds() + cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice) + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost) + h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false) + return + } + + // All providers failed + if lastErr != nil { + writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) + } else { + writeError(w, http.StatusBadGateway, "all providers failed") + } +} + +type streamChunk struct { + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` + Created int64 `json:"created,omitempty"` + Model string `json:"model,omitempty"` + Choices []any `json:"choices,omitempty"` + Usage *provider.Usage `json:"usage,omitempty"` +} diff --git a/llm-gateway/internal/storage/db.go b/llm-gateway/internal/storage/db.go new file mode 100644 index 0000000..60f0396 --- /dev/null +++ b/llm-gateway/internal/storage/db.go @@ -0,0 +1,103 @@ +package storage + +import ( + "database/sql" + "fmt" + "log" + "path/filepath" + "time" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/sqlite" + "github.com/golang-migrate/migrate/v4/source/iofs" + _ "modernc.org/sqlite" + + "llm-gateway/internal/storage/migrations" +) + +type DB struct { + *sql.DB +} + +func Open(path string) (*DB, error) { + dir := filepath.Dir(path) + if dir != "." && dir != "" { + // Ensure directory exists — caller should create it if needed + } + + db, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000") + if err != nil { + return nil, fmt.Errorf("opening database: %w", err) + } + + // Performance pragmas + for _, pragma := range []string{ + "PRAGMA foreign_keys = ON", + "PRAGMA temp_store = MEMORY", + "PRAGMA mmap_size = 268435456", + } { + if _, err := db.Exec(pragma); err != nil { + return nil, fmt.Errorf("setting pragma %s: %w", pragma, err) + } + } + + db.SetMaxOpenConns(1) // SQLite is single-writer + db.SetMaxIdleConns(1) + + if err := runMigrations(db); err != nil { + return nil, fmt.Errorf("running migrations: %w", err) + } + + return &DB{db}, nil +} + +func runMigrations(db *sql.DB) error { + sourceDriver, err := iofs.New(migrations.FS, ".") + if err != nil { + return fmt.Errorf("creating migration source: %w", err) + } + + dbDriver, err := sqlite.WithInstance(db, &sqlite.Config{}) + if err != nil { + return fmt.Errorf("creating migration db driver: %w", err) + } + + m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite", dbDriver) + if err != nil { + return fmt.Errorf("creating migrator: %w", err) + } + + if err := m.Up(); err != nil && err != migrate.ErrNoChange { + return fmt.Errorf("applying migrations: %w", err) + } + + return nil +} + +// CleanupOldRecords deletes records older than retentionDays. +func (db *DB) CleanupOldRecords(retentionDays int) error { + cutoff := time.Now().AddDate(0, 0, -retentionDays).Unix() + result, err := db.Exec("DELETE FROM request_logs WHERE timestamp < ?", cutoff) + if err != nil { + return err + } + affected, _ := result.RowsAffected() + if affected > 0 { + log.Printf("Cleaned up %d old request log records", affected) + } + return nil +} + +// TodaySpend returns the total cost in USD for a given token today. +func (db *DB) TodaySpend(tokenName string) (float64, error) { + startOfDay := time.Now().Truncate(24 * time.Hour).Unix() + var total sql.NullFloat64 + err := db.QueryRow( + "SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?", + tokenName, startOfDay, + ).Scan(&total) + if err != nil { + return 0, err + } + return total.Float64, nil +} diff --git a/llm-gateway/internal/storage/logger.go b/llm-gateway/internal/storage/logger.go new file mode 100644 index 0000000..ad2e829 --- /dev/null +++ b/llm-gateway/internal/storage/logger.go @@ -0,0 +1,132 @@ +package storage + +import ( + "log" + "time" +) + +type RequestLog struct { + Timestamp int64 + TokenName string + Model string + Provider string + ProviderModel string + InputTokens int + OutputTokens int + CostUSD float64 + LatencyMS int64 + Status string // success, error, cached + ErrorMessage string + Streaming bool + Cached bool +} + +type AsyncLogger struct { + db *DB + ch chan RequestLog + done chan struct{} + OnFlush func() // called after successful flush, if set +} + +func NewAsyncLogger(db *DB, bufferSize int) *AsyncLogger { + if bufferSize == 0 { + bufferSize = 1000 + } + l := &AsyncLogger{ + db: db, + ch: make(chan RequestLog, bufferSize), + done: make(chan struct{}), + } + go l.run() + return l +} + +func (l *AsyncLogger) Log(r RequestLog) { + select { + case l.ch <- r: + default: + log.Println("WARNING: request log buffer full, dropping entry") + } +} + +func (l *AsyncLogger) Close() { + close(l.ch) + <-l.done +} + +func (l *AsyncLogger) run() { + defer close(l.done) + + batch := make([]RequestLog, 0, 100) + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case r, ok := <-l.ch: + if !ok { + // Channel closed, flush remaining + if len(batch) > 0 { + l.flush(batch) + } + return + } + batch = append(batch, r) + if len(batch) >= 100 { + l.flush(batch) + batch = batch[:0] + } + case <-ticker.C: + if len(batch) > 0 { + l.flush(batch) + batch = batch[:0] + } + } + } +} + +func (l *AsyncLogger) flush(batch []RequestLog) { + tx, err := l.db.Begin() + if err != nil { + log.Printf("ERROR: starting log transaction: %v", err) + return + } + + stmt, err := tx.Prepare(`INSERT INTO request_logs + (timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + if err != nil { + log.Printf("ERROR: preparing log statement: %v", err) + tx.Rollback() + return + } + defer stmt.Close() + + for _, r := range batch { + streaming := 0 + if r.Streaming { + streaming = 1 + } + cached := 0 + if r.Cached { + cached = 1 + } + _, err := stmt.Exec( + r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel, + r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS, + r.Status, r.ErrorMessage, streaming, cached, + ) + if err != nil { + log.Printf("ERROR: inserting log: %v", err) + } + } + + if err := tx.Commit(); err != nil { + log.Printf("ERROR: committing log batch: %v", err) + return + } + + if l.OnFlush != nil { + l.OnFlush() + } +} diff --git a/llm-gateway/internal/storage/migrations/001_init.down.sql b/llm-gateway/internal/storage/migrations/001_init.down.sql new file mode 100644 index 0000000..bd1fad9 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/001_init.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS request_logs; diff --git a/llm-gateway/internal/storage/migrations/001_init.up.sql b/llm-gateway/internal/storage/migrations/001_init.up.sql new file mode 100644 index 0000000..610d375 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/001_init.up.sql @@ -0,0 +1,20 @@ +CREATE TABLE IF NOT EXISTS request_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp INTEGER NOT NULL, + token_name TEXT NOT NULL, + model TEXT NOT NULL, + provider TEXT NOT NULL, + provider_model TEXT NOT NULL, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + cost_usd REAL DEFAULT 0, + latency_ms INTEGER DEFAULT 0, + status TEXT NOT NULL, + error_message TEXT DEFAULT '', + streaming INTEGER DEFAULT 0, + cached INTEGER DEFAULT 0 +); + +CREATE INDEX IF NOT EXISTS idx_timestamp ON request_logs(timestamp); +CREATE INDEX IF NOT EXISTS idx_token ON request_logs(token_name); +CREATE INDEX IF NOT EXISTS idx_model ON request_logs(model); diff --git a/llm-gateway/internal/storage/migrations/002_users.down.sql b/llm-gateway/internal/storage/migrations/002_users.down.sql new file mode 100644 index 0000000..12f43a1 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/002_users.down.sql @@ -0,0 +1,3 @@ +DROP TABLE IF EXISTS api_tokens; +DROP TABLE IF EXISTS sessions; +DROP TABLE IF EXISTS users; diff --git a/llm-gateway/internal/storage/migrations/002_users.up.sql b/llm-gateway/internal/storage/migrations/002_users.up.sql new file mode 100644 index 0000000..fb8eb11 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/002_users.up.sql @@ -0,0 +1,33 @@ +CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + is_admin INTEGER NOT NULL DEFAULT 0, + totp_secret TEXT DEFAULT '', + totp_enabled INTEGER NOT NULL DEFAULT 0, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL +); + +CREATE TABLE sessions ( + id TEXT PRIMARY KEY, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL +); +CREATE INDEX idx_sessions_user ON sessions(user_id); +CREATE INDEX idx_sessions_expires ON sessions(expires_at); + +CREATE TABLE api_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + key_hash TEXT NOT NULL, + key_prefix TEXT NOT NULL, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + rate_limit_rpm INTEGER DEFAULT 60, + daily_budget_usd REAL DEFAULT 0, + created_at INTEGER NOT NULL, + last_used_at INTEGER DEFAULT 0 +); +CREATE UNIQUE INDEX idx_api_tokens_hash ON api_tokens(key_hash); +CREATE INDEX idx_api_tokens_user ON api_tokens(user_id); diff --git a/llm-gateway/internal/storage/migrations/003_user_email.down.sql b/llm-gateway/internal/storage/migrations/003_user_email.down.sql new file mode 100644 index 0000000..e00308d --- /dev/null +++ b/llm-gateway/internal/storage/migrations/003_user_email.down.sql @@ -0,0 +1,5 @@ +-- SQLite doesn't support DROP COLUMN before 3.35.0, so we recreate +CREATE TABLE users_backup AS SELECT id, username, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users; +DROP TABLE users; +ALTER TABLE users_backup RENAME TO users; +CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username); diff --git a/llm-gateway/internal/storage/migrations/003_user_email.up.sql b/llm-gateway/internal/storage/migrations/003_user_email.up.sql new file mode 100644 index 0000000..9468211 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/003_user_email.up.sql @@ -0,0 +1 @@ +ALTER TABLE users ADD COLUMN email TEXT DEFAULT ''; diff --git a/llm-gateway/internal/storage/migrations/embed.go b/llm-gateway/internal/storage/migrations/embed.go new file mode 100644 index 0000000..91cca1c --- /dev/null +++ b/llm-gateway/internal/storage/migrations/embed.go @@ -0,0 +1,6 @@ +package migrations + +import "embed" + +//go:embed *.sql +var FS embed.FS diff --git a/monitoring/prometheus.yml b/monitoring/prometheus.yml index a992d6a..37588dc 100644 --- a/monitoring/prometheus.yml +++ b/monitoring/prometheus.yml @@ -2,9 +2,9 @@ global: scrape_interval: 30s scrape_configs: - - job_name: 'new-api' + - job_name: 'llm-gateway' static_configs: - - targets: ['new-api:3000'] + - targets: ['llm-gateway:3000'] - job_name: 'node' static_configs: