From 291b8f4863bbab365c92196a3f6d076aa960dc1f Mon Sep 17 00:00:00 2001 From: Ray Andrew Date: Sun, 15 Feb 2026 04:52:09 -0600 Subject: [PATCH] feat(gateway): add monthly budget support and webhook notifications for circuit breaker and budget events --- llm-gateway/cmd/gateway/main.go | 68 +++++-- llm-gateway/go.mod | 2 +- llm-gateway/internal/auth/store.go | 72 +++---- llm-gateway/internal/auth/store_test.go | 1 + llm-gateway/internal/config/config.go | 77 +++++-- llm-gateway/internal/config/config_test.go | 1 - llm-gateway/internal/dashboard/api.go | 29 +++ llm-gateway/internal/dashboard/handler.go | 8 +- .../templates/partials/models-page.html | 16 +- .../templates/partials/settings.html | 24 +++ .../dashboard/templates/partials/tokens.html | 10 +- llm-gateway/internal/provider/health.go | 24 ++- llm-gateway/internal/provider/health_test.go | 14 +- llm-gateway/internal/provider/openai.go | 42 ++++ llm-gateway/internal/provider/provider.go | 29 +++ llm-gateway/internal/provider/registry.go | 30 +++ .../internal/provider/registry_test.go | 4 + llm-gateway/internal/proxy/dedup.go | 107 ++++++++++ llm-gateway/internal/proxy/dedup_test.go | 74 +++++++ llm-gateway/internal/proxy/handler.go | 190 +++++++++++++++++- llm-gateway/internal/proxy/models.go | 61 +++++- llm-gateway/internal/proxy/ratelimit.go | 59 +++++- llm-gateway/internal/proxy/ratelimit_test.go | 22 +- llm-gateway/internal/proxy/stream.go | 12 +- llm-gateway/internal/storage/db.go | 15 ++ llm-gateway/internal/storage/logger.go | 11 +- .../009_token_monthly_budget.down.sql | 1 + .../009_token_monthly_budget.up.sql | 1 + .../migrations/010_request_type.down.sql | 1 + .../migrations/010_request_type.up.sql | 1 + llm-gateway/internal/webhook/webhook.go | 123 ++++++++++++ 31 files changed, 1005 insertions(+), 124 deletions(-) create mode 100644 llm-gateway/internal/proxy/dedup.go create mode 100644 llm-gateway/internal/proxy/dedup_test.go create mode 100644 llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql create mode 100644 llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql create mode 100644 llm-gateway/internal/storage/migrations/010_request_type.down.sql create mode 100644 llm-gateway/internal/storage/migrations/010_request_type.up.sql create mode 100644 llm-gateway/internal/webhook/webhook.go diff --git a/llm-gateway/cmd/gateway/main.go b/llm-gateway/cmd/gateway/main.go index 8bd10b2..836781a 100644 --- a/llm-gateway/cmd/gateway/main.go +++ b/llm-gateway/cmd/gateway/main.go @@ -25,6 +25,7 @@ import ( "llm-gateway/internal/provider" "llm-gateway/internal/proxy" "llm-gateway/internal/storage" + "llm-gateway/internal/webhook" ) var version = "dev" @@ -95,16 +96,41 @@ func main() { // Provider health tracker healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker) + // Webhook notifier + var notifier *webhook.Notifier + if len(cfg.Webhooks) > 0 { + notifier = webhook.NewNotifier(cfg.Webhooks) + defer notifier.Close() + log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks)) + + // Wire health tracker state changes to webhook + healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) { + eventType := webhook.EventCircuitBreakerOpen + if to == provider.CircuitClosed { + eventType = webhook.EventCircuitBreakerClosed + } + notifier.Notify(webhook.Event{ + Type: eventType, + Data: map[string]any{ + "provider": providerName, + "from": from.String(), + "to": to.String(), + }, + }) + } + } + // Auth store (static tokens checked in-memory, not seeded to DB) var staticTokens []auth.StaticToken for _, t := range cfg.Tokens { if t.Key != "" { staticTokens = append(staticTokens, auth.StaticToken{ - Name: t.Name, - Key: t.Key, - RateLimitRPM: t.RateLimitRPM, - DailyBudgetUSD: t.DailyBudgetUSD, - MaxConcurrent: t.MaxConcurrent, + Name: t.Name, + Key: t.Key, + RateLimitRPM: t.RateLimitRPM, + DailyBudgetUSD: t.DailyBudgetUSD, + MonthlyBudgetUSD: t.MonthlyBudgetUSD, + MaxConcurrent: t.MaxConcurrent, }) log.Printf("Loaded static token: %s", t.Name) } @@ -133,14 +159,27 @@ func main() { // Handlers proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker) proxyHandler.SetDebugLogger(debugLogger) - modelsHandler := proxy.NewModelsHandler(registry) + + // Request deduplication + if cfg.Dedup.Enabled { + dedup := proxy.NewDeduplicator(cfg.Dedup.Window) + defer dedup.Close() + proxyHandler.SetDeduplicator(dedup) + log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window) + } + + modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg) proxyAuth := proxy.NewAuthMiddleware(authStore) rateLimiter := proxy.NewRateLimiter(db) + if notifier != nil { + rateLimiter.SetNotifier(notifier) + } concurrencyLimiter := proxy.NewConcurrencyLimiter() statsAPI := dashboard.NewStatsAPI(db, authStore) statsAPI.SetHealthTracker(healthTracker) statsAPI.SetAuditLogger(auditLogger) statsAPI.SetDebugLogger(debugLogger) + statsAPI.SetConfigPath(*configPath) if c != nil { statsAPI.SetCache(c) } @@ -196,6 +235,7 @@ func main() { r.Use(rateLimiter.Check) r.Use(concurrencyLimiter.Check) r.Post("/v1/chat/completions", proxyHandler.ChatCompletions) + r.Post("/v1/embeddings", proxyHandler.Embeddings) r.Get("/v1/models", modelsHandler.ListModels) }) @@ -266,7 +306,7 @@ func main() { r.Get("/api/export/logs", exportHandler.ExportLogs) r.Get("/api/export/stats", exportHandler.ExportStats) - // Admin-only: user management, audit, debug + // Admin-only: user management, audit, debug, config validation r.Group(func(r chi.Router) { r.Use(authMiddleware.RequireAdmin) r.Get("/api/auth/users", authHandlers.ListUsers) @@ -276,6 +316,9 @@ func main() { // Audit log r.Get("/api/stats/audit", statsAPI.AuditLogs) + // Config validation + r.Get("/api/config/validate", statsAPI.ValidateConfig) + // Debug logging r.Post("/api/debug/toggle", statsAPI.DebugToggle) r.Get("/api/debug/status", statsAPI.DebugStatus) @@ -332,11 +375,12 @@ func main() { for _, t := range newCfg.Tokens { if t.Key != "" { newStaticTokens = append(newStaticTokens, auth.StaticToken{ - Name: t.Name, - Key: t.Key, - RateLimitRPM: t.RateLimitRPM, - DailyBudgetUSD: t.DailyBudgetUSD, - MaxConcurrent: t.MaxConcurrent, + Name: t.Name, + Key: t.Key, + RateLimitRPM: t.RateLimitRPM, + DailyBudgetUSD: t.DailyBudgetUSD, + MonthlyBudgetUSD: t.MonthlyBudgetUSD, + MaxConcurrent: t.MaxConcurrent, }) } } diff --git a/llm-gateway/go.mod b/llm-gateway/go.mod index 231f0a7..08c7f3b 100644 --- a/llm-gateway/go.mod +++ b/llm-gateway/go.mod @@ -4,6 +4,7 @@ go 1.24.0 require ( github.com/go-chi/chi/v5 v5.2.5 + github.com/go-chi/cors v1.2.2 github.com/golang-migrate/migrate/v4 v4.19.1 github.com/pquerna/otp v1.5.0 github.com/prometheus/client_golang v1.23.2 @@ -19,7 +20,6 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/go-chi/cors v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/llm-gateway/internal/auth/store.go b/llm-gateway/internal/auth/store.go index cfeec0c..4f62ff7 100644 --- a/llm-gateway/internal/auth/store.go +++ b/llm-gateway/internal/auth/store.go @@ -31,25 +31,27 @@ type Session struct { } type APIToken struct { - ID int64 `json:"id"` - Name string `json:"name"` - KeyPrefix string `json:"key_prefix"` - KeyHash string `json:"-"` - UserID int64 `json:"user_id"` - RateLimitRPM int `json:"rate_limit_rpm"` - DailyBudgetUSD float64 `json:"daily_budget_usd"` - MaxConcurrent int `json:"max_concurrent"` - CreatedAt int64 `json:"created_at"` - LastUsedAt int64 `json:"last_used_at"` + ID int64 `json:"id"` + Name string `json:"name"` + KeyPrefix string `json:"key_prefix"` + KeyHash string `json:"-"` + UserID int64 `json:"user_id"` + RateLimitRPM int `json:"rate_limit_rpm"` + DailyBudgetUSD float64 `json:"daily_budget_usd"` + MonthlyBudgetUSD float64 `json:"monthly_budget_usd"` + MaxConcurrent int `json:"max_concurrent"` + CreatedAt int64 `json:"created_at"` + LastUsedAt int64 `json:"last_used_at"` } // StaticToken represents a token defined in config (checked in-memory, never stored in DB). type StaticToken struct { - Name string - Key string - RateLimitRPM int - DailyBudgetUSD float64 - MaxConcurrent int + Name string + Key string + RateLimitRPM int + DailyBudgetUSD float64 + MonthlyBudgetUSD float64 + MaxConcurrent int } type Store struct { @@ -289,12 +291,13 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) { prefix = prefix[:11] } return &APIToken{ - ID: -1, // sentinel: static token - Name: st.Name, - KeyPrefix: prefix, - RateLimitRPM: st.RateLimitRPM, - DailyBudgetUSD: st.DailyBudgetUSD, - MaxConcurrent: st.MaxConcurrent, + ID: -1, // sentinel: static token + Name: st.Name, + KeyPrefix: prefix, + RateLimitRPM: st.RateLimitRPM, + DailyBudgetUSD: st.DailyBudgetUSD, + MonthlyBudgetUSD: st.MonthlyBudgetUSD, + MaxConcurrent: st.MaxConcurrent, }, nil } } @@ -305,9 +308,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) { var t APIToken err := s.db.QueryRow( - "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?", + "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?", keyHash, - ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) + ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) if err != nil { return nil, err } @@ -323,12 +326,13 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { prefix = prefix[:11] } tokens = append(tokens, APIToken{ - ID: -1, // sentinel: static token - Name: st.Name, - KeyPrefix: prefix, - RateLimitRPM: st.RateLimitRPM, - DailyBudgetUSD: st.DailyBudgetUSD, - MaxConcurrent: st.MaxConcurrent, + ID: -1, // sentinel: static token + Name: st.Name, + KeyPrefix: prefix, + RateLimitRPM: st.RateLimitRPM, + DailyBudgetUSD: st.DailyBudgetUSD, + MonthlyBudgetUSD: st.MonthlyBudgetUSD, + MaxConcurrent: st.MaxConcurrent, }) } @@ -336,9 +340,9 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { var rows *sql.Rows var err error if userID == 0 { - rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id") + rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id") } else { - rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) + rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) } if err != nil { return tokens, nil @@ -347,7 +351,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) { for rows.Next() { var t APIToken - if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil { + if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil { return tokens, nil } tokens = append(tokens, t) @@ -363,9 +367,9 @@ func (s *Store) DeleteAPIToken(id int64) error { func (s *Store) GetAPIToken(id int64) (*APIToken, error) { var t APIToken err := s.db.QueryRow( - "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?", + "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?", id, - ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) + ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) if err != nil { return nil, err } diff --git a/llm-gateway/internal/auth/store_test.go b/llm-gateway/internal/auth/store_test.go index fc4678f..d566224 100644 --- a/llm-gateway/internal/auth/store_test.go +++ b/llm-gateway/internal/auth/store_test.go @@ -42,6 +42,7 @@ func setupTestDB(t *testing.T) *sql.DB { user_id INTEGER NOT NULL, rate_limit_rpm INTEGER DEFAULT 0, daily_budget_usd REAL DEFAULT 0, + monthly_budget_usd REAL DEFAULT 0, max_concurrent INTEGER DEFAULT 0, created_at INTEGER NOT NULL, last_used_at INTEGER DEFAULT 0 diff --git a/llm-gateway/internal/config/config.go b/llm-gateway/internal/config/config.go index 3800eea..ffa934b 100644 --- a/llm-gateway/internal/config/config.go +++ b/llm-gateway/internal/config/config.go @@ -20,11 +20,24 @@ type Config struct { Retry RetryConfig `yaml:"retry"` Debug DebugConfig `yaml:"debug"` CORS CORSConfig `yaml:"cors"` + Dedup DedupConfig `yaml:"dedup"` + Webhooks []WebhookConfig `yaml:"webhooks"` Providers []ProviderConfig `yaml:"providers"` Models []ModelConfig `yaml:"models"` Tokens []TokenConfig `yaml:"tokens"` } +type DedupConfig struct { + Enabled bool `yaml:"enabled"` + Window time.Duration `yaml:"window"` // max time to wait for dedup result +} + +type WebhookConfig struct { + URL string `yaml:"url"` + Events []string `yaml:"events"` // event types to send + Secret string `yaml:"secret"` // optional HMAC secret +} + type PricingLookupConfig struct { URL string `yaml:"url"` RefreshInterval time.Duration `yaml:"refresh_interval"` @@ -36,20 +49,21 @@ type DefaultAdminConfig struct { } type TokenConfig struct { - Name string `yaml:"name"` - Key string `yaml:"key"` - RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited - DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited - MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited + Name string `yaml:"name"` + Key string `yaml:"key"` + RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited + DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited + MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited + MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited } type ServerConfig struct { - Listen string `yaml:"listen"` - RequestTimeout time.Duration `yaml:"request_timeout"` - StreamingTimeout time.Duration `yaml:"streaming_timeout"` - MaxRequestBodyMB int `yaml:"max_request_body_mb"` - SessionSecret string `yaml:"session_secret"` - DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` + Listen string `yaml:"listen"` + RequestTimeout time.Duration `yaml:"request_timeout"` + StreamingTimeout time.Duration `yaml:"streaming_timeout"` + MaxRequestBodyMB int `yaml:"max_request_body_mb"` + SessionSecret string `yaml:"session_secret"` + DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` } type CircuitBreakerConfig struct { @@ -66,10 +80,10 @@ type RetryConfig struct { } type DebugConfig struct { - Enabled bool `yaml:"enabled"` - MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies) - RetentionDays int `yaml:"retention_days"` - DataDir string `yaml:"data_dir"` + Enabled bool `yaml:"enabled"` + MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies) + RetentionDays int `yaml:"retention_days"` + DataDir string `yaml:"data_dir"` } type CORSConfig struct { @@ -100,10 +114,12 @@ type ProviderConfig struct { } type ModelConfig struct { - Name string `yaml:"name"` - Aliases []string `yaml:"aliases"` - Routes []RouteConfig `yaml:"routes"` - LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost + Name string `yaml:"name"` + Aliases []string `yaml:"aliases"` + Routes []RouteConfig `yaml:"routes"` + LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost + RequestTimeout time.Duration `yaml:"request_timeout"` // per-model override; 0 = use server default + StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default } type RouteConfig struct { @@ -131,14 +147,15 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("parsing config: %w", err) } - if err := cfg.validate(); err != nil { + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validating config: %w", err) } return &cfg, nil } -func (c *Config) validate() error { +// Validate checks the config for correctness and applies defaults. +func (c *Config) Validate() error { if c.Server.Listen == "" { c.Server.Listen = "0.0.0.0:3000" } @@ -201,6 +218,11 @@ func (c *Config) validate() error { c.CORS.MaxAge = 300 } + // Dedup defaults + if c.Dedup.Window == 0 { + c.Dedup.Window = 30 * time.Second + } + if len(c.Providers) == 0 { return fmt.Errorf("at least one provider is required") } @@ -266,6 +288,19 @@ func (c *Config) validate() error { return nil } +// ValidateBytes parses raw YAML and returns a list of validation errors. +func ValidateBytes(data []byte) []string { + expanded := os.ExpandEnv(string(data)) + var cfg Config + if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil { + return []string{"parse error: " + err.Error()} + } + if err := cfg.Validate(); err != nil { + return []string{err.Error()} + } + return nil +} + // ProviderByName returns the provider config by name. func (c *Config) ProviderByName(name string) *ProviderConfig { for i := range c.Providers { diff --git a/llm-gateway/internal/config/config_test.go b/llm-gateway/internal/config/config_test.go index d40617a..3a65f7a 100644 --- a/llm-gateway/internal/config/config_test.go +++ b/llm-gateway/internal/config/config_test.go @@ -735,4 +735,3 @@ models: t.Errorf("error = %q, want to contain api_key validation message", err.Error()) } } - diff --git a/llm-gateway/internal/dashboard/api.go b/llm-gateway/internal/dashboard/api.go index 346d703..e1f4639 100644 --- a/llm-gateway/internal/dashboard/api.go +++ b/llm-gateway/internal/dashboard/api.go @@ -3,6 +3,7 @@ package dashboard import ( "encoding/json" "net/http" + "os" "sort" "strconv" "time" @@ -11,6 +12,7 @@ import ( "llm-gateway/internal/auth" "llm-gateway/internal/cache" + "llm-gateway/internal/config" "llm-gateway/internal/provider" "llm-gateway/internal/storage" ) @@ -109,6 +111,7 @@ type StatsAPI struct { cache *cache.Cache auditLogger *storage.AuditLogger debugLogger *storage.DebugLogger + configPath string } func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI { @@ -135,6 +138,11 @@ func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) { s.debugLogger = dl } +// SetConfigPath sets the config file path for validation. +func (s *StatsAPI) SetConfigPath(path string) { + s.configPath = path +} + // TokenNamesForUser returns the token names that belong to the user. // Admins get nil (no filter), non-admins get their token names. func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string { @@ -712,6 +720,27 @@ func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) { writeJSON(w, entry) } +// ValidateConfig validates the config file at the stored path. +func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) { + if s.configPath == "" { + w.WriteHeader(http.StatusInternalServerError) + writeJSON(w, map[string]any{"valid": false, "errors": []string{"config path not set"}}) + return + } + data, err := os.ReadFile(s.configPath) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + writeJSON(w, map[string]any{"valid": false, "errors": []string{"failed to read config: " + err.Error()}}) + return + } + errs := config.ValidateBytes(data) + if len(errs) > 0 { + writeJSON(w, map[string]any{"valid": false, "errors": errs}) + return + } + writeJSON(w, map[string]any{"valid": true, "errors": []string{}}) +} + func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(v) diff --git a/llm-gateway/internal/dashboard/handler.go b/llm-gateway/internal/dashboard/handler.go index b637ccd..e672637 100644 --- a/llm-gateway/internal/dashboard/handler.go +++ b/llm-gateway/internal/dashboard/handler.go @@ -127,9 +127,9 @@ type PageData struct { // Models routing page data ModelRoutes []provider.ModelRouteInfo // Audit page data - AuditResult *storage.AuditQueryResult + AuditResult *storage.AuditQueryResult AuditFilterActions []string - FilterAction string + FilterAction string // Debug page data DebugResult *storage.DebugLogQueryResult DebugEnabled bool @@ -275,6 +275,10 @@ func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) { data.ModelRoutes = d.registry.AllRoutes() } + if d.statsAPI.healthTracker != nil { + data.ProviderHealth = d.statsAPI.healthTracker.Status() + } + d.renderDashboardPage(w, r, "partials/models-page.html", data) } diff --git a/llm-gateway/internal/dashboard/templates/partials/models-page.html b/llm-gateway/internal/dashboard/templates/partials/models-page.html index c81109e..6d4a60b 100644 --- a/llm-gateway/internal/dashboard/templates/partials/models-page.html +++ b/llm-gateway/internal/dashboard/templates/partials/models-page.html @@ -6,7 +6,7 @@ {{if .ModelRoutes}} {{range .ModelRoutes}}
-

{{.Name}}

+

{{.Name}}{{if .Aliases}} aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}{{end}}

@@ -15,9 +15,11 @@ + + {{$health := $.ProviderHealth}} {{range .Routes}} @@ -25,6 +27,18 @@ + {{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/settings.html b/llm-gateway/internal/dashboard/templates/partials/settings.html index faf009b..2befd7b 100644 --- a/llm-gateway/internal/dashboard/templates/partials/settings.html +++ b/llm-gateway/internal/dashboard/templates/partials/settings.html @@ -70,6 +70,15 @@ +{{if .User.IsAdmin}} +
+

Config Validation

+

Validate the current gateway configuration file for errors.

+ +
+
+{{end}} + {{end}} diff --git a/llm-gateway/internal/dashboard/templates/partials/tokens.html b/llm-gateway/internal/dashboard/templates/partials/tokens.html index 7f330d2..58f85bd 100644 --- a/llm-gateway/internal/dashboard/templates/partials/tokens.html +++ b/llm-gateway/internal/dashboard/templates/partials/tokens.html @@ -19,7 +19,10 @@ - + - +
Priority Input Price (per 1M) Output Price (per 1M)Health
{{.ProviderName}}{{.Priority}} {{formatPrice .InputPrice}} {{formatPrice .OutputPrice}} + {{$pname := .ProviderName}} + {{range $health}} + {{if eq .Provider $pname}} + {{if eq .Status "healthy"}}healthy + {{else if eq .Status "degraded"}}degraded + {{else}}down + {{end}} + {{end}} + {{end}} + {{if not $health}}-{{end}} +
{{.Name}} {{.KeyPrefix}}... {{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}} + {{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}} + {{if gt .MonthlyBudgetUSD 0.0}}
${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}} +
{{$spend := index $.TokenSpend .Name}} {{if gt .DailyBudgetUSD 0.0}} @@ -49,7 +52,10 @@ {{.Name}} {{.KeyPrefix}}... {{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}} + {{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}} + {{if gt .MonthlyBudgetUSD 0.0}}
${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}} +
{{$spend := index $.TokenSpend .Name}} {{if gt .DailyBudgetUSD 0.0}} diff --git a/llm-gateway/internal/provider/health.go b/llm-gateway/internal/provider/health.go index 91cb965..ad3638a 100644 --- a/llm-gateway/internal/provider/health.go +++ b/llm-gateway/internal/provider/health.go @@ -46,8 +46,8 @@ type HealthEvent struct { // ProviderHealth is the computed health status for a provider. type ProviderHealth struct { - Provider string `json:"provider"` - Status string `json:"status"` // healthy, degraded, down + Provider string `json:"provider"` + Status string `json:"status"` // healthy, degraded, down ErrorRate float64 `json:"error_rate"` AvgLatency float64 `json:"avg_latency_ms"` Total int `json:"total"` @@ -57,11 +57,12 @@ type ProviderHealth struct { // HealthTracker tracks per-provider health using a sliding window. type HealthTracker struct { - mu sync.RWMutex - windows map[string][]HealthEvent - windowDu time.Duration - circuits map[string]*ProviderCircuit - cbConfig config.CircuitBreakerConfig + mu sync.RWMutex + windows map[string][]HealthEvent + windowDu time.Duration + circuits map[string]*ProviderCircuit + cbConfig config.CircuitBreakerConfig + OnStateChange func(provider string, from, to CircuitState) } // NewHealthTracker creates a health tracker with the given window duration. @@ -135,6 +136,8 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) { h.circuits[providerName] = circuit } + prevState := circuit.State + switch circuit.State { case CircuitClosed: // Check if error threshold exceeded @@ -164,6 +167,13 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) { circuit.OpenedAt = time.Now() } } + + if circuit.State != prevState && h.OnStateChange != nil { + cb := h.OnStateChange + from, to := prevState, circuit.State + // Call outside lock to avoid deadlocks + go cb(providerName, from, to) + } } // errorRateUnlocked computes error rate within window. Must be called with lock held. diff --git a/llm-gateway/internal/provider/health_test.go b/llm-gateway/internal/provider/health_test.go index 6d021b1..8dda99e 100644 --- a/llm-gateway/internal/provider/health_test.go +++ b/llm-gateway/internal/provider/health_test.go @@ -47,13 +47,13 @@ func TestHealthTracker_Record(t *testing.T) { func TestHealthTracker_Status(t *testing.T) { tests := []struct { - name string - successCount int - errorCount int - wantStatus string - wantErrorRate float64 - wantTotal int - wantErrors int + name string + successCount int + errorCount int + wantStatus string + wantErrorRate float64 + wantTotal int + wantErrors int }{ { name: "healthy - no errors", diff --git a/llm-gateway/internal/provider/openai.go b/llm-gateway/internal/provider/openai.go index 0e434f3..9c096b3 100644 --- a/llm-gateway/internal/provider/openai.go +++ b/llm-gateway/internal/provider/openai.go @@ -108,6 +108,48 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, return resp.Body, nil } +func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) { + reqCopy := *req + reqCopy.Model = model + + body, err := json.Marshal(reqCopy) + if err != nil { + return nil, fmt.Errorf("marshaling request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + p.setHeaders(httpReq) + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, &ProviderError{ + StatusCode: resp.StatusCode, + Body: string(respBody), + Provider: p.name, + } + } + + var embResp EmbeddingResponse + if err := json.Unmarshal(respBody, &embResp); err != nil { + return nil, fmt.Errorf("unmarshaling response: %w", err) + } + + return &embResp, nil +} + func (p *OpenAIProvider) setHeaders(req *http.Request) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+p.apiKey) diff --git a/llm-gateway/internal/provider/provider.go b/llm-gateway/internal/provider/provider.go index 5686865..3d552ea 100644 --- a/llm-gateway/internal/provider/provider.go +++ b/llm-gateway/internal/provider/provider.go @@ -52,9 +52,38 @@ type Usage struct { TotalTokens int `json:"total_tokens"` } +// EmbeddingRequest is the OpenAI-compatible embedding request. +type EmbeddingRequest struct { + Model string `json:"model"` + Input any `json:"input"` // string or []string + EncodingFormat string `json:"encoding_format,omitempty"` +} + +// EmbeddingResponse is the OpenAI-compatible embedding response. +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingData `json:"data"` + Model string `json:"model"` + Usage *EmbeddingUsage `json:"usage,omitempty"` +} + +// EmbeddingData holds a single embedding vector. +type EmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingUsage reports token usage for embeddings. +type EmbeddingUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + // Provider sends requests to an LLM API. type Provider interface { Name() string ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) + Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) } diff --git a/llm-gateway/internal/provider/registry.go b/llm-gateway/internal/provider/registry.go index 5a4dced..2096738 100644 --- a/llm-gateway/internal/provider/registry.go +++ b/llm-gateway/internal/provider/registry.go @@ -4,10 +4,17 @@ import ( "fmt" "sort" "sync" + "time" "llm-gateway/internal/config" ) +// ModelTimeouts holds per-model timeout overrides. +type ModelTimeouts struct { + RequestTimeout time.Duration + StreamingTimeout time.Duration +} + // Route maps a model to a specific provider with pricing. type Route struct { Provider Provider @@ -24,6 +31,7 @@ type Registry struct { balancers map[string]LoadBalancer aliases map[string]string // alias -> canonical name order []string // preserves config order (canonical names only) + timeouts map[string]*ModelTimeouts } func NewRegistry(cfg *config.Config) (*Registry, error) { @@ -46,6 +54,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error { balancers := make(map[string]LoadBalancer) aliases := make(map[string]string) order := make([]string, 0, len(cfg.Models)) + timeouts := make(map[string]*ModelTimeouts) for _, mc := range cfg.Models { var modelRoutes []Route @@ -74,6 +83,14 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error { // Load balancer balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing) + // Per-model timeouts + if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 { + timeouts[mc.Name] = &ModelTimeouts{ + RequestTimeout: mc.RequestTimeout, + StreamingTimeout: mc.StreamingTimeout, + } + } + // Register aliases for _, alias := range mc.Aliases { aliases[alias] = mc.Name @@ -85,6 +102,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error { r.balancers = balancers r.aliases = aliases r.order = order + r.timeouts = timeouts r.mu.Unlock() return nil @@ -135,6 +153,18 @@ func (r *Registry) ModelNames() []string { return names } +// ModelTimeoutsFor returns per-model timeout overrides, resolving aliases. Returns nil if none set. +func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts { + r.mu.RLock() + defer r.mu.RUnlock() + + canonical := model + if alias, ok := r.aliases[model]; ok { + canonical = alias + } + return r.timeouts[canonical] +} + // RouteInfo exposes route details for dashboard display. type RouteInfo struct { ProviderName string `json:"provider_name"` diff --git a/llm-gateway/internal/provider/registry_test.go b/llm-gateway/internal/provider/registry_test.go index b04c45a..6c2a5cc 100644 --- a/llm-gateway/internal/provider/registry_test.go +++ b/llm-gateway/internal/provider/registry_test.go @@ -23,6 +23,10 @@ func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *Chat return nil, nil } +func (m *mockProvider) Embedding(_ context.Context, _ string, _ *EmbeddingRequest) (*EmbeddingResponse, error) { + return nil, nil +} + // newTestRegistry builds a Registry directly without going through config parsing. func newTestRegistry(models []testModel) *Registry { r := &Registry{ diff --git a/llm-gateway/internal/proxy/dedup.go b/llm-gateway/internal/proxy/dedup.go new file mode 100644 index 0000000..d55b6c5 --- /dev/null +++ b/llm-gateway/internal/proxy/dedup.go @@ -0,0 +1,107 @@ +package proxy + +import ( + "crypto/sha256" + "encoding/hex" + "sync" + "time" +) + +// inflight represents an in-progress deduplicated request. +type inflight struct { + done chan struct{} + result []byte + statusCode int + createdAt time.Time +} + +// Deduplicator coalesces identical concurrent non-streaming requests. +type Deduplicator struct { + mu sync.Mutex + flights map[string]*inflight + window time.Duration + done chan struct{} +} + +// NewDeduplicator creates a new request deduplicator. +func NewDeduplicator(window time.Duration) *Deduplicator { + if window == 0 { + window = 30 * time.Second + } + d := &Deduplicator{ + flights: make(map[string]*inflight), + window: window, + done: make(chan struct{}), + } + go d.cleanup() + return d +} + +// DedupKey computes a dedup key from model name and request body. +func DedupKey(model string, body []byte) string { + h := sha256.New() + h.Write([]byte(model)) + h.Write([]byte{0}) + h.Write(body) + return hex.EncodeToString(h.Sum(nil)) +} + +// TryJoin attempts to join an in-flight request. Returns the inflight entry and +// whether this caller is the leader (true) or a follower (false). +func (d *Deduplicator) TryJoin(key string) (*inflight, bool) { + d.mu.Lock() + defer d.mu.Unlock() + + if f, ok := d.flights[key]; ok { + return f, false // follower + } + + f := &inflight{ + done: make(chan struct{}), + createdAt: time.Now(), + } + d.flights[key] = f + return f, true // leader +} + +// Complete signals completion of a deduplicated request. +func (d *Deduplicator) Complete(key string, result []byte, statusCode int) { + d.mu.Lock() + f, ok := d.flights[key] + delete(d.flights, key) + d.mu.Unlock() + + if ok { + f.result = result + f.statusCode = statusCode + close(f.done) + } +} + +// Close stops the background cleanup goroutine. +func (d *Deduplicator) Close() { + close(d.done) +} + +// cleanup periodically removes stale in-flight entries. +func (d *Deduplicator) cleanup() { + ticker := time.NewTicker(d.window) + defer ticker.Stop() + + for { + select { + case <-d.done: + return + case <-ticker.C: + d.mu.Lock() + now := time.Now() + for key, f := range d.flights { + if now.Sub(f.createdAt) > d.window*2 { + delete(d.flights, key) + close(f.done) // unblock any waiting followers + } + } + d.mu.Unlock() + } + } +} diff --git a/llm-gateway/internal/proxy/dedup_test.go b/llm-gateway/internal/proxy/dedup_test.go new file mode 100644 index 0000000..900a53d --- /dev/null +++ b/llm-gateway/internal/proxy/dedup_test.go @@ -0,0 +1,74 @@ +package proxy + +import ( + "sync" + "testing" + "time" +) + +func TestDedupKey(t *testing.T) { + k1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`)) + k2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`)) + k3 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`)) + + if k1 != k2 { + t.Error("identical requests should produce the same key") + } + if k1 == k3 { + t.Error("different requests should produce different keys") + } +} + +func TestDeduplicator_LeaderFollower(t *testing.T) { + d := NewDeduplicator(5 * time.Second) + defer d.Close() + + key := DedupKey("gpt-4", []byte(`test`)) + + // First call is leader + f1, isLeader := d.TryJoin(key) + if !isLeader { + t.Fatal("first caller should be leader") + } + + // Second call with same key is follower + f2, isLeader := d.TryJoin(key) + if isLeader { + t.Fatal("second caller should be follower") + } + if f1 != f2 { + t.Fatal("follower should get same inflight entry") + } + + // Complete the request + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + <-f2.done + if string(f2.result) != "response" { + t.Error("follower should receive leader's result") + } + if f2.statusCode != 200 { + t.Error("follower should receive leader's status code") + } + }() + + d.Complete(key, []byte("response"), 200) + wg.Wait() +} + +func TestDeduplicator_DifferentKeys(t *testing.T) { + d := NewDeduplicator(5 * time.Second) + defer d.Close() + + _, isLeader1 := d.TryJoin("key1") + _, isLeader2 := d.TryJoin("key2") + + if !isLeader1 || !isLeader2 { + t.Error("different keys should both be leaders") + } + + d.Complete("key1", []byte("r1"), 200) + d.Complete("key2", []byte("r2"), 200) +} diff --git a/llm-gateway/internal/proxy/handler.go b/llm-gateway/internal/proxy/handler.go index cb6c619..1b24cc2 100644 --- a/llm-gateway/internal/proxy/handler.go +++ b/llm-gateway/internal/proxy/handler.go @@ -53,6 +53,7 @@ type Handler struct { cfg *config.Config healthTracker *provider.HealthTracker debugLogger *storage.DebugLogger + dedup *Deduplicator } func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler { @@ -70,6 +71,10 @@ func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) { h.debugLogger = dl } +func (h *Handler) SetDeduplicator(d *Deduplicator) { + h.dedup = d +} + func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) if err != nil { @@ -118,11 +123,47 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { } } + // Apply per-model timeout for non-streaming requests + modelTimeouts := h.registry.ModelTimeoutsFor(req.Model) + if req.Stream { - h.handleStream(w, r, &req, routes, tokenName, requestID) + h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts) return } + // Request deduplication for non-streaming requests + if h.dedup != nil { + dedupKey := DedupKey(req.Model, body) + flight, isLeader := h.dedup.TryJoin(dedupKey) + if !isLeader { + // Wait for the leader to complete + select { + case <-flight.done: + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Request-ID", requestID) + w.Header().Set("X-Dedup", "HIT") + w.WriteHeader(flight.statusCode) + w.Write(flight.result) + return + case <-r.Context().Done(): + writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup") + return + } + } + // Leader: proceed normally, but capture response for followers + defer func() { + // If we haven't completed yet (e.g., panic), clean up + }() + h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey) + return + } + + if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 { + ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout) + defer cancel() + r = r.WithContext(ctx) + } + h.handleNonStream(w, r, &req, routes, tokenName, body, requestID) } @@ -233,6 +274,153 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p } } +// handleNonStreamDedup wraps handleNonStream to capture the response for dedup followers. +func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return + } + + var req provider.EmbeddingRequest + if err := json.Unmarshal(body, &req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error()) + return + } + + if req.Model == "" { + writeError(w, http.StatusBadRequest, "model is required") + return + } + + routes, ok := h.registry.Lookup(req.Model) + if !ok { + writeError(w, http.StatusNotFound, "model not found: "+req.Model) + return + } + + routes = h.filterHealthyRoutes(routes) + tokenName := getTokenName(r.Context()) + requestID := middleware.GetReqID(r.Context()) + + var lastErr error + for i, route := range routes { + if i > 0 { + backoff := backoffDuration(i, h.cfg.Retry) + select { + case <-time.After(backoff): + case <-r.Context().Done(): + writeError(w, http.StatusGatewayTimeout, "request cancelled") + return + } + } + + start := time.Now() + resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req) + latency := time.Since(start).Milliseconds() + + if err != nil { + var pe *provider.ProviderError + if errors.As(err, &pe) && !pe.IsRetryable() { + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0) + h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error()) + if h.healthTracker != nil { + h.healthTracker.Record(route.Provider.Name(), latency, err) + } + w.Header().Set("X-Request-ID", requestID) + writeErrorRaw(w, pe.StatusCode, pe.Body) + return + } + lastErr = err + log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err) + h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error()) + if h.healthTracker != nil { + h.healthTracker.Record(route.Provider.Name(), latency, err) + } + continue + } + + if h.healthTracker != nil { + h.healthTracker.Record(route.Provider.Name(), latency, nil) + } + + promptTokens := 0 + if resp.Usage != nil { + promptTokens = resp.Usage.PromptTokens + } + cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice + + h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost) + h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "") + + resp.Model = req.Model + + respBytes, err := json.Marshal(resp) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to marshal response") + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Request-ID", requestID) + w.Write(respBytes) + return + } + + w.Header().Set("X-Request-ID", requestID) + if lastErr != nil { + writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error()) + } else { + writeError(w, http.StatusBadGateway, "all providers failed") + } +} + +func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) { + h.logger.Log(storage.RequestLog{ + RequestID: requestID, + Timestamp: time.Now().Unix(), + TokenName: tokenName, + Model: model, + Provider: providerName, + ProviderModel: providerModel, + InputTokens: inputTokens, + CostUSD: cost, + LatencyMS: latencyMS, + Status: status, + ErrorMessage: errMsg, + RequestType: "embedding", + }) +} + +func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) { + if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 { + ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout) + defer cancel() + r = r.WithContext(ctx) + } + + rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK} + h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID) + h.dedup.Complete(dedupKey, rec.body, rec.statusCode) +} + +// responseRecorder captures the response for dedup. +type responseRecorder struct { + http.ResponseWriter + statusCode int + body []byte +} + +func (r *responseRecorder) WriteHeader(code int) { + r.statusCode = code + r.ResponseWriter.WriteHeader(code) +} + +func (r *responseRecorder) Write(b []byte) (int, error) { + r.body = append(r.body, b...) + return r.ResponseWriter.Write(b) +} + // filterHealthyRoutes removes providers with open circuit breakers. // If all are filtered out, returns original routes as fallback. func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route { diff --git a/llm-gateway/internal/proxy/models.go b/llm-gateway/internal/proxy/models.go index 56e1aba..184fbcc 100644 --- a/llm-gateway/internal/proxy/models.go +++ b/llm-gateway/internal/proxy/models.go @@ -5,27 +5,66 @@ import ( "net/http" "time" + "llm-gateway/internal/config" "llm-gateway/internal/provider" ) type ModelsHandler struct { - registry *provider.Registry + registry *provider.Registry + healthTracker *provider.HealthTracker + cfg *config.Config } -func NewModelsHandler(registry *provider.Registry) *ModelsHandler { - return &ModelsHandler{registry: registry} +func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler { + return &ModelsHandler{ + registry: registry, + healthTracker: healthTracker, + cfg: cfg, + } } func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) { - names := h.registry.ModelNames() - models := make([]map[string]any, len(names)) - for i, name := range names { - models[i] = map[string]any{ - "id": name, - "object": "model", - "created": time.Now().Unix(), - "owned_by": "llm-gateway", + allRoutes := h.registry.AllRoutes() + models := make([]map[string]any, 0, len(allRoutes)) + + for _, m := range allRoutes { + providers := make([]map[string]any, 0, len(m.Routes)) + for _, rt := range m.Routes { + healthy := true + if h.healthTracker != nil { + healthy = h.healthTracker.IsAvailable(rt.ProviderName) + } + providers = append(providers, map[string]any{ + "name": rt.ProviderName, + "model": rt.ProviderModel, + "input_price": rt.InputPrice, + "output_price": rt.OutputPrice, + "priority": rt.Priority, + "healthy": healthy, + }) } + + // Find load balancing strategy from config + loadBalancing := "first" + for _, mc := range h.cfg.Models { + if mc.Name == m.Name { + if mc.LoadBalancing != "" { + loadBalancing = mc.LoadBalancing + } + break + } + } + + models = append(models, map[string]any{ + "id": m.Name, + "object": "model", + "created": time.Now().Unix(), + "owned_by": "llm-gateway", + "providers": providers, + "provider_count": len(providers), + "load_balancing": loadBalancing, + "aliases": m.Aliases, + }) } w.Header().Set("Content-Type", "application/json") diff --git a/llm-gateway/internal/proxy/ratelimit.go b/llm-gateway/internal/proxy/ratelimit.go index 240bbe2..bc06174 100644 --- a/llm-gateway/internal/proxy/ratelimit.go +++ b/llm-gateway/internal/proxy/ratelimit.go @@ -8,12 +8,15 @@ import ( "time" "llm-gateway/internal/storage" + "llm-gateway/internal/webhook" ) type RateLimiter struct { - db *storage.DB - mu sync.Mutex - buckets map[string]*tokenBucket + db *storage.DB + mu sync.Mutex + buckets map[string]*tokenBucket + notifier *webhook.Notifier + budgetNotified sync.Map // tracks which token+budget combos have been notified } type tokenBucket struct { @@ -30,6 +33,11 @@ func NewRateLimiter(db *storage.DB) *RateLimiter { } } +// SetNotifier sets the webhook notifier for budget threshold alerts. +func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) { + rl.notifier = n +} + func (rl *RateLimiter) Check(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { apiToken := getAPIToken(r.Context()) @@ -63,9 +71,24 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler { // Check daily budget if apiToken.DailyBudgetUSD > 0 { spent, err := rl.db.TodaySpend(tokenName) - if err == nil && spent >= apiToken.DailyBudgetUSD { - writeError(w, http.StatusTooManyRequests, "daily budget exceeded") - return + if err == nil { + if spent >= apiToken.DailyBudgetUSD { + writeError(w, http.StatusTooManyRequests, "daily budget exceeded") + return + } + rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD) + } + } + + // Check monthly budget + if apiToken.MonthlyBudgetUSD > 0 { + spent, err := rl.db.MonthSpend(tokenName) + if err == nil { + if spent >= apiToken.MonthlyBudgetUSD { + writeError(w, http.StatusTooManyRequests, "monthly budget exceeded") + return + } + rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD) } } @@ -73,6 +96,30 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler { }) } +// checkBudgetThreshold fires a webhook notification when spend reaches 80% of budget. +func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) { + if rl.notifier == nil || budget <= 0 { + return + } + if spent/budget < 0.8 { + return + } + key := tokenName + ":" + budgetType + if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded { + return // already notified + } + rl.notifier.Notify(webhook.Event{ + Type: webhook.EventBudgetThreshold, + Data: map[string]any{ + "token": tokenName, + "budget_type": budgetType, + "spent": spent, + "budget": budget, + "percent": spent / budget * 100, + }, + }) +} + func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) { rl.mu.Lock() defer rl.mu.Unlock() diff --git a/llm-gateway/internal/proxy/ratelimit_test.go b/llm-gateway/internal/proxy/ratelimit_test.go index 8d1afb7..20cb311 100644 --- a/llm-gateway/internal/proxy/ratelimit_test.go +++ b/llm-gateway/internal/proxy/ratelimit_test.go @@ -45,11 +45,11 @@ var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { func TestRateLimiter_Allow(t *testing.T) { tests := []struct { - name string - rateLimitRPM int - numRequests int - wantAllowed int - wantDenied int + name string + rateLimitRPM int + numRequests int + wantAllowed int + wantDenied int }{ { name: "allows requests within limit", @@ -191,12 +191,12 @@ func TestRateLimiter_AllowReturnValues(t *testing.T) { func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) { tests := []struct { - name string - rateLimitRPM int - numRequests int - wantStatusCode int - wantLimitHeader string - wantRetryAfter bool + name string + rateLimitRPM int + numRequests int + wantStatusCode int + wantLimitHeader string + wantRetryAfter bool }{ { name: "sets rate limit headers on allowed request", diff --git a/llm-gateway/internal/proxy/stream.go b/llm-gateway/internal/proxy/stream.go index c1304d5..74067c3 100644 --- a/llm-gateway/internal/proxy/stream.go +++ b/llm-gateway/internal/proxy/stream.go @@ -13,7 +13,7 @@ import ( "llm-gateway/internal/provider" ) -func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) { +func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) { flusher, ok := w.(http.Flusher) if !ok { writeError(w, http.StatusInternalServerError, "streaming not supported") @@ -60,11 +60,15 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov continue } - // Apply streaming timeout + // Apply streaming timeout (per-model override takes precedence) + streamingTimeout := h.cfg.Server.StreamingTimeout + if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 { + streamingTimeout = modelTimeouts.StreamingTimeout + } var streamCtx context.Context var streamCancel context.CancelFunc - if h.cfg.Server.StreamingTimeout > 0 { - streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout) + if streamingTimeout > 0 { + streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout) } else { streamCtx, streamCancel = context.WithCancel(r.Context()) } diff --git a/llm-gateway/internal/storage/db.go b/llm-gateway/internal/storage/db.go index 60d705a..f7fc480 100644 --- a/llm-gateway/internal/storage/db.go +++ b/llm-gateway/internal/storage/db.go @@ -102,6 +102,21 @@ func (db *DB) TodaySpend(tokenName string) (float64, error) { return total.Float64, nil } +// MonthSpend returns the total cost in USD for a given token this month. +func (db *DB) MonthSpend(tokenName string) (float64, error) { + now := time.Now() + startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix() + var total sql.NullFloat64 + err := db.QueryRow( + "SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?", + tokenName, startOfMonth, + ).Scan(&total) + if err != nil { + return 0, err + } + return total.Float64, nil +} + // TodaySpendAll returns today's spend for all tokens as a map. func (db *DB) TodaySpendAll() (map[string]float64, error) { startOfDay := time.Now().Truncate(24 * time.Hour).Unix() diff --git a/llm-gateway/internal/storage/logger.go b/llm-gateway/internal/storage/logger.go index d832dd4..ce54675 100644 --- a/llm-gateway/internal/storage/logger.go +++ b/llm-gateway/internal/storage/logger.go @@ -20,6 +20,7 @@ type RequestLog struct { ErrorMessage string Streaming bool Cached bool + RequestType string // "chat" or "embedding" } type AsyncLogger struct { @@ -94,8 +95,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) { } stmt, err := tx.Prepare(`INSERT INTO request_logs - (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) + (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) if err != nil { log.Printf("ERROR: preparing log statement: %v", err) tx.Rollback() @@ -112,10 +113,14 @@ func (l *AsyncLogger) flush(batch []RequestLog) { if r.Cached { cached = 1 } + reqType := r.RequestType + if reqType == "" { + reqType = "chat" + } _, err := stmt.Exec( r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel, r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS, - r.Status, r.ErrorMessage, streaming, cached, + r.Status, r.ErrorMessage, streaming, cached, reqType, ) if err != nil { log.Printf("ERROR: inserting log: %v", err) diff --git a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql new file mode 100644 index 0000000..f288a97 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.down.sql @@ -0,0 +1 @@ +ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd; diff --git a/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql new file mode 100644 index 0000000..7d9dfbc --- /dev/null +++ b/llm-gateway/internal/storage/migrations/009_token_monthly_budget.up.sql @@ -0,0 +1 @@ +ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0; diff --git a/llm-gateway/internal/storage/migrations/010_request_type.down.sql b/llm-gateway/internal/storage/migrations/010_request_type.down.sql new file mode 100644 index 0000000..52e41cb --- /dev/null +++ b/llm-gateway/internal/storage/migrations/010_request_type.down.sql @@ -0,0 +1 @@ +ALTER TABLE request_logs DROP COLUMN request_type; diff --git a/llm-gateway/internal/storage/migrations/010_request_type.up.sql b/llm-gateway/internal/storage/migrations/010_request_type.up.sql new file mode 100644 index 0000000..20bca14 --- /dev/null +++ b/llm-gateway/internal/storage/migrations/010_request_type.up.sql @@ -0,0 +1 @@ +ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat'; diff --git a/llm-gateway/internal/webhook/webhook.go b/llm-gateway/internal/webhook/webhook.go new file mode 100644 index 0000000..098a0e2 --- /dev/null +++ b/llm-gateway/internal/webhook/webhook.go @@ -0,0 +1,123 @@ +package webhook + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "log" + "net/http" + "time" + + "llm-gateway/internal/config" +) + +// Event types. +const ( + EventCircuitBreakerOpen = "circuit_breaker.open" + EventCircuitBreakerClosed = "circuit_breaker.closed" + EventBudgetThreshold = "budget.threshold" +) + +// Event represents a webhook notification payload. +type Event struct { + Type string `json:"type"` + Timestamp time.Time `json:"timestamp"` + Data map[string]any `json:"data"` +} + +// Notifier sends webhook notifications. +type Notifier struct { + webhooks []config.WebhookConfig + ch chan Event + done chan struct{} + client *http.Client +} + +// NewNotifier creates a webhook notifier from config. +func NewNotifier(webhooks []config.WebhookConfig) *Notifier { + n := &Notifier{ + webhooks: webhooks, + ch: make(chan Event, 100), + done: make(chan struct{}), + client: &http.Client{Timeout: 10 * time.Second}, + } + go n.run() + return n +} + +// Notify queues an event for delivery (non-blocking). +func (n *Notifier) Notify(evt Event) { + if evt.Timestamp.IsZero() { + evt.Timestamp = time.Now() + } + select { + case n.ch <- evt: + default: + log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type) + } +} + +// Close drains pending events and shuts down. +func (n *Notifier) Close() { + close(n.ch) + <-n.done +} + +func (n *Notifier) run() { + defer close(n.done) + for evt := range n.ch { + for _, wh := range n.webhooks { + if !n.shouldSend(wh, evt.Type) { + continue + } + n.send(wh, evt) + } + } +} + +func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool { + if len(wh.Events) == 0 { + return true // no filter = send all + } + for _, e := range wh.Events { + if e == eventType { + return true + } + } + return false +} + +func (n *Notifier) send(wh config.WebhookConfig, evt Event) { + body, err := json.Marshal(evt) + if err != nil { + log.Printf("ERROR: webhook marshal: %v", err) + return + } + + req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body)) + if err != nil { + log.Printf("ERROR: webhook request: %v", err) + return + } + req.Header.Set("Content-Type", "application/json") + + if wh.Secret != "" { + mac := hmac.New(sha256.New, []byte(wh.Secret)) + mac.Write(body) + sig := hex.EncodeToString(mac.Sum(nil)) + req.Header.Set("X-Webhook-Signature", "sha256="+sig) + } + + resp, err := n.client.Do(req) + if err != nil { + log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err) + return + } + resp.Body.Close() + + if resp.StatusCode >= 400 { + log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode) + } +}