package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/prometheus/client_golang/prometheus/promhttp"

	"llm-gateway/internal/auth"
	"llm-gateway/internal/cache"
	"llm-gateway/internal/config"
	"llm-gateway/internal/dashboard"
	"llm-gateway/internal/metrics"
	"llm-gateway/internal/pricing"
	"llm-gateway/internal/provider"
	"llm-gateway/internal/proxy"
	"llm-gateway/internal/storage"
)

// version is reported in the startup log line. NOTE(review): presumably
// overridden at build time via -ldflags "-X main.version=..." — confirm
// against the build scripts.
var version = "dev"

const (
	// shutdownTimeout bounds how long a graceful shutdown waits for in-flight
	// requests before the remaining connections are force-closed.
	shutdownTimeout = 10 * time.Second

	// sessionCleanupInterval is how often expired sessions are purged.
	sessionCleanupInterval = time.Hour
)

// main parses flags and runs the gateway. All real work lives in run so that
// every deferred cleanup (async log flush, database close, pricing lookup
// close) executes before the process exits, including on failure paths;
// calling log.Fatal from inside run would skip those defers.
func main() {
	configPath := flag.String("config", "configs/config.yaml", "path to config file")
	flag.Parse()

	log.Printf("llm-gateway %s starting", version)

	if err := run(*configPath); err != nil {
		log.Fatal(err)
	}
}

// run wires up pricing, storage, cache, providers, auth, the dashboard and the
// HTTP server from the config at configPath, then serves until SIGINT/SIGTERM
// arrives or the listener fails.
//
// Returned error strings are capitalised on purpose: main prints them verbatim
// as the fatal log line, so they reproduce the gateway's original log output.
func run(configPath string) error {
	cfg, err := config.Load(configPath)
	if err != nil {
		return fmt.Errorf("Failed to load config: %w", err)
	}

	// Pricing lookup (fetches from URL, refreshes periodically)
	pricingLookup := pricing.NewLookup(cfg.Pricing.URL, cfg.Pricing.RefreshInterval)
	defer pricingLookup.Close()

	// Auto-fill pricing for routes that configure neither an input nor an
	// output price. Writes go through the indexed elements of cfg.Models
	// rather than the range variables, which may be copies.
	for i, m := range cfg.Models {
		for j, r := range m.Routes {
			if r.Pricing.Input != 0 || r.Pricing.Output != 0 {
				continue
			}
			in := &cfg.Models[i].Routes[j].Pricing.Input
			out := &cfg.Models[i].Routes[j].Pricing.Output
			if pricingLookup.FillMissing(r.Provider, r.Model, in, out) {
				log.Printf("Auto-filled pricing for %s via %s: $%.2f/$%.2f per 1M tokens",
					m.Name, r.Provider, *in, *out)
			}
		}
	}

	// Database
	db, err := storage.Open(cfg.Database.Path)
	if err != nil {
		return fmt.Errorf("Failed to open database: %w", err)
	}
	defer db.Close()

	// NOTE(review): in this file retention cleanup only runs at startup, so a
	// long-running process keeps records past RetentionDays until restarted.
	if err := db.CleanupOldRecords(cfg.Database.RetentionDays); err != nil {
		log.Printf("WARNING: retention cleanup failed: %v", err)
	}

	asyncLogger := storage.NewAsyncLogger(db, 1000)
	defer asyncLogger.Close()

	// SSE broker for real-time dashboard updates
	sseBroker := dashboard.NewSSEBroker()
	asyncLogger.OnFlush = sseBroker.Notify

	// Cache (optional). A nil c means caching is off.
	var c *cache.Cache
	if cfg.Cache.Enabled {
		c, err = cache.New(cfg.Cache.Address, cfg.Cache.TTL)
		if err != nil {
			// Never hand a possibly half-initialised client to the proxy or
			// the health check.
			c = nil
			log.Printf("WARNING: cache disabled: %v", err)
		} else {
			log.Printf("Cache connected to %s", cfg.Cache.Address)
		}
	}

	// Provider registry
	registry, err := provider.NewRegistry(cfg)
	if err != nil {
		return fmt.Errorf("Failed to build provider registry: %w", err)
	}
	log.Printf("Registered %d models", len(cfg.Models))

	// Auth store (static tokens checked in-memory, not seeded to DB)
	staticTokens := loadStaticTokens(cfg)
	authStore := auth.NewStore(db.DB, staticTokens)
	authMiddleware := auth.NewMiddleware(authStore)
	authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret)

	// Seed default admin
	seedDefaultAdmin(cfg, authStore)

	// Metrics
	appMetrics := metrics.New()

	// Handlers
	proxyHandler := proxy.NewHandler(registry, asyncLogger, c, appMetrics, cfg)
	modelsHandler := proxy.NewModelsHandler(registry)
	proxyAuth := proxy.NewAuthMiddleware(authStore)
	rateLimiter := proxy.NewRateLimiter(db)
	statsAPI := dashboard.NewStatsAPI(db, authStore)
	dash := dashboard.NewDashboard(authStore, statsAPI)

	// Router. RequestID runs first so the ID is available to everything
	// downstream, following the ordering shown in chi's documentation.
	// NOTE(review): RealIP rewrites RemoteAddr from client-supplied forwarding
	// headers; that is only trustworthy behind a trusted reverse proxy.
	r := chi.NewRouter()
	r.Use(middleware.RequestID)
	r.Use(middleware.RealIP)
	r.Use(middleware.Recoverer)

	// Health & metrics (public)
	r.Get("/health", func(w http.ResponseWriter, req *http.Request) {
		if err := db.Ping(); err != nil {
			http.Error(w, "database unhealthy", http.StatusServiceUnavailable)
			return
		}
		if c != nil {
			if err := c.Ping(req.Context()); err != nil {
				http.Error(w, "cache unhealthy", http.StatusServiceUnavailable)
				return
			}
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("OK")) // a failed write means the client went away; nothing to do
	})
	r.Handle("/metrics", promhttp.Handler())

	// OpenAI-compatible API (API token auth via Bearer header)
	r.Group(func(r chi.Router) {
		r.Use(proxyAuth.Authenticate)
		r.Use(rateLimiter.Check)
		r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
		r.Get("/v1/models", modelsHandler.ListModels)
	})

	// Auth pages (public)
	r.Get("/login", dash.LoginPage)
	r.Get("/setup", dash.SetupPage)

	// Auth API endpoints (public)
	r.Post("/api/auth/login", authHandlers.Login)
	r.Post("/api/auth/setup", authHandlers.Setup)
	r.Post("/api/auth/login/totp", authHandlers.LoginTOTP)

	// Root redirect
	r.Get("/", func(w http.ResponseWriter, req *http.Request) {
		http.Redirect(w, req, "/dashboard", http.StatusFound)
	})

	// Authenticated pages and API
	r.Group(func(r chi.Router) {
		r.Use(authMiddleware.RequireAuth)

		// Dashboard pages (HTMX)
		r.Get("/dashboard", dash.DashboardPage)
		r.Get("/tokens", dash.TokensPage)
		r.Get("/settings", dash.SettingsPage)

		// Admin-only pages
		r.Group(func(r chi.Router) {
			r.Use(authMiddleware.RequireAdmin)
			r.Get("/users", dash.UsersPage)
		})

		// Auth API
		r.Post("/api/auth/logout", authHandlers.Logout)
		r.Get("/api/auth/me", authHandlers.Me)
		r.Put("/api/auth/me/password", authHandlers.ChangePassword)
		r.Put("/api/auth/me/username", authHandlers.ChangeUsername)
		r.Put("/api/auth/me/email", authHandlers.ChangeEmail)
		r.Post("/api/auth/totp/setup", authHandlers.TOTPSetup)
		r.Post("/api/auth/totp/verify", authHandlers.TOTPVerify)
		r.Delete("/api/auth/totp", authHandlers.TOTPDisable)

		// API token management
		r.Get("/api/tokens", authHandlers.ListTokens)
		r.Post("/api/tokens", authHandlers.CreateToken)
		r.Delete("/api/tokens/{id}", authHandlers.DeleteToken)

		// SSE events
		r.Get("/api/events", sseBroker.ServeHTTP)

		// Dashboard stats
		r.Get("/api/stats/summary", statsAPI.Summary)
		r.Get("/api/stats/models", statsAPI.Models)
		r.Get("/api/stats/providers", statsAPI.Providers)
		r.Get("/api/stats/tokens", statsAPI.Tokens)
		r.Get("/api/stats/timeseries", statsAPI.Timeseries)

		// Admin-only: user management
		r.Group(func(r chi.Router) {
			r.Use(authMiddleware.RequireAdmin)
			r.Get("/api/auth/users", authHandlers.ListUsers)
			r.Post("/api/auth/users", authHandlers.CreateUser)
			r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser)
		})
	})

	// ctx is cancelled when SIGINT/SIGTERM arrives or when stop is called.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)

	// Periodic session cleanup. The deferred func below is registered after
	// the database defers, so it runs first: the goroutine is stopped and
	// awaited before the store it queries is closed.
	cleanupDone := make(chan struct{})
	go func() {
		defer close(cleanupDone)
		cleanSessionsPeriodically(ctx, authStore, sessionCleanupInterval)
	}()
	defer func() {
		stop()
		<-cleanupDone
	}()

	// Server.
	// NOTE(review): WriteTimeout also caps long-lived responses such as the
	// /api/events SSE stream (and streamed completions, if ChatCompletions
	// streams). Confirm those handlers extend their own write deadline or that
	// clients are expected to reconnect.
	srv := &http.Server{
		Addr:         cfg.Server.Listen,
		Handler:      r,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: cfg.Server.RequestTimeout + 10*time.Second,
		IdleTimeout:  120 * time.Second,
	}

	// Buffered so the serving goroutine can always deliver its result and exit.
	serverErr := make(chan error, 1)
	go func() {
		log.Printf("Listening on %s", cfg.Server.Listen)
		serverErr <- srv.ListenAndServe()
	}()

	select {
	case err := <-serverErr:
		// Neither Shutdown nor Close has been called yet, so this is a real
		// listener failure (e.g. address already in use). Returning instead
		// of exiting lets the deferred cleanup run.
		return fmt.Errorf("Server failed: %w", err)
	case <-ctx.Done():
	}

	// Restore default signal handling so a second signal terminates the
	// process immediately instead of waiting out the drain.
	stop()
	log.Println("Shutting down...")

	shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		// Shutdown waits for connections to go idle but never interrupts them,
		// so long-lived streams can outlast the deadline; close them forcibly.
		log.Printf("WARNING: graceful shutdown incomplete: %v", err)
		if err := srv.Close(); err != nil {
			log.Printf("WARNING: forced close failed: %v", err)
		}
	}
	log.Println("Stopped")
	return nil
}

// loadStaticTokens converts the API tokens declared in cfg into static tokens
// for the auth store, which checks them in memory rather than seeding them to
// the database. Entries without a key are skipped with a warning instead of
// being dropped silently.
func loadStaticTokens(cfg *config.Config) []auth.StaticToken {
	var tokens []auth.StaticToken
	for _, t := range cfg.Tokens {
		if t.Key == "" {
			log.Printf("WARNING: skipping static token %q: no key configured", t.Name)
			continue
		}
		tokens = append(tokens, auth.StaticToken{
			Name:           t.Name,
			Key:            t.Key,
			RateLimitRPM:   t.RateLimitRPM,
			DailyBudgetUSD: t.DailyBudgetUSD,
		})
		log.Printf("Loaded static token: %s", t.Name)
	}
	return tokens
}

// cleanSessionsPeriodically calls store.CleanExpiredSessions once per interval
// until ctx is done. Failures are logged and retried on the next tick.
func cleanSessionsPeriodically(ctx context.Context, store *auth.Store, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := store.CleanExpiredSessions(); err != nil {
				log.Printf("WARNING: session cleanup failed: %v", err)
			}
		}
	}
}

// seedDefaultAdmin creates the user configured in cfg.Server.DefaultAdmin when
// the auth store has no users yet. It does nothing if any user already exists
// or if either the username or the password is empty. A creation failure is
// logged, not fatal.
func seedDefaultAdmin(cfg *config.Config, authStore *auth.Store) {
	if authStore.HasAnyUser() {
		return
	}
	da := cfg.Server.DefaultAdmin
	if da.Username == "" || da.Password == "" {
		return
	}
	user, err := authStore.CreateUser(da.Username, da.Password, true)
	if err != nil {
		log.Printf("WARNING: failed to create default admin: %v", err)
		return
	}
	log.Printf("Created default admin user: %s (id=%d)", user.Username, user.ID)
}