feat(gateway): add monthly budget support and webhook notifications for circuit breaker and budget events
This commit is contained in:
parent
28a694744d
commit
291b8f4863
31 changed files with 1005 additions and 124 deletions
|
|
@ -25,6 +25,7 @@ import (
|
|||
"llm-gateway/internal/provider"
|
||||
"llm-gateway/internal/proxy"
|
||||
"llm-gateway/internal/storage"
|
||||
"llm-gateway/internal/webhook"
|
||||
)
|
||||
|
||||
var version = "dev"
|
||||
|
|
@ -95,6 +96,30 @@ func main() {
|
|||
// Provider health tracker
|
||||
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
|
||||
|
||||
// Webhook notifier
|
||||
var notifier *webhook.Notifier
|
||||
if len(cfg.Webhooks) > 0 {
|
||||
notifier = webhook.NewNotifier(cfg.Webhooks)
|
||||
defer notifier.Close()
|
||||
log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks))
|
||||
|
||||
// Wire health tracker state changes to webhook
|
||||
healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) {
|
||||
eventType := webhook.EventCircuitBreakerOpen
|
||||
if to == provider.CircuitClosed {
|
||||
eventType = webhook.EventCircuitBreakerClosed
|
||||
}
|
||||
notifier.Notify(webhook.Event{
|
||||
Type: eventType,
|
||||
Data: map[string]any{
|
||||
"provider": providerName,
|
||||
"from": from.String(),
|
||||
"to": to.String(),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Auth store (static tokens checked in-memory, not seeded to DB)
|
||||
var staticTokens []auth.StaticToken
|
||||
for _, t := range cfg.Tokens {
|
||||
|
|
@ -104,6 +129,7 @@ func main() {
|
|||
Key: t.Key,
|
||||
RateLimitRPM: t.RateLimitRPM,
|
||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
||||
MonthlyBudgetUSD: t.MonthlyBudgetUSD,
|
||||
MaxConcurrent: t.MaxConcurrent,
|
||||
})
|
||||
log.Printf("Loaded static token: %s", t.Name)
|
||||
|
|
@ -133,14 +159,27 @@ func main() {
|
|||
// Handlers
|
||||
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
|
||||
proxyHandler.SetDebugLogger(debugLogger)
|
||||
modelsHandler := proxy.NewModelsHandler(registry)
|
||||
|
||||
// Request deduplication
|
||||
if cfg.Dedup.Enabled {
|
||||
dedup := proxy.NewDeduplicator(cfg.Dedup.Window)
|
||||
defer dedup.Close()
|
||||
proxyHandler.SetDeduplicator(dedup)
|
||||
log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window)
|
||||
}
|
||||
|
||||
modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg)
|
||||
proxyAuth := proxy.NewAuthMiddleware(authStore)
|
||||
rateLimiter := proxy.NewRateLimiter(db)
|
||||
if notifier != nil {
|
||||
rateLimiter.SetNotifier(notifier)
|
||||
}
|
||||
concurrencyLimiter := proxy.NewConcurrencyLimiter()
|
||||
statsAPI := dashboard.NewStatsAPI(db, authStore)
|
||||
statsAPI.SetHealthTracker(healthTracker)
|
||||
statsAPI.SetAuditLogger(auditLogger)
|
||||
statsAPI.SetDebugLogger(debugLogger)
|
||||
statsAPI.SetConfigPath(*configPath)
|
||||
if c != nil {
|
||||
statsAPI.SetCache(c)
|
||||
}
|
||||
|
|
@ -196,6 +235,7 @@ func main() {
|
|||
r.Use(rateLimiter.Check)
|
||||
r.Use(concurrencyLimiter.Check)
|
||||
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
|
||||
r.Post("/v1/embeddings", proxyHandler.Embeddings)
|
||||
r.Get("/v1/models", modelsHandler.ListModels)
|
||||
})
|
||||
|
||||
|
|
@ -266,7 +306,7 @@ func main() {
|
|||
r.Get("/api/export/logs", exportHandler.ExportLogs)
|
||||
r.Get("/api/export/stats", exportHandler.ExportStats)
|
||||
|
||||
// Admin-only: user management, audit, debug
|
||||
// Admin-only: user management, audit, debug, config validation
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(authMiddleware.RequireAdmin)
|
||||
r.Get("/api/auth/users", authHandlers.ListUsers)
|
||||
|
|
@ -276,6 +316,9 @@ func main() {
|
|||
// Audit log
|
||||
r.Get("/api/stats/audit", statsAPI.AuditLogs)
|
||||
|
||||
// Config validation
|
||||
r.Get("/api/config/validate", statsAPI.ValidateConfig)
|
||||
|
||||
// Debug logging
|
||||
r.Post("/api/debug/toggle", statsAPI.DebugToggle)
|
||||
r.Get("/api/debug/status", statsAPI.DebugStatus)
|
||||
|
|
@ -336,6 +379,7 @@ func main() {
|
|||
Key: t.Key,
|
||||
RateLimitRPM: t.RateLimitRPM,
|
||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
||||
MonthlyBudgetUSD: t.MonthlyBudgetUSD,
|
||||
MaxConcurrent: t.MaxConcurrent,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ go 1.24.0
|
|||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.2.5
|
||||
github.com/go-chi/cors v1.2.2
|
||||
github.com/golang-migrate/migrate/v4 v4.19.1
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
|
|
@ -19,7 +20,6 @@ require (
|
|||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/go-chi/cors v1.2.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ type APIToken struct {
|
|||
UserID int64 `json:"user_id"`
|
||||
RateLimitRPM int `json:"rate_limit_rpm"`
|
||||
DailyBudgetUSD float64 `json:"daily_budget_usd"`
|
||||
MonthlyBudgetUSD float64 `json:"monthly_budget_usd"`
|
||||
MaxConcurrent int `json:"max_concurrent"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
LastUsedAt int64 `json:"last_used_at"`
|
||||
|
|
@ -49,6 +50,7 @@ type StaticToken struct {
|
|||
Key string
|
||||
RateLimitRPM int
|
||||
DailyBudgetUSD float64
|
||||
MonthlyBudgetUSD float64
|
||||
MaxConcurrent int
|
||||
}
|
||||
|
||||
|
|
@ -294,6 +296,7 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
|||
KeyPrefix: prefix,
|
||||
RateLimitRPM: st.RateLimitRPM,
|
||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
||||
MonthlyBudgetUSD: st.MonthlyBudgetUSD,
|
||||
MaxConcurrent: st.MaxConcurrent,
|
||||
}, nil
|
||||
}
|
||||
|
|
@ -305,9 +308,9 @@ func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
|||
|
||||
var t APIToken
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
||||
keyHash,
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -328,6 +331,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
|||
KeyPrefix: prefix,
|
||||
RateLimitRPM: st.RateLimitRPM,
|
||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
||||
MonthlyBudgetUSD: st.MonthlyBudgetUSD,
|
||||
MaxConcurrent: st.MaxConcurrent,
|
||||
})
|
||||
}
|
||||
|
|
@ -336,9 +340,9 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
|||
var rows *sql.Rows
|
||||
var err error
|
||||
if userID == 0 {
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
|
||||
} else {
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
||||
}
|
||||
if err != nil {
|
||||
return tokens, nil
|
||||
|
|
@ -347,7 +351,7 @@ func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
|||
|
||||
for rows.Next() {
|
||||
var t APIToken
|
||||
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
||||
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
||||
return tokens, nil
|
||||
}
|
||||
tokens = append(tokens, t)
|
||||
|
|
@ -363,9 +367,9 @@ func (s *Store) DeleteAPIToken(id int64) error {
|
|||
func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
|
||||
var t APIToken
|
||||
err := s.db.QueryRow(
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
|
||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
|
||||
id,
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ func setupTestDB(t *testing.T) *sql.DB {
|
|||
user_id INTEGER NOT NULL,
|
||||
rate_limit_rpm INTEGER DEFAULT 0,
|
||||
daily_budget_usd REAL DEFAULT 0,
|
||||
monthly_budget_usd REAL DEFAULT 0,
|
||||
max_concurrent INTEGER DEFAULT 0,
|
||||
created_at INTEGER NOT NULL,
|
||||
last_used_at INTEGER DEFAULT 0
|
||||
|
|
|
|||
|
|
@ -20,11 +20,24 @@ type Config struct {
|
|||
Retry RetryConfig `yaml:"retry"`
|
||||
Debug DebugConfig `yaml:"debug"`
|
||||
CORS CORSConfig `yaml:"cors"`
|
||||
Dedup DedupConfig `yaml:"dedup"`
|
||||
Webhooks []WebhookConfig `yaml:"webhooks"`
|
||||
Providers []ProviderConfig `yaml:"providers"`
|
||||
Models []ModelConfig `yaml:"models"`
|
||||
Tokens []TokenConfig `yaml:"tokens"`
|
||||
}
|
||||
|
||||
// DedupConfig controls coalescing of identical concurrent non-streaming requests.
type DedupConfig struct {
	Enabled bool          `yaml:"enabled"`
	Window  time.Duration `yaml:"window"` // max time to wait for dedup result
}

// WebhookConfig describes one webhook endpoint and which events it receives.
type WebhookConfig struct {
	URL    string   `yaml:"url"`
	Events []string `yaml:"events"` // event types to send
	Secret string   `yaml:"secret"` // optional HMAC secret
}
|
||||
|
||||
type PricingLookupConfig struct {
|
||||
URL string `yaml:"url"`
|
||||
RefreshInterval time.Duration `yaml:"refresh_interval"`
|
||||
|
|
@ -40,6 +53,7 @@ type TokenConfig struct {
|
|||
Key string `yaml:"key"`
|
||||
RateLimitRPM int `yaml:"rate_limit_rpm"` // 0 = unlimited
|
||||
DailyBudgetUSD float64 `yaml:"daily_budget_usd"` // 0 = unlimited
|
||||
MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited
|
||||
MaxConcurrent int `yaml:"max_concurrent"` // 0 = unlimited
|
||||
}
|
||||
|
||||
|
|
@ -104,6 +118,8 @@ type ModelConfig struct {
|
|||
Aliases []string `yaml:"aliases"`
|
||||
Routes []RouteConfig `yaml:"routes"`
|
||||
LoadBalancing string `yaml:"load_balancing"` // first, round-robin, random, least-cost
|
||||
RequestTimeout time.Duration `yaml:"request_timeout"` // per-model override; 0 = use server default
|
||||
StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default
|
||||
}
|
||||
|
||||
type RouteConfig struct {
|
||||
|
|
@ -131,14 +147,15 @@ func Load(path string) (*Config, error) {
|
|||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.validate(); err != nil {
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("validating config: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
// Validate checks the config for correctness and applies defaults.
|
||||
func (c *Config) Validate() error {
|
||||
if c.Server.Listen == "" {
|
||||
c.Server.Listen = "0.0.0.0:3000"
|
||||
}
|
||||
|
|
@ -201,6 +218,11 @@ func (c *Config) validate() error {
|
|||
c.CORS.MaxAge = 300
|
||||
}
|
||||
|
||||
// Dedup defaults
|
||||
if c.Dedup.Window == 0 {
|
||||
c.Dedup.Window = 30 * time.Second
|
||||
}
|
||||
|
||||
if len(c.Providers) == 0 {
|
||||
return fmt.Errorf("at least one provider is required")
|
||||
}
|
||||
|
|
@ -266,6 +288,19 @@ func (c *Config) validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ValidateBytes parses raw YAML and returns a list of validation errors.
|
||||
func ValidateBytes(data []byte) []string {
|
||||
expanded := os.ExpandEnv(string(data))
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
|
||||
return []string{"parse error: " + err.Error()}
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return []string{err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProviderByName returns the provider config by name.
|
||||
func (c *Config) ProviderByName(name string) *ProviderConfig {
|
||||
for i := range c.Providers {
|
||||
|
|
|
|||
|
|
@ -735,4 +735,3 @@ models:
|
|||
t.Errorf("error = %q, want to contain api_key validation message", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package dashboard
|
|||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
|
@ -11,6 +12,7 @@ import (
|
|||
|
||||
"llm-gateway/internal/auth"
|
||||
"llm-gateway/internal/cache"
|
||||
"llm-gateway/internal/config"
|
||||
"llm-gateway/internal/provider"
|
||||
"llm-gateway/internal/storage"
|
||||
)
|
||||
|
|
@ -109,6 +111,7 @@ type StatsAPI struct {
|
|||
cache *cache.Cache
|
||||
auditLogger *storage.AuditLogger
|
||||
debugLogger *storage.DebugLogger
|
||||
configPath string
|
||||
}
|
||||
|
||||
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
|
||||
|
|
@ -135,6 +138,11 @@ func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
|
|||
s.debugLogger = dl
|
||||
}
|
||||
|
||||
// SetConfigPath sets the config file path for validation.
// The path is only stored here; it is read later by ValidateConfig,
// so no I/O or existence check happens at call time.
func (s *StatsAPI) SetConfigPath(path string) {
	s.configPath = path
}
|
||||
|
||||
// TokenNamesForUser returns the token names that belong to the user.
|
||||
// Admins get nil (no filter), non-admins get their token names.
|
||||
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
|
||||
|
|
@ -712,6 +720,27 @@ func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
|
|||
writeJSON(w, entry)
|
||||
}
|
||||
|
||||
// ValidateConfig validates the config file at the stored path.
|
||||
func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
if s.configPath == "" {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
writeJSON(w, map[string]any{"valid": false, "errors": []string{"config path not set"}})
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(s.configPath)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
writeJSON(w, map[string]any{"valid": false, "errors": []string{"failed to read config: " + err.Error()}})
|
||||
return
|
||||
}
|
||||
errs := config.ValidateBytes(data)
|
||||
if len(errs) > 0 {
|
||||
writeJSON(w, map[string]any{"valid": false, "errors": errs})
|
||||
return
|
||||
}
|
||||
writeJSON(w, map[string]any{"valid": true, "errors": []string{}})
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(v)
|
||||
|
|
|
|||
|
|
@ -275,6 +275,10 @@ func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) {
|
|||
data.ModelRoutes = d.registry.AllRoutes()
|
||||
}
|
||||
|
||||
if d.statsAPI.healthTracker != nil {
|
||||
data.ProviderHealth = d.statsAPI.healthTracker.Status()
|
||||
}
|
||||
|
||||
d.renderDashboardPage(w, r, "partials/models-page.html", data)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
{{if .ModelRoutes}}
|
||||
{{range .ModelRoutes}}
|
||||
<div class="section">
|
||||
<h2>{{.Name}}</h2>
|
||||
<h2>{{.Name}}{{if .Aliases}} <span style="font-size:0.75rem;color:var(--text-muted);font-weight:400;">aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}</span>{{end}}</h2>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
|
|
@ -15,9 +15,11 @@
|
|||
<th>Priority</th>
|
||||
<th>Input Price (per 1M)</th>
|
||||
<th>Output Price (per 1M)</th>
|
||||
<th>Health</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{$health := $.ProviderHealth}}
|
||||
{{range .Routes}}
|
||||
<tr>
|
||||
<td>{{.ProviderName}}</td>
|
||||
|
|
@ -25,6 +27,18 @@
|
|||
<td><span class="badge badge-priority">{{.Priority}}</span></td>
|
||||
<td>{{formatPrice .InputPrice}}</td>
|
||||
<td>{{formatPrice .OutputPrice}}</td>
|
||||
<td>
|
||||
{{$pname := .ProviderName}}
|
||||
{{range $health}}
|
||||
{{if eq .Provider $pname}}
|
||||
{{if eq .Status "healthy"}}<span class="badge" style="background:#166534;color:#4ade80;">healthy</span>
|
||||
{{else if eq .Status "degraded"}}<span class="badge" style="background:#92400e;color:#fbbf24;">degraded</span>
|
||||
{{else}}<span class="badge" style="background:#991b1b;color:#f87171;">down</span>
|
||||
{{end}}
|
||||
{{end}}
|
||||
{{end}}
|
||||
{{if not $health}}<span style="color:var(--text-muted);">-</span>{{end}}
|
||||
</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
|
|
|
|||
|
|
@ -70,6 +70,15 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
{{if .User.IsAdmin}}
|
||||
<div class="section">
|
||||
<h2>Config Validation</h2>
|
||||
<p style="color:#94a3b8;font-size:0.85rem;margin-bottom:12px;">Validate the current gateway configuration file for errors.</p>
|
||||
<button class="btn btn-sm btn-primary" onclick="validateConfig()">Validate Config</button>
|
||||
<div id="config-validation-result" style="margin-top:12px;"></div>
|
||||
</div>
|
||||
{{end}}
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/qrious@4.0.2/dist/qrious.min.js"></script>
|
||||
<script>
|
||||
function showMsg(id, msg, isError) {
|
||||
|
|
@ -167,5 +176,20 @@ async function disableTOTP() {
|
|||
htmx.ajax('GET', '/settings', {target: '#content', swap: 'innerHTML'});
|
||||
} catch (e) { alert(e.message); }
|
||||
}
|
||||
|
||||
// Fetches /api/config/validate (admin-only; session cookie sent via
// credentials: 'same-origin') and renders the result into the
// #config-validation-result container.
async function validateConfig() {
  var el = document.getElementById('config-validation-result');
  el.innerHTML = '<span style="color:#94a3b8;">Validating...</span>';
  try {
    var resp = await fetch('/api/config/validate', { credentials: 'same-origin' });
    var data = await resp.json();
    if (data.valid) {
      el.innerHTML = '<div class="success-msg">Configuration is valid.</div>';
    } else {
      // NOTE(review): error strings are interpolated into innerHTML without
      // escaping. They come from the server-side validator, but confirm none
      // can echo user-controlled config values back as markup.
      var errs = (data.errors||[]).map(function(e) { return '<li>' + e + '</li>'; }).join('');
      el.innerHTML = '<div class="error-msg">Configuration errors:<ul style="margin:4px 0 0 16px;">' + errs + '</ul></div>';
    }
  } catch (e) { el.innerHTML = '<div class="error-msg">' + e.message + '</div>'; }
}
|
||||
</script>
|
||||
{{end}}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,10 @@
|
|||
<td>{{.Name}}</td>
|
||||
<td><code>{{.KeyPrefix}}...</code></td>
|
||||
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
|
||||
<td>{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}</td>
|
||||
<td>
|
||||
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
|
||||
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
|
||||
</td>
|
||||
<td>
|
||||
{{$spend := index $.TokenSpend .Name}}
|
||||
{{if gt .DailyBudgetUSD 0.0}}
|
||||
|
|
@ -49,7 +52,10 @@
|
|||
<td>{{.Name}}</td>
|
||||
<td><code>{{.KeyPrefix}}...</code></td>
|
||||
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
|
||||
<td>{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}{{else}}unlimited{{end}}</td>
|
||||
<td>
|
||||
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
|
||||
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
|
||||
</td>
|
||||
<td>
|
||||
{{$spend := index $.TokenSpend .Name}}
|
||||
{{if gt .DailyBudgetUSD 0.0}}
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ type HealthTracker struct {
|
|||
windowDu time.Duration
|
||||
circuits map[string]*ProviderCircuit
|
||||
cbConfig config.CircuitBreakerConfig
|
||||
OnStateChange func(provider string, from, to CircuitState)
|
||||
}
|
||||
|
||||
// NewHealthTracker creates a health tracker with the given window duration.
|
||||
|
|
@ -135,6 +136,8 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
|
|||
h.circuits[providerName] = circuit
|
||||
}
|
||||
|
||||
prevState := circuit.State
|
||||
|
||||
switch circuit.State {
|
||||
case CircuitClosed:
|
||||
// Check if error threshold exceeded
|
||||
|
|
@ -164,6 +167,13 @@ func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
|
|||
circuit.OpenedAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
if circuit.State != prevState && h.OnStateChange != nil {
|
||||
cb := h.OnStateChange
|
||||
from, to := prevState, circuit.State
|
||||
// Call outside lock to avoid deadlocks
|
||||
go cb(providerName, from, to)
|
||||
}
|
||||
}
|
||||
|
||||
// errorRateUnlocked computes error rate within window. Must be called with lock held.
|
||||
|
|
|
|||
|
|
@ -108,6 +108,48 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string,
|
|||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||
reqCopy := *req
|
||||
reqCopy.Model = model
|
||||
|
||||
body, err := json.Marshal(reqCopy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
p.setHeaders(httpReq)
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, &ProviderError{
|
||||
StatusCode: resp.StatusCode,
|
||||
Body: string(respBody),
|
||||
Provider: p.name,
|
||||
}
|
||||
}
|
||||
|
||||
var embResp EmbeddingResponse
|
||||
if err := json.Unmarshal(respBody, &embResp); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
||||
}
|
||||
|
||||
return &embResp, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) setHeaders(req *http.Request) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
|
|
|
|||
|
|
@ -52,9 +52,38 @@ type Usage struct {
|
|||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// EmbeddingRequest is the OpenAI-compatible embedding request.
type EmbeddingRequest struct {
	Model          string `json:"model"`
	Input          any    `json:"input"` // string or []string
	EncodingFormat string `json:"encoding_format,omitempty"`
}

// EmbeddingResponse is the OpenAI-compatible embedding response.
type EmbeddingResponse struct {
	Object string          `json:"object"`
	Data   []EmbeddingData `json:"data"`
	Model  string          `json:"model"`
	Usage  *EmbeddingUsage `json:"usage,omitempty"` // nil when the upstream omits usage
}

// EmbeddingData holds a single embedding vector.
type EmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"` // position of the corresponding input
}

// EmbeddingUsage reports token usage for embeddings.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
|
||||
|
||||
// Provider sends requests to an LLM API.
type Provider interface {
	// Name returns the provider's configured name.
	Name() string
	// ChatCompletion performs a non-streaming chat completion.
	ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error)
	// ChatCompletionStream performs a streaming chat completion; the caller
	// must close the returned body.
	ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error)
	// Embedding performs an OpenAI-compatible embeddings request.
	Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error)
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,17 @@ import (
|
|||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"llm-gateway/internal/config"
|
||||
)
|
||||
|
||||
// ModelTimeouts holds per-model timeout overrides.
type ModelTimeouts struct {
	RequestTimeout   time.Duration // 0 = use server default
	StreamingTimeout time.Duration // 0 = use server default
}
|
||||
|
||||
// Route maps a model to a specific provider with pricing.
|
||||
type Route struct {
|
||||
Provider Provider
|
||||
|
|
@ -24,6 +31,7 @@ type Registry struct {
|
|||
balancers map[string]LoadBalancer
|
||||
aliases map[string]string // alias -> canonical name
|
||||
order []string // preserves config order (canonical names only)
|
||||
timeouts map[string]*ModelTimeouts
|
||||
}
|
||||
|
||||
func NewRegistry(cfg *config.Config) (*Registry, error) {
|
||||
|
|
@ -46,6 +54,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
|||
balancers := make(map[string]LoadBalancer)
|
||||
aliases := make(map[string]string)
|
||||
order := make([]string, 0, len(cfg.Models))
|
||||
timeouts := make(map[string]*ModelTimeouts)
|
||||
|
||||
for _, mc := range cfg.Models {
|
||||
var modelRoutes []Route
|
||||
|
|
@ -74,6 +83,14 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
|||
// Load balancer
|
||||
balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
|
||||
|
||||
// Per-model timeouts
|
||||
if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 {
|
||||
timeouts[mc.Name] = &ModelTimeouts{
|
||||
RequestTimeout: mc.RequestTimeout,
|
||||
StreamingTimeout: mc.StreamingTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Register aliases
|
||||
for _, alias := range mc.Aliases {
|
||||
aliases[alias] = mc.Name
|
||||
|
|
@ -85,6 +102,7 @@ func (r *Registry) buildFromConfig(cfg *config.Config) error {
|
|||
r.balancers = balancers
|
||||
r.aliases = aliases
|
||||
r.order = order
|
||||
r.timeouts = timeouts
|
||||
r.mu.Unlock()
|
||||
|
||||
return nil
|
||||
|
|
@ -135,6 +153,18 @@ func (r *Registry) ModelNames() []string {
|
|||
return names
|
||||
}
|
||||
|
||||
// ModelTimeoutsFor returns per-model timeout overrides, resolving aliases. Returns nil if none set.
|
||||
func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
canonical := model
|
||||
if alias, ok := r.aliases[model]; ok {
|
||||
canonical = alias
|
||||
}
|
||||
return r.timeouts[canonical]
|
||||
}
|
||||
|
||||
// RouteInfo exposes route details for dashboard display.
|
||||
type RouteInfo struct {
|
||||
ProviderName string `json:"provider_name"`
|
||||
|
|
|
|||
|
|
@ -23,6 +23,10 @@ func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *Chat
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// Embedding satisfies the Provider interface for registry tests; the tests
// never exercise it, so it returns a nil response and nil error.
func (m *mockProvider) Embedding(_ context.Context, _ string, _ *EmbeddingRequest) (*EmbeddingResponse, error) {
	return nil, nil
}
|
||||
|
||||
// newTestRegistry builds a Registry directly without going through config parsing.
|
||||
func newTestRegistry(models []testModel) *Registry {
|
||||
r := &Registry{
|
||||
|
|
|
|||
107
llm-gateway/internal/proxy/dedup.go
Normal file
107
llm-gateway/internal/proxy/dedup.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// inflight represents an in-progress deduplicated request.
type inflight struct {
	done       chan struct{} // closed when the leader completes (or cleanup reaps the entry)
	result     []byte        // response body; set by the leader before done is closed
	statusCode int           // HTTP status; set by the leader before done is closed
	createdAt  time.Time     // used by cleanup to detect abandoned entries
}

// Deduplicator coalesces identical concurrent non-streaming requests.
type Deduplicator struct {
	mu      sync.Mutex
	flights map[string]*inflight // dedup key -> in-progress request
	window  time.Duration        // dedup window; also drives cleanup cadence
	done    chan struct{}        // closed by Close to stop the cleanup goroutine
}
|
||||
|
||||
// NewDeduplicator creates a new request deduplicator.
|
||||
func NewDeduplicator(window time.Duration) *Deduplicator {
|
||||
if window == 0 {
|
||||
window = 30 * time.Second
|
||||
}
|
||||
d := &Deduplicator{
|
||||
flights: make(map[string]*inflight),
|
||||
window: window,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go d.cleanup()
|
||||
return d
|
||||
}
|
||||
|
||||
// DedupKey computes a dedup key from model name and request body.
// The key is the hex-encoded SHA-256 of model || NUL || body; the NUL
// separator prevents ("ab","c") and ("a","bc") from colliding.
func DedupKey(model string, body []byte) string {
	buf := make([]byte, 0, len(model)+1+len(body))
	buf = append(buf, model...)
	buf = append(buf, 0)
	buf = append(buf, body...)

	digest := sha256.Sum256(buf)
	return hex.EncodeToString(digest[:])
}
|
||||
|
||||
// TryJoin attempts to join an in-flight request. Returns the inflight entry and
|
||||
// whether this caller is the leader (true) or a follower (false).
|
||||
func (d *Deduplicator) TryJoin(key string) (*inflight, bool) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if f, ok := d.flights[key]; ok {
|
||||
return f, false // follower
|
||||
}
|
||||
|
||||
f := &inflight{
|
||||
done: make(chan struct{}),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
d.flights[key] = f
|
||||
return f, true // leader
|
||||
}
|
||||
|
||||
// Complete signals completion of a deduplicated request.
|
||||
func (d *Deduplicator) Complete(key string, result []byte, statusCode int) {
|
||||
d.mu.Lock()
|
||||
f, ok := d.flights[key]
|
||||
delete(d.flights, key)
|
||||
d.mu.Unlock()
|
||||
|
||||
if ok {
|
||||
f.result = result
|
||||
f.statusCode = statusCode
|
||||
close(f.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the background cleanup goroutine.
// It must be called at most once; a second call would close the channel again.
func (d *Deduplicator) Close() {
	close(d.done)
}
|
||||
|
||||
// cleanup periodically removes stale in-flight entries.
|
||||
func (d *Deduplicator) cleanup() {
|
||||
ticker := time.NewTicker(d.window)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-d.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
d.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, f := range d.flights {
|
||||
if now.Sub(f.createdAt) > d.window*2 {
|
||||
delete(d.flights, key)
|
||||
close(f.done) // unblock any waiting followers
|
||||
}
|
||||
}
|
||||
d.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
74
llm-gateway/internal/proxy/dedup_test.go
Normal file
74
llm-gateway/internal/proxy/dedup_test.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDedupKey(t *testing.T) {
|
||||
k1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
||||
k2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
||||
k3 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
||||
|
||||
if k1 != k2 {
|
||||
t.Error("identical requests should produce the same key")
|
||||
}
|
||||
if k1 == k3 {
|
||||
t.Error("different requests should produce different keys")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeduplicator_LeaderFollower(t *testing.T) {
|
||||
d := NewDeduplicator(5 * time.Second)
|
||||
defer d.Close()
|
||||
|
||||
key := DedupKey("gpt-4", []byte(`test`))
|
||||
|
||||
// First call is leader
|
||||
f1, isLeader := d.TryJoin(key)
|
||||
if !isLeader {
|
||||
t.Fatal("first caller should be leader")
|
||||
}
|
||||
|
||||
// Second call with same key is follower
|
||||
f2, isLeader := d.TryJoin(key)
|
||||
if isLeader {
|
||||
t.Fatal("second caller should be follower")
|
||||
}
|
||||
if f1 != f2 {
|
||||
t.Fatal("follower should get same inflight entry")
|
||||
}
|
||||
|
||||
// Complete the request
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-f2.done
|
||||
if string(f2.result) != "response" {
|
||||
t.Error("follower should receive leader's result")
|
||||
}
|
||||
if f2.statusCode != 200 {
|
||||
t.Error("follower should receive leader's status code")
|
||||
}
|
||||
}()
|
||||
|
||||
d.Complete(key, []byte("response"), 200)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestDeduplicator_DifferentKeys(t *testing.T) {
|
||||
d := NewDeduplicator(5 * time.Second)
|
||||
defer d.Close()
|
||||
|
||||
_, isLeader1 := d.TryJoin("key1")
|
||||
_, isLeader2 := d.TryJoin("key2")
|
||||
|
||||
if !isLeader1 || !isLeader2 {
|
||||
t.Error("different keys should both be leaders")
|
||||
}
|
||||
|
||||
d.Complete("key1", []byte("r1"), 200)
|
||||
d.Complete("key2", []byte("r2"), 200)
|
||||
}
|
||||
|
|
@ -53,6 +53,7 @@ type Handler struct {
|
|||
cfg *config.Config
|
||||
healthTracker *provider.HealthTracker
|
||||
debugLogger *storage.DebugLogger
|
||||
dedup *Deduplicator
|
||||
}
|
||||
|
||||
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
|
||||
|
|
@ -70,6 +71,10 @@ func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
|
|||
h.debugLogger = dl
|
||||
}
|
||||
|
||||
// SetDeduplicator installs an optional request deduplicator. When non-nil,
// identical concurrent non-streaming requests (same model + body) are
// collapsed into a single upstream call. Must be set before the handler
// starts serving; the field is read without synchronization.
func (h *Handler) SetDeduplicator(d *Deduplicator) {
	h.dedup = d
}
|
||||
|
||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
||||
if err != nil {
|
||||
|
|
@ -118,11 +123,47 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// Apply per-model timeout for non-streaming requests
|
||||
modelTimeouts := h.registry.ModelTimeoutsFor(req.Model)
|
||||
|
||||
if req.Stream {
|
||||
h.handleStream(w, r, &req, routes, tokenName, requestID)
|
||||
h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts)
|
||||
return
|
||||
}
|
||||
|
||||
// Request deduplication for non-streaming requests
|
||||
if h.dedup != nil {
|
||||
dedupKey := DedupKey(req.Model, body)
|
||||
flight, isLeader := h.dedup.TryJoin(dedupKey)
|
||||
if !isLeader {
|
||||
// Wait for the leader to complete
|
||||
select {
|
||||
case <-flight.done:
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
w.Header().Set("X-Dedup", "HIT")
|
||||
w.WriteHeader(flight.statusCode)
|
||||
w.Write(flight.result)
|
||||
return
|
||||
case <-r.Context().Done():
|
||||
writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup")
|
||||
return
|
||||
}
|
||||
}
|
||||
// Leader: proceed normally, but capture response for followers
|
||||
defer func() {
|
||||
// If we haven't completed yet (e.g., panic), clean up
|
||||
}()
|
||||
h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey)
|
||||
return
|
||||
}
|
||||
|
||||
if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
|
||||
defer cancel()
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
|
||||
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
|
||||
}
|
||||
|
||||
|
|
@ -233,6 +274,153 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
|
|||
}
|
||||
}
|
||||
|
||||
// Embeddings handles the embeddings endpoint: it parses the request,
// resolves provider routes for the model, and tries each healthy route in
// order (with backoff between attempts) until one succeeds.
// (NOTE: the previous comment here described handleNonStreamDedup and was
// misplaced.)
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
	// Read at most MaxRequestBodyMB megabytes of the request body.
	body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
	if err != nil {
		writeError(w, http.StatusBadRequest, "failed to read request body")
		return
	}

	var req provider.EmbeddingRequest
	if err := json.Unmarshal(body, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
		return
	}

	if req.Model == "" {
		writeError(w, http.StatusBadRequest, "model is required")
		return
	}

	routes, ok := h.registry.Lookup(req.Model)
	if !ok {
		writeError(w, http.StatusNotFound, "model not found: "+req.Model)
		return
	}

	// Drop providers whose circuit breaker is open (filterHealthyRoutes
	// falls back to the full list if everything is unhealthy).
	routes = h.filterHealthyRoutes(routes)
	tokenName := getTokenName(r.Context())
	requestID := middleware.GetReqID(r.Context())

	var lastErr error
	for i, route := range routes {
		// Back off before each retry (not before the first attempt),
		// bailing out early if the client goes away.
		if i > 0 {
			backoff := backoffDuration(i, h.cfg.Retry)
			select {
			case <-time.After(backoff):
			case <-r.Context().Done():
				writeError(w, http.StatusGatewayTimeout, "request cancelled")
				return
			}
		}

		start := time.Now()
		resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req)
		latency := time.Since(start).Milliseconds()

		if err != nil {
			// Non-retryable provider errors are surfaced to the client
			// immediately instead of falling through to other routes.
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				w.Header().Set("X-Request-ID", requestID)
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}
			// Retryable failure: record it and try the next route.
			lastErr = err
			log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}

		// Success: feed the circuit breaker a nil error.
		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}

		promptTokens := 0
		if resp.Usage != nil {
			promptTokens = resp.Usage.PromptTokens
		}
		// Embeddings consume input tokens only; InputPrice is per million.
		cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice

		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost)
		h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "")

		// Echo the gateway-facing model name, not the provider's alias.
		resp.Model = req.Model

		respBytes, err := json.Marshal(resp)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "failed to marshal response")
			return
		}

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Request-ID", requestID)
		w.Write(respBytes)
		return
	}

	// Every route failed (or none remained after filtering).
	w.Header().Set("X-Request-ID", requestID)
	if lastErr != nil {
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}
|
||||
|
||||
func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) {
|
||||
h.logger.Log(storage.RequestLog{
|
||||
RequestID: requestID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
TokenName: tokenName,
|
||||
Model: model,
|
||||
Provider: providerName,
|
||||
ProviderModel: providerModel,
|
||||
InputTokens: inputTokens,
|
||||
CostUSD: cost,
|
||||
LatencyMS: latencyMS,
|
||||
Status: status,
|
||||
ErrorMessage: errMsg,
|
||||
RequestType: "embedding",
|
||||
})
|
||||
}
|
||||
|
||||
// handleNonStreamDedup runs the non-streaming path as the dedup leader: it
// wraps the ResponseWriter in a recorder, delegates to handleNonStream,
// then publishes the captured body/status so followers blocked on the same
// dedup key can replay the response. Complete runs even when the upstream
// call failed, so followers receive the error response as well.
func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) {
	// Apply the per-model request timeout override, when configured.
	if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
		ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
		defer cancel()
		r = r.WithContext(ctx)
	}

	// statusCode defaults to 200 for handlers that never call WriteHeader.
	rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
	h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID)
	h.dedup.Complete(dedupKey, rec.body, rec.statusCode)
}
|
||||
|
||||
// responseRecorder captures the response for dedup.
// It tees everything written through it into body so the dedup leader can
// hand the exact bytes to followers.
// NOTE(review): it does not forward Flush/Hijack, so it is only suitable
// for the non-streaming path where it is used.
type responseRecorder struct {
	http.ResponseWriter
	// statusCode is the recorded status; initialized to 200 by the caller.
	statusCode int
	// body accumulates every chunk written through Write.
	body []byte
}
|
||||
|
||||
// WriteHeader records the status code, then forwards it to the underlying
// ResponseWriter.
func (r *responseRecorder) WriteHeader(code int) {
	r.statusCode = code
	r.ResponseWriter.WriteHeader(code)
}
|
||||
|
||||
// Write appends the chunk to the captured body, then forwards it to the
// client. The copy holds the whole response in memory, which is acceptable
// for the non-streaming JSON responses this recorder is used with.
func (r *responseRecorder) Write(b []byte) (int, error) {
	r.body = append(r.body, b...)
	return r.ResponseWriter.Write(b)
}
|
||||
|
||||
// filterHealthyRoutes removes providers with open circuit breakers.
|
||||
// If all are filtered out, returns original routes as fallback.
|
||||
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
|
||||
|
|
|
|||
|
|
@ -5,27 +5,66 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"llm-gateway/internal/config"
|
||||
"llm-gateway/internal/provider"
|
||||
)
|
||||
|
||||
// ModelsHandler serves the model-listing endpoint, enriching each model
// with per-provider details and circuit-breaker health.
type ModelsHandler struct {
	registry *provider.Registry
	// healthTracker may be nil; providers are then reported as healthy.
	healthTracker *provider.HealthTracker
	// cfg is consulted for per-model load-balancing strategy.
	cfg *config.Config
}
|
||||
|
||||
func NewModelsHandler(registry *provider.Registry) *ModelsHandler {
|
||||
return &ModelsHandler{registry: registry}
|
||||
func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler {
|
||||
return &ModelsHandler{
|
||||
registry: registry,
|
||||
healthTracker: healthTracker,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) {
|
||||
names := h.registry.ModelNames()
|
||||
models := make([]map[string]any, len(names))
|
||||
for i, name := range names {
|
||||
models[i] = map[string]any{
|
||||
"id": name,
|
||||
allRoutes := h.registry.AllRoutes()
|
||||
models := make([]map[string]any, 0, len(allRoutes))
|
||||
|
||||
for _, m := range allRoutes {
|
||||
providers := make([]map[string]any, 0, len(m.Routes))
|
||||
for _, rt := range m.Routes {
|
||||
healthy := true
|
||||
if h.healthTracker != nil {
|
||||
healthy = h.healthTracker.IsAvailable(rt.ProviderName)
|
||||
}
|
||||
providers = append(providers, map[string]any{
|
||||
"name": rt.ProviderName,
|
||||
"model": rt.ProviderModel,
|
||||
"input_price": rt.InputPrice,
|
||||
"output_price": rt.OutputPrice,
|
||||
"priority": rt.Priority,
|
||||
"healthy": healthy,
|
||||
})
|
||||
}
|
||||
|
||||
// Find load balancing strategy from config
|
||||
loadBalancing := "first"
|
||||
for _, mc := range h.cfg.Models {
|
||||
if mc.Name == m.Name {
|
||||
if mc.LoadBalancing != "" {
|
||||
loadBalancing = mc.LoadBalancing
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
models = append(models, map[string]any{
|
||||
"id": m.Name,
|
||||
"object": "model",
|
||||
"created": time.Now().Unix(),
|
||||
"owned_by": "llm-gateway",
|
||||
}
|
||||
"providers": providers,
|
||||
"provider_count": len(providers),
|
||||
"load_balancing": loadBalancing,
|
||||
"aliases": m.Aliases,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
|
|
|||
|
|
@ -8,12 +8,15 @@ import (
|
|||
"time"
|
||||
|
||||
"llm-gateway/internal/storage"
|
||||
"llm-gateway/internal/webhook"
|
||||
)
|
||||
|
||||
// RateLimiter enforces per-token request-rate limits plus daily and
// monthly budget caps, and emits webhook alerts at budget thresholds.
type RateLimiter struct {
	db *storage.DB
	// mu guards buckets.
	mu sync.Mutex
	buckets map[string]*tokenBucket
	// notifier is optional (set via SetNotifier); nil disables alerts.
	notifier *webhook.Notifier
	// budgetNotified tracks which token+budget combos have been notified.
	// NOTE(review): entries are never cleared here, so a combo alerts at
	// most once per process lifetime — confirm this is intended.
	budgetNotified sync.Map // tracks which token+budget combos have been notified
}
|
||||
|
||||
type tokenBucket struct {
|
||||
|
|
@ -30,6 +33,11 @@ func NewRateLimiter(db *storage.DB) *RateLimiter {
|
|||
}
|
||||
}
|
||||
|
||||
// SetNotifier sets the webhook notifier for budget threshold alerts.
// Passing nil disables alerts. Must be called before the middleware starts
// serving; the field is read without synchronization.
func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) {
	rl.notifier = n
}
|
||||
|
||||
func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
apiToken := getAPIToken(r.Context())
|
||||
|
|
@ -63,16 +71,55 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
|
|||
// Check daily budget
|
||||
if apiToken.DailyBudgetUSD > 0 {
|
||||
spent, err := rl.db.TodaySpend(tokenName)
|
||||
if err == nil && spent >= apiToken.DailyBudgetUSD {
|
||||
if err == nil {
|
||||
if spent >= apiToken.DailyBudgetUSD {
|
||||
writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
|
||||
return
|
||||
}
|
||||
rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD)
|
||||
}
|
||||
}
|
||||
|
||||
// Check monthly budget
|
||||
if apiToken.MonthlyBudgetUSD > 0 {
|
||||
spent, err := rl.db.MonthSpend(tokenName)
|
||||
if err == nil {
|
||||
if spent >= apiToken.MonthlyBudgetUSD {
|
||||
writeError(w, http.StatusTooManyRequests, "monthly budget exceeded")
|
||||
return
|
||||
}
|
||||
rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD)
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// checkBudgetThreshold fires a webhook notification when spend reaches 80% of budget.
|
||||
func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) {
|
||||
if rl.notifier == nil || budget <= 0 {
|
||||
return
|
||||
}
|
||||
if spent/budget < 0.8 {
|
||||
return
|
||||
}
|
||||
key := tokenName + ":" + budgetType
|
||||
if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded {
|
||||
return // already notified
|
||||
}
|
||||
rl.notifier.Notify(webhook.Event{
|
||||
Type: webhook.EventBudgetThreshold,
|
||||
Data: map[string]any{
|
||||
"token": tokenName,
|
||||
"budget_type": budgetType,
|
||||
"spent": spent,
|
||||
"budget": budget,
|
||||
"percent": spent / budget * 100,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
"llm-gateway/internal/provider"
|
||||
)
|
||||
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
|
||||
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
writeError(w, http.StatusInternalServerError, "streaming not supported")
|
||||
|
|
@ -60,11 +60,15 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
|
|||
continue
|
||||
}
|
||||
|
||||
// Apply streaming timeout
|
||||
// Apply streaming timeout (per-model override takes precedence)
|
||||
streamingTimeout := h.cfg.Server.StreamingTimeout
|
||||
if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 {
|
||||
streamingTimeout = modelTimeouts.StreamingTimeout
|
||||
}
|
||||
var streamCtx context.Context
|
||||
var streamCancel context.CancelFunc
|
||||
if h.cfg.Server.StreamingTimeout > 0 {
|
||||
streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
|
||||
if streamingTimeout > 0 {
|
||||
streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout)
|
||||
} else {
|
||||
streamCtx, streamCancel = context.WithCancel(r.Context())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,6 +102,21 @@ func (db *DB) TodaySpend(tokenName string) (float64, error) {
|
|||
return total.Float64, nil
|
||||
}
|
||||
|
||||
// MonthSpend returns the total cost in USD for a given token this month.
|
||||
func (db *DB) MonthSpend(tokenName string) (float64, error) {
|
||||
now := time.Now()
|
||||
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix()
|
||||
var total sql.NullFloat64
|
||||
err := db.QueryRow(
|
||||
"SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
|
||||
tokenName, startOfMonth,
|
||||
).Scan(&total)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total.Float64, nil
|
||||
}
|
||||
|
||||
// TodaySpendAll returns today's spend for all tokens as a map.
|
||||
func (db *DB) TodaySpendAll() (map[string]float64, error) {
|
||||
startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ type RequestLog struct {
|
|||
ErrorMessage string
|
||||
Streaming bool
|
||||
Cached bool
|
||||
RequestType string // "chat" or "embedding"
|
||||
}
|
||||
|
||||
type AsyncLogger struct {
|
||||
|
|
@ -94,8 +95,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
|
|||
}
|
||||
|
||||
stmt, err := tx.Prepare(`INSERT INTO request_logs
|
||||
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: preparing log statement: %v", err)
|
||||
tx.Rollback()
|
||||
|
|
@ -112,10 +113,14 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
|
|||
if r.Cached {
|
||||
cached = 1
|
||||
}
|
||||
reqType := r.RequestType
|
||||
if reqType == "" {
|
||||
reqType = "chat"
|
||||
}
|
||||
_, err := stmt.Exec(
|
||||
r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
|
||||
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
|
||||
r.Status, r.ErrorMessage, streaming, cached,
|
||||
r.Status, r.ErrorMessage, streaming, cached, reqType,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: inserting log: %v", err)
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd;
|
||||
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0;
|
||||
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE request_logs DROP COLUMN request_type;
|
||||
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat';
|
||||
123
llm-gateway/internal/webhook/webhook.go
Normal file
123
llm-gateway/internal/webhook/webhook.go
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
package webhook
|
||||
|
||||
import (
	"bytes"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"time"

	"llm-gateway/internal/config"
)
|
||||
|
||||
// Event types emitted by the gateway. Webhook configs may filter on these
// via their Events list (see shouldSend).
const (
	EventCircuitBreakerOpen = "circuit_breaker.open"
	EventCircuitBreakerClosed = "circuit_breaker.closed"
	EventBudgetThreshold = "budget.threshold"
)
|
||||
|
||||
// Event represents a webhook notification payload.
type Event struct {
	// Type is one of the Event* constants above.
	Type string `json:"type"`
	// Timestamp is stamped by Notify when left zero.
	Timestamp time.Time `json:"timestamp"`
	// Data carries event-specific fields (provider name, budget figures, ...).
	Data map[string]any `json:"data"`
}
|
||||
|
||||
// Notifier sends webhook notifications. Events are queued on a buffered
// channel and delivered sequentially by one background goroutine started
// in NewNotifier.
// NOTE(review): Notify sends on ch and Close closes it, so a Notify racing
// with Close would panic on a send to a closed channel — all producers
// must stop before Close is called.
type Notifier struct {
	webhooks []config.WebhookConfig
	// ch is the bounded event queue; Notify drops events when it is full.
	ch chan Event
	// done is closed by run when the queue has drained.
	done chan struct{}
	client *http.Client
}
|
||||
|
||||
// NewNotifier creates a webhook notifier from config.
|
||||
func NewNotifier(webhooks []config.WebhookConfig) *Notifier {
|
||||
n := &Notifier{
|
||||
webhooks: webhooks,
|
||||
ch: make(chan Event, 100),
|
||||
done: make(chan struct{}),
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
go n.run()
|
||||
return n
|
||||
}
|
||||
|
||||
// Notify queues an event for delivery (non-blocking). It stamps the event
// with the current time when the caller left Timestamp zero. If the queue
// is full the event is dropped with a warning rather than blocking the
// request path. Must not be called concurrently with or after Close: the
// send would panic on the closed channel.
func (n *Notifier) Notify(evt Event) {
	if evt.Timestamp.IsZero() {
		evt.Timestamp = time.Now()
	}
	select {
	case n.ch <- evt:
	default:
		log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type)
	}
}
|
||||
|
||||
// Close drains pending events and shuts down. Closing the channel lets the
// delivery goroutine finish the queued events; the receive on done blocks
// until it exits. All producers must have stopped calling Notify first
// (see Notify), and Close must be called at most once.
func (n *Notifier) Close() {
	close(n.ch)
	<-n.done
}
|
||||
|
||||
func (n *Notifier) run() {
|
||||
defer close(n.done)
|
||||
for evt := range n.ch {
|
||||
for _, wh := range n.webhooks {
|
||||
if !n.shouldSend(wh, evt.Type) {
|
||||
continue
|
||||
}
|
||||
n.send(wh, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool {
|
||||
if len(wh.Events) == 0 {
|
||||
return true // no filter = send all
|
||||
}
|
||||
for _, e := range wh.Events {
|
||||
if e == eventType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (n *Notifier) send(wh config.WebhookConfig, evt Event) {
|
||||
body, err := json.Marshal(evt)
|
||||
if err != nil {
|
||||
log.Printf("ERROR: webhook marshal: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
log.Printf("ERROR: webhook request: %v", err)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if wh.Secret != "" {
|
||||
mac := hmac.New(sha256.New, []byte(wh.Secret))
|
||||
mac.Write(body)
|
||||
sig := hex.EncodeToString(mac.Sum(nil))
|
||||
req.Header.Set("X-Webhook-Signature", "sha256="+sig)
|
||||
}
|
||||
|
||||
resp, err := n.client.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue