290 lines
8.2 KiB
Go
290 lines
8.2 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
|
|
"llm-gateway/internal/auth"
|
|
"llm-gateway/internal/cache"
|
|
"llm-gateway/internal/config"
|
|
"llm-gateway/internal/dashboard"
|
|
"llm-gateway/internal/metrics"
|
|
"llm-gateway/internal/pricing"
|
|
"llm-gateway/internal/provider"
|
|
"llm-gateway/internal/proxy"
|
|
"llm-gateway/internal/storage"
|
|
)
|
|
|
|
// version is the build identifier logged at startup. It defaults to "dev";
// NOTE(review): presumably overridden at build time (e.g. via -ldflags -X) — confirm.
var version = "dev"
|
|
|
|
func main() {
|
|
configPath := flag.String("config", "configs/config.yaml", "path to config file")
|
|
flag.Parse()
|
|
|
|
log.Printf("llm-gateway %s starting", version)
|
|
|
|
cfg, err := config.Load(*configPath)
|
|
if err != nil {
|
|
log.Fatalf("Failed to load config: %v", err)
|
|
}
|
|
|
|
// Pricing lookup (fetches from URL, refreshes periodically)
|
|
pricingLookup := pricing.NewLookup(cfg.Pricing.URL, cfg.Pricing.RefreshInterval)
|
|
defer pricingLookup.Close()
|
|
|
|
// Auto-fill missing pricing from fetched data
|
|
for i, m := range cfg.Models {
|
|
for j, r := range m.Routes {
|
|
if r.Pricing.Input == 0 && r.Pricing.Output == 0 {
|
|
if pricingLookup.FillMissing(r.Provider, r.Model, &cfg.Models[i].Routes[j].Pricing.Input, &cfg.Models[i].Routes[j].Pricing.Output) {
|
|
log.Printf("Auto-filled pricing for %s via %s: $%.2f/$%.2f per 1M tokens",
|
|
m.Name, r.Provider, cfg.Models[i].Routes[j].Pricing.Input, cfg.Models[i].Routes[j].Pricing.Output)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Database
|
|
db, err := storage.Open(cfg.Database.Path)
|
|
if err != nil {
|
|
log.Fatalf("Failed to open database: %v", err)
|
|
}
|
|
defer db.Close()
|
|
|
|
if err := db.CleanupOldRecords(cfg.Database.RetentionDays); err != nil {
|
|
log.Printf("WARNING: retention cleanup failed: %v", err)
|
|
}
|
|
|
|
asyncLogger := storage.NewAsyncLogger(db, 1000)
|
|
defer asyncLogger.Close()
|
|
|
|
// SSE broker for real-time dashboard updates
|
|
sseBroker := dashboard.NewSSEBroker()
|
|
asyncLogger.OnFlush = sseBroker.Notify
|
|
|
|
// Cache (optional)
|
|
var c *cache.Cache
|
|
if cfg.Cache.Enabled {
|
|
c, err = cache.New(cfg.Cache.Address, cfg.Cache.TTL)
|
|
if err != nil {
|
|
log.Printf("WARNING: cache disabled: %v", err)
|
|
} else {
|
|
log.Printf("Cache connected to %s", cfg.Cache.Address)
|
|
}
|
|
}
|
|
|
|
// Provider registry
|
|
registry, err := provider.NewRegistry(cfg)
|
|
if err != nil {
|
|
log.Fatalf("Failed to build provider registry: %v", err)
|
|
}
|
|
log.Printf("Registered %d models", len(cfg.Models))
|
|
|
|
// Provider health tracker
|
|
healthTracker := provider.NewHealthTracker(5 * time.Minute)
|
|
|
|
// Auth store (static tokens checked in-memory, not seeded to DB)
|
|
var staticTokens []auth.StaticToken
|
|
for _, t := range cfg.Tokens {
|
|
if t.Key != "" {
|
|
staticTokens = append(staticTokens, auth.StaticToken{
|
|
Name: t.Name,
|
|
Key: t.Key,
|
|
RateLimitRPM: t.RateLimitRPM,
|
|
DailyBudgetUSD: t.DailyBudgetUSD,
|
|
})
|
|
log.Printf("Loaded static token: %s", t.Name)
|
|
}
|
|
}
|
|
authStore := auth.NewStore(db.DB, staticTokens)
|
|
authMiddleware := auth.NewMiddleware(authStore)
|
|
authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret)
|
|
|
|
// Seed default admin
|
|
seedDefaultAdmin(cfg, authStore)
|
|
|
|
// Metrics
|
|
m := metrics.New()
|
|
|
|
// Handlers
|
|
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
|
|
modelsHandler := proxy.NewModelsHandler(registry)
|
|
proxyAuth := proxy.NewAuthMiddleware(authStore)
|
|
rateLimiter := proxy.NewRateLimiter(db)
|
|
statsAPI := dashboard.NewStatsAPI(db, authStore)
|
|
statsAPI.SetHealthTracker(healthTracker)
|
|
if c != nil {
|
|
statsAPI.SetCache(c)
|
|
}
|
|
dash := dashboard.NewDashboard(authStore, statsAPI)
|
|
dash.SetRegistry(registry)
|
|
if c != nil {
|
|
dash.SetCache(c)
|
|
}
|
|
|
|
// Router
|
|
r := chi.NewRouter()
|
|
r.Use(middleware.RealIP)
|
|
r.Use(middleware.Recoverer)
|
|
r.Use(middleware.RequestID)
|
|
|
|
// Health & metrics (public)
|
|
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
if err := db.Ping(); err != nil {
|
|
http.Error(w, "database unhealthy", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
if c != nil {
|
|
if err := c.Ping(r.Context()); err != nil {
|
|
http.Error(w, "cache unhealthy", http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
r.Handle("/metrics", promhttp.Handler())
|
|
|
|
// OpenAI-compatible API (API token auth via Bearer header)
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(proxyAuth.Authenticate)
|
|
r.Use(rateLimiter.Check)
|
|
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
|
|
r.Get("/v1/models", modelsHandler.ListModels)
|
|
})
|
|
|
|
// Auth pages (public)
|
|
r.Get("/login", dash.LoginPage)
|
|
r.Get("/setup", dash.SetupPage)
|
|
|
|
// Auth API endpoints (public)
|
|
r.Post("/api/auth/login", authHandlers.Login)
|
|
r.Post("/api/auth/setup", authHandlers.Setup)
|
|
r.Post("/api/auth/login/totp", authHandlers.LoginTOTP)
|
|
|
|
// Root redirect
|
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
|
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
|
})
|
|
|
|
// Authenticated pages and API
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(authMiddleware.RequireAuth)
|
|
|
|
// Dashboard pages (HTMX)
|
|
r.Get("/dashboard", dash.DashboardPage)
|
|
r.Get("/logs", dash.LogsPage)
|
|
r.Get("/models", dash.ModelsPage)
|
|
r.Get("/tokens", dash.TokensPage)
|
|
r.Get("/settings", dash.SettingsPage)
|
|
|
|
// Admin-only pages
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(authMiddleware.RequireAdmin)
|
|
r.Get("/users", dash.UsersPage)
|
|
})
|
|
|
|
// Auth API
|
|
r.Post("/api/auth/logout", authHandlers.Logout)
|
|
r.Get("/api/auth/me", authHandlers.Me)
|
|
r.Put("/api/auth/me/password", authHandlers.ChangePassword)
|
|
r.Put("/api/auth/me/username", authHandlers.ChangeUsername)
|
|
r.Put("/api/auth/me/email", authHandlers.ChangeEmail)
|
|
r.Post("/api/auth/totp/setup", authHandlers.TOTPSetup)
|
|
r.Post("/api/auth/totp/verify", authHandlers.TOTPVerify)
|
|
r.Delete("/api/auth/totp", authHandlers.TOTPDisable)
|
|
|
|
// API token management
|
|
r.Get("/api/tokens", authHandlers.ListTokens)
|
|
r.Post("/api/tokens", authHandlers.CreateToken)
|
|
r.Delete("/api/tokens/{id}", authHandlers.DeleteToken)
|
|
|
|
// SSE events
|
|
r.Get("/api/events", sseBroker.ServeHTTP)
|
|
|
|
// Dashboard stats
|
|
r.Get("/api/stats/summary", statsAPI.Summary)
|
|
r.Get("/api/stats/models", statsAPI.Models)
|
|
r.Get("/api/stats/providers", statsAPI.Providers)
|
|
r.Get("/api/stats/tokens", statsAPI.Tokens)
|
|
r.Get("/api/stats/timeseries", statsAPI.Timeseries)
|
|
r.Get("/api/stats/logs", statsAPI.Logs)
|
|
r.Get("/api/stats/latency", statsAPI.Latency)
|
|
r.Get("/api/stats/cost-breakdown", statsAPI.CostBreakdown)
|
|
r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler)
|
|
r.Get("/api/stats/cache", statsAPI.CacheStats)
|
|
|
|
// Admin-only: user management
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(authMiddleware.RequireAdmin)
|
|
r.Get("/api/auth/users", authHandlers.ListUsers)
|
|
r.Post("/api/auth/users", authHandlers.CreateUser)
|
|
r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser)
|
|
})
|
|
})
|
|
|
|
// Periodic session cleanup
|
|
go func() {
|
|
ticker := time.NewTicker(1 * time.Hour)
|
|
defer ticker.Stop()
|
|
for range ticker.C {
|
|
if err := authStore.CleanExpiredSessions(); err != nil {
|
|
log.Printf("WARNING: session cleanup failed: %v", err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Server
|
|
srv := &http.Server{
|
|
Addr: cfg.Server.Listen,
|
|
Handler: r,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: cfg.Server.RequestTimeout + 10*time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
|
|
// Graceful shutdown
|
|
done := make(chan os.Signal, 1)
|
|
signal.Notify(done, os.Interrupt, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
log.Printf("Listening on %s", cfg.Server.Listen)
|
|
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Fatalf("Server failed: %v", err)
|
|
}
|
|
}()
|
|
|
|
<-done
|
|
log.Println("Shutting down...")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
srv.Shutdown(ctx)
|
|
|
|
log.Println("Stopped")
|
|
}
|
|
|
|
// seedDefaultAdmin creates the default admin user if no users exist.
|
|
func seedDefaultAdmin(cfg *config.Config, authStore *auth.Store) {
|
|
if !authStore.HasAnyUser() {
|
|
da := cfg.Server.DefaultAdmin
|
|
if da.Username != "" && da.Password != "" {
|
|
user, err := authStore.CreateUser(da.Username, da.Password, true)
|
|
if err != nil {
|
|
log.Printf("WARNING: failed to create default admin: %v", err)
|
|
} else {
|
|
log.Printf("Created default admin user: %s (id=%d)", user.Username, user.ID)
|
|
}
|
|
}
|
|
}
|
|
}
|