feat(gateway): add monthly budget support and webhook notifications for circuit breaker and budget events
This commit is contained in:
parent
28a694744d
commit
291b8f4863
31 changed files with 1005 additions and 124 deletions
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"llm-gateway/internal/provider"
|
"llm-gateway/internal/provider"
|
||||||
"llm-gateway/internal/proxy"
|
"llm-gateway/internal/proxy"
|
||||||
"llm-gateway/internal/storage"
|
"llm-gateway/internal/storage"
|
||||||
|
"llm-gateway/internal/webhook"
|
||||||
)
|
)
|
||||||
|
|
||||||
var version = "dev"
|
var version = "dev"
|
||||||
|
|
@ -95,16 +96,41 @@ func main() {
|
||||||
// Provider health tracker
|
// Provider health tracker
|
||||||
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
|
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
|
||||||
|
|
||||||
|
// Webhook notifier
|
||||||
|
var notifier *webhook.Notifier
|
||||||
|
if len(cfg.Webhooks) > 0 {
|
||||||
|
notifier = webhook.NewNotifier(cfg.Webhooks)
|
||||||
|
defer notifier.Close()
|
||||||
|
log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks))
|
||||||
|
|
||||||
|
// Wire health tracker state changes to webhook
|
||||||
|
healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) {
|
||||||
|
eventType := webhook.EventCircuitBreakerOpen
|
||||||
|
if to == provider.CircuitClosed {
|
||||||
|
eventType = webhook.EventCircuitBreakerClosed
|
||||||
|
}
|
||||||
|
notifier.Notify(webhook.Event{
|
||||||
|
Type: eventType,
|
||||||
|
Data: map[string]any{
|
||||||
|
"provider": providerName,
|
||||||
|
"from": from.String(),
|
||||||
|
"to": to.String(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Auth store (static tokens checked in-memory, not seeded to DB)
|
// Auth store (static tokens checked in-memory, not seeded to DB)
|
||||||
var staticTokens []auth.StaticToken
|
var staticTokens []auth.StaticToken
|
||||||
for _, t := range cfg.Tokens {
|
for _, t := range cfg.Tokens {
|
||||||
if t.Key != "" {
|
if t.Key != "" {
|
||||||
staticTokens = append(staticTokens, auth.StaticToken{
|
staticTokens = append(staticTokens, auth.StaticToken{
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Key: t.Key,
|
Key: t.Key,
|
||||||
RateLimitRPM: t.RateLimitRPM,
|
RateLimitRPM: t.RateLimitRPM,
|
||||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
DailyBudgetUSD: t.DailyBudgetUSD,
|
||||||
MaxConcurrent: t.MaxConcurrent,
|
MonthlyBudgetUSD: t.MonthlyBudgetUSD,
|
||||||
|
MaxConcurrent: t.MaxConcurrent,
|
||||||
})
|
})
|
||||||
log.Printf("Loaded static token: %s", t.Name)
|
log.Printf("Loaded static token: %s", t.Name)
|
||||||
}
|
}
|
||||||
|
|
@ -133,14 +159,27 @@ func main() {
|
||||||
// Handlers
|
// Handlers
|
||||||
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
|
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
|
||||||
proxyHandler.SetDebugLogger(debugLogger)
|
proxyHandler.SetDebugLogger(debugLogger)
|
||||||
modelsHandler := proxy.NewModelsHandler(registry)
|
|
||||||
|
// Request deduplication
|
||||||
|
if cfg.Dedup.Enabled {
|
||||||
|
dedup := proxy.NewDeduplicator(cfg.Dedup.Window)
|
||||||
|
defer dedup.Close()
|
||||||
|
proxyHandler.SetDeduplicator(dedup)
|
||||||
|
log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg)
|
||||||
proxyAuth := proxy.NewAuthMiddleware(authStore)
|
proxyAuth := proxy.NewAuthMiddleware(authStore)
|
||||||
rateLimiter := proxy.NewRateLimiter(db)
|
rateLimiter := proxy.NewRateLimiter(db)
|
||||||
|
if notifier != nil {
|
||||||
|
rateLimiter.SetNotifier(notifier)
|
||||||
|
}
|
||||||
concurrencyLimiter := proxy.NewConcurrencyLimiter()
|
concurrencyLimiter := proxy.NewConcurrencyLimiter()
|
||||||
statsAPI := dashboard.NewStatsAPI(db, authStore)
|
statsAPI := dashboard.NewStatsAPI(db, authStore)
|
||||||
statsAPI.SetHealthTracker(healthTracker)
|
statsAPI.SetHealthTracker(healthTracker)
|
||||||
statsAPI.SetAuditLogger(auditLogger)
|
statsAPI.SetAuditLogger(auditLogger)
|
||||||
statsAPI.SetDebugLogger(debugLogger)
|
statsAPI.SetDebugLogger(debugLogger)
|
||||||
|
statsAPI.SetConfigPath(*configPath)
|
||||||
if c != nil {
|
if c != nil {
|
||||||
statsAPI.SetCache(c)
|
statsAPI.SetCache(c)
|
||||||
}
|
}
|
||||||
|
|
@ -196,6 +235,7 @@ func main() {
|
||||||
r.Use(rateLimiter.Check)
|
r.Use(rateLimiter.Check)
|
||||||
r.Use(concurrencyLimiter.Check)
|
r.Use(concurrencyLimiter.Check)
|
||||||
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
|
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
|
||||||
|
r.Post("/v1/embeddings", proxyHandler.Embeddings)
|
||||||
r.Get("/v1/models", modelsHandler.ListModels)
|
r.Get("/v1/models", modelsHandler.ListModels)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -266,7 +306,7 @@ func main() {
|
||||||
r.Get("/api/export/logs", exportHandler.ExportLogs)
|
r.Get("/api/export/logs", exportHandler.ExportLogs)
|
||||||
r.Get("/api/export/stats", exportHandler.ExportStats)
|
r.Get("/api/export/stats", exportHandler.ExportStats)
|
||||||
|
|
||||||
// Admin-only: user management, audit, debug
|
// Admin-only: user management, audit, debug, config validation
|
||||||
r.Group(func(r chi.Router) {
|
r.Group(func(r chi.Router) {
|
||||||
r.Use(authMiddleware.RequireAdmin)
|
r.Use(authMiddleware.RequireAdmin)
|
||||||
r.Get("/api/auth/users", authHandlers.ListUsers)
|
r.Get("/api/auth/users", authHandlers.ListUsers)
|
||||||
|
|
@ -276,6 +316,9 @@ func main() {
|
||||||
// Audit log
|
// Audit log
|
||||||
r.Get("/api/stats/audit", statsAPI.AuditLogs)
|
r.Get("/api/stats/audit", statsAPI.AuditLogs)
|
||||||
|
|
||||||
|
// Config validation
|
||||||
|
r.Get("/api/config/validate", statsAPI.ValidateConfig)
|
||||||
|
|
||||||
// Debug logging
|
// Debug logging
|
||||||
r.Post("/api/debug/toggle", statsAPI.DebugToggle)
|
r.Post("/api/debug/toggle", statsAPI.DebugToggle)
|
||||||
r.Get("/api/debug/status", statsAPI.DebugStatus)
|
r.Get("/api/debug/status", statsAPI.DebugStatus)
|
||||||
|
|
@ -332,11 +375,12 @@ func main() {
|
||||||
for _, t := range newCfg.Tokens {
|
for _, t := range newCfg.Tokens {
|
||||||
if t.Key != "" {
|
if t.Key != "" {
|
||||||
newStaticTokens = append(newStaticTokens, auth.StaticToken{
|
newStaticTokens = append(newStaticTokens, auth.StaticToken{
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Key: t.Key,
|
Key: t.Key,
|
||||||
RateLimitRPM: t.RateLimitRPM,
|
RateLimitRPM: t.RateLimitRPM,
|
||||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
DailyBudgetUSD: t.DailyBudgetUSD,
|
||||||
MaxConcurrent: t.MaxConcurrent,
|
MonthlyBudgetUSD: t.MonthlyBudgetUSD,
|
||||||
|
MaxConcurrent: t.MaxConcurrent,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ go 1.24.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/go-chi/chi/v5 v5.2.5
|
github.com/go-chi/chi/v5 v5.2.5
|
||||||
|
github.com/go-chi/cors v1.2.2
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1
|
github.com/golang-migrate/migrate/v4 v4.19.1
|
||||||
github.com/pquerna/otp v1.5.0
|
github.com/pquerna/otp v1.5.0
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
|
|
@ -19,7 +20,6 @@ require (
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/go-chi/cors v1.2.2 // indirect
|
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
|
|
||||||
|
|
@ -31,25 +31,27 @@ type Session struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type APIToken struct {
|
type APIToken struct {
|
||||||
ID int64 `json:"id"`
|
ID int64 `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
KeyPrefix string `json:"key_prefix"`
|
KeyPrefix string `json:"key_prefix"`
|
||||||
KeyHash string `json:"-"`
|
KeyHash string `json:"-"`
|
||||||
UserID int64 `json:"user_id"`
|
UserID int64 `json:"user_id"`
|
||||||
RateLimitRPM int `json:"rate_limit_rpm"`
|
RateLimitRPM int `json:"rate_limit_rpm"`
|
||||||
DailyBudgetUSD float64 `json:"daily_budget_usd"`
|
DailyBudgetUSD float64 `json:"daily_budget_usd"`
|
||||||
MaxConcurrent int `json:"max_concurrent"`
|
MonthlyBudgetUSD float64 `json:"monthly_budget_usd"`
|
||||||
CreatedAt int64 `json:"created_at"`
|
MaxConcurrent int `json:"max_concurrent"`
|
||||||
LastUsedAt int64 `json:"last_used_at"`
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
LastUsedAt int64 `json:"last_used_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaticToken represents a token defined in config (checked in-memory, never stored in DB).
|
// StaticToken represents a token defined in config (checked in-memory, never stored in DB).
|
||||||
type StaticToken struct {
|
type StaticToken struct {
|
||||||
Name string
|
Name string
|
||||||
Key string
|
Key string
|
||||||
RateLimitRPM int
|
RateLimitRPM int
|
||||||
DailyBudgetUSD float64
|
DailyBudgetUSD float64
|
||||||
MaxConcurrent int
|
MonthlyBudgetUSD float64
|
||||||
|
MaxConcurrent int
|
||||||
}
|
}
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
|
|
@ -289,12 +291,13 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
||||||
prefix = prefix[:11]
|
prefix = prefix[:11]
|
||||||
}
|
}
|
||||||
return &APIToken{
|
return &APIToken{
|
||||||
ID: -1, // sentinel: static token
|
ID: -1, // sentinel: static token
|
||||||
Name: st.Name,
|
Name: st.Name,
|
||||||
KeyPrefix: prefix,
|
KeyPrefix: prefix,
|
||||||
RateLimitRPM: st.RateLimitRPM,
|
RateLimitRPM: st.RateLimitRPM,
|
||||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
DailyBudgetUSD: st.DailyBudgetUSD,
|
||||||
MaxConcurrent: st.MaxConcurrent,
|
MonthlyBudgetUSD: st.MonthlyBudgetUSD,
|
||||||
|
MaxConcurrent: st.MaxConcurrent,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -305,9 +308,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
||||||
|
|
||||||
var t APIToken
|
var t APIToken
|
||||||
err := s.db.QueryRow(
|
err := s.db.QueryRow(
|
||||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
||||||
keyHash,
|
keyHash,
|
||||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -323,12 +326,13 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
||||||
prefix = prefix[:11]
|
prefix = prefix[:11]
|
||||||
}
|
}
|
||||||
tokens = append(tokens, APIToken{
|
tokens = append(tokens, APIToken{
|
||||||
ID: -1, // sentinel: static token
|
ID: -1, // sentinel: static token
|
||||||
Name: st.Name,
|
Name: st.Name,
|
||||||
KeyPrefix: prefix,
|
KeyPrefix: prefix,
|
||||||
RateLimitRPM: st.RateLimitRPM,
|
RateLimitRPM: st.RateLimitRPM,
|
||||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
DailyBudgetUSD: st.DailyBudgetUSD,
|
||||||
MaxConcurrent: st.MaxConcurrent,
|
MonthlyBudgetUSD: st.MonthlyBudgetUSD,
|
||||||
|
MaxConcurrent: st.MaxConcurrent,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -336,9 +340,9 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
if userID == 0 {
|
if userID == 0 {
|
||||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
|
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
|
||||||
} else {
|
} else {
|
||||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return tokens, nil
|
return tokens, nil
|
||||||
|
|
@ -347,7 +351,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var t APIToken
|
var t APIToken
|
||||||
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
||||||
return tokens, nil
|
return tokens, nil
|
||||||
}
|
}
|
||||||
tokens = append(tokens, t)
|
tokens = append(tokens, t)
|
||||||
|
|
@ -363,9 +367,9 @@ func (s *Store) DeleteAPIToken(id int64) error {
|
||||||
func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
|
func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
|
||||||
var t APIToken
|
var t APIToken
|
||||||
err := s.db.QueryRow(
|
err := s.db.QueryRow(
|
||||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
|
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
|
||||||
id,
|
id,
|
||||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ func setupTestDB(t *testing.T) *sql.DB {
|
||||||
user_id INTEGER NOT NULL,
|
user_id INTEGER NOT NULL,
|
||||||
rate_limit_rpm INTEGER DEFAULT 0,
|
rate_limit_rpm INTEGER DEFAULT 0,
|
||||||
daily_budget_usd REAL DEFAULT 0,
|
daily_budget_usd REAL DEFAULT 0,
|
||||||
|
monthly_budget_usd REAL DEFAULT 0,
|
||||||
max_concurrent INTEGER DEFAULT 0,
|
max_concurrent INTEGER DEFAULT 0,
|
||||||
created_at INTEGER NOT NULL,
|
created_at INTEGER NOT NULL,
|
||||||
last_used_at INTEGER DEFAULT 0
|
last_used_at INTEGER DEFAULT 0
|
||||||
|
|
|
||||||
|
|
@ -20,11 +20,24 @@ type Config struct {
|
||||||
Retry RetryConfig `yaml:"retry"`
|
Retry RetryConfig `yaml:"retry"`
|
||||||
Debug DebugConfig `yaml:"debug"`
|
Debug DebugConfig `yaml:"debug"`
|
||||||
CORS CORSConfig `yaml:"cors"`
|
CORS CORSConfig `yaml:"cors"`
|
||||||
|
Dedup DedupConfig `yaml:"dedup"`
|
||||||
|
Webhooks []WebhookConfig `yaml:"webhooks"`
|
||||||
Providers []ProviderConfig `yaml:"providers"`
|
Providers []ProviderConfig `yaml:"providers"`
|
||||||
Models []ModelConfig `yaml:"models"`
|
Models []ModelConfig `yaml:"models"`
|
||||||
Tokens []TokenConfig `yaml:"tokens"`
|
Tokens []TokenConfig `yaml:"tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DedupConfig holds the settings read from the `dedup` section of the config.
// main only constructs a request deduplicator when Enabled is true.
type DedupConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
// Validate replaces a zero Window with a 30 second default.
Window time.Duration `yaml:"window"` // max time to wait for dedup result
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebhookConfig describes one entry of the `webhooks` list in the config.
// main creates a webhook notifier only when at least one entry is present.
type WebhookConfig struct {
|
||||||
|
// URL is the endpoint for this webhook entry.
URL string `yaml:"url"`
|
||||||
|
Events []string `yaml:"events"` // event types to send
|
||||||
|
Secret string `yaml:"secret"` // optional HMAC secret
|
||||||
|
}
|
||||||
|
|
||||||
type PricingLookupConfig struct {
|
type PricingLookupConfig struct {
|
||||||
URL string `yaml:"url"`
|
URL string `yaml:"url"`
|
||||||
RefreshInterval time.Duration `yaml:"refresh_interval"`
|
RefreshInterval time.Duration `yaml:"refresh_interval"`
|
||||||
|
|
@ -36,20 +49,21 @@ type DefaultAdminConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenConfig struct {
|
type TokenConfig struct {
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
Key string `yaml:"key"`
|
Key string `yaml:"key"`
|
||||||
RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
|
RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
|
||||||
DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
|
DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
|
||||||
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
|
MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited
|
||||||
|
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Listen string `yaml:"listen"`
|
Listen string `yaml:"listen"`
|
||||||
RequestTimeout time.Duration `yaml:"request_timeout"`
|
RequestTimeout time.Duration `yaml:"request_timeout"`
|
||||||
StreamingTimeout time.Duration `yaml:"streaming_timeout"`
|
StreamingTimeout time.Duration `yaml:"streaming_timeout"`
|
||||||
MaxRequestBodyMB int `yaml:"max_request_body_mb"`
|
MaxRequestBodyMB int `yaml:"max_request_body_mb"`
|
||||||
SessionSecret string `yaml:"session_secret"`
|
SessionSecret string `yaml:"session_secret"`
|
||||||
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
|
DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CircuitBreakerConfig struct {
|
type CircuitBreakerConfig struct {
|
||||||
|
|
@ -66,10 +80,10 @@ type RetryConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DebugConfig struct {
|
type DebugConfig struct {
|
||||||
Enabled bool `yaml:"enabled"`
|
Enabled bool `yaml:"enabled"`
|
||||||
MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
|
MaxBodyBytes int `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
|
||||||
RetentionDays int `yaml:"retention_days"`
|
RetentionDays int `yaml:"retention_days"`
|
||||||
DataDir string `yaml:"data_dir"`
|
DataDir string `yaml:"data_dir"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CORSConfig struct {
|
type CORSConfig struct {
|
||||||
|
|
@ -100,10 +114,12 @@ type ProviderConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelConfig struct {
|
type ModelConfig struct {
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
Aliases []string `yaml:"aliases"`
|
Aliases []string `yaml:"aliases"`
|
||||||
Routes []RouteConfig `yaml:"routes"`
|
Routes []RouteConfig `yaml:"routes"`
|
||||||
LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
|
LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
|
||||||
|
RequestTimeout time.Duration `yaml:"request_timeout"` // per-model override; 0 = use server default
|
||||||
|
StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteConfig struct {
|
type RouteConfig struct {
|
||||||
|
|
@ -131,14 +147,15 @@ func Load(path string) (*Config, error) {
|
||||||
return nil, fmt.Errorf("parsing config: %w", err)
|
return nil, fmt.Errorf("parsing config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := cfg.validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return nil, fmt.Errorf("validating config: %w", err)
|
return nil, fmt.Errorf("validating config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) validate() error {
|
// Validate checks the config for correctness and applies defaults.
|
||||||
|
func (c *Config) Validate() error {
|
||||||
if c.Server.Listen == "" {
|
if c.Server.Listen == "" {
|
||||||
c.Server.Listen = "0.0.0.0:3000"
|
c.Server.Listen = "0.0.0.0:3000"
|
||||||
}
|
}
|
||||||
|
|
@ -201,6 +218,11 @@ func (c *Config) validate() error {
|
||||||
c.CORS.MaxAge = 300
|
c.CORS.MaxAge = 300
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dedup defaults
|
||||||
|
if c.Dedup.Window == 0 {
|
||||||
|
c.Dedup.Window = 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
if len(c.Providers) == 0 {
|
if len(c.Providers) == 0 {
|
||||||
return fmt.Errorf("at least one provider is required")
|
return fmt.Errorf("at least one provider is required")
|
||||||
}
|
}
|
||||||
|
|
@ -266,6 +288,19 @@ func (c *Config) validate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateBytes parses raw YAML and returns a list of validation errors.
|
||||||
|
func ValidateBytes(data []byte) []string {
|
||||||
|
expanded := os.ExpandEnv(string(data))
|
||||||
|
var cfg Config
|
||||||
|
if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
|
||||||
|
return []string{"parse error: " + err.Error()}
|
||||||
|
}
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
return []string{err.Error()}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderByName returns the provider config by name.
|
// ProviderByName returns the provider config by name.
|
||||||
func (c *Config) ProviderByName(name string) *ProviderConfig {
|
func (c *Config) ProviderByName(name string) *ProviderConfig {
|
||||||
for i := range c.Providers {
|
for i := range c.Providers {
|
||||||
|
|
|
||||||
|
|
@ -735,4 +735,3 @@ models:
|
||||||
t.Errorf("error = %q, want to contain api_key validation message", err.Error())
|
t.Errorf("error = %q, want to contain api_key validation message", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package dashboard
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -11,6 +12,7 @@ import (
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
"llm-gateway/internal/auth"
|
||||||
"llm-gateway/internal/cache"
|
"llm-gateway/internal/cache"
|
||||||
|
"llm-gateway/internal/config"
|
||||||
"llm-gateway/internal/provider"
|
"llm-gateway/internal/provider"
|
||||||
"llm-gateway/internal/storage"
|
"llm-gateway/internal/storage"
|
||||||
)
|
)
|
||||||
|
|
@ -109,6 +111,7 @@ type StatsAPI struct {
|
||||||
cache *cache.Cache
|
cache *cache.Cache
|
||||||
auditLogger *storage.AuditLogger
|
auditLogger *storage.AuditLogger
|
||||||
debugLogger *storage.DebugLogger
|
debugLogger *storage.DebugLogger
|
||||||
|
configPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
|
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
|
||||||
|
|
@ -135,6 +138,11 @@ func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
|
||||||
s.debugLogger = dl
|
s.debugLogger = dl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetConfigPath stores the config file path that ValidateConfig reads and
// validates. While the path is empty, ValidateConfig answers with an error.
|
||||||
|
func (s *StatsAPI) SetConfigPath(path string) {
|
||||||
|
s.configPath = path
|
||||||
|
}
|
||||||
|
|
||||||
// TokenNamesForUser returns the token names that belong to the user.
|
// TokenNamesForUser returns the token names that belong to the user.
|
||||||
// Admins get nil (no filter), non-admins get their token names.
|
// Admins get nil (no filter), non-admins get their token names.
|
||||||
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
|
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
|
||||||
|
|
@ -712,6 +720,27 @@ func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
|
||||||
writeJSON(w, entry)
|
writeJSON(w, entry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateConfig validates the config file at the stored path.
|
||||||
|
func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if s.configPath == "" {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
writeJSON(w, map[string]any{"valid": false, "errors": []string{"config path not set"}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := os.ReadFile(s.configPath)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
writeJSON(w, map[string]any{"valid": false, "errors": []string{"failed to read config: " + err.Error()}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errs := config.ValidateBytes(data)
|
||||||
|
if len(errs) > 0 {
|
||||||
|
writeJSON(w, map[string]any{"valid": false, "errors": errs})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, map[string]any{"valid": true, "errors": []string{}})
|
||||||
|
}
|
||||||
|
|
||||||
func writeJSON(w http.ResponseWriter, v any) {
|
func writeJSON(w http.ResponseWriter, v any) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(v)
|
json.NewEncoder(w).Encode(v)
|
||||||
|
|
|
||||||
|
|
@ -127,9 +127,9 @@ type PageData struct {
|
||||||
// Models routing page data
|
// Models routing page data
|
||||||
ModelRoutes []provider.ModelRouteInfo
|
ModelRoutes []provider.ModelRouteInfo
|
||||||
// Audit page data
|
// Audit page data
|
||||||
AuditResult *storage.AuditQueryResult
|
AuditResult *storage.AuditQueryResult
|
||||||
AuditFilterActions []string
|
AuditFilterActions []string
|
||||||
FilterAction string
|
FilterAction string
|
||||||
// Debug page data
|
// Debug page data
|
||||||
DebugResult *storage.DebugLogQueryResult
|
DebugResult *storage.DebugLogQueryResult
|
||||||
DebugEnabled bool
|
DebugEnabled bool
|
||||||
|
|
@ -275,6 +275,10 @@ func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) {
|
||||||
data.ModelRoutes = d.registry.AllRoutes()
|
data.ModelRoutes = d.registry.AllRoutes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if d.statsAPI.healthTracker != nil {
|
||||||
|
data.ProviderHealth = d.statsAPI.healthTracker.Status()
|
||||||
|
}
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/models-page.html", data)
|
d.renderDashboardPage(w, r, "partials/models-page.html", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
{{if .ModelRoutes}}
|
{{if .ModelRoutes}}
|
||||||
{{range .ModelRoutes}}
|
{{range .ModelRoutes}}
|
||||||
<div class="section">
|
<div class="section">
|
||||||
<h2>{{.Name}}</h2>
|
<h2>{{.Name}}{{if .Aliases}} <span style="font-size:0.75rem;color:var(--text-muted);font-weight:400;">aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}</span>{{end}}</h2>
|
||||||
<table>
|
<table>
|
||||||
<thead>
|
<thead>
|
||||||
<tr>
|
<tr>
|
||||||
|
|
@ -15,9 +15,11 @@
|
||||||
<th>Priority</th>
|
<th>Priority</th>
|
||||||
<th>Input Price (per 1M)</th>
|
<th>Input Price (per 1M)</th>
|
||||||
<th>Output Price (per 1M)</th>
|
<th>Output Price (per 1M)</th>
|
||||||
|
<th>Health</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
<tbody>
|
<tbody>
|
||||||
|
{{$health := $.ProviderHealth}}
|
||||||
{{range .Routes}}
|
{{range .Routes}}
|
||||||
<tr>
|
<tr>
|
||||||
<td>{{.ProviderName}}</td>
|
<td>{{.ProviderName}}</td>
|
||||||
|
|
@ -25,6 +27,18 @@
|
||||||
<td><span class="badge badge-priority">{{.Priority}}</span></td>
|
<td><span class="badge badge-priority">{{.Priority}}</span></td>
|
||||||
<td>{{formatPrice .InputPrice}}</td>
|
<td>{{formatPrice .InputPrice}}</td>
|
||||||
<td>{{formatPrice .OutputPrice}}</td>
|
<td>{{formatPrice .OutputPrice}}</td>
|
||||||
|
<td>
|
||||||
|
{{$pname := .ProviderName}}
|
||||||
|
{{range $health}}
|
||||||
|
{{if eq .Provider $pname}}
|
||||||
|
{{if eq .Status "healthy"}}<span class="badge" style="background:#166534;color:#4ade80;">healthy</span>
|
||||||
|
{{else if eq .Status "degraded"}}<span class="badge" style="background:#92400e;color:#fbbf24;">degraded</span>
|
||||||
|
{{else}}<span class="badge" style="background:#991b1b;color:#f87171;">down</span>
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
|
{{if not $health}}<span style="color:var(--text-muted);">-</span>{{end}}
|
||||||
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{{end}}
|
{{end}}
|
||||||
</tbody>
|
</tbody>
|
||||||
|
|
|
||||||
|
|
@ -70,6 +70,15 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{{if .User.IsAdmin}}
|
||||||
|
<div class="section">
|
||||||
|
<h2>Config Validation</h2>
|
||||||
|
<p style="color:#94a3b8;font-size:0.85rem;margin-bottom:12px;">Validate the current gateway configuration file for errors.</p>
|
||||||
|
<button class="btn btn-sm btn-primary" onclick="validateConfig()">Validate Config</button>
|
||||||
|
<div id="config-validation-result" style="margin-top:12px;"></div>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
<script src="https://cdn.jsdelivr.net/npm/qrious@4.0.2/dist/qrious.min.js"></script>
|
<script src="https://cdn.jsdelivr.net/npm/qrious@4.0.2/dist/qrious.min.js"></script>
|
||||||
<script>
|
<script>
|
||||||
function showMsg(id, msg, isError) {
|
function showMsg(id, msg, isError) {
|
||||||
|
|
@ -167,5 +176,20 @@ async function disableTOTP() {
|
||||||
htmx.ajax('GET', '/settings', {target: '#content', swap: 'innerHTML'});
|
htmx.ajax('GET', '/settings', {target: '#content', swap: 'innerHTML'});
|
||||||
} catch (e) { alert(e.message); }
|
} catch (e) { alert(e.message); }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateConfig asks the server to validate the gateway config file and
// renders the outcome into #config-validation-result.
//
// Server-supplied error strings and exception messages are inserted with
// textContent instead of being concatenated into innerHTML, so any markup
// they contain is displayed literally rather than interpreted.
async function validateConfig() {
    var el = document.getElementById('config-validation-result');
    el.innerHTML = '<span style="color:#94a3b8;">Validating...</span>';

    // Replace the result area with a single message box and return it.
    function showBox(className, text) {
        var box = document.createElement('div');
        box.className = className;
        box.textContent = text;
        el.textContent = '';
        el.appendChild(box);
        return box;
    }

    try {
        var resp = await fetch('/api/config/validate', { credentials: 'same-origin' });
        var data = await resp.json();
        if (data.valid) {
            showBox('success-msg', 'Configuration is valid.');
            return;
        }

        var box = showBox('error-msg', 'Configuration errors:');
        var list = document.createElement('ul');
        list.style.margin = '4px 0 0 16px';
        (data.errors || []).forEach(function (msg) {
            var item = document.createElement('li');
            item.textContent = msg;
            list.appendChild(item);
        });
        box.appendChild(list);
    } catch (e) {
        // Network failures and non-JSON responses end up here.
        showBox('error-msg', e.message);
    }
}
|
||||||
</script>
|
</script>
|
||||||
{{end}}
|
{{end}}
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,10 @@
|
||||||
<td>{{.Name}}</td>
|
<td>{{.Name}}</td>
|
||||||
<td><code>{{.KeyPrefix}}...</code></td>
|
<td><code>{{.KeyPrefix}}...</code></td>
|
||||||
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
|
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
|
||||||
<td>{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}</td>
|
<td>
|
||||||
|
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
|
||||||
|
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
|
||||||
|
</td>
|
||||||
<td>
|
<td>
|
||||||
{{$spend := index $.TokenSpend .Name}}
|
{{$spend := index $.TokenSpend .Name}}
|
||||||
{{if gt .DailyBudgetUSD 0.0}}
|
{{if gt .DailyBudgetUSD 0.0}}
|
||||||
|
|
@ -49,7 +52,10 @@
|
||||||
<td>{{.Name}}</td>
|
<td>{{.Name}}</td>
|
||||||
<td><code>{{.KeyPrefix}}...</code></td>
|
<td><code>{{.KeyPrefix}}...</code></td>
|
||||||
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
|
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
|
||||||
<td>{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}</td>
|
<td>
|
||||||
|
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
|
||||||
|
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
|
||||||
|
</td>
|
||||||
<td>
|
<td>
|
||||||
{{$spend := index $.TokenSpend .Name}}
|
{{$spend := index $.TokenSpend .Name}}
|
||||||
{{if gt .DailyBudgetUSD 0.0}}
|
{{if gt .DailyBudgetUSD 0.0}}
|
||||||
|
|
|
||||||
|
|
@ -46,8 +46,8 @@ type HealthEvent struct {
|
||||||
|
|
||||||
// ProviderHealth is the computed health status for a provider.
|
// ProviderHealth is the computed health status for a provider.
|
||||||
type ProviderHealth struct {
|
type ProviderHealth struct {
|
||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
Status string `json:"status"` // healthy, degraded, down
|
Status string `json:"status"` // healthy, degraded, down
|
||||||
ErrorRate float64 `json:"error_rate"`
|
ErrorRate float64 `json:"error_rate"`
|
||||||
AvgLatency float64 `json:"avg_latency_ms"`
|
AvgLatency float64 `json:"avg_latency_ms"`
|
||||||
Total int `json:"total"`
|
Total int `json:"total"`
|
||||||
|
|
@ -57,11 +57,12 @@ type ProviderHealth struct {
|
||||||
|
|
||||||
// HealthTracker tracks per-provider health using a sliding window.
|
// HealthTracker tracks per-provider health using a sliding window.
|
||||||
type HealthTracker struct {
|
type HealthTracker struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
windows map[string][]HealthEvent
|
windows map[string][]HealthEvent
|
||||||
windowDu time.Duration
|
windowDu time.Duration
|
||||||
circuits map[string]*ProviderCircuit
|
circuits map[string]*ProviderCircuit
|
||||||
cbConfig config.CircuitBreakerConfig
|
cbConfig config.CircuitBreakerConfig
|
||||||
|
OnStateChange func(provider string, from, to CircuitState)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHealthTracker creates a health tracker with the given window duration.
|
// NewHealthTracker creates a health tracker with the given window duration.
|
||||||
|
|
@ -135,6 +136,8 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
|
||||||
h.circuits[providerName] = circuit
|
h.circuits[providerName] = circuit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prevState := circuit.State
|
||||||
|
|
||||||
switch circuit.State {
|
switch circuit.State {
|
||||||
case CircuitClosed:
|
case CircuitClosed:
|
||||||
// Check if error threshold exceeded
|
// Check if error threshold exceeded
|
||||||
|
|
@ -164,6 +167,13 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
|
||||||
circuit.OpenedAt = time.Now()
|
circuit.OpenedAt = time.Now()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if circuit.State != prevState && h.OnStateChange != nil {
|
||||||
|
cb := h.OnStateChange
|
||||||
|
from, to := prevState, circuit.State
|
||||||
|
// Call outside lock to avoid deadlocks
|
||||||
|
go cb(providerName, from, to)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// errorRateUnlocked computes error rate within window. Must be called with lock held.
|
// errorRateUnlocked computes error rate within window. Must be called with lock held.
|
||||||
|
|
|
||||||
|
|
@ -47,13 +47,13 @@ func TestHealthTracker_Record(t *testing.T) {
|
||||||
|
|
||||||
func TestHealthTracker_Status(t *testing.T) {
|
func TestHealthTracker_Status(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
successCount int
|
successCount int
|
||||||
errorCount int
|
errorCount int
|
||||||
wantStatus string
|
wantStatus string
|
||||||
wantErrorRate float64
|
wantErrorRate float64
|
||||||
wantTotal int
|
wantTotal int
|
||||||
wantErrors int
|
wantErrors int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "healthy - no errors",
|
name: "healthy - no errors",
|
||||||
|
|
|
||||||
|
|
@ -108,6 +108,48 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string,
|
||||||
return resp.Body, nil
|
return resp.Body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||||
|
reqCopy := *req
|
||||||
|
reqCopy.Model = model
|
||||||
|
|
||||||
|
body, err := json.Marshal(reqCopy)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("marshaling request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
p.setHeaders(httpReq)
|
||||||
|
|
||||||
|
resp, err := p.client.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("sending request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, &ProviderError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Body: string(respBody),
|
||||||
|
Provider: p.name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var embResp EmbeddingResponse
|
||||||
|
if err := json.Unmarshal(respBody, &embResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &embResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *OpenAIProvider) setHeaders(req *http.Request) {
|
func (p *OpenAIProvider) setHeaders(req *http.Request) {
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||||
|
|
|
||||||
|
|
@ -52,9 +52,38 @@ type Usage struct {
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmbeddingRequest is the OpenAI-compatible embedding request.
|
||||||
|
type EmbeddingRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Input any `json:"input"` // string or []string
|
||||||
|
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingResponse is the OpenAI-compatible embedding response.
|
||||||
|
type EmbeddingResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []EmbeddingData `json:"data"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Usage *EmbeddingUsage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingData holds a single embedding vector.
|
||||||
|
type EmbeddingData struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Embedding []float64 `json:"embedding"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmbeddingUsage reports token usage for embeddings.
|
||||||
|
type EmbeddingUsage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
// Provider sends requests to an LLM API.
|
// Provider sends requests to an LLM API.
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Name() string
|
Name() string
|
||||||
ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error)
|
ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error)
|
||||||
ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error)
|
ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error)
|
||||||
|
Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,17 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
"llm-gateway/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ModelTimeouts holds per-model timeout overrides.
|
||||||
|
type ModelTimeouts struct {
|
||||||
|
RequestTimeout time.Duration
|
||||||
|
StreamingTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
// Route maps a model to a specific provider with pricing.
|
// Route maps a model to a specific provider with pricing.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Provider Provider
|
Provider Provider
|
||||||
|
|
@ -24,6 +31,7 @@ type Registry struct {
|
||||||
balancers map[string]LoadBalancer
|
balancers map[string]LoadBalancer
|
||||||
aliases map[string]string // alias -> canonical name
|
aliases map[string]string // alias -> canonical name
|
||||||
order []string // preserves config order (canonical names only)
|
order []string // preserves config order (canonical names only)
|
||||||
|
timeouts map[string]*ModelTimeouts
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRegistry(cfg *config.Config) (*Registry, error) {
|
func NewRegistry(cfg *config.Config) (*Registry, error) {
|
||||||
|
|
@ -46,6 +54,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
||||||
balancers := make(map[string]LoadBalancer)
|
balancers := make(map[string]LoadBalancer)
|
||||||
aliases := make(map[string]string)
|
aliases := make(map[string]string)
|
||||||
order := make([]string, 0, len(cfg.Models))
|
order := make([]string, 0, len(cfg.Models))
|
||||||
|
timeouts := make(map[string]*ModelTimeouts)
|
||||||
|
|
||||||
for _, mc := range cfg.Models {
|
for _, mc := range cfg.Models {
|
||||||
var modelRoutes []Route
|
var modelRoutes []Route
|
||||||
|
|
@ -74,6 +83,14 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
||||||
// Load balancer
|
// Load balancer
|
||||||
balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
|
balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
|
||||||
|
|
||||||
|
// Per-model timeouts
|
||||||
|
if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 {
|
||||||
|
timeouts[mc.Name] = &ModelTimeouts{
|
||||||
|
RequestTimeout: mc.RequestTimeout,
|
||||||
|
StreamingTimeout: mc.StreamingTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Register aliases
|
// Register aliases
|
||||||
for _, alias := range mc.Aliases {
|
for _, alias := range mc.Aliases {
|
||||||
aliases[alias] = mc.Name
|
aliases[alias] = mc.Name
|
||||||
|
|
@ -85,6 +102,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
||||||
r.balancers = balancers
|
r.balancers = balancers
|
||||||
r.aliases = aliases
|
r.aliases = aliases
|
||||||
r.order = order
|
r.order = order
|
||||||
|
r.timeouts = timeouts
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -135,6 +153,18 @@ func (r *Registry) ModelNames() []string {
|
||||||
return names
|
return names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelTimeoutsFor returns per-model timeout overrides, resolving aliases. Returns nil if none set.
|
||||||
|
func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
canonical := model
|
||||||
|
if alias, ok := r.aliases[model]; ok {
|
||||||
|
canonical = alias
|
||||||
|
}
|
||||||
|
return r.timeouts[canonical]
|
||||||
|
}
|
||||||
|
|
||||||
// RouteInfo exposes route details for dashboard display.
|
// RouteInfo exposes route details for dashboard display.
|
||||||
type RouteInfo struct {
|
type RouteInfo struct {
|
||||||
ProviderName string `json:"provider_name"`
|
ProviderName string `json:"provider_name"`
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,10 @@ func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *Chat
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockProvider) Embedding(_ context.Context, _ string, _ *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// newTestRegistry builds a Registry directly without going through config parsing.
|
// newTestRegistry builds a Registry directly without going through config parsing.
|
||||||
func newTestRegistry(models []testModel) *Registry {
|
func newTestRegistry(models []testModel) *Registry {
|
||||||
r := &Registry{
|
r := &Registry{
|
||||||
|
|
|
||||||
107
llm-gateway/internal/proxy/dedup.go
Normal file
107
llm-gateway/internal/proxy/dedup.go
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response handed to followers whose flight was expired by the cleanup loop
// before its leader called Complete. Followers pass the status straight to
// http.ResponseWriter.WriteHeader, which panics on codes below 100, so an
// expired flight must never be released with a zero status.
const (
	dedupExpiredStatus = 504 // http.StatusGatewayTimeout
	dedupExpiredBody   = `{"error":{"message":"deduplicated request expired before the leader completed","type":"gateway_timeout"}}`
)

// inflight represents an in-progress deduplicated request.
//
// result and statusCode are written exactly once, before done is closed;
// readers must wait on done before reading them.
type inflight struct {
	done       chan struct{}
	result     []byte
	statusCode int
	createdAt  time.Time
}

// Deduplicator coalesces identical concurrent non-streaming requests.
type Deduplicator struct {
	mu        sync.Mutex
	flights   map[string]*inflight
	window    time.Duration
	done      chan struct{}
	closeOnce sync.Once // makes Close safe to call more than once
}

// NewDeduplicator creates a new request deduplicator and starts its
// background cleanup goroutine; call Close to stop it.
//
// A non-positive window selects the 30s default. (time.NewTicker panics on a
// non-positive interval, so negative values must not reach the cleanup loop.)
func NewDeduplicator(window time.Duration) *Deduplicator {
	if window <= 0 {
		window = 30 * time.Second
	}
	d := &Deduplicator{
		flights: make(map[string]*inflight),
		window:  window,
		done:    make(chan struct{}),
	}
	go d.cleanup()
	return d
}

// DedupKey computes a dedup key from model name and request body: the
// hex-encoded SHA-256 of model, a zero byte, and body. The zero byte separates
// the two parts so ("ab", "c") and ("a", "bc") hash differently.
func DedupKey(model string, body []byte) string {
	h := sha256.New()
	h.Write([]byte(model))
	h.Write([]byte{0})
	h.Write(body)
	return hex.EncodeToString(h.Sum(nil))
}

// TryJoin attempts to join an in-flight request. Returns the inflight entry and
// whether this caller is the leader (true) or a follower (false).
//
// The leader must eventually call Complete with the same key; followers wait
// on the entry's done channel and then read its result and statusCode.
func (d *Deduplicator) TryJoin(key string) (*inflight, bool) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if f, ok := d.flights[key]; ok {
		return f, false // follower
	}

	f := &inflight{
		done:      make(chan struct{}),
		createdAt: time.Now(),
	}
	d.flights[key] = f
	return f, true // leader
}

// Complete signals completion of a deduplicated request, publishing result and
// statusCode to any followers. It is a no-op when no flight is registered
// under key (for example because the cleanup loop already expired it).
func (d *Deduplicator) Complete(key string, result []byte, statusCode int) {
	d.mu.Lock()
	f, ok := d.flights[key]
	delete(d.flights, key)
	d.mu.Unlock()

	if ok {
		// The entry is no longer reachable through the map, so nothing else
		// can write to it or close it concurrently.
		f.result = result
		f.statusCode = statusCode
		close(f.done)
	}
}

// Close stops the background cleanup goroutine. It is safe to call Close more
// than once.
func (d *Deduplicator) Close() {
	d.closeOnce.Do(func() { close(d.done) })
}

// cleanup periodically removes stale in-flight entries (older than twice the
// window) and releases their followers with a gateway-timeout response.
func (d *Deduplicator) cleanup() {
	ticker := time.NewTicker(d.window)
	defer ticker.Stop()

	for {
		select {
		case <-d.done:
			return
		case <-ticker.C:
			d.mu.Lock()
			now := time.Now()
			for key, f := range d.flights {
				if now.Sub(f.createdAt) > d.window*2 {
					delete(d.flights, key)
					// Give followers a well-formed response before
					// unblocking them; a zero status would make their
					// WriteHeader call panic.
					f.result = []byte(dedupExpiredBody)
					f.statusCode = dedupExpiredStatus
					close(f.done) // unblock any waiting followers
				}
			}
			d.mu.Unlock()
		}
	}
}
|
||||||
74
llm-gateway/internal/proxy/dedup_test.go
Normal file
74
llm-gateway/internal/proxy/dedup_test.go
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDedupKey(t *testing.T) {
|
||||||
|
k1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
||||||
|
k2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
||||||
|
k3 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||||
|
|
||||||
|
if k1 != k2 {
|
||||||
|
t.Error("identical requests should produce the same key")
|
||||||
|
}
|
||||||
|
if k1 == k3 {
|
||||||
|
t.Error("different requests should produce different keys")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeduplicator_LeaderFollower(t *testing.T) {
|
||||||
|
d := NewDeduplicator(5 * time.Second)
|
||||||
|
defer d.Close()
|
||||||
|
|
||||||
|
key := DedupKey("gpt-4", []byte(`test`))
|
||||||
|
|
||||||
|
// First call is leader
|
||||||
|
f1, isLeader := d.TryJoin(key)
|
||||||
|
if !isLeader {
|
||||||
|
t.Fatal("first caller should be leader")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second call with same key is follower
|
||||||
|
f2, isLeader := d.TryJoin(key)
|
||||||
|
if isLeader {
|
||||||
|
t.Fatal("second caller should be follower")
|
||||||
|
}
|
||||||
|
if f1 != f2 {
|
||||||
|
t.Fatal("follower should get same inflight entry")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the request
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-f2.done
|
||||||
|
if string(f2.result) != "response" {
|
||||||
|
t.Error("follower should receive leader's result")
|
||||||
|
}
|
||||||
|
if f2.statusCode != 200 {
|
||||||
|
t.Error("follower should receive leader's status code")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
d.Complete(key, []byte("response"), 200)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeduplicator_DifferentKeys(t *testing.T) {
|
||||||
|
d := NewDeduplicator(5 * time.Second)
|
||||||
|
defer d.Close()
|
||||||
|
|
||||||
|
_, isLeader1 := d.TryJoin("key1")
|
||||||
|
_, isLeader2 := d.TryJoin("key2")
|
||||||
|
|
||||||
|
if !isLeader1 || !isLeader2 {
|
||||||
|
t.Error("different keys should both be leaders")
|
||||||
|
}
|
||||||
|
|
||||||
|
d.Complete("key1", []byte("r1"), 200)
|
||||||
|
d.Complete("key2", []byte("r2"), 200)
|
||||||
|
}
|
||||||
|
|
@ -53,6 +53,7 @@ type Handler struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
healthTracker *provider.HealthTracker
|
healthTracker *provider.HealthTracker
|
||||||
debugLogger *storage.DebugLogger
|
debugLogger *storage.DebugLogger
|
||||||
|
dedup *Deduplicator
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
|
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
|
||||||
|
|
@ -70,6 +71,10 @@ func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
|
||||||
h.debugLogger = dl
|
h.debugLogger = dl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) SetDeduplicator(d *Deduplicator) {
|
||||||
|
h.dedup = d
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -118,11 +123,47 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply per-model timeout for non-streaming requests
|
||||||
|
modelTimeouts := h.registry.ModelTimeoutsFor(req.Model)
|
||||||
|
|
||||||
if req.Stream {
|
if req.Stream {
|
||||||
h.handleStream(w, r, &req, routes, tokenName, requestID)
|
h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Request deduplication for non-streaming requests
|
||||||
|
if h.dedup != nil {
|
||||||
|
dedupKey := DedupKey(req.Model, body)
|
||||||
|
flight, isLeader := h.dedup.TryJoin(dedupKey)
|
||||||
|
if !isLeader {
|
||||||
|
// Wait for the leader to complete
|
||||||
|
select {
|
||||||
|
case <-flight.done:
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
|
w.Header().Set("X-Dedup", "HIT")
|
||||||
|
w.WriteHeader(flight.statusCode)
|
||||||
|
w.Write(flight.result)
|
||||||
|
return
|
||||||
|
case <-r.Context().Done():
|
||||||
|
writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Leader: proceed normally, but capture response for followers
|
||||||
|
defer func() {
|
||||||
|
// If we haven't completed yet (e.g., panic), clean up
|
||||||
|
}()
|
||||||
|
h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
|
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -233,6 +274,153 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleNonStreamDedup wraps handleNonStream to capture the response for dedup followers.
|
||||||
|
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req provider.EmbeddingRequest
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Model == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "model is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, ok := h.registry.Lookup(req.Model)
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusNotFound, "model not found: "+req.Model)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
routes = h.filterHealthyRoutes(routes)
|
||||||
|
tokenName := getTokenName(r.Context())
|
||||||
|
requestID := middleware.GetReqID(r.Context())
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for i, route := range routes {
|
||||||
|
if i > 0 {
|
||||||
|
backoff := backoffDuration(i, h.cfg.Retry)
|
||||||
|
select {
|
||||||
|
case <-time.After(backoff):
|
||||||
|
case <-r.Context().Done():
|
||||||
|
writeError(w, http.StatusGatewayTimeout, "request cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req)
|
||||||
|
latency := time.Since(start).Milliseconds()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
var pe *provider.ProviderError
|
||||||
|
if errors.As(err, &pe) && !pe.IsRetryable() {
|
||||||
|
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
||||||
|
h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
|
||||||
|
if h.healthTracker != nil {
|
||||||
|
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
||||||
|
}
|
||||||
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
|
writeErrorRaw(w, pe.StatusCode, pe.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err)
|
||||||
|
h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
|
||||||
|
if h.healthTracker != nil {
|
||||||
|
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.healthTracker != nil {
|
||||||
|
h.healthTracker.Record(route.Provider.Name(), latency, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
promptTokens := 0
|
||||||
|
if resp.Usage != nil {
|
||||||
|
promptTokens = resp.Usage.PromptTokens
|
||||||
|
}
|
||||||
|
cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice
|
||||||
|
|
||||||
|
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost)
|
||||||
|
h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "")
|
||||||
|
|
||||||
|
resp.Model = req.Model
|
||||||
|
|
||||||
|
respBytes, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, "failed to marshal response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
|
w.Write(respBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
|
if lastErr != nil {
|
||||||
|
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusBadGateway, "all providers failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) {
|
||||||
|
h.logger.Log(storage.RequestLog{
|
||||||
|
RequestID: requestID,
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
TokenName: tokenName,
|
||||||
|
Model: model,
|
||||||
|
Provider: providerName,
|
||||||
|
ProviderModel: providerModel,
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
CostUSD: cost,
|
||||||
|
LatencyMS: latencyMS,
|
||||||
|
Status: status,
|
||||||
|
ErrorMessage: errMsg,
|
||||||
|
RequestType: "embedding",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) {
|
||||||
|
if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
|
||||||
|
h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID)
|
||||||
|
h.dedup.Complete(dedupKey, rec.body, rec.statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseRecorder captures the response for dedup.
//
// It forwards every call to the wrapped ResponseWriter while remembering the
// status code that was actually sent and a copy of all body bytes, so the same
// response can be replayed to dedup followers.
type responseRecorder struct {
	http.ResponseWriter
	statusCode  int
	body        []byte
	wroteHeader bool // true once a final status has been sent (explicitly or implicitly)
}

// WriteHeader records code if it is the first final status and forwards the
// call. Later calls are still forwarded (the underlying writer decides how to
// treat them) but do not change the recorded status, because only the first
// final status reaches the client.
func (r *responseRecorder) WriteHeader(code int) {
	// 1xx informational responses (other than 101) may precede the final
	// status and must not be recorded as the response status.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !informational && !r.wroteHeader {
		r.statusCode = code
		r.wroteHeader = true
	}
	r.ResponseWriter.WriteHeader(code)
}

// Write appends b to the captured body and forwards it. A Write before any
// WriteHeader implies a 200 status, which is recorded as such.
func (r *responseRecorder) Write(b []byte) (int, error) {
	if !r.wroteHeader {
		r.statusCode = http.StatusOK
		r.wroteHeader = true
	}
	// Capture the full payload regardless of whether the leader's client
	// accepts it, so followers still get the complete response.
	r.body = append(r.body, b...)
	return r.ResponseWriter.Write(b)
}
|
||||||
|
|
||||||
// filterHealthyRoutes removes providers with open circuit breakers.
|
// filterHealthyRoutes removes providers with open circuit breakers.
|
||||||
// If all are filtered out, returns original routes as fallback.
|
// If all are filtered out, returns original routes as fallback.
|
||||||
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
|
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
|
||||||
|
|
|
||||||
|
|
@ -5,27 +5,66 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"llm-gateway/internal/config"
|
||||||
"llm-gateway/internal/provider"
|
"llm-gateway/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelsHandler struct {
|
type ModelsHandler struct {
|
||||||
registry *provider.Registry
|
registry *provider.Registry
|
||||||
|
healthTracker *provider.HealthTracker
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewModelsHandler(registry *provider.Registry) *ModelsHandler {
|
func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler {
|
||||||
return &ModelsHandler{registry: registry}
|
return &ModelsHandler{
|
||||||
|
registry: registry,
|
||||||
|
healthTracker: healthTracker,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) {
|
func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) {
|
||||||
names := h.registry.ModelNames()
|
allRoutes := h.registry.AllRoutes()
|
||||||
models := make([]map[string]any, len(names))
|
models := make([]map[string]any, 0, len(allRoutes))
|
||||||
for i, name := range names {
|
|
||||||
models[i] = map[string]any{
|
for _, m := range allRoutes {
|
||||||
"id": name,
|
providers := make([]map[string]any, 0, len(m.Routes))
|
||||||
"object": "model",
|
for _, rt := range m.Routes {
|
||||||
"created": time.Now().Unix(),
|
healthy := true
|
||||||
"owned_by": "llm-gateway",
|
if h.healthTracker != nil {
|
||||||
|
healthy = h.healthTracker.IsAvailable(rt.ProviderName)
|
||||||
|
}
|
||||||
|
providers = append(providers, map[string]any{
|
||||||
|
"name": rt.ProviderName,
|
||||||
|
"model": rt.ProviderModel,
|
||||||
|
"input_price": rt.InputPrice,
|
||||||
|
"output_price": rt.OutputPrice,
|
||||||
|
"priority": rt.Priority,
|
||||||
|
"healthy": healthy,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find load balancing strategy from config
|
||||||
|
loadBalancing := "first"
|
||||||
|
for _, mc := range h.cfg.Models {
|
||||||
|
if mc.Name == m.Name {
|
||||||
|
if mc.LoadBalancing != "" {
|
||||||
|
loadBalancing = mc.LoadBalancing
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
models = append(models, map[string]any{
|
||||||
|
"id": m.Name,
|
||||||
|
"object": "model",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"owned_by": "llm-gateway",
|
||||||
|
"providers": providers,
|
||||||
|
"provider_count": len(providers),
|
||||||
|
"load_balancing": loadBalancing,
|
||||||
|
"aliases": m.Aliases,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
|
||||||
|
|
@ -8,12 +8,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"llm-gateway/internal/storage"
|
"llm-gateway/internal/storage"
|
||||||
|
"llm-gateway/internal/webhook"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RateLimiter struct {
|
type RateLimiter struct {
|
||||||
db *storage.DB
|
db *storage.DB
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
buckets map[string]*tokenBucket
|
buckets map[string]*tokenBucket
|
||||||
|
notifier *webhook.Notifier
|
||||||
|
budgetNotified sync.Map // tracks which token+budget combos have been notified
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenBucket struct {
|
type tokenBucket struct {
|
||||||
|
|
@ -30,6 +33,11 @@ func NewRateLimiter(db *storage.DB) *RateLimiter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNotifier sets the webhook notifier for budget threshold alerts.
|
||||||
|
func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) {
|
||||||
|
rl.notifier = n
|
||||||
|
}
|
||||||
|
|
||||||
func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
apiToken := getAPIToken(r.Context())
|
apiToken := getAPIToken(r.Context())
|
||||||
|
|
@ -63,9 +71,24 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
||||||
// Check daily budget
|
// Check daily budget
|
||||||
if apiToken.DailyBudgetUSD > 0 {
|
if apiToken.DailyBudgetUSD > 0 {
|
||||||
spent, err := rl.db.TodaySpend(tokenName)
|
spent, err := rl.db.TodaySpend(tokenName)
|
||||||
if err == nil && spent >= apiToken.DailyBudgetUSD {
|
if err == nil {
|
||||||
writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
|
if spent >= apiToken.DailyBudgetUSD {
|
||||||
return
|
writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check monthly budget
|
||||||
|
if apiToken.MonthlyBudgetUSD > 0 {
|
||||||
|
spent, err := rl.db.MonthSpend(tokenName)
|
||||||
|
if err == nil {
|
||||||
|
if spent >= apiToken.MonthlyBudgetUSD {
|
||||||
|
writeError(w, http.StatusTooManyRequests, "monthly budget exceeded")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -73,6 +96,30 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkBudgetThreshold fires a webhook notification when spend reaches 80% of budget.
|
||||||
|
func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) {
|
||||||
|
if rl.notifier == nil || budget <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if spent/budget < 0.8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := tokenName + ":" + budgetType
|
||||||
|
if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded {
|
||||||
|
return // already notified
|
||||||
|
}
|
||||||
|
rl.notifier.Notify(webhook.Event{
|
||||||
|
Type: webhook.EventBudgetThreshold,
|
||||||
|
Data: map[string]any{
|
||||||
|
"token": tokenName,
|
||||||
|
"budget_type": budgetType,
|
||||||
|
"spent": spent,
|
||||||
|
"budget": budget,
|
||||||
|
"percent": spent / budget * 100,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
|
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
defer rl.mu.Unlock()
|
defer rl.mu.Unlock()
|
||||||
|
|
|
||||||
|
|
@ -45,11 +45,11 @@ var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
func TestRateLimiter_Allow(t *testing.T) {
|
func TestRateLimiter_Allow(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
rateLimitRPM int
|
rateLimitRPM int
|
||||||
numRequests int
|
numRequests int
|
||||||
wantAllowed int
|
wantAllowed int
|
||||||
wantDenied int
|
wantDenied int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "allows requests within limit",
|
name: "allows requests within limit",
|
||||||
|
|
@ -191,12 +191,12 @@ func TestRateLimiter_AllowReturnValues(t *testing.T) {
|
||||||
|
|
||||||
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
|
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
rateLimitRPM int
|
rateLimitRPM int
|
||||||
numRequests int
|
numRequests int
|
||||||
wantStatusCode int
|
wantStatusCode int
|
||||||
wantLimitHeader string
|
wantLimitHeader string
|
||||||
wantRetryAfter bool
|
wantRetryAfter bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "sets rate limit headers on allowed request",
|
name: "sets rate limit headers on allowed request",
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import (
|
||||||
"llm-gateway/internal/provider"
|
"llm-gateway/internal/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
|
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) {
|
||||||
flusher, ok := w.(http.Flusher)
|
flusher, ok := w.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
writeError(w, http.StatusInternalServerError, "streaming not supported")
|
writeError(w, http.StatusInternalServerError, "streaming not supported")
|
||||||
|
|
@ -60,11 +60,15 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply streaming timeout
|
// Apply streaming timeout (per-model override takes precedence)
|
||||||
|
streamingTimeout := h.cfg.Server.StreamingTimeout
|
||||||
|
if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 {
|
||||||
|
streamingTimeout = modelTimeouts.StreamingTimeout
|
||||||
|
}
|
||||||
var streamCtx context.Context
|
var streamCtx context.Context
|
||||||
var streamCancel context.CancelFunc
|
var streamCancel context.CancelFunc
|
||||||
if h.cfg.Server.StreamingTimeout > 0 {
|
if streamingTimeout > 0 {
|
||||||
streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
|
streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout)
|
||||||
} else {
|
} else {
|
||||||
streamCtx, streamCancel = context.WithCancel(r.Context())
|
streamCtx, streamCancel = context.WithCancel(r.Context())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,21 @@ func (db *DB) TodaySpend(tokenName string) (float64, error) {
|
||||||
return total.Float64, nil
|
return total.Float64, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MonthSpend returns the total cost in USD for a given token this month.
|
||||||
|
func (db *DB) MonthSpend(tokenName string) (float64, error) {
|
||||||
|
now := time.Now()
|
||||||
|
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix()
|
||||||
|
var total sql.NullFloat64
|
||||||
|
err := db.QueryRow(
|
||||||
|
"SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
|
||||||
|
tokenName, startOfMonth,
|
||||||
|
).Scan(&total)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return total.Float64, nil
|
||||||
|
}
|
||||||
|
|
||||||
// TodaySpendAll returns today's spend for all tokens as a map.
|
// TodaySpendAll returns today's spend for all tokens as a map.
|
||||||
func (db *DB) TodaySpendAll() (map[string]float64, error) {
|
func (db *DB) TodaySpendAll() (map[string]float64, error) {
|
||||||
startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
|
startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ type RequestLog struct {
|
||||||
ErrorMessage string
|
ErrorMessage string
|
||||||
Streaming bool
|
Streaming bool
|
||||||
Cached bool
|
Cached bool
|
||||||
|
RequestType string // "chat" or "embedding"
|
||||||
}
|
}
|
||||||
|
|
||||||
type AsyncLogger struct {
|
type AsyncLogger struct {
|
||||||
|
|
@ -94,8 +95,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt, err := tx.Prepare(`INSERT INTO request_logs
|
stmt, err := tx.Prepare(`INSERT INTO request_logs
|
||||||
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
|
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("ERROR: preparing log statement: %v", err)
|
log.Printf("ERROR: preparing log statement: %v", err)
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
|
|
@ -112,10 +113,14 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
|
||||||
if r.Cached {
|
if r.Cached {
|
||||||
cached = 1
|
cached = 1
|
||||||
}
|
}
|
||||||
|
reqType := r.RequestType
|
||||||
|
if reqType == "" {
|
||||||
|
reqType = "chat"
|
||||||
|
}
|
||||||
_, err := stmt.Exec(
|
_, err := stmt.Exec(
|
||||||
r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
|
r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
|
||||||
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
|
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
|
||||||
r.Status, r.ErrorMessage, streaming, cached,
|
r.Status, r.ErrorMessage, streaming, cached, reqType,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("ERROR: inserting log: %v", err)
|
log.Printf("ERROR: inserting log: %v", err)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd;
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0;
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
ALTER TABLE request_logs DROP COLUMN request_type;
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat';
|
||||||
123
llm-gateway/internal/webhook/webhook.go
Normal file
123
llm-gateway/internal/webhook/webhook.go
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
package webhook
|
||||||
|
|
||||||
|
import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"time"

	"llm-gateway/internal/config"
)
|
||||||
|
|
||||||
|
// Event type identifiers carried in Event.Type and matched against a webhook's
// configured event filter.

// EventCircuitBreakerOpen is the event type for "circuit_breaker.open".
const EventCircuitBreakerOpen = "circuit_breaker.open"

// EventCircuitBreakerClosed is the event type for "circuit_breaker.closed".
const EventCircuitBreakerClosed = "circuit_breaker.closed"

// EventBudgetThreshold is the event type for "budget.threshold".
const EventBudgetThreshold = "budget.threshold"
|
||||||
|
|
||||||
|
// Event is the payload delivered to webhook endpoints, serialized as JSON.
type Event struct {
	// Type is the event type identifier, e.g. EventBudgetThreshold.
	Type string `json:"type"`

	// Timestamp is when the event occurred. Notify fills it with the current
	// time when it is left zero.
	Timestamp time.Time `json:"timestamp"`

	// Data holds the event-specific fields.
	Data map[string]any `json:"data"`
}
|
||||||
|
|
||||||
|
// Notifier sends webhook notifications.
|
||||||
|
type Notifier struct {
|
||||||
|
webhooks []config.WebhookConfig
|
||||||
|
ch chan Event
|
||||||
|
done chan struct{}
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotifier creates a webhook notifier from config.
|
||||||
|
func NewNotifier(webhooks []config.WebhookConfig) *Notifier {
|
||||||
|
n := &Notifier{
|
||||||
|
webhooks: webhooks,
|
||||||
|
ch: make(chan Event, 100),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
client: &http.Client{Timeout: 10 * time.Second},
|
||||||
|
}
|
||||||
|
go n.run()
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify queues an event for delivery (non-blocking).
|
||||||
|
func (n *Notifier) Notify(evt Event) {
|
||||||
|
if evt.Timestamp.IsZero() {
|
||||||
|
evt.Timestamp = time.Now()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case n.ch <- evt:
|
||||||
|
default:
|
||||||
|
log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close drains pending events and shuts down.
|
||||||
|
func (n *Notifier) Close() {
|
||||||
|
close(n.ch)
|
||||||
|
<-n.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) run() {
|
||||||
|
defer close(n.done)
|
||||||
|
for evt := range n.ch {
|
||||||
|
for _, wh := range n.webhooks {
|
||||||
|
if !n.shouldSend(wh, evt.Type) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
n.send(wh, evt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool {
|
||||||
|
if len(wh.Events) == 0 {
|
||||||
|
return true // no filter = send all
|
||||||
|
}
|
||||||
|
for _, e := range wh.Events {
|
||||||
|
if e == eventType {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) send(wh config.WebhookConfig, evt Event) {
|
||||||
|
body, err := json.Marshal(evt)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("ERROR: webhook marshal: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("ERROR: webhook request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if wh.Secret != "" {
|
||||||
|
mac := hmac.New(sha256.New, []byte(wh.Secret))
|
||||||
|
mac.Write(body)
|
||||||
|
sig := hex.EncodeToString(mac.Sum(nil))
|
||||||
|
req.Header.Set("X-Webhook-Signature", "sha256="+sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := n.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue