ai-servers/llm-gateway/internal/proxy/ratelimit.go

90 lines
1.8 KiB
Go

package proxy
import (
"net/http"
"sync"
"time"
"llm-gateway/internal/storage"
)
// RateLimiter enforces per-API-token request-rate limits (token bucket)
// and daily spend-budget checks for proxied requests. Safe for concurrent
// use by multiple request goroutines.
type RateLimiter struct {
	db      *storage.DB             // used to look up today's spend for budget checks
	mu      sync.Mutex              // guards buckets and all tokenBucket state
	buckets map[string]*tokenBucket // one bucket per API token name, created lazily; NOTE(review): never pruned, grows unboundedly
}
// tokenBucket holds classic token-bucket state: tokens accrue at
// refillRate up to maxTokens, and each allowed request spends one token.
// All fields are protected by RateLimiter.mu.
type tokenBucket struct {
	tokens     float64   // tokens currently available to spend
	maxTokens  float64   // bucket capacity (equal to the RPM limit at creation)
	refillRate float64   // tokens per second
	lastRefill time.Time // last time the bucket was topped up
}
// NewRateLimiter returns a RateLimiter that consults db for daily spend
// totals and starts with no per-token buckets allocated.
func NewRateLimiter(db *storage.DB) *RateLimiter {
	rl := &RateLimiter{db: db}
	rl.buckets = make(map[string]*tokenBucket)
	return rl
}
// Check wraps next with rate-limit enforcement. Requests carrying an API
// token are checked against the token's requests-per-minute limit and its
// daily USD budget; violations receive 429 Too Many Requests. Requests
// without an API token in the context pass through untouched.
func (rl *RateLimiter) Check(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tok := getAPIToken(r.Context())
		if tok == nil {
			// No token attached to this request: nothing to enforce.
			next.ServeHTTP(w, r)
			return
		}

		name := tok.Name

		// Requests-per-minute limit via the per-token bucket (0 = unlimited).
		if tok.RateLimitRPM > 0 && !rl.allow(name, tok.RateLimitRPM) {
			writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
			return
		}

		// Daily budget (0 = unlimited). A spend-lookup failure fails open:
		// the request is allowed rather than rejected on a DB error.
		if tok.DailyBudgetUSD > 0 {
			if spent, err := rl.db.TodaySpend(name); err == nil && spent >= tok.DailyBudgetUSD {
				writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
				return
			}
		}

		next.ServeHTTP(w, r)
	})
}
// allow reports whether one request for tokenName may proceed under a
// limit of rateLimitRPM requests per minute, implemented as a token
// bucket that refills continuously at rateLimitRPM/60 tokens per second.
// Callers must pass rateLimitRPM > 0.
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	limit := float64(rateLimitRPM)
	bucket, ok := rl.buckets[tokenName]
	if !ok {
		// First request for this token: start with a full bucket.
		bucket = &tokenBucket{
			tokens:     limit,
			maxTokens:  limit,
			refillRate: limit / 60.0,
			lastRefill: time.Now(),
		}
		rl.buckets[tokenName] = bucket
	} else if bucket.maxTokens != limit {
		// The token's configured RPM changed since the bucket was created.
		// Previously the stale capacity/refill rate were enforced forever;
		// resync them and clamp any excess stored tokens.
		bucket.maxTokens = limit
		bucket.refillRate = limit / 60.0
		if bucket.tokens > limit {
			bucket.tokens = limit
		}
	}

	// Top up for the time elapsed since the last call, capped at capacity.
	now := time.Now()
	bucket.tokens += now.Sub(bucket.lastRefill).Seconds() * bucket.refillRate
	if bucket.tokens > bucket.maxTokens {
		bucket.tokens = bucket.maxTokens
	}
	bucket.lastRefill = now

	// Spend one token if one is available.
	if bucket.tokens < 1 {
		return false
	}
	bucket.tokens--
	return true
}