feat(gateway): add monthly budget support and webhook notifications for circuit breaker and budget events

This commit is contained in:
Ray Andrew 2026-02-15 04:52:09 -06:00
parent 28a694744d
commit 291b8f4863
Signed by: rayandrew
SSH key fingerprint: SHA256:EUCV+qCSqkap8rR+p+zGjxHfKI06G0GJKgo1DIOniQY
31 changed files with 1005 additions and 124 deletions

View file

@ -25,6 +25,7 @@ import (
"llm-gateway/internal/provider" "llm-gateway/internal/provider"
"llm-gateway/internal/proxy" "llm-gateway/internal/proxy"
"llm-gateway/internal/storage" "llm-gateway/internal/storage"
"llm-gateway/internal/webhook"
) )
var version = "dev" var version = "dev"
@ -95,16 +96,41 @@ func main() {
// Provider health tracker // Provider health tracker
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker) healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
// Webhook notifier
var notifier *webhook.Notifier
if len(cfg.Webhooks) > 0 {
notifier = webhook.NewNotifier(cfg.Webhooks)
defer notifier.Close()
log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks))
// Wire health tracker state changes to webhook
healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) {
eventType := webhook.EventCircuitBreakerOpen
if to == provider.CircuitClosed {
eventType = webhook.EventCircuitBreakerClosed
}
notifier.Notify(webhook.Event{
Type: eventType,
Data: map[string]any{
"provider": providerName,
"from": from.String(),
"to": to.String(),
},
})
}
}
// Auth store (static tokens checked in-memory, not seeded to DB) // Auth store (static tokens checked in-memory, not seeded to DB)
var staticTokens []auth.StaticToken var staticTokens []auth.StaticToken
for _, t := range cfg.Tokens { for _, t := range cfg.Tokens {
if t.Key != "" { if t.Key != "" {
staticTokens = append(staticTokens, auth.StaticToken{ staticTokens = append(staticTokens, auth.StaticToken{
Name: t.Name, Name: t.Name,
Key: t.Key, Key: t.Key,
RateLimitRPM: t.RateLimitRPM, RateLimitRPM: t.RateLimitRPM,
DailyBudgetUSD: t.DailyBudgetUSD, DailyBudgetUSD: t.DailyBudgetUSD,
MaxConcurrent: t.MaxConcurrent, MonthlyBudgetUSD: t.MonthlyBudgetUSD,
MaxConcurrent: t.MaxConcurrent,
}) })
log.Printf("Loaded static token: %s", t.Name) log.Printf("Loaded static token: %s", t.Name)
} }
@ -133,14 +159,27 @@ func main() {
// Handlers // Handlers
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker) proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
proxyHandler.SetDebugLogger(debugLogger) proxyHandler.SetDebugLogger(debugLogger)
modelsHandler := proxy.NewModelsHandler(registry)
// Request deduplication
if cfg.Dedup.Enabled {
dedup := proxy.NewDeduplicator(cfg.Dedup.Window)
defer dedup.Close()
proxyHandler.SetDeduplicator(dedup)
log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window)
}
modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg)
proxyAuth := proxy.NewAuthMiddleware(authStore) proxyAuth := proxy.NewAuthMiddleware(authStore)
rateLimiter := proxy.NewRateLimiter(db) rateLimiter := proxy.NewRateLimiter(db)
if notifier != nil {
rateLimiter.SetNotifier(notifier)
}
concurrencyLimiter := proxy.NewConcurrencyLimiter() concurrencyLimiter := proxy.NewConcurrencyLimiter()
statsAPI := dashboard.NewStatsAPI(db, authStore) statsAPI := dashboard.NewStatsAPI(db, authStore)
statsAPI.SetHealthTracker(healthTracker) statsAPI.SetHealthTracker(healthTracker)
statsAPI.SetAuditLogger(auditLogger) statsAPI.SetAuditLogger(auditLogger)
statsAPI.SetDebugLogger(debugLogger) statsAPI.SetDebugLogger(debugLogger)
statsAPI.SetConfigPath(*configPath)
if c != nil { if c != nil {
statsAPI.SetCache(c) statsAPI.SetCache(c)
} }
@ -196,6 +235,7 @@ func main() {
r.Use(rateLimiter.Check) r.Use(rateLimiter.Check)
r.Use(concurrencyLimiter.Check) r.Use(concurrencyLimiter.Check)
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions) r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
r.Post("/v1/embeddings", proxyHandler.Embeddings)
r.Get("/v1/models", modelsHandler.ListModels) r.Get("/v1/models", modelsHandler.ListModels)
}) })
@ -266,7 +306,7 @@ func main() {
r.Get("/api/export/logs", exportHandler.ExportLogs) r.Get("/api/export/logs", exportHandler.ExportLogs)
r.Get("/api/export/stats", exportHandler.ExportStats) r.Get("/api/export/stats", exportHandler.ExportStats)
// Admin-only: user management, audit, debug // Admin-only: user management, audit, debug, config validation
r.Group(func(r chi.Router) { r.Group(func(r chi.Router) {
r.Use(authMiddleware.RequireAdmin) r.Use(authMiddleware.RequireAdmin)
r.Get("/api/auth/users", authHandlers.ListUsers) r.Get("/api/auth/users", authHandlers.ListUsers)
@ -276,6 +316,9 @@ func main() {
// Audit log // Audit log
r.Get("/api/stats/audit", statsAPI.AuditLogs) r.Get("/api/stats/audit", statsAPI.AuditLogs)
// Config validation
r.Get("/api/config/validate", statsAPI.ValidateConfig)
// Debug logging // Debug logging
r.Post("/api/debug/toggle", statsAPI.DebugToggle) r.Post("/api/debug/toggle", statsAPI.DebugToggle)
r.Get("/api/debug/status", statsAPI.DebugStatus) r.Get("/api/debug/status", statsAPI.DebugStatus)
@ -332,11 +375,12 @@ func main() {
for _, t := range newCfg.Tokens { for _, t := range newCfg.Tokens {
if t.Key != "" { if t.Key != "" {
newStaticTokens = append(newStaticTokens, auth.StaticToken{ newStaticTokens = append(newStaticTokens, auth.StaticToken{
Name: t.Name, Name: t.Name,
Key: t.Key, Key: t.Key,
RateLimitRPM: t.RateLimitRPM, RateLimitRPM: t.RateLimitRPM,
DailyBudgetUSD: t.DailyBudgetUSD, DailyBudgetUSD: t.DailyBudgetUSD,
MaxConcurrent: t.MaxConcurrent, MonthlyBudgetUSD: t.MonthlyBudgetUSD,
MaxConcurrent: t.MaxConcurrent,
}) })
} }
} }

View file

@ -4,6 +4,7 @@ go 1.24.0
require ( require (
github.com/go-chi/chi/v5 v5.2.5 github.com/go-chi/chi/v5 v5.2.5
github.com/go-chi/cors v1.2.2
github.com/golang-migrate/migrate/v4 v4.19.1 github.com/golang-migrate/migrate/v4 v4.19.1
github.com/pquerna/otp v1.5.0 github.com/pquerna/otp v1.5.0
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
@ -19,7 +20,6 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-chi/cors v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect

View file

@ -31,25 +31,27 @@ type Session struct {
} }
type APIToken struct { type APIToken struct {
ID int64 `json:"id"` ID int64 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
KeyPrefix string `json:"key_prefix"` KeyPrefix string `json:"key_prefix"`
KeyHash string `json:"-"` KeyHash string `json:"-"`
UserID int64 `json:"user_id"` UserID int64 `json:"user_id"`
RateLimitRPM int `json:"rate_limit_rpm"` RateLimitRPM int `json:"rate_limit_rpm"`
DailyBudgetUSD float64 `json:"daily_budget_usd"` DailyBudgetUSD float64 `json:"daily_budget_usd"`
MaxConcurrent int `json:"max_concurrent"` MonthlyBudgetUSD float64 `json:"monthly_budget_usd"`
CreatedAt int64 `json:"created_at"` MaxConcurrent int `json:"max_concurrent"`
LastUsedAt int64 `json:"last_used_at"` CreatedAt int64 `json:"created_at"`
LastUsedAt int64 `json:"last_used_at"`
} }
// StaticToken represents a token defined in config (checked in-memory, never stored in DB). // StaticToken represents a token defined in config (checked in-memory, never stored in DB).
type StaticToken struct { type StaticToken struct {
Name string Name string
Key string Key string
RateLimitRPM int RateLimitRPM int
DailyBudgetUSD float64 DailyBudgetUSD float64
MaxConcurrent int MonthlyBudgetUSD float64
MaxConcurrent int
} }
type Store struct { type Store struct {
@ -289,12 +291,13 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
prefix = prefix[:11] prefix = prefix[:11]
} }
return &APIToken{ return &APIToken{
ID: -1, // sentinel: static token ID: -1, // sentinel: static token
Name: st.Name, Name: st.Name,
KeyPrefix: prefix, KeyPrefix: prefix,
RateLimitRPM: st.RateLimitRPM, RateLimitRPM: st.RateLimitRPM,
DailyBudgetUSD: st.DailyBudgetUSD, DailyBudgetUSD: st.DailyBudgetUSD,
MaxConcurrent: st.MaxConcurrent, MonthlyBudgetUSD: st.MonthlyBudgetUSD,
MaxConcurrent: st.MaxConcurrent,
}, nil }, nil
} }
} }
@ -305,9 +308,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
var t APIToken var t APIToken
err := s.db.QueryRow( err := s.db.QueryRow(
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?", "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
keyHash, keyHash,
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -323,12 +326,13 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
prefix = prefix[:11] prefix = prefix[:11]
} }
tokens = append(tokens, APIToken{ tokens = append(tokens, APIToken{
ID: -1, // sentinel: static token ID: -1, // sentinel: static token
Name: st.Name, Name: st.Name,
KeyPrefix: prefix, KeyPrefix: prefix,
RateLimitRPM: st.RateLimitRPM, RateLimitRPM: st.RateLimitRPM,
DailyBudgetUSD: st.DailyBudgetUSD, DailyBudgetUSD: st.DailyBudgetUSD,
MaxConcurrent: st.MaxConcurrent, MonthlyBudgetUSD: st.MonthlyBudgetUSD,
MaxConcurrent: st.MaxConcurrent,
}) })
} }
@ -336,9 +340,9 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if userID == 0 { if userID == 0 {
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id") rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
} else { } else {
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID) rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
} }
if err != nil { if err != nil {
return tokens, nil return tokens, nil
@ -347,7 +351,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
for rows.Next() { for rows.Next() {
var t APIToken var t APIToken
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil { if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
return tokens, nil return tokens, nil
} }
tokens = append(tokens, t) tokens = append(tokens, t)
@ -363,9 +367,9 @@ func (s *Store) DeleteAPIToken(id int64) error {
func (s *Store) GetAPIToken(id int64) (*APIToken, error) { func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
var t APIToken var t APIToken
err := s.db.QueryRow( err := s.db.QueryRow(
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?", "SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
id, id,
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt) ).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -42,6 +42,7 @@ func setupTestDB(t *testing.T) *sql.DB {
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
rate_limit_rpm INTEGER DEFAULT 0, rate_limit_rpm INTEGER DEFAULT 0,
daily_budget_usd REAL DEFAULT 0, daily_budget_usd REAL DEFAULT 0,
monthly_budget_usd REAL DEFAULT 0,
max_concurrent INTEGER DEFAULT 0, max_concurrent INTEGER DEFAULT 0,
created_at INTEGER NOT NULL, created_at INTEGER NOT NULL,
last_used_at INTEGER DEFAULT 0 last_used_at INTEGER DEFAULT 0

View file

@ -20,11 +20,24 @@ type Config struct {
Retry RetryConfig `yaml:"retry"` Retry RetryConfig `yaml:"retry"`
Debug DebugConfig `yaml:"debug"` Debug DebugConfig `yaml:"debug"`
CORS CORSConfig `yaml:"cors"` CORS CORSConfig `yaml:"cors"`
Dedup DedupConfig `yaml:"dedup"`
Webhooks []WebhookConfig `yaml:"webhooks"`
Providers []ProviderConfig `yaml:"providers"` Providers []ProviderConfig `yaml:"providers"`
Models []ModelConfig `yaml:"models"` Models []ModelConfig `yaml:"models"`
Tokens []TokenConfig `yaml:"tokens"` Tokens []TokenConfig `yaml:"tokens"`
} }
// DedupConfig controls request deduplication for identical in-flight requests.
type DedupConfig struct {
	Enabled bool          `yaml:"enabled"`
	Window  time.Duration `yaml:"window"` // max time to wait for dedup result; 0 applies the default set in Validate
}

// WebhookConfig describes a single webhook endpoint the gateway notifies.
type WebhookConfig struct {
	URL    string   `yaml:"url"`
	Events []string `yaml:"events"` // event types to send
	Secret string   `yaml:"secret"` // optional HMAC secret
}
type PricingLookupConfig struct { type PricingLookupConfig struct {
URL string `yaml:"url"` URL string `yaml:"url"`
RefreshInterval time.Duration `yaml:"refresh_interval"` RefreshInterval time.Duration `yaml:"refresh_interval"`
@ -36,20 +49,21 @@ type DefaultAdminConfig struct {
} }
type TokenConfig struct { type TokenConfig struct {
Name string `yaml:"name"` Name string `yaml:"name"`
Key string `yaml:"key"` Key string `yaml:"key"`
RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
} }
type ServerConfig struct { type ServerConfig struct {
Listen string `yaml:"listen"` Listen string `yaml:"listen"`
RequestTimeout time.Duration `yaml:"request_timeout"` RequestTimeout time.Duration `yaml:"request_timeout"`
StreamingTimeout time.Duration `yaml:"streaming_timeout"` StreamingTimeout time.Duration `yaml:"streaming_timeout"`
MaxRequestBodyMB int `yaml:"max_request_body_mb"` MaxRequestBodyMB int `yaml:"max_request_body_mb"`
SessionSecret string `yaml:"session_secret"` SessionSecret string `yaml:"session_secret"`
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"` DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
} }
type CircuitBreakerConfig struct { type CircuitBreakerConfig struct {
@ -66,10 +80,10 @@ type RetryConfig struct {
} }
type DebugConfig struct { type DebugConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies) MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
RetentionDays int `yaml:"retention_days"` RetentionDays int `yaml:"retention_days"`
DataDir string `yaml:"data_dir"` DataDir string `yaml:"data_dir"`
} }
type CORSConfig struct { type CORSConfig struct {
@ -100,10 +114,12 @@ type ProviderConfig struct {
} }
type ModelConfig struct { type ModelConfig struct {
Name string `yaml:"name"` Name string `yaml:"name"`
Aliases []string `yaml:"aliases"` Aliases []string `yaml:"aliases"`
Routes []RouteConfig `yaml:"routes"` Routes []RouteConfig `yaml:"routes"`
LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
RequestTimeout time.Duration `yaml:"request_timeout"` // per-model override; 0 = use server default
StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default
} }
type RouteConfig struct { type RouteConfig struct {
@ -131,14 +147,15 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("parsing config: %w", err) return nil, fmt.Errorf("parsing config: %w", err)
} }
if err := cfg.validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("validating config: %w", err) return nil, fmt.Errorf("validating config: %w", err)
} }
return &cfg, nil return &cfg, nil
} }
func (c *Config) validate() error { // Validate checks the config for correctness and applies defaults.
func (c *Config) Validate() error {
if c.Server.Listen == "" { if c.Server.Listen == "" {
c.Server.Listen = "0.0.0.0:3000" c.Server.Listen = "0.0.0.0:3000"
} }
@ -201,6 +218,11 @@ func (c *Config) validate() error {
c.CORS.MaxAge = 300 c.CORS.MaxAge = 300
} }
// Dedup defaults
if c.Dedup.Window == 0 {
c.Dedup.Window = 30 * time.Second
}
if len(c.Providers) == 0 { if len(c.Providers) == 0 {
return fmt.Errorf("at least one provider is required") return fmt.Errorf("at least one provider is required")
} }
@ -266,6 +288,19 @@ func (c *Config) validate() error {
return nil return nil
} }
// ValidateBytes parses a raw YAML config document and returns any
// validation error messages. Environment variables in the document are
// expanded first, mirroring Load. A nil result means the config is valid.
func ValidateBytes(data []byte) []string {
	raw := os.ExpandEnv(string(data))

	var parsed Config
	err := yaml.Unmarshal([]byte(raw), &parsed)
	if err != nil {
		return []string{"parse error: " + err.Error()}
	}

	err = parsed.Validate()
	if err != nil {
		return []string{err.Error()}
	}
	return nil
}
// ProviderByName returns the provider config by name. // ProviderByName returns the provider config by name.
func (c *Config) ProviderByName(name string) *ProviderConfig { func (c *Config) ProviderByName(name string) *ProviderConfig {
for i := range c.Providers { for i := range c.Providers {

View file

@ -735,4 +735,3 @@ models:
t.Errorf("error = %q, want to contain api_key validation message", err.Error()) t.Errorf("error = %q, want to contain api_key validation message", err.Error())
} }
} }

View file

@ -3,6 +3,7 @@ package dashboard
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"os"
"sort" "sort"
"strconv" "strconv"
"time" "time"
@ -11,6 +12,7 @@ import (
"llm-gateway/internal/auth" "llm-gateway/internal/auth"
"llm-gateway/internal/cache" "llm-gateway/internal/cache"
"llm-gateway/internal/config"
"llm-gateway/internal/provider" "llm-gateway/internal/provider"
"llm-gateway/internal/storage" "llm-gateway/internal/storage"
) )
@ -109,6 +111,7 @@ type StatsAPI struct {
cache *cache.Cache cache *cache.Cache
auditLogger *storage.AuditLogger auditLogger *storage.AuditLogger
debugLogger *storage.DebugLogger debugLogger *storage.DebugLogger
configPath string
} }
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI { func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
@ -135,6 +138,11 @@ func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
s.debugLogger = dl s.debugLogger = dl
} }
// SetConfigPath records the path of the YAML config file so that
// ValidateConfig can re-read and validate it on demand.
func (s *StatsAPI) SetConfigPath(path string) {
	s.configPath = path
}
// TokenNamesForUser returns the token names that belong to the user. // TokenNamesForUser returns the token names that belong to the user.
// Admins get nil (no filter), non-admins get their token names. // Admins get nil (no filter), non-admins get their token names.
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string { func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
@ -712,6 +720,27 @@ func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
writeJSON(w, entry) writeJSON(w, entry)
} }
// ValidateConfig re-reads the config file at the stored path and reports
// whether it passes validation. The response body is always of the form
// {"valid": bool, "errors": [...]}; a 500 status is used when the path is
// unset or the file cannot be read.
func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) {
	// respond writes the uniform validation payload.
	respond := func(valid bool, errs []string) {
		writeJSON(w, map[string]any{"valid": valid, "errors": errs})
	}

	if s.configPath == "" {
		w.WriteHeader(http.StatusInternalServerError)
		respond(false, []string{"config path not set"})
		return
	}

	raw, err := os.ReadFile(s.configPath)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		respond(false, []string{"failed to read config: " + err.Error()})
		return
	}

	if errs := config.ValidateBytes(raw); len(errs) > 0 {
		respond(false, errs)
		return
	}
	respond(true, []string{})
}
func writeJSON(w http.ResponseWriter, v any) { func writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(v) json.NewEncoder(w).Encode(v)

View file

@ -127,9 +127,9 @@ type PageData struct {
// Models routing page data // Models routing page data
ModelRoutes []provider.ModelRouteInfo ModelRoutes []provider.ModelRouteInfo
// Audit page data // Audit page data
AuditResult *storage.AuditQueryResult AuditResult *storage.AuditQueryResult
AuditFilterActions []string AuditFilterActions []string
FilterAction string FilterAction string
// Debug page data // Debug page data
DebugResult *storage.DebugLogQueryResult DebugResult *storage.DebugLogQueryResult
DebugEnabled bool DebugEnabled bool
@ -275,6 +275,10 @@ func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) {
data.ModelRoutes = d.registry.AllRoutes() data.ModelRoutes = d.registry.AllRoutes()
} }
if d.statsAPI.healthTracker != nil {
data.ProviderHealth = d.statsAPI.healthTracker.Status()
}
d.renderDashboardPage(w, r, "partials/models-page.html", data) d.renderDashboardPage(w, r, "partials/models-page.html", data)
} }

View file

@ -6,7 +6,7 @@
{{if .ModelRoutes}} {{if .ModelRoutes}}
{{range .ModelRoutes}} {{range .ModelRoutes}}
<div class="section"> <div class="section">
<h2>{{.Name}}</h2> <h2>{{.Name}}{{if .Aliases}} <span style="font-size:0.75rem;color:var(--text-muted);font-weight:400;">aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}</span>{{end}}</h2>
<table> <table>
<thead> <thead>
<tr> <tr>
@ -15,9 +15,11 @@
<th>Priority</th> <th>Priority</th>
<th>Input Price (per 1M)</th> <th>Input Price (per 1M)</th>
<th>Output Price (per 1M)</th> <th>Output Price (per 1M)</th>
<th>Health</th>
</tr> </tr>
</thead> </thead>
<tbody> <tbody>
{{$health := $.ProviderHealth}}
{{range .Routes}} {{range .Routes}}
<tr> <tr>
<td>{{.ProviderName}}</td> <td>{{.ProviderName}}</td>
@ -25,6 +27,18 @@
<td><span class="badge badge-priority">{{.Priority}}</span></td> <td><span class="badge badge-priority">{{.Priority}}</span></td>
<td>{{formatPrice .InputPrice}}</td> <td>{{formatPrice .InputPrice}}</td>
<td>{{formatPrice .OutputPrice}}</td> <td>{{formatPrice .OutputPrice}}</td>
<td>
{{$pname := .ProviderName}}
{{range $health}}
{{if eq .Provider $pname}}
{{if eq .Status "healthy"}}<span class="badge" style="background:#166534;color:#4ade80;">healthy</span>
{{else if eq .Status "degraded"}}<span class="badge" style="background:#92400e;color:#fbbf24;">degraded</span>
{{else}}<span class="badge" style="background:#991b1b;color:#f87171;">down</span>
{{end}}
{{end}}
{{end}}
{{if not $health}}<span style="color:var(--text-muted);">-</span>{{end}}
</td>
</tr> </tr>
{{end}} {{end}}
</tbody> </tbody>

View file

@ -70,6 +70,15 @@
</div> </div>
</div> </div>
{{if .User.IsAdmin}}
<div class="section">
<h2>Config Validation</h2>
<p style="color:#94a3b8;font-size:0.85rem;margin-bottom:12px;">Validate the current gateway configuration file for errors.</p>
<button class="btn btn-sm btn-primary" onclick="validateConfig()">Validate Config</button>
<div id="config-validation-result" style="margin-top:12px;"></div>
</div>
{{end}}
<script src="https://cdn.jsdelivr.net/npm/qrious@4.0.2/dist/qrious.min.js"></script> <script src="https://cdn.jsdelivr.net/npm/qrious@4.0.2/dist/qrious.min.js"></script>
<script> <script>
function showMsg(id, msg, isError) { function showMsg(id, msg, isError) {
@ -167,5 +176,20 @@ async function disableTOTP() {
htmx.ajax('GET', '/settings', {target: '#content', swap: 'innerHTML'}); htmx.ajax('GET', '/settings', {target: '#content', swap: 'innerHTML'});
} catch (e) { alert(e.message); } } catch (e) { alert(e.message); }
} }
// Validate the gateway config file via the admin API and render the result.
// Error strings originate from the config file's contents (e.g. YAML parse
// errors quote the file), so they are HTML-escaped before being injected
// into the page via innerHTML.
async function validateConfig() {
  var el = document.getElementById('config-validation-result');
  var esc = function(s) {
    return String(s).replace(/&/g, '&amp;').replace(/</g, '&lt;')
      .replace(/>/g, '&gt;').replace(/"/g, '&quot;');
  };
  el.innerHTML = '<span style="color:#94a3b8;">Validating...</span>';
  try {
    var resp = await fetch('/api/config/validate', { credentials: 'same-origin' });
    var data = await resp.json();
    if (data.valid) {
      el.innerHTML = '<div class="success-msg">Configuration is valid.</div>';
    } else {
      var errs = (data.errors||[]).map(function(e) { return '<li>' + esc(e) + '</li>'; }).join('');
      el.innerHTML = '<div class="error-msg">Configuration errors:<ul style="margin:4px 0 0 16px;">' + errs + '</ul></div>';
    }
  } catch (e) { el.innerHTML = '<div class="error-msg">' + esc(e.message) + '</div>'; }
}
</script> </script>
{{end}} {{end}}

View file

@ -19,7 +19,10 @@
<td>{{.Name}}</td> <td>{{.Name}}</td>
<td><code>{{.KeyPrefix}}...</code></td> <td><code>{{.KeyPrefix}}...</code></td>
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td> <td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
<td>{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}</td> <td>
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
</td>
<td> <td>
{{$spend := index $.TokenSpend .Name}} {{$spend := index $.TokenSpend .Name}}
{{if gt .DailyBudgetUSD 0.0}} {{if gt .DailyBudgetUSD 0.0}}
@ -49,7 +52,10 @@
<td>{{.Name}}</td> <td>{{.Name}}</td>
<td><code>{{.KeyPrefix}}...</code></td> <td><code>{{.KeyPrefix}}...</code></td>
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td> <td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
<td>{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}</td> <td>
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
</td>
<td> <td>
{{$spend := index $.TokenSpend .Name}} {{$spend := index $.TokenSpend .Name}}
{{if gt .DailyBudgetUSD 0.0}} {{if gt .DailyBudgetUSD 0.0}}

View file

@ -46,8 +46,8 @@ type HealthEvent struct {
// ProviderHealth is the computed health status for a provider. // ProviderHealth is the computed health status for a provider.
type ProviderHealth struct { type ProviderHealth struct {
Provider string `json:"provider"` Provider string `json:"provider"`
Status string `json:"status"` // healthy, degraded, down Status string `json:"status"` // healthy, degraded, down
ErrorRate float64 `json:"error_rate"` ErrorRate float64 `json:"error_rate"`
AvgLatency float64 `json:"avg_latency_ms"` AvgLatency float64 `json:"avg_latency_ms"`
Total int `json:"total"` Total int `json:"total"`
@ -57,11 +57,12 @@ type ProviderHealth struct {
// HealthTracker tracks per-provider health using a sliding window. // HealthTracker tracks per-provider health using a sliding window.
type HealthTracker struct { type HealthTracker struct {
mu sync.RWMutex mu sync.RWMutex
windows map[string][]HealthEvent windows map[string][]HealthEvent
windowDu time.Duration windowDu time.Duration
circuits map[string]*ProviderCircuit circuits map[string]*ProviderCircuit
cbConfig config.CircuitBreakerConfig cbConfig config.CircuitBreakerConfig
OnStateChange func(provider string, from, to CircuitState)
} }
// NewHealthTracker creates a health tracker with the given window duration. // NewHealthTracker creates a health tracker with the given window duration.
@ -135,6 +136,8 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
h.circuits[providerName] = circuit h.circuits[providerName] = circuit
} }
prevState := circuit.State
switch circuit.State { switch circuit.State {
case CircuitClosed: case CircuitClosed:
// Check if error threshold exceeded // Check if error threshold exceeded
@ -164,6 +167,13 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
circuit.OpenedAt = time.Now() circuit.OpenedAt = time.Now()
} }
} }
if circuit.State != prevState && h.OnStateChange != nil {
cb := h.OnStateChange
from, to := prevState, circuit.State
// Call outside lock to avoid deadlocks
go cb(providerName, from, to)
}
} }
// errorRateUnlocked computes error rate within window. Must be called with lock held. // errorRateUnlocked computes error rate within window. Must be called with lock held.

View file

@ -47,13 +47,13 @@ func TestHealthTracker_Record(t *testing.T) {
func TestHealthTracker_Status(t *testing.T) { func TestHealthTracker_Status(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
successCount int successCount int
errorCount int errorCount int
wantStatus string wantStatus string
wantErrorRate float64 wantErrorRate float64
wantTotal int wantTotal int
wantErrors int wantErrors int
}{ }{
{ {
name: "healthy - no errors", name: "healthy - no errors",

View file

@ -108,6 +108,48 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string,
return resp.Body, nil return resp.Body, nil
} }
// Embedding sends an OpenAI-compatible embeddings request to the provider's
// /embeddings endpoint and returns the decoded response. The Model field of
// the outgoing request is overridden with the given model name.
func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) {
	// Shallow-copy so the caller's request is not mutated when Model is overridden.
	reqCopy := *req
	reqCopy.Model = model
	body, err := json.Marshal(reqCopy)
	if err != nil {
		return nil, fmt.Errorf("marshaling request: %w", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("creating request: %w", err)
	}
	// Content-Type and Authorization headers (see setHeaders).
	p.setHeaders(httpReq)
	resp, err := p.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("sending request: %w", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("reading response: %w", err)
	}
	// Non-200 responses are surfaced as a typed *ProviderError carrying the
	// upstream status and body so callers can inspect/forward them.
	if resp.StatusCode != http.StatusOK {
		return nil, &ProviderError{
			StatusCode: resp.StatusCode,
			Body:       string(respBody),
			Provider:   p.name,
		}
	}
	var embResp EmbeddingResponse
	if err := json.Unmarshal(respBody, &embResp); err != nil {
		return nil, fmt.Errorf("unmarshaling response: %w", err)
	}
	return &embResp, nil
}
func (p *OpenAIProvider) setHeaders(req *http.Request) { func (p *OpenAIProvider) setHeaders(req *http.Request) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey) req.Header.Set("Authorization", "Bearer "+p.apiKey)

View file

@ -52,9 +52,38 @@ type Usage struct {
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
} }
// EmbeddingRequest is the OpenAI-compatible embedding request body.
type EmbeddingRequest struct {
	Model string `json:"model"`
	Input any    `json:"input"` // string or []string
	// EncodingFormat is passed through to the provider when set.
	EncodingFormat string `json:"encoding_format,omitempty"`
}

// EmbeddingResponse is the OpenAI-compatible embedding response body.
type EmbeddingResponse struct {
	Object string          `json:"object"`
	Data   []EmbeddingData `json:"data"`
	Model  string          `json:"model"`
	Usage  *EmbeddingUsage `json:"usage,omitempty"` // nil when the provider omits usage
}

// EmbeddingData holds a single embedding vector.
type EmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"` // position of this vector within the input batch
}

// EmbeddingUsage reports token usage for embeddings.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
// Provider sends requests to an LLM API. // Provider sends requests to an LLM API.
type Provider interface { type Provider interface {
Name() string Name() string
ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error)
ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error)
Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error)
} }

View file

@ -4,10 +4,17 @@ import (
"fmt" "fmt"
"sort" "sort"
"sync" "sync"
"time"
"llm-gateway/internal/config" "llm-gateway/internal/config"
) )
// ModelTimeouts holds per-model timeout overrides. A zero value for either
// field means "no override" — the server-wide default applies.
type ModelTimeouts struct {
	RequestTimeout   time.Duration // deadline for non-streaming requests
	StreamingTimeout time.Duration // deadline for streaming responses
}
// Route maps a model to a specific provider with pricing. // Route maps a model to a specific provider with pricing.
type Route struct { type Route struct {
Provider Provider Provider Provider
@ -24,6 +31,7 @@ type Registry struct {
balancers map[string]LoadBalancer balancers map[string]LoadBalancer
aliases map[string]string // alias -> canonical name aliases map[string]string // alias -> canonical name
order []string // preserves config order (canonical names only) order []string // preserves config order (canonical names only)
timeouts map[string]*ModelTimeouts
} }
func NewRegistry(cfg *config.Config) (*Registry, error) { func NewRegistry(cfg *config.Config) (*Registry, error) {
@ -46,6 +54,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
balancers := make(map[string]LoadBalancer) balancers := make(map[string]LoadBalancer)
aliases := make(map[string]string) aliases := make(map[string]string)
order := make([]string, 0, len(cfg.Models)) order := make([]string, 0, len(cfg.Models))
timeouts := make(map[string]*ModelTimeouts)
for _, mc := range cfg.Models { for _, mc := range cfg.Models {
var modelRoutes []Route var modelRoutes []Route
@ -74,6 +83,14 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
// Load balancer // Load balancer
balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing) balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
// Per-model timeouts
if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 {
timeouts[mc.Name] = &ModelTimeouts{
RequestTimeout: mc.RequestTimeout,
StreamingTimeout: mc.StreamingTimeout,
}
}
// Register aliases // Register aliases
for _, alias := range mc.Aliases { for _, alias := range mc.Aliases {
aliases[alias] = mc.Name aliases[alias] = mc.Name
@ -85,6 +102,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
r.balancers = balancers r.balancers = balancers
r.aliases = aliases r.aliases = aliases
r.order = order r.order = order
r.timeouts = timeouts
r.mu.Unlock() r.mu.Unlock()
return nil return nil
@ -135,6 +153,18 @@ func (r *Registry) ModelNames() []string {
return names return names
} }
// ModelTimeoutsFor returns per-model timeout overrides, resolving aliases. Returns nil if none set.
func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts {
	r.mu.RLock()
	defer r.mu.RUnlock()
	// The aliases map goes alias -> canonical model name; fall back to the
	// given name when it is already canonical.
	canonical, ok := r.aliases[model]
	if !ok {
		canonical = model
	}
	return r.timeouts[canonical]
}
// RouteInfo exposes route details for dashboard display. // RouteInfo exposes route details for dashboard display.
type RouteInfo struct { type RouteInfo struct {
ProviderName string `json:"provider_name"` ProviderName string `json:"provider_name"`

View file

@ -23,6 +23,10 @@ func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *Chat
return nil, nil return nil, nil
} }
// Embedding satisfies the Provider interface for tests. It returns a nil
// response and nil error — registry tests exercise routing only, never
// actual provider calls.
func (m *mockProvider) Embedding(_ context.Context, _ string, _ *EmbeddingRequest) (*EmbeddingResponse, error) {
	return nil, nil
}
// newTestRegistry builds a Registry directly without going through config parsing. // newTestRegistry builds a Registry directly without going through config parsing.
func newTestRegistry(models []testModel) *Registry { func newTestRegistry(models []testModel) *Registry {
r := &Registry{ r := &Registry{

View file

@ -0,0 +1,107 @@
package proxy
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
// inflight represents an in-progress deduplicated request. The leader fills
// in result/statusCode and then closes done; followers block on done and
// read both fields afterwards (the channel close provides the necessary
// happens-before ordering, so no extra locking is required on the fields).
type inflight struct {
	done       chan struct{}
	result     []byte
	statusCode int
	createdAt  time.Time
}
// Deduplicator coalesces identical concurrent non-streaming requests: the
// first caller for a key (the leader) executes the request, and later
// callers (followers) wait for and replay the leader's response.
type Deduplicator struct {
	mu      sync.Mutex
	flights map[string]*inflight // key -> in-progress request
	window  time.Duration        // cleanup tick; entries older than 2x window are reaped
	done    chan struct{}        // closed by Close to stop the cleanup goroutine
}
// NewDeduplicator creates a new request deduplicator. A zero window defaults
// to 30 seconds. A background cleanup goroutine is started and runs until
// Close is called.
func NewDeduplicator(window time.Duration) *Deduplicator {
	const defaultWindow = 30 * time.Second
	if window == 0 {
		window = defaultWindow
	}
	dedup := &Deduplicator{
		flights: make(map[string]*inflight),
		window:  window,
		done:    make(chan struct{}),
	}
	go dedup.cleanup()
	return dedup
}
// DedupKey computes a dedup key from model name and request body. A NUL byte
// separates the two inputs so ("ab","c") and ("a","bc") cannot collide; the
// key is the hex-encoded SHA-256 digest of the combined input.
func DedupKey(model string, body []byte) string {
	input := make([]byte, 0, len(model)+1+len(body))
	input = append(input, model...)
	input = append(input, 0)
	input = append(input, body...)
	digest := sha256.Sum256(input)
	return hex.EncodeToString(digest[:])
}
// TryJoin attempts to join an in-flight request. Returns the inflight entry and
// whether this caller is the leader (true) or a follower (false). The first
// caller for a key becomes the leader and is responsible for calling Complete.
func (d *Deduplicator) TryJoin(key string) (*inflight, bool) {
	d.mu.Lock()
	defer d.mu.Unlock()
	// An existing flight means somebody is already doing this work: follow it.
	if existing, ok := d.flights[key]; ok {
		return existing, false
	}
	entry := &inflight{
		done:      make(chan struct{}),
		createdAt: time.Now(),
	}
	d.flights[key] = entry
	return entry, true
}
// Complete signals completion of a deduplicated request. The entry is removed
// from the flight table before done is closed, so callers that arrive after
// completion start a fresh flight rather than joining a finished one. Calling
// Complete for a key that was already completed or reaped is a no-op.
func (d *Deduplicator) Complete(key string, result []byte, statusCode int) {
	d.mu.Lock()
	entry, found := d.flights[key]
	delete(d.flights, key)
	d.mu.Unlock()
	if !found {
		return
	}
	entry.result = result
	entry.statusCode = statusCode
	close(entry.done)
}
// Close stops the background cleanup goroutine.
// NOTE(review): Close does not wait for the goroutine to exit and does not
// release still-in-flight entries; the Deduplicator must not be reused after
// Close — confirm callers only Close at process shutdown.
func (d *Deduplicator) Close() {
	close(d.done)
}
// cleanup periodically removes stale in-flight entries whose leader never
// called Complete (e.g. the leader goroutine died or hung). It runs until
// Close is called.
func (d *Deduplicator) cleanup() {
	ticker := time.NewTicker(d.window)
	defer ticker.Stop()
	for {
		select {
		case <-d.done:
			return
		case <-ticker.C:
			d.mu.Lock()
			now := time.Now()
			for key, f := range d.flights {
				if now.Sub(f.createdAt) > d.window*2 {
					delete(d.flights, key)
					// Fill in a well-formed timeout response before waking
					// followers. Previously followers were released with a
					// zero status code, and http.ResponseWriter.WriteHeader(0)
					// panics on an invalid status.
					f.statusCode = 504 // http.StatusGatewayTimeout
					f.result = []byte(`{"error":{"message":"deduplicated request timed out"}}`)
					close(f.done) // unblock any waiting followers
				}
			}
			d.mu.Unlock()
		}
	}
}

View file

@ -0,0 +1,74 @@
package proxy
import (
"sync"
"testing"
"time"
)
// TestDedupKey verifies the key is a pure function of (model, body):
// identical inputs collide, differing bodies do not.
func TestDedupKey(t *testing.T) {
	same1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
	same2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
	other := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`))

	if same1 != same2 {
		t.Error("identical requests should produce the same key")
	}
	if same1 == other {
		t.Error("different requests should produce different keys")
	}
}
// TestDeduplicator_LeaderFollower verifies the core dedup contract: the
// first TryJoin for a key is the leader, later joins are followers sharing
// the same inflight entry, and Complete publishes the leader's result and
// status code to followers blocked on the done channel.
func TestDeduplicator_LeaderFollower(t *testing.T) {
	d := NewDeduplicator(5 * time.Second)
	defer d.Close()
	key := DedupKey("gpt-4", []byte(`test`))
	// First call is leader
	f1, isLeader := d.TryJoin(key)
	if !isLeader {
		t.Fatal("first caller should be leader")
	}
	// Second call with same key is follower
	f2, isLeader := d.TryJoin(key)
	if isLeader {
		t.Fatal("second caller should be follower")
	}
	if f1 != f2 {
		t.Fatal("follower should get same inflight entry")
	}
	// Complete the request. The goroutine reads f2's fields only after the
	// done channel closes, matching how real followers synchronize.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-f2.done
		if string(f2.result) != "response" {
			t.Error("follower should receive leader's result")
		}
		if f2.statusCode != 200 {
			t.Error("follower should receive leader's status code")
		}
	}()
	d.Complete(key, []byte("response"), 200)
	wg.Wait()
}
// TestDeduplicator_DifferentKeys verifies that distinct keys never coalesce:
// each key's first caller is its own leader.
func TestDeduplicator_DifferentKeys(t *testing.T) {
	d := NewDeduplicator(5 * time.Second)
	defer d.Close()

	_, firstIsLeader := d.TryJoin("key1")
	_, secondIsLeader := d.TryJoin("key2")
	if !firstIsLeader || !secondIsLeader {
		t.Error("different keys should both be leaders")
	}

	d.Complete("key1", []byte("r1"), 200)
	d.Complete("key2", []byte("r2"), 200)
}

View file

@ -53,6 +53,7 @@ type Handler struct {
cfg *config.Config cfg *config.Config
healthTracker *provider.HealthTracker healthTracker *provider.HealthTracker
debugLogger *storage.DebugLogger debugLogger *storage.DebugLogger
dedup *Deduplicator
} }
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler { func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
@ -70,6 +71,10 @@ func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
h.debugLogger = dl h.debugLogger = dl
} }
func (h *Handler) SetDeduplicator(d *Deduplicator) {
h.dedup = d
}
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20)) body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
if err != nil { if err != nil {
@ -118,11 +123,47 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
} }
} }
// Apply per-model timeout for non-streaming requests
modelTimeouts := h.registry.ModelTimeoutsFor(req.Model)
if req.Stream { if req.Stream {
h.handleStream(w, r, &req, routes, tokenName, requestID) h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts)
return return
} }
// Request deduplication for non-streaming requests
if h.dedup != nil {
dedupKey := DedupKey(req.Model, body)
flight, isLeader := h.dedup.TryJoin(dedupKey)
if !isLeader {
// Wait for the leader to complete
select {
case <-flight.done:
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Request-ID", requestID)
w.Header().Set("X-Dedup", "HIT")
w.WriteHeader(flight.statusCode)
w.Write(flight.result)
return
case <-r.Context().Done():
writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup")
return
}
}
// Leader: proceed normally, but capture response for followers
defer func() {
// If we haven't completed yet (e.g., panic), clean up
}()
h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey)
return
}
if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
defer cancel()
r = r.WithContext(ctx)
}
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID) h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
} }
@ -233,6 +274,153 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
} }
} }
// Embeddings handles POST /v1/embeddings: it validates the request, resolves
// the model to its provider routes, and tries each healthy route in order
// (with backoff between attempts) until one succeeds.
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
	// Bound the body read by the configured maximum request size.
	body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
	if err != nil {
		writeError(w, http.StatusBadRequest, "failed to read request body")
		return
	}
	var req provider.EmbeddingRequest
	if err := json.Unmarshal(body, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
		return
	}
	if req.Model == "" {
		writeError(w, http.StatusBadRequest, "model is required")
		return
	}
	routes, ok := h.registry.Lookup(req.Model)
	if !ok {
		writeError(w, http.StatusNotFound, "model not found: "+req.Model)
		return
	}
	// Drop providers whose circuit breaker is open (falls back to all
	// routes if every provider would be filtered out).
	routes = h.filterHealthyRoutes(routes)
	tokenName := getTokenName(r.Context())
	requestID := middleware.GetReqID(r.Context())
	var lastErr error
	for i, route := range routes {
		if i > 0 {
			// Back off before each retry, but stay responsive to client
			// cancellation.
			backoff := backoffDuration(i, h.cfg.Retry)
			select {
			case <-time.After(backoff):
			case <-r.Context().Done():
				writeError(w, http.StatusGatewayTimeout, "request cancelled")
				return
			}
		}
		start := time.Now()
		resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req)
		latency := time.Since(start).Milliseconds()
		if err != nil {
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				// Non-retryable provider error: record it and relay the
				// provider's status and body verbatim to the client.
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				w.Header().Set("X-Request-ID", requestID)
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}
			// Retryable failure: remember it and try the next route.
			// NOTE(review): retryable errors are logged but not counted via
			// metrics.RecordRequest, unlike the non-retryable branch —
			// confirm this asymmetry matches the chat-completion path.
			lastErr = err
			log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}
		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}
		promptTokens := 0
		if resp.Usage != nil {
			promptTokens = resp.Usage.PromptTokens
		}
		// Embeddings are billed on input (prompt) tokens only.
		cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice
		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost)
		h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "")
		// Echo the gateway-facing model name, not the provider's internal one.
		resp.Model = req.Model
		respBytes, err := json.Marshal(resp)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "failed to marshal response")
			return
		}
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Request-ID", requestID)
		w.Write(respBytes)
		return
	}
	// Every route failed with a retryable error (or no routes were available).
	w.Header().Set("X-Request-ID", requestID)
	if lastErr != nil {
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}
// logEmbeddingRequest records one embedding request (success or failure) in
// the async request log. Output tokens are always zero for embeddings, so
// only the input-token count and cost are recorded; RequestType tags the row
// as "embedding" so dashboards can separate it from chat traffic.
func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) {
	h.logger.Log(storage.RequestLog{
		RequestID:     requestID,
		Timestamp:     time.Now().Unix(),
		TokenName:     tokenName,
		Model:         model,
		Provider:      providerName,
		ProviderModel: providerModel,
		InputTokens:   inputTokens,
		CostUSD:       cost,
		LatencyMS:     latencyMS,
		Status:        status,
		ErrorMessage:  errMsg,
		RequestType:   "embedding",
	})
}
// handleNonStreamDedup runs the normal non-streaming path as the dedup
// leader, capturing the status code and body via a responseRecorder so that
// followers blocked on the same dedup key can replay the exact response.
// Complete is always called — even for error responses — so followers are
// never left blocked waiting on this flight.
func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) {
	// Per-model request timeout, if configured, bounds the leader's attempt.
	if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
		ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
		defer cancel()
		r = r.WithContext(ctx)
	}
	rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
	h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID)
	// Publish the captured response (including error bodies) to followers.
	h.dedup.Complete(dedupKey, rec.body, rec.statusCode)
}
// responseRecorder captures the response for dedup. It tees every write to
// the underlying ResponseWriter while accumulating a copy of the body and
// the final status code for replay to dedup followers.
// NOTE(review): it does not forward http.Flusher/http.Hijacker; it is only
// used on the non-streaming path where neither is needed — confirm.
type responseRecorder struct {
	http.ResponseWriter
	statusCode int // defaults to 200 at construction; updated by WriteHeader
	body       []byte
}

// WriteHeader records the status code and forwards it downstream.
func (r *responseRecorder) WriteHeader(code int) {
	r.statusCode = code
	r.ResponseWriter.WriteHeader(code)
}

// Write appends to the captured body and forwards the bytes downstream.
func (r *responseRecorder) Write(b []byte) (int, error) {
	r.body = append(r.body, b...)
	return r.ResponseWriter.Write(b)
}
// filterHealthyRoutes removes providers with open circuit breakers. // filterHealthyRoutes removes providers with open circuit breakers.
// If all are filtered out, returns original routes as fallback. // If all are filtered out, returns original routes as fallback.
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route { func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {

View file

@ -5,27 +5,66 @@ import (
"net/http" "net/http"
"time" "time"
"llm-gateway/internal/config"
"llm-gateway/internal/provider" "llm-gateway/internal/provider"
) )
type ModelsHandler struct { type ModelsHandler struct {
registry *provider.Registry registry *provider.Registry
healthTracker *provider.HealthTracker
cfg *config.Config
} }
func NewModelsHandler(registry *provider.Registry) *ModelsHandler { func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler {
return &ModelsHandler{registry: registry} return &ModelsHandler{
registry: registry,
healthTracker: healthTracker,
cfg: cfg,
}
} }
func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) { func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) {
names := h.registry.ModelNames() allRoutes := h.registry.AllRoutes()
models := make([]map[string]any, len(names)) models := make([]map[string]any, 0, len(allRoutes))
for i, name := range names {
models[i] = map[string]any{ for _, m := range allRoutes {
"id": name, providers := make([]map[string]any, 0, len(m.Routes))
"object": "model", for _, rt := range m.Routes {
"created": time.Now().Unix(), healthy := true
"owned_by": "llm-gateway", if h.healthTracker != nil {
healthy = h.healthTracker.IsAvailable(rt.ProviderName)
}
providers = append(providers, map[string]any{
"name": rt.ProviderName,
"model": rt.ProviderModel,
"input_price": rt.InputPrice,
"output_price": rt.OutputPrice,
"priority": rt.Priority,
"healthy": healthy,
})
} }
// Find load balancing strategy from config
loadBalancing := "first"
for _, mc := range h.cfg.Models {
if mc.Name == m.Name {
if mc.LoadBalancing != "" {
loadBalancing = mc.LoadBalancing
}
break
}
}
models = append(models, map[string]any{
"id": m.Name,
"object": "model",
"created": time.Now().Unix(),
"owned_by": "llm-gateway",
"providers": providers,
"provider_count": len(providers),
"load_balancing": loadBalancing,
"aliases": m.Aliases,
})
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")

View file

@ -8,12 +8,15 @@ import (
"time" "time"
"llm-gateway/internal/storage" "llm-gateway/internal/storage"
"llm-gateway/internal/webhook"
) )
type RateLimiter struct { type RateLimiter struct {
db *storage.DB db *storage.DB
mu sync.Mutex mu sync.Mutex
buckets map[string]*tokenBucket buckets map[string]*tokenBucket
notifier *webhook.Notifier
budgetNotified sync.Map // tracks which token+budget combos have been notified
} }
type tokenBucket struct { type tokenBucket struct {
@ -30,6 +33,11 @@ func NewRateLimiter(db *storage.DB) *RateLimiter {
} }
} }
// SetNotifier sets the webhook notifier for budget threshold alerts.
func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) {
rl.notifier = n
}
func (rl *RateLimiter) Check(next http.Handler) http.Handler { func (rl *RateLimiter) Check(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
apiToken := getAPIToken(r.Context()) apiToken := getAPIToken(r.Context())
@ -63,9 +71,24 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
// Check daily budget // Check daily budget
if apiToken.DailyBudgetUSD > 0 { if apiToken.DailyBudgetUSD > 0 {
spent, err := rl.db.TodaySpend(tokenName) spent, err := rl.db.TodaySpend(tokenName)
if err == nil && spent >= apiToken.DailyBudgetUSD { if err == nil {
writeError(w, http.StatusTooManyRequests, "daily budget exceeded") if spent >= apiToken.DailyBudgetUSD {
return writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
return
}
rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD)
}
}
// Check monthly budget
if apiToken.MonthlyBudgetUSD > 0 {
spent, err := rl.db.MonthSpend(tokenName)
if err == nil {
if spent >= apiToken.MonthlyBudgetUSD {
writeError(w, http.StatusTooManyRequests, "monthly budget exceeded")
return
}
rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD)
} }
} }
@ -73,6 +96,30 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
}) })
} }
// checkBudgetThreshold fires a webhook notification when spend reaches 80%
// of budget. An alert is sent at most once per token per budget period: the
// dedup key embeds the current day (for "daily") or month (for "monthly"),
// so alerts re-arm when a new period starts. Previously the key had no
// period component, meaning a token could only ever be alerted once for the
// lifetime of the process.
func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) {
	if rl.notifier == nil || budget <= 0 {
		return
	}
	if spent/budget < 0.8 {
		return
	}
	// Scope the notified-flag to the current period. Old period entries are
	// left in the map (one per token per period, negligible growth).
	now := time.Now()
	period := now.Format("2006-01") // monthly granularity
	if budgetType == "daily" {
		period = now.Format("2006-01-02")
	}
	key := tokenName + ":" + budgetType + ":" + period
	if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded {
		return // already notified this period
	}
	rl.notifier.Notify(webhook.Event{
		Type: webhook.EventBudgetThreshold,
		Data: map[string]any{
			"token":       tokenName,
			"budget_type": budgetType,
			"spent":       spent,
			"budget":      budget,
			"percent":     spent / budget * 100,
		},
	})
}
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) { func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
rl.mu.Lock() rl.mu.Lock()
defer rl.mu.Unlock() defer rl.mu.Unlock()

View file

@ -45,11 +45,11 @@ var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
func TestRateLimiter_Allow(t *testing.T) { func TestRateLimiter_Allow(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
rateLimitRPM int rateLimitRPM int
numRequests int numRequests int
wantAllowed int wantAllowed int
wantDenied int wantDenied int
}{ }{
{ {
name: "allows requests within limit", name: "allows requests within limit",
@ -191,12 +191,12 @@ func TestRateLimiter_AllowReturnValues(t *testing.T) {
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) { func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
rateLimitRPM int rateLimitRPM int
numRequests int numRequests int
wantStatusCode int wantStatusCode int
wantLimitHeader string wantLimitHeader string
wantRetryAfter bool wantRetryAfter bool
}{ }{
{ {
name: "sets rate limit headers on allowed request", name: "sets rate limit headers on allowed request",

View file

@ -13,7 +13,7 @@ import (
"llm-gateway/internal/provider" "llm-gateway/internal/provider"
) )
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) { func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) {
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
writeError(w, http.StatusInternalServerError, "streaming not supported") writeError(w, http.StatusInternalServerError, "streaming not supported")
@ -60,11 +60,15 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
continue continue
} }
// Apply streaming timeout // Apply streaming timeout (per-model override takes precedence)
streamingTimeout := h.cfg.Server.StreamingTimeout
if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 {
streamingTimeout = modelTimeouts.StreamingTimeout
}
var streamCtx context.Context var streamCtx context.Context
var streamCancel context.CancelFunc var streamCancel context.CancelFunc
if h.cfg.Server.StreamingTimeout > 0 { if streamingTimeout > 0 {
streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout) streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout)
} else { } else {
streamCtx, streamCancel = context.WithCancel(r.Context()) streamCtx, streamCancel = context.WithCancel(r.Context())
} }

View file

@ -102,6 +102,21 @@ func (db *DB) TodaySpend(tokenName string) (float64, error) {
return total.Float64, nil return total.Float64, nil
} }
// MonthSpend returns the total cost in USD for a given token this month.
// The month boundary is midnight on the 1st in the server's local timezone.
// NOTE(review): TodaySpendAll computes its day boundary with
// time.Truncate(24h), which is UTC-based — confirm the intended timezone
// semantics are consistent across the spend queries.
func (db *DB) MonthSpend(tokenName string) (float64, error) {
	now := time.Now()
	startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix()
	// SUM(...) is NULL when no rows match; NullFloat64's zero value then
	// yields 0 spend.
	var total sql.NullFloat64
	err := db.QueryRow(
		"SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
		tokenName, startOfMonth,
	).Scan(&total)
	if err != nil {
		return 0, err
	}
	return total.Float64, nil
}
// TodaySpendAll returns today's spend for all tokens as a map. // TodaySpendAll returns today's spend for all tokens as a map.
func (db *DB) TodaySpendAll() (map[string]float64, error) { func (db *DB) TodaySpendAll() (map[string]float64, error) {
startOfDay := time.Now().Truncate(24 * time.Hour).Unix() startOfDay := time.Now().Truncate(24 * time.Hour).Unix()

View file

@ -20,6 +20,7 @@ type RequestLog struct {
ErrorMessage string ErrorMessage string
Streaming bool Streaming bool
Cached bool Cached bool
RequestType string // "chat" or "embedding"
} }
type AsyncLogger struct { type AsyncLogger struct {
@ -94,8 +95,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
} }
stmt, err := tx.Prepare(`INSERT INTO request_logs stmt, err := tx.Prepare(`INSERT INTO request_logs
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached) (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil { if err != nil {
log.Printf("ERROR: preparing log statement: %v", err) log.Printf("ERROR: preparing log statement: %v", err)
tx.Rollback() tx.Rollback()
@ -112,10 +113,14 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
if r.Cached { if r.Cached {
cached = 1 cached = 1
} }
reqType := r.RequestType
if reqType == "" {
reqType = "chat"
}
_, err := stmt.Exec( _, err := stmt.Exec(
r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel, r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS, r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
r.Status, r.ErrorMessage, streaming, cached, r.Status, r.ErrorMessage, streaming, cached, reqType,
) )
if err != nil { if err != nil {
log.Printf("ERROR: inserting log: %v", err) log.Printf("ERROR: inserting log: %v", err)

View file

@ -0,0 +1 @@
-- Rollback: remove the monthly budget column from api_tokens.
-- NOTE(review): DROP COLUMN requires SQLite 3.35+ — confirm the minimum
-- supported SQLite version before relying on this down migration.
ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd;

View file

@ -0,0 +1 @@
ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0;

View file

@ -0,0 +1 @@
-- Rollback: remove the request_type column from request_logs.
-- NOTE(review): DROP COLUMN requires SQLite 3.35+ — confirm the minimum
-- supported SQLite version before relying on this down migration.
ALTER TABLE request_logs DROP COLUMN request_type;

View file

@ -0,0 +1 @@
ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat';

View file

@ -0,0 +1,123 @@
package webhook
import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"sync"
	"time"

	"llm-gateway/internal/config"
)
// Event type identifiers carried in the webhook payload's "type" field and
// usable in a webhook's event filter.
const (
	EventCircuitBreakerOpen   = "circuit_breaker.open"
	EventCircuitBreakerClosed = "circuit_breaker.closed"
	EventBudgetThreshold      = "budget.threshold"
)
// Event represents a webhook notification payload, serialized as the JSON
// body of the outgoing POST request.
type Event struct {
	Type      string         `json:"type"`      // one of the Event* constants
	Timestamp time.Time      `json:"timestamp"` // filled by Notify if zero
	Data      map[string]any `json:"data"`      // event-specific details
}
// Notifier sends webhook notifications. Events are queued on a buffered
// channel and delivered by a single background goroutine, so callers never
// block on network I/O.
type Notifier struct {
	webhooks []config.WebhookConfig
	ch       chan Event
	done     chan struct{} // closed by run() once the queue is drained
	client   *http.Client

	mu     sync.Mutex // guards closed, and enqueues racing with Close
	closed bool       // set by Close; Notify drops events afterwards
}

// NewNotifier creates a webhook notifier from config and starts its delivery
// goroutine. Call Close to drain the queue and shut it down.
func NewNotifier(webhooks []config.WebhookConfig) *Notifier {
	n := &Notifier{
		webhooks: webhooks,
		ch:       make(chan Event, 100),
		done:     make(chan struct{}),
		client:   &http.Client{Timeout: 10 * time.Second},
	}
	go n.run()
	return n
}

// Notify queues an event for delivery (non-blocking). Events are dropped
// with a warning when the queue is full or the notifier has been closed.
// Previously a Notify racing with Close could panic with "send on closed
// channel"; the mutex/closed flag makes shutdown safe against concurrent
// callers (circuit-breaker callbacks, rate limiter).
func (n *Notifier) Notify(evt Event) {
	if evt.Timestamp.IsZero() {
		evt.Timestamp = time.Now()
	}
	n.mu.Lock()
	defer n.mu.Unlock()
	if n.closed {
		log.Printf("WARNING: webhook notifier closed, dropping event %s", evt.Type)
		return
	}
	select {
	case n.ch <- evt:
	default:
		log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type)
	}
}

// Close drains pending events and shuts down. It is idempotent, and Notify
// calls after (or racing with) Close are dropped rather than panicking.
func (n *Notifier) Close() {
	n.mu.Lock()
	if n.closed {
		n.mu.Unlock()
		return
	}
	n.closed = true
	close(n.ch)
	n.mu.Unlock()
	<-n.done
}
// run is the delivery loop: it consumes queued events until the channel is
// closed by Close, fanning each event out sequentially to every webhook
// whose filter matches, then signals completion by closing done.
func (n *Notifier) run() {
	defer close(n.done)
	for evt := range n.ch {
		for _, wh := range n.webhooks {
			if !n.shouldSend(wh, evt.Type) {
				continue
			}
			n.send(wh, evt)
		}
	}
}
// shouldSend reports whether the given webhook subscribes to eventType.
// A webhook configured with no event filter receives every event type.
func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool {
	if len(wh.Events) == 0 {
		return true // no filter = send all
	}
	for i := range wh.Events {
		if wh.Events[i] == eventType {
			return true
		}
	}
	return false
}
// send delivers one event to one webhook endpoint as a JSON POST. When a
// shared secret is configured, the payload is signed with HMAC-SHA256 and
// the hex signature is sent as "X-Webhook-Signature: sha256=<hex>".
// Failures are logged and never retried.
func (n *Notifier) send(wh config.WebhookConfig, evt Event) {
	body, err := json.Marshal(evt)
	if err != nil {
		log.Printf("ERROR: webhook marshal: %v", err)
		return
	}
	req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body))
	if err != nil {
		log.Printf("ERROR: webhook request: %v", err)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	if wh.Secret != "" {
		mac := hmac.New(sha256.New, []byte(wh.Secret))
		mac.Write(body)
		sig := hex.EncodeToString(mac.Sum(nil))
		req.Header.Set("X-Webhook-Signature", "sha256="+sig)
	}
	resp, err := n.client.Do(req)
	if err != nil {
		log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err)
		return
	}
	defer resp.Body.Close()
	// Drain the body (bounded) so the transport can reuse the underlying
	// connection; closing an unread body defeats HTTP keep-alive. Best-effort.
	_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 1<<20))
	if resp.StatusCode >= 400 {
		log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode)
	}
}