diff --git a/llm-gateway/internal/dashboard/templates/partials/debug.html b/llm-gateway/internal/dashboard/templates/partials/debug.html
new file mode 100644
index 0000000..ffb419b
--- /dev/null
+++ b/llm-gateway/internal/dashboard/templates/partials/debug.html
@@ -0,0 +1,100 @@
+{{define "content"}}
+
+
+
+ Debug Mode
+
+ {{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}}
+
+
+
+
+
+
+ |
+ Time |
+ Request ID |
+ Token |
+ Model |
+ Provider |
+ Status |
+
+
+
+ {{range $i, $entry := .DebugResult.Entries}}
+
+ | ▶ |
+ {{formatTimeDetail $entry.Timestamp}} |
+ {{$entry.RequestID}} |
+ {{$entry.TokenName}} |
+ {{$entry.Model}} |
+ {{$entry.Provider}} |
+
+ {{if and (ge $entry.ResponseStatus 200) (lt $entry.ResponseStatus 300)}}{{$entry.ResponseStatus}}
+ {{else if ge $entry.ResponseStatus 400}}{{$entry.ResponseStatus}}
+ {{else}}{{$entry.ResponseStatus}}{{end}}
+ |
+
+
+
+
+ Request Headers:
+ {{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}
+ Request Body:
+ {{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}
+ Response Body:
+ {{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}
+
+ |
+
+ {{end}}
+ {{if not .DebugResult.Entries}}
+ | No debug log entries |
+ {{end}}
+
+
+
+ {{if gt .DebugResult.TotalPages 1}}
+
+ {{end}}
+
+
+
+{{end}}
diff --git a/llm-gateway/internal/dashboard/templates/partials/logs.html b/llm-gateway/internal/dashboard/templates/partials/logs.html
index 79dca12..bf8c640 100644
--- a/llm-gateway/internal/dashboard/templates/partials/logs.html
+++ b/llm-gateway/internal/dashboard/templates/partials/logs.html
@@ -20,6 +20,9 @@
+
+
+
@@ -116,5 +119,15 @@ function toggleExpand(id) {
var el = document.getElementById(id);
if (el) el.classList.toggle('show');
}
+function exportLogs(format) {
+ var params = ['format=' + format];
+ var model = document.getElementById('filter-model').value;
+ var token = document.getElementById('filter-token').value;
+ var status = document.getElementById('filter-status').value;
+ if (model) params.push('model=' + encodeURIComponent(model));
+ if (token) params.push('token=' + encodeURIComponent(token));
+ if (status) params.push('status=' + encodeURIComponent(status));
+ window.open('/api/export/logs?' + params.join('&'), '_blank');
+}
{{end}}
diff --git a/llm-gateway/internal/metrics/prometheus.go b/llm-gateway/internal/metrics/prometheus.go
index 7e8b635..620f51e 100644
--- a/llm-gateway/internal/metrics/prometheus.go
+++ b/llm-gateway/internal/metrics/prometheus.go
@@ -10,6 +10,8 @@ type Metrics struct {
requestDuration *prometheus.HistogramVec
tokensTotal *prometheus.CounterVec
costTotal *prometheus.CounterVec
+ cacheHits prometheus.Counter
+ cacheMisses prometheus.Counter
}
func New() *Metrics {
@@ -34,6 +36,16 @@ func New() *Metrics {
Name: "llm_gateway_cost_usd_total",
Help: "Total cost in USD",
}, []string{"model", "provider", "token_name"}),
+
+ cacheHits: promauto.NewCounter(prometheus.CounterOpts{
+ Name: "llm_gateway_cache_hits_total",
+ Help: "Total number of cache hits",
+ }),
+
+ cacheMisses: promauto.NewCounter(prometheus.CounterOpts{
+ Name: "llm_gateway_cache_misses_total",
+ Help: "Total number of cache misses",
+ }),
}
}
@@ -51,3 +63,11 @@ func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, l
m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost)
}
}
+
+func (m *Metrics) RecordCacheHit() {
+ m.cacheHits.Inc()
+}
+
+func (m *Metrics) RecordCacheMiss() {
+ m.cacheMisses.Inc()
+}
diff --git a/llm-gateway/internal/provider/balancer.go b/llm-gateway/internal/provider/balancer.go
new file mode 100644
index 0000000..602d885
--- /dev/null
+++ b/llm-gateway/internal/provider/balancer.go
@@ -0,0 +1,144 @@
+package provider
+
+import (
+ "math/rand"
+ "sort"
+ "sync/atomic"
+)
+
+// LoadBalancer reorders routes for load distribution.
+type LoadBalancer interface {
+ Reorder(routes []Route) []Route
+}
+
+// NewLoadBalancer creates a load balancer by strategy name.
+func NewLoadBalancer(strategy string) LoadBalancer {
+ switch strategy {
+ case "round-robin":
+ return &RoundRobinBalancer{}
+ case "random":
+ return &RandomBalancer{}
+ case "least-cost":
+ return &LeastCostBalancer{}
+ default:
+ return &FirstBalancer{}
+ }
+}
+
+// FirstBalancer is a no-op that preserves original order.
+type FirstBalancer struct{}
+
+func (b *FirstBalancer) Reorder(routes []Route) []Route {
+ return routes
+}
+
+// RoundRobinBalancer rotates routes within same-priority groups.
+type RoundRobinBalancer struct {
+ counter atomic.Uint64
+}
+
+// Reorder rotates the routes within each same-priority group by a shared
+// monotonically increasing counter, so successive calls start each group at
+// the next route. Priority group order itself is preserved. The input slice
+// is never mutated; a copy is returned.
+func (b *RoundRobinBalancer) Reorder(routes []Route) []Route {
+	if len(routes) <= 1 {
+		return routes
+	}
+
+	result := make([]Route, len(routes))
+	copy(result, routes)
+
+	// Group by priority and rotate within each group.
+	groups := groupByPriority(result)
+	idx := 0
+	count := b.counter.Add(1)
+	for _, group := range groups {
+		if len(group) > 1 {
+			// Reduce modulo the group size while still in uint64: converting
+			// the raw counter with int(count) first can go negative once the
+			// counter exceeds MaxInt, which would make the offset negative
+			// and index out of range.
+			offset := int(count % uint64(len(group)))
+			for j := 0; j < len(group); j++ {
+				result[idx] = group[(j+offset)%len(group)]
+				idx++
+			}
+		} else {
+			result[idx] = group[0]
+			idx++
+		}
+	}
+
+	return result
+}
+
+// RandomBalancer shuffles routes within same-priority groups.
+type RandomBalancer struct{}
+
+func (b *RandomBalancer) Reorder(routes []Route) []Route {
+ if len(routes) <= 1 {
+ return routes
+ }
+
+ result := make([]Route, len(routes))
+ copy(result, routes)
+
+ groups := groupByPriority(result)
+ idx := 0
+ for _, group := range groups {
+ rand.Shuffle(len(group), func(i, j int) {
+ group[i], group[j] = group[j], group[i]
+ })
+ for _, r := range group {
+ result[idx] = r
+ idx++
+ }
+ }
+
+ return result
+}
+
+// LeastCostBalancer sorts by price within same-priority groups.
+type LeastCostBalancer struct{}
+
+func (b *LeastCostBalancer) Reorder(routes []Route) []Route {
+ if len(routes) <= 1 {
+ return routes
+ }
+
+ result := make([]Route, len(routes))
+ copy(result, routes)
+
+ groups := groupByPriority(result)
+ idx := 0
+ for _, group := range groups {
+ sort.Slice(group, func(i, j int) bool {
+ costI := group[i].InputPrice + group[i].OutputPrice
+ costJ := group[j].InputPrice + group[j].OutputPrice
+ return costI < costJ
+ })
+ for _, r := range group {
+ result[idx] = r
+ idx++
+ }
+ }
+
+ return result
+}
+
+// groupByPriority splits routes into groups of same priority, preserving order.
+func groupByPriority(routes []Route) [][]Route {
+ if len(routes) == 0 {
+ return nil
+ }
+
+ var groups [][]Route
+ currentPriority := routes[0].Priority
+ currentGroup := []Route{routes[0]}
+
+ for i := 1; i < len(routes); i++ {
+ if routes[i].Priority == currentPriority {
+ currentGroup = append(currentGroup, routes[i])
+ } else {
+ groups = append(groups, currentGroup)
+ currentPriority = routes[i].Priority
+ currentGroup = []Route{routes[i]}
+ }
+ }
+ groups = append(groups, currentGroup)
+
+ return groups
+}
diff --git a/llm-gateway/internal/provider/balancer_test.go b/llm-gateway/internal/provider/balancer_test.go
new file mode 100644
index 0000000..cc5378e
--- /dev/null
+++ b/llm-gateway/internal/provider/balancer_test.go
@@ -0,0 +1,294 @@
+package provider
+
+import (
+ "fmt"
+ "testing"
+)
+
+type routeSpec struct {
+ name string
+ priority int
+ input float64
+ output float64
+}
+
+func makeRoutes(specs ...routeSpec) []Route {
+ routes := make([]Route, len(specs))
+ for i, s := range specs {
+ routes[i] = Route{
+ Provider: &mockProvider{name: s.name},
+ ProviderModel: s.name + "-model",
+ Priority: s.priority,
+ InputPrice: s.input,
+ OutputPrice: s.output,
+ }
+ }
+ return routes
+}
+
+func routeNames(routes []Route) []string {
+ names := make([]string, len(routes))
+ for i, r := range routes {
+ names[i] = r.Provider.Name()
+ }
+ return names
+}
+
+func TestFirstBalancer_PreservesOrder(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"a", 1, 1.0, 1.0},
+ routeSpec{"b", 1, 2.0, 2.0},
+ routeSpec{"c", 1, 3.0, 3.0},
+ )
+
+ b := &FirstBalancer{}
+ result := b.Reorder(routes)
+
+ names := routeNames(result)
+ if names[0] != "a" || names[1] != "b" || names[2] != "c" {
+ t.Fatalf("expected [a b c], got %v", names)
+ }
+}
+
+func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"a", 1, 1.0, 1.0},
+ routeSpec{"b", 1, 1.0, 1.0},
+ routeSpec{"c", 1, 1.0, 1.0},
+ )
+
+ b := &RoundRobinBalancer{}
+
+ // Collect the first element from multiple calls
+ seen := make(map[string]bool)
+ for i := 0; i < 6; i++ {
+ result := b.Reorder(routes)
+ seen[result[0].Provider.Name()] = true
+ }
+
+ // All routes should have appeared as first at some point
+ for _, name := range []string{"a", "b", "c"} {
+ if !seen[name] {
+ t.Errorf("expected %q to appear as first element in rotation", name)
+ }
+ }
+}
+
+func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"a", 1, 1.0, 1.0},
+ routeSpec{"b", 1, 1.0, 1.0},
+ routeSpec{"c", 2, 1.0, 1.0},
+ )
+
+ b := &RoundRobinBalancer{}
+
+ // Priority 2 route should always be last
+ for i := 0; i < 5; i++ {
+ result := b.Reorder(routes)
+ if result[2].Provider.Name() != "c" {
+ t.Fatalf("expected priority-2 route 'c' at the end, got %q", result[2].Provider.Name())
+ }
+ }
+}
+
+func TestRandomBalancer_AllRoutesPresent(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"a", 1, 1.0, 1.0},
+ routeSpec{"b", 1, 1.0, 1.0},
+ routeSpec{"c", 1, 1.0, 1.0},
+ )
+
+ b := &RandomBalancer{}
+
+ for i := 0; i < 10; i++ {
+ result := b.Reorder(routes)
+ if len(result) != 3 {
+ t.Fatalf("expected 3 routes, got %d", len(result))
+ }
+
+ names := make(map[string]bool)
+ for _, r := range result {
+ names[r.Provider.Name()] = true
+ }
+ for _, want := range []string{"a", "b", "c"} {
+ if !names[want] {
+ t.Errorf("missing route %q in result", want)
+ }
+ }
+ }
+}
+
+func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"a", 1, 1.0, 1.0},
+ routeSpec{"b", 1, 1.0, 1.0},
+ routeSpec{"c", 2, 1.0, 1.0},
+ )
+
+ b := &RandomBalancer{}
+
+ for i := 0; i < 10; i++ {
+ result := b.Reorder(routes)
+ if result[2].Provider.Name() != "c" {
+ t.Fatalf("expected priority-2 route 'c' last, got %q", result[2].Provider.Name())
+ }
+ }
+}
+
+func TestLeastCostBalancer_SortsByCost(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"expensive", 1, 10.0, 10.0},
+ routeSpec{"cheap", 1, 1.0, 1.0},
+ routeSpec{"medium", 1, 5.0, 5.0},
+ )
+
+ b := &LeastCostBalancer{}
+ result := b.Reorder(routes)
+
+ names := routeNames(result)
+ expected := []string{"cheap", "medium", "expensive"}
+ for i, want := range expected {
+ if names[i] != want {
+ t.Errorf("position %d: got %q, want %q", i, names[i], want)
+ }
+ }
+}
+
+func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) {
+ routes := makeRoutes(
+ routeSpec{"expensive-p1", 1, 10.0, 10.0},
+ routeSpec{"cheap-p1", 1, 1.0, 1.0},
+ routeSpec{"cheap-p2", 2, 0.5, 0.5},
+ )
+
+ b := &LeastCostBalancer{}
+ result := b.Reorder(routes)
+
+ names := routeNames(result)
+ // Within priority 1, cheap should come first; priority 2 always last
+ if names[0] != "cheap-p1" {
+ t.Errorf("expected cheap-p1 first, got %q", names[0])
+ }
+ if names[1] != "expensive-p1" {
+ t.Errorf("expected expensive-p1 second, got %q", names[1])
+ }
+ if names[2] != "cheap-p2" {
+ t.Errorf("expected cheap-p2 last, got %q", names[2])
+ }
+}
+
+func TestGroupByPriority(t *testing.T) {
+ tests := []struct {
+ name string
+ priorities []int
+ wantGroups [][]int
+ }{
+ {
+ name: "empty",
+ priorities: nil,
+ wantGroups: nil,
+ },
+ {
+ name: "single",
+ priorities: []int{1},
+ wantGroups: [][]int{{1}},
+ },
+ {
+ name: "all same",
+ priorities: []int{1, 1, 1},
+ wantGroups: [][]int{{1, 1, 1}},
+ },
+ {
+ name: "two groups",
+ priorities: []int{1, 1, 2, 2},
+ wantGroups: [][]int{{1, 1}, {2, 2}},
+ },
+ {
+ name: "three groups",
+ priorities: []int{1, 2, 2, 3},
+ wantGroups: [][]int{{1}, {2, 2}, {3}},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var routes []Route
+ for _, p := range tt.priorities {
+ routes = append(routes, Route{Priority: p})
+ }
+
+ groups := groupByPriority(routes)
+
+ if tt.wantGroups == nil {
+ if groups != nil {
+ t.Fatalf("expected nil groups, got %v", groups)
+ }
+ return
+ }
+
+ if len(groups) != len(tt.wantGroups) {
+ t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups))
+ }
+
+ for i, wg := range tt.wantGroups {
+ if len(groups[i]) != len(wg) {
+ t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i]))
+ continue
+ }
+ for j, wp := range wg {
+ if groups[i][j].Priority != wp {
+ t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority)
+ }
+ }
+ }
+ })
+ }
+}
+
+func TestBalancer_SingleRoute(t *testing.T) {
+ routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0})
+
+ balancers := []struct {
+ name string
+ balancer LoadBalancer
+ }{
+ {"first", &FirstBalancer{}},
+ {"round-robin", &RoundRobinBalancer{}},
+ {"random", &RandomBalancer{}},
+ {"least-cost", &LeastCostBalancer{}},
+ }
+
+ for _, bb := range balancers {
+ t.Run(bb.name, func(t *testing.T) {
+ result := bb.balancer.Reorder(routes)
+ if len(result) != 1 || result[0].Provider.Name() != "only" {
+ t.Fatalf("expected single route 'only', got %v", routeNames(result))
+ }
+ })
+ }
+}
+
+func TestNewLoadBalancer(t *testing.T) {
+ tests := []struct {
+ strategy string
+ wantType string
+ }{
+ {"round-robin", "*provider.RoundRobinBalancer"},
+ {"random", "*provider.RandomBalancer"},
+ {"least-cost", "*provider.LeastCostBalancer"},
+ {"first", "*provider.FirstBalancer"},
+ {"unknown", "*provider.FirstBalancer"},
+ {"", "*provider.FirstBalancer"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.strategy, func(t *testing.T) {
+ b := NewLoadBalancer(tt.strategy)
+ got := fmt.Sprintf("%T", b)
+ if got != tt.wantType {
+ t.Errorf("NewLoadBalancer(%q) = %s, want %s", tt.strategy, got, tt.wantType)
+ }
+ })
+ }
+}
diff --git a/llm-gateway/internal/provider/health.go b/llm-gateway/internal/provider/health.go
index ae6d97e..91cb965 100644
--- a/llm-gateway/internal/provider/health.go
+++ b/llm-gateway/internal/provider/health.go
@@ -3,8 +3,39 @@ package provider
import (
"sync"
"time"
+
+ "llm-gateway/internal/config"
)
+// CircuitState represents the state of a circuit breaker.
+type CircuitState int
+
+const (
+ CircuitClosed CircuitState = iota // normal operation
+ CircuitOpen // blocking requests
+ CircuitHalfOpen // testing with probe request
+)
+
+func (s CircuitState) String() string {
+ switch s {
+ case CircuitClosed:
+ return "closed"
+ case CircuitOpen:
+ return "open"
+ case CircuitHalfOpen:
+ return "half-open"
+ default:
+ return "unknown"
+ }
+}
+
+// ProviderCircuit tracks circuit breaker state for a single provider.
+type ProviderCircuit struct {
+ State CircuitState
+ OpenedAt time.Time
+ LastProbe time.Time
+}
+
// HealthEvent represents a single request outcome for a provider.
type HealthEvent struct {
Timestamp time.Time
@@ -15,12 +46,13 @@ type HealthEvent struct {
// ProviderHealth is the computed health status for a provider.
type ProviderHealth struct {
- Provider string `json:"provider"`
- Status string `json:"status"` // healthy, degraded, down
- ErrorRate float64 `json:"error_rate"`
- AvgLatency float64 `json:"avg_latency_ms"`
- Total int `json:"total"`
- Errors int `json:"errors"`
+ Provider string `json:"provider"`
+ Status string `json:"status"` // healthy, degraded, down
+ ErrorRate float64 `json:"error_rate"`
+ AvgLatency float64 `json:"avg_latency_ms"`
+ Total int `json:"total"`
+ Errors int `json:"errors"`
+ CircuitState string `json:"circuit_state"`
}
// HealthTracker tracks per-provider health using a sliding window.
@@ -28,20 +60,52 @@ type HealthTracker struct {
mu sync.RWMutex
windows map[string][]HealthEvent
windowDu time.Duration
+ circuits map[string]*ProviderCircuit
+ cbConfig config.CircuitBreakerConfig
}
// NewHealthTracker creates a health tracker with the given window duration.
-func NewHealthTracker(window time.Duration) *HealthTracker {
+func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker {
if window == 0 {
window = 5 * time.Minute
}
return &HealthTracker{
windows: make(map[string][]HealthEvent),
+ circuits: make(map[string]*ProviderCircuit),
windowDu: window,
+ cbConfig: cbCfg,
}
}
-// Record adds a health event for a provider.
+// IsAvailable returns true if the provider's circuit breaker allows requests.
+//
+// NOTE(review): this method holds only the read lock, so when an open
+// circuit's cooldown has elapsed it reports true WITHOUT changing state —
+// the actual Open -> HalfOpen transition happens later in Record. Until a
+// request outcome is recorded, every concurrent caller is admitted, so the
+// half-open "single probe" is best-effort, not strict. Confirm this
+// thundering-herd window is acceptable for the configured providers.
+func (h *HealthTracker) IsAvailable(provider string) bool {
+ if !h.cbConfig.Enabled {
+ return true
+ }
+
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+
+ circuit, ok := h.circuits[provider]
+ if !ok {
+ return true // no circuit = closed = available
+ }
+
+ switch circuit.State {
+ case CircuitOpen:
+ // Check if cooldown has elapsed -> transition to half-open
+ if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
+ return true // will transition to half-open on next record
+ }
+ return false
+ case CircuitHalfOpen:
+ return true // allow probe
+ default:
+ return true
+ }
+}
+
+// Record adds a health event for a provider and evaluates circuit transitions.
func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
event := HealthEvent{
Timestamp: time.Now(),
@@ -57,6 +121,69 @@ func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
h.windows[provider] = append(h.windows[provider], event)
h.prune(provider)
+
+ if h.cbConfig.Enabled {
+ h.evaluateCircuit(provider, err)
+ }
+}
+
+// evaluateCircuit transitions circuit breaker state. Must be called with lock held.
+func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
+ circuit, ok := h.circuits[providerName]
+ if !ok {
+ circuit = &ProviderCircuit{State: CircuitClosed}
+ h.circuits[providerName] = circuit
+ }
+
+ switch circuit.State {
+ case CircuitClosed:
+ // Check if error threshold exceeded
+ errorRate, total := h.errorRateUnlocked(providerName)
+ if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold {
+ circuit.State = CircuitOpen
+ circuit.OpenedAt = time.Now()
+ }
+ case CircuitOpen:
+ // Check if cooldown elapsed -> half-open
+ if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
+ circuit.State = CircuitHalfOpen
+ circuit.LastProbe = time.Now()
+ // Evaluate the probe result immediately
+ if lastErr == nil {
+ circuit.State = CircuitClosed
+ } else {
+ circuit.State = CircuitOpen
+ circuit.OpenedAt = time.Now()
+ }
+ }
+ case CircuitHalfOpen:
+ if lastErr == nil {
+ circuit.State = CircuitClosed
+ } else {
+ circuit.State = CircuitOpen
+ circuit.OpenedAt = time.Now()
+ }
+ }
+}
+
+// errorRateUnlocked computes error rate within window. Must be called with lock held.
+func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) {
+ cutoff := time.Now().Add(-h.windowDu)
+ events := h.windows[provider]
+ var total, errors int
+ for _, e := range events {
+ if e.Timestamp.Before(cutoff) {
+ continue
+ }
+ total++
+ if e.IsError {
+ errors++
+ }
+ }
+ if total == 0 {
+ return 0, 0
+ }
+ return float64(errors) / float64(total), total
}
// Status returns computed health for all tracked providers.
@@ -94,13 +221,19 @@ func (h *HealthTracker) Status() []ProviderHealth {
status = "degraded"
}
+ circuitState := "closed"
+ if circuit, ok := h.circuits[provider]; ok {
+ circuitState = circuit.State.String()
+ }
+
results = append(results, ProviderHealth{
- Provider: provider,
- Status: status,
- ErrorRate: errorRate,
- AvgLatency: float64(totalLatency) / float64(total),
- Total: total,
- Errors: errors,
+ Provider: provider,
+ Status: status,
+ ErrorRate: errorRate,
+ AvgLatency: float64(totalLatency) / float64(total),
+ Total: total,
+ Errors: errors,
+ CircuitState: circuitState,
})
}
diff --git a/llm-gateway/internal/provider/health_test.go b/llm-gateway/internal/provider/health_test.go
new file mode 100644
index 0000000..6d021b1
--- /dev/null
+++ b/llm-gateway/internal/provider/health_test.go
@@ -0,0 +1,345 @@
+package provider
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "llm-gateway/internal/config"
+)
+
+func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker {
+ return NewHealthTracker(window, cb)
+}
+
+func defaultCBConfig() config.CircuitBreakerConfig {
+ return config.CircuitBreakerConfig{
+ Enabled: true,
+ ErrorThreshold: 0.5,
+ MinRequests: 3,
+ CooldownDuration: 100 * time.Millisecond,
+ }
+}
+
+func TestHealthTracker_Record(t *testing.T) {
+ ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
+
+ ht.Record("provA", 100, nil)
+ ht.Record("provA", 200, errors.New("fail"))
+ ht.Record("provB", 50, nil)
+
+ ht.mu.RLock()
+ defer ht.mu.RUnlock()
+
+ if len(ht.windows["provA"]) != 2 {
+ t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"]))
+ }
+ if len(ht.windows["provB"]) != 1 {
+ t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"]))
+ }
+
+ // Verify event fields
+ ev := ht.windows["provA"][1]
+ if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 {
+ t.Fatalf("unexpected event fields: %+v", ev)
+ }
+}
+
+func TestHealthTracker_Status(t *testing.T) {
+ tests := []struct {
+ name string
+ successCount int
+ errorCount int
+ wantStatus string
+ wantErrorRate float64
+ wantTotal int
+ wantErrors int
+ }{
+ {
+ name: "healthy - no errors",
+ successCount: 10,
+ errorCount: 0,
+ wantStatus: "healthy",
+ wantErrorRate: 0.0,
+ wantTotal: 10,
+ wantErrors: 0,
+ },
+ {
+ name: "healthy - below 10% errors",
+ successCount: 19,
+ errorCount: 1,
+ wantStatus: "healthy",
+ wantErrorRate: 0.05,
+ wantTotal: 20,
+ wantErrors: 1,
+ },
+ {
+ name: "degraded - 20% errors",
+ successCount: 8,
+ errorCount: 2,
+ wantStatus: "degraded",
+ wantErrorRate: 0.2,
+ wantTotal: 10,
+ wantErrors: 2,
+ },
+ {
+ name: "degraded - exactly 10% errors",
+ successCount: 9,
+ errorCount: 1,
+ wantStatus: "degraded",
+ wantErrorRate: 0.1,
+ wantTotal: 10,
+ wantErrors: 1,
+ },
+ {
+ name: "down - 50% errors",
+ successCount: 5,
+ errorCount: 5,
+ wantStatus: "down",
+ wantErrorRate: 0.5,
+ wantTotal: 10,
+ wantErrors: 5,
+ },
+ {
+ name: "down - all errors",
+ successCount: 0,
+ errorCount: 5,
+ wantStatus: "down",
+ wantErrorRate: 1.0,
+ wantTotal: 5,
+ wantErrors: 5,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
+
+ for i := 0; i < tt.successCount; i++ {
+ ht.Record("prov", 100, nil)
+ }
+ for i := 0; i < tt.errorCount; i++ {
+ ht.Record("prov", 100, errors.New("err"))
+ }
+
+ statuses := ht.Status()
+ if len(statuses) != 1 {
+ t.Fatalf("expected 1 status, got %d", len(statuses))
+ }
+
+ s := statuses[0]
+ if s.Status != tt.wantStatus {
+ t.Errorf("status = %q, want %q", s.Status, tt.wantStatus)
+ }
+ if s.Total != tt.wantTotal {
+ t.Errorf("total = %d, want %d", s.Total, tt.wantTotal)
+ }
+ if s.Errors != tt.wantErrors {
+ t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors)
+ }
+ // Allow small float tolerance
+ if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 {
+ t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate)
+ }
+ })
+ }
+}
+
+func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) {
+ cb := defaultCBConfig()
+ cb.MinRequests = 3
+ cb.ErrorThreshold = 0.5
+
+ ht := newTestTracker(5*time.Minute, cb)
+
+ // Record errors to exceed threshold (3 errors out of 3 = 100% > 50%)
+ ht.Record("prov", 100, errors.New("err"))
+ ht.Record("prov", 100, errors.New("err"))
+ ht.Record("prov", 100, errors.New("err"))
+
+ ht.mu.RLock()
+ state := ht.circuits["prov"].State
+ ht.mu.RUnlock()
+
+ if state != CircuitOpen {
+ t.Fatalf("expected CircuitOpen, got %s", state)
+ }
+
+ if ht.IsAvailable("prov") {
+ t.Fatal("expected IsAvailable=false when circuit is open")
+ }
+}
+
+func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) {
+ cb := defaultCBConfig()
+ cb.CooldownDuration = 50 * time.Millisecond
+
+ ht := newTestTracker(5*time.Minute, cb)
+
+ // Trip the circuit
+ for i := 0; i < 5; i++ {
+ ht.Record("prov", 100, errors.New("err"))
+ }
+
+ if ht.IsAvailable("prov") {
+ t.Fatal("expected circuit open, IsAvailable should be false")
+ }
+
+ // Wait for cooldown
+ time.Sleep(60 * time.Millisecond)
+
+ // After cooldown, IsAvailable should return true (will transition to half-open)
+ if !ht.IsAvailable("prov") {
+ t.Fatal("expected IsAvailable=true after cooldown")
+ }
+}
+
+func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
+ cb := defaultCBConfig()
+ cb.CooldownDuration = 10 * time.Millisecond
+
+ ht := newTestTracker(5*time.Minute, cb)
+
+ // Trip the circuit
+ for i := 0; i < 5; i++ {
+ ht.Record("prov", 100, errors.New("err"))
+ }
+
+ // Wait for cooldown so next Record transitions through Open->HalfOpen
+ time.Sleep(20 * time.Millisecond)
+
+ // A successful record should transition: Open -> HalfOpen -> Closed
+ ht.Record("prov", 100, nil)
+
+ ht.mu.RLock()
+ state := ht.circuits["prov"].State
+ ht.mu.RUnlock()
+
+ if state != CircuitClosed {
+ t.Fatalf("expected CircuitClosed after success in half-open, got %s", state)
+ }
+
+ if !ht.IsAvailable("prov") {
+ t.Fatal("expected IsAvailable=true after circuit closed")
+ }
+}
+
+func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
+ cb := defaultCBConfig()
+ cb.CooldownDuration = 10 * time.Millisecond
+
+ ht := newTestTracker(5*time.Minute, cb)
+
+ // Trip the circuit
+ for i := 0; i < 5; i++ {
+ ht.Record("prov", 100, errors.New("err"))
+ }
+
+ // Wait for cooldown
+ time.Sleep(20 * time.Millisecond)
+
+ // A failed record should transition: Open -> HalfOpen -> Open
+ ht.Record("prov", 100, errors.New("still failing"))
+
+ ht.mu.RLock()
+ state := ht.circuits["prov"].State
+ ht.mu.RUnlock()
+
+ if state != CircuitOpen {
+ t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state)
+ }
+}
+
+func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) {
+ ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false})
+
+ // Even with errors, IsAvailable should return true when CB is disabled
+ for i := 0; i < 10; i++ {
+ ht.Record("prov", 100, errors.New("err"))
+ }
+
+ if !ht.IsAvailable("prov") {
+ t.Fatal("expected IsAvailable=true when circuit breaker disabled")
+ }
+}
+
+func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) {
+ ht := newTestTracker(5*time.Minute, defaultCBConfig())
+
+ if !ht.IsAvailable("unknown") {
+ t.Fatal("expected IsAvailable=true for unknown provider (no circuit)")
+ }
+}
+
+func TestHealthTracker_WindowPruning(t *testing.T) {
+ // Use a tiny window so events expire quickly
+ ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
+
+ ht.Record("prov", 100, nil)
+ ht.Record("prov", 200, nil)
+
+ // Wait for events to expire
+ time.Sleep(60 * time.Millisecond)
+
+ // Record a new event to trigger pruning
+ ht.Record("prov", 300, nil)
+
+ ht.mu.RLock()
+ count := len(ht.windows["prov"])
+ ht.mu.RUnlock()
+
+ if count != 1 {
+ t.Fatalf("expected 1 event after pruning, got %d", count)
+ }
+}
+
+func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) {
+ ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
+
+ ht.Record("prov", 100, nil)
+
+ // Wait for events to expire
+ time.Sleep(60 * time.Millisecond)
+
+ statuses := ht.Status()
+ if len(statuses) != 0 {
+ t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses))
+ }
+}
+
+func TestHealthTracker_Status_AvgLatency(t *testing.T) {
+ ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
+
+ ht.Record("prov", 100, nil)
+ ht.Record("prov", 200, nil)
+ ht.Record("prov", 300, nil)
+
+ statuses := ht.Status()
+ if len(statuses) != 1 {
+ t.Fatalf("expected 1 status, got %d", len(statuses))
+ }
+
+ want := 200.0
+ if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 {
+ t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want)
+ }
+}
+
+func TestHealthTracker_Status_CircuitStateReported(t *testing.T) {
+ cb := defaultCBConfig()
+ ht := newTestTracker(5*time.Minute, cb)
+
+ // Trip the circuit
+ for i := 0; i < 5; i++ {
+ ht.Record("prov", 100, errors.New("err"))
+ }
+
+ statuses := ht.Status()
+ if len(statuses) != 1 {
+ t.Fatalf("expected 1 status, got %d", len(statuses))
+ }
+
+ if statuses[0].CircuitState != "open" {
+ t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open")
+ }
+}
diff --git a/llm-gateway/internal/provider/openai.go b/llm-gateway/internal/provider/openai.go
index 1278eea..0e434f3 100644
--- a/llm-gateway/internal/provider/openai.go
+++ b/llm-gateway/internal/provider/openai.go
@@ -111,6 +111,12 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string,
func (p *OpenAIProvider) setHeaders(req *http.Request) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+p.apiKey)
+ // Forward request ID if present in context
+ // NOTE(review): a bare string context key trips `go vet` / staticcheck
+ // SA1029 and can collide with other packages using the same string.
+ // Prefer an unexported typed key shared with whatever middleware stores
+ // "requestID" — confirm where that value is set before changing the key.
+ if reqID := req.Context().Value("requestID"); reqID != nil {
+ if id, ok := reqID.(string); ok && id != "" {
+ req.Header.Set("X-Request-ID", id)
+ }
+ }
}
// ProviderError represents a non-200 response from a provider.
diff --git a/llm-gateway/internal/provider/registry.go b/llm-gateway/internal/provider/registry.go
index 1a2302b..5a4dced 100644
--- a/llm-gateway/internal/provider/registry.go
+++ b/llm-gateway/internal/provider/registry.go
@@ -3,6 +3,7 @@ package provider
import (
"fmt"
"sort"
+ "sync"
"llm-gateway/internal/config"
)
@@ -18,26 +19,40 @@ type Route struct {
// Registry maps model names to provider routes.
type Registry struct {
- routes map[string][]Route
- order []string // preserves config order
+ mu sync.RWMutex
+ routes map[string][]Route
+ balancers map[string]LoadBalancer
+ aliases map[string]string // alias -> canonical name
+ order []string // preserves config order (canonical names only)
}
func NewRegistry(cfg *config.Config) (*Registry, error) {
+ r := &Registry{}
+ if err := r.buildFromConfig(cfg); err != nil {
+ return nil, err
+ }
+ return r, nil
+}
+
+func (r *Registry) buildFromConfig(cfg *config.Config) error {
// Build providers
providers := make(map[string]Provider)
for _, pc := range cfg.Providers {
providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout)
}
- // Build routes (preserving config order)
+ // Build routes
routes := make(map[string][]Route)
+ balancers := make(map[string]LoadBalancer)
+ aliases := make(map[string]string)
order := make([]string, 0, len(cfg.Models))
+
for _, mc := range cfg.Models {
var modelRoutes []Route
for _, rc := range mc.Routes {
p, ok := providers[rc.Provider]
if !ok {
- return nil, fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
+ return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
}
pc := cfg.ProviderByName(rc.Provider)
priority := pc.Priority
@@ -55,20 +70,69 @@ func NewRegistry(cfg *config.Config) (*Registry, error) {
})
routes[mc.Name] = modelRoutes
order = append(order, mc.Name)
+
+ // Load balancer
+ balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)
+
+ // Register aliases
+ for _, alias := range mc.Aliases {
+ aliases[alias] = mc.Name
+ }
}
- return &Registry{routes: routes, order: order}, nil
+ r.mu.Lock()
+ r.routes = routes
+ r.balancers = balancers
+ r.aliases = aliases
+ r.order = order
+ r.mu.Unlock()
+
+ return nil
}
-// Lookup returns the routes for a model name.
+// Reload rebuilds routes from new config. Used for hot-reload.
+func (r *Registry) Reload(cfg *config.Config) error {
+ return r.buildFromConfig(cfg)
+}
+
+// Lookup returns the routes for a model name (resolving aliases), reordered
+// by the model's configured load balancer. The returned slice is always a
+// copy, so callers can never mutate the registry's internal route tables.
func (r *Registry) Lookup(model string) ([]Route, bool) {
- routes, ok := r.routes[model]
- return routes, ok
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	// Resolve alias to its canonical model name.
+	canonical := model
+	if alias, ok := r.aliases[model]; ok {
+		canonical = alias
+	}
+
+	routes, ok := r.routes[canonical]
+	if !ok {
+		return nil, false
+	}
+
+	// Clone before handing to the balancer: FirstBalancer returns its input
+	// unchanged, which would otherwise leak the registry's shared backing
+	// array to the caller.
+	cloned := make([]Route, len(routes))
+	copy(cloned, routes)
+
+	if balancer, ok := r.balancers[canonical]; ok {
+		cloned = balancer.Reorder(cloned)
+	}
+
+	return cloned, true
}
-// ModelNames returns all registered model names in config order.
+// ModelNames returns all registered model names in config order (including aliases).
func (r *Registry) ModelNames() []string {
- return r.order
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ var names []string
+ for _, name := range r.order {
+ names = append(names, name)
+ }
+ // Add aliases
+ for alias := range r.aliases {
+ names = append(names, alias)
+ }
+ return names
}
// RouteInfo exposes route details for dashboard display.
@@ -82,16 +146,29 @@ type RouteInfo struct {
// ModelRouteInfo exposes a model and its routes for dashboard display.
type ModelRouteInfo struct {
- Name string `json:"name"`
- Routes []RouteInfo `json:"routes"`
+ Name string `json:"name"`
+ Aliases []string `json:"aliases,omitempty"`
+ Routes []RouteInfo `json:"routes"`
}
// AllRoutes returns all models and their routes in config order.
func (r *Registry) AllRoutes() []ModelRouteInfo {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+
+ // Build reverse alias map
+ modelAliases := make(map[string][]string)
+ for alias, canonical := range r.aliases {
+ modelAliases[canonical] = append(modelAliases[canonical], alias)
+ }
+
results := make([]ModelRouteInfo, 0, len(r.order))
for _, name := range r.order {
routes := r.routes[name]
- info := ModelRouteInfo{Name: name}
+ info := ModelRouteInfo{
+ Name: name,
+ Aliases: modelAliases[name],
+ }
for _, rt := range routes {
info.Routes = append(info.Routes, RouteInfo{
ProviderName: rt.Provider.Name(),
diff --git a/llm-gateway/internal/provider/registry_test.go b/llm-gateway/internal/provider/registry_test.go
new file mode 100644
index 0000000..b04c45a
--- /dev/null
+++ b/llm-gateway/internal/provider/registry_test.go
@@ -0,0 +1,282 @@
+package provider
+
+import (
+ "context"
+ "io"
+ "testing"
+
+ "llm-gateway/internal/config"
+)
+
+// mockProvider implements the Provider interface for testing.
+// Only Name carries state; the completion methods are never exercised by
+// the registry tests and return zero values.
+type mockProvider struct {
+ name string
+}
+
+func (m *mockProvider) Name() string { return m.name }
+
+// ChatCompletion satisfies Provider; returns (nil, nil) — tests never call it.
+func (m *mockProvider) ChatCompletion(_ context.Context, _ string, _ *ChatRequest) (*ChatResponse, error) {
+	return nil, nil
+}
+
+// ChatCompletionStream satisfies Provider; returns (nil, nil) — streaming is not exercised.
+func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *ChatRequest) (io.ReadCloser, error) {
+	return nil, nil
+}
+
+// newTestRegistry builds a Registry directly without going through config
+// parsing. Every model gets a FirstBalancer so Lookup ordering is stable.
+func newTestRegistry(models []testModel) *Registry {
+	reg := &Registry{
+		routes:    make(map[string][]Route),
+		balancers: make(map[string]LoadBalancer),
+		aliases:   make(map[string]string),
+	}
+
+	for _, tm := range models {
+		reg.order = append(reg.order, tm.name)
+		reg.routes[tm.name] = tm.routes
+		reg.balancers[tm.name] = &FirstBalancer{}
+		for _, a := range tm.aliases {
+			reg.aliases[a] = tm.name
+		}
+	}
+
+	return reg
+}
+
+// testModel describes one model entry for newTestRegistry.
+type testModel struct {
+	name    string
+	aliases []string
+	routes  []Route
+}
+
+// TestRegistry_Lookup_Canonical verifies lookup by the canonical model name.
+func TestRegistry_Lookup_Canonical(t *testing.T) {
+ reg := newTestRegistry([]testModel{
+ {
+ name: "gpt-4",
+ routes: []Route{
+ {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
+ },
+ },
+ })
+
+ routes, ok := reg.Lookup("gpt-4")
+ if !ok {
+ t.Fatal("expected Lookup to find gpt-4")
+ }
+ if len(routes) != 1 {
+ t.Fatalf("expected 1 route, got %d", len(routes))
+ }
+ if routes[0].Provider.Name() != "openai" {
+ t.Errorf("expected provider 'openai', got %q", routes[0].Provider.Name())
+ }
+}
+
+// TestRegistry_Lookup_Alias verifies that aliases resolve to the canonical
+// model's routes and that unknown names report not-found.
+func TestRegistry_Lookup_Alias(t *testing.T) {
+ reg := newTestRegistry([]testModel{
+ {
+ name: "gpt-4",
+ aliases: []string{"gpt4", "gpt-4-latest"},
+ routes: []Route{
+ {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
+ },
+ },
+ })
+
+ tests := []struct {
+ name string
+ model string
+ found bool
+ }{
+ {"canonical", "gpt-4", true},
+ {"alias1", "gpt4", true},
+ {"alias2", "gpt-4-latest", true},
+ {"unknown", "gpt-5", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ routes, ok := reg.Lookup(tt.model)
+ if ok != tt.found {
+ t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found)
+ }
+ if tt.found && len(routes) != 1 {
+ t.Fatalf("expected 1 route, got %d", len(routes))
+ }
+ })
+ }
+}
+
+// TestRegistry_ModelNames_IncludesAliases verifies ModelNames returns the
+// union of canonical names and aliases, with no extras. Comparison is
+// set-based because alias order comes from map iteration.
+func TestRegistry_ModelNames_IncludesAliases(t *testing.T) {
+ reg := newTestRegistry([]testModel{
+ {
+ name: "gpt-4",
+ aliases: []string{"gpt4"},
+ routes: []Route{
+ {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
+ },
+ },
+ {
+ name: "claude-3",
+ routes: []Route{
+ {Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1},
+ },
+ },
+ })
+
+ names := reg.ModelNames()
+
+ want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true}
+ got := make(map[string]bool)
+ for _, n := range names {
+ got[n] = true
+ }
+
+ for name := range want {
+ if !got[name] {
+ t.Errorf("expected %q in ModelNames, not found", name)
+ }
+ }
+
+ // Length check catches duplicates or unexpected extra names.
+ if len(names) != len(want) {
+ t.Errorf("expected %d names, got %d: %v", len(want), len(names), names)
+ }
+}
+
+// TestRegistry_AllRoutes_ShowsAliases verifies AllRoutes exposes the alias
+// list (order-independent) and the stored route order for dashboard display.
+func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) {
+ reg := newTestRegistry([]testModel{
+ {
+ name: "gpt-4",
+ aliases: []string{"gpt4", "gpt-4-latest"},
+ routes: []Route{
+ {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
+ {Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2},
+ },
+ },
+ })
+
+ allRoutes := reg.AllRoutes()
+ if len(allRoutes) != 1 {
+ t.Fatalf("expected 1 model, got %d", len(allRoutes))
+ }
+
+ m := allRoutes[0]
+ if m.Name != "gpt-4" {
+ t.Errorf("expected name 'gpt-4', got %q", m.Name)
+ }
+
+ // Aliases come from a reverse map build, so compare as a set.
+ aliasSet := make(map[string]bool)
+ for _, a := range m.Aliases {
+ aliasSet[a] = true
+ }
+ if !aliasSet["gpt4"] || !aliasSet["gpt-4-latest"] {
+ t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", m.Aliases)
+ }
+
+ if len(m.Routes) != 2 {
+ t.Fatalf("expected 2 routes, got %d", len(m.Routes))
+ }
+ if m.Routes[0].ProviderName != "openai" {
+ t.Errorf("expected first route provider 'openai', got %q", m.Routes[0].ProviderName)
+ }
+ if m.Routes[1].ProviderName != "azure" {
+ t.Errorf("expected second route provider 'azure', got %q", m.Routes[1].ProviderName)
+ }
+}
+
+// TestRegistry_AllRoutes_ConfigOrder verifies AllRoutes preserves insertion
+// (config) order rather than sorting alphabetically: model-b was registered
+// first and must come out first.
+func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) {
+ reg := newTestRegistry([]testModel{
+ {
+ name: "model-b",
+ routes: []Route{
+ {Provider: &mockProvider{name: "prov"}, ProviderModel: "b", Priority: 1},
+ },
+ },
+ {
+ name: "model-a",
+ routes: []Route{
+ {Provider: &mockProvider{name: "prov"}, ProviderModel: "a", Priority: 1},
+ },
+ },
+ })
+
+ allRoutes := reg.AllRoutes()
+ if len(allRoutes) != 2 {
+ t.Fatalf("expected 2 models, got %d", len(allRoutes))
+ }
+ if allRoutes[0].Name != "model-b" {
+ t.Errorf("expected first model 'model-b', got %q", allRoutes[0].Name)
+ }
+ if allRoutes[1].Name != "model-a" {
+ t.Errorf("expected second model 'model-a', got %q", allRoutes[1].Name)
+ }
+}
+
+// TestRegistry_PrioritySorting verifies that a directly-built registry does
+// NOT reorder routes: priority sorting happens in buildFromConfig, so
+// AllRoutes must return exactly the stored sequence. Asserting the exact
+// order (3, 1, 2) is stronger than the previous set-membership check, which
+// would have passed even if routes were shuffled or sorted.
+func TestRegistry_PrioritySorting(t *testing.T) {
+	reg := newTestRegistry([]testModel{
+		{
+			name: "multi-provider",
+			routes: []Route{
+				{Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3},
+				{Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1},
+				{Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2},
+			},
+		},
+	})
+
+	allRoutes := reg.AllRoutes()
+	if len(allRoutes) != 1 {
+		t.Fatalf("expected 1 model, got %d", len(allRoutes))
+	}
+
+	routes := allRoutes[0].Routes
+	wantPriorities := []int{3, 1, 2}
+	wantProviders := []string{"low-priority", "high-priority", "mid-priority"}
+	if len(routes) != len(wantPriorities) {
+		t.Fatalf("expected %d routes, got %d", len(wantPriorities), len(routes))
+	}
+	for i := range routes {
+		if routes[i].Priority != wantPriorities[i] {
+			t.Errorf("route %d: priority = %d, want %d", i, routes[i].Priority, wantPriorities[i])
+		}
+		if routes[i].ProviderName != wantProviders[i] {
+			t.Errorf("route %d: provider = %q, want %q", i, routes[i].ProviderName, wantProviders[i])
+		}
+	}
+}
+
+// TestRegistry_NewRegistry_UnknownProvider verifies construction fails when
+// a route references a provider that is not declared in the config.
+func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) {
+ cfg := &config.Config{
+ Models: []config.ModelConfig{
+ {
+ Name: "test-model",
+ Routes: []config.RouteConfig{
+ {Provider: "nonexistent", Model: "m"},
+ },
+ },
+ },
+ }
+
+ _, err := NewRegistry(cfg)
+ if err == nil {
+ t.Fatal("expected error for unknown provider, got nil")
+ }
+}
+
+// TestRegistry_Lookup_NotFound verifies the not-found path of Lookup.
+func TestRegistry_Lookup_NotFound(t *testing.T) {
+ reg := newTestRegistry([]testModel{
+ {
+ name: "gpt-4",
+ routes: []Route{
+ {Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
+ },
+ },
+ })
+
+ _, ok := reg.Lookup("nonexistent")
+ if ok {
+ t.Fatal("expected Lookup to return false for nonexistent model")
+ }
+}
diff --git a/llm-gateway/internal/proxy/concurrency.go b/llm-gateway/internal/proxy/concurrency.go
new file mode 100644
index 0000000..4f28262
--- /dev/null
+++ b/llm-gateway/internal/proxy/concurrency.go
@@ -0,0 +1,51 @@
+package proxy
+
+import (
+ "net/http"
+ "sync"
+ "sync/atomic"
+)
+
+// ConcurrencyLimiter enforces per-token concurrent request limits.
+// Counters are created lazily per token name and never removed, so memory
+// grows with the number of distinct token names seen.
+type ConcurrencyLimiter struct {
+ mu sync.Mutex
+ counters map[string]*atomic.Int64
+}
+
+// NewConcurrencyLimiter returns a limiter with an empty counter table.
+func NewConcurrencyLimiter() *ConcurrencyLimiter {
+ return &ConcurrencyLimiter{
+ counters: make(map[string]*atomic.Int64),
+ }
+}
+
+// getCounter returns the in-flight counter for tokenName, creating it on
+// first use. The mutex guards only the map; the counter itself is atomic.
+func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 {
+ cl.mu.Lock()
+ defer cl.mu.Unlock()
+ c, ok := cl.counters[tokenName]
+ if !ok {
+ c = &atomic.Int64{}
+ cl.counters[tokenName] = c
+ }
+ return c
+}
+
+// Check wraps next with a per-token concurrency guard. Requests without a
+// token, or with MaxConcurrent <= 0, pass through unlimited.
+//
+// The increment-then-check pattern means a rejected request briefly inflates
+// the counter (it decrements via the defer on return), so under a burst a
+// request can be rejected even though a slot frees microseconds later.
+func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ apiToken := getAPIToken(r.Context())
+ if apiToken == nil || apiToken.MaxConcurrent <= 0 {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ counter := cl.getCounter(apiToken.Name)
+ current := counter.Add(1)
+ defer counter.Add(-1)
+
+ if current > int64(apiToken.MaxConcurrent) {
+ writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded")
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+}
diff --git a/llm-gateway/internal/proxy/concurrency_test.go b/llm-gateway/internal/proxy/concurrency_test.go
new file mode 100644
index 0000000..fa6ccbf
--- /dev/null
+++ b/llm-gateway/internal/proxy/concurrency_test.go
@@ -0,0 +1,317 @@
+package proxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "llm-gateway/internal/auth"
+)
+
+// TestConcurrencyLimiter_AllowsWithinLimit verifies that when the number of
+// concurrent requests does not exceed MaxConcurrent, none are rejected.
+func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) {
+ tests := []struct {
+ name string
+ maxConcurrent int
+ numRequests int
+ wantAllowed int
+ }{
+ {
+ name: "single request within limit",
+ maxConcurrent: 5,
+ numRequests: 1,
+ wantAllowed: 1,
+ },
+ {
+ name: "all requests within limit",
+ maxConcurrent: 5,
+ numRequests: 5,
+ wantAllowed: 5,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cl := NewConcurrencyLimiter()
+
+ token := &auth.APIToken{
+ Name: "conc-token",
+ MaxConcurrent: tt.maxConcurrent,
+ }
+
+ var allowed atomic.Int64
+ var wg sync.WaitGroup
+ // Use a channel to hold all goroutines inside the handler simultaneously.
+ gate := make(chan struct{})
+
+ handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ allowed.Add(1)
+ <-gate // Block until released.
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ for i := 0; i < tt.numRequests; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx := withAPIToken(req.Context(), token)
+ req = req.WithContext(ctx)
+ handler.ServeHTTP(rec, req)
+ }()
+ }
+
+ // Best-effort pause so the requests overlap. Correctness does not depend
+ // on it here: with numRequests <= maxConcurrent every request is allowed
+ // regardless of interleaving.
+ time.Sleep(50 * time.Millisecond)
+ close(gate)
+ wg.Wait()
+
+ if int(allowed.Load()) != tt.wantAllowed {
+ t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed)
+ }
+ })
+ }
+}
+
+// TestConcurrencyLimiter_DeniesOverLimit verifies that requests beyond
+// MaxConcurrent are rejected with 429. The previous version synchronized
+// with a fixed 50ms sleep, which is flaky under load: closing the gate
+// before every goroutine had passed the limiter check lets stragglers
+// through and makes the denial count wrong. This version waits
+// deterministically: with the gate closed, exactly maxConcurrent requests
+// block inside the handler and the remainder are denied and finish, so we
+// poll those two counters before releasing the gate. The dead `results`
+// slice (written, never read) is also removed.
+func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) {
+	tests := []struct {
+		name          string
+		maxConcurrent int
+		numRequests   int
+		wantDenied    int
+	}{
+		{
+			name:          "one over limit",
+			maxConcurrent: 2,
+			numRequests:   3,
+			wantDenied:    1,
+		},
+		{
+			name:          "many over limit",
+			maxConcurrent: 1,
+			numRequests:   5,
+			wantDenied:    4,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			cl := NewConcurrencyLimiter()
+
+			token := &auth.APIToken{
+				Name:          "conc-token",
+				MaxConcurrent: tt.maxConcurrent,
+			}
+
+			var denied, entered, settled atomic.Int64
+			var wg sync.WaitGroup
+			gate := make(chan struct{})
+
+			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				entered.Add(1)
+				<-gate
+				w.WriteHeader(http.StatusOK)
+			}))
+
+			for i := 0; i < tt.numRequests; i++ {
+				wg.Add(1)
+				go func() {
+					defer wg.Done()
+					rec := httptest.NewRecorder()
+					req := httptest.NewRequest(http.MethodGet, "/", nil)
+					req = req.WithContext(withAPIToken(req.Context(), token))
+					handler.ServeHTTP(rec, req)
+					if rec.Code == http.StatusTooManyRequests {
+						denied.Add(1)
+					}
+					settled.Add(1)
+				}()
+			}
+
+			// Wait until maxConcurrent requests are blocked in the handler and
+			// the rest have been rejected, then release the gate.
+			wantBlocked := int64(tt.maxConcurrent)
+			wantSettled := int64(tt.numRequests - tt.maxConcurrent)
+			deadline := time.Now().Add(5 * time.Second)
+			for entered.Load() < wantBlocked || settled.Load() < wantSettled {
+				if time.Now().After(deadline) {
+					close(gate)
+					wg.Wait()
+					t.Fatalf("timeout waiting for steady state: entered=%d settled=%d", entered.Load(), settled.Load())
+				}
+				time.Sleep(time.Millisecond)
+			}
+			close(gate)
+			wg.Wait()
+
+			if int(denied.Load()) != tt.wantDenied {
+				t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied)
+			}
+		})
+	}
+}
+
+// TestConcurrencyLimiter_CounterDecrementsAfterCompletion verifies that a
+// completed request releases its slot: with MaxConcurrent=1, two sequential
+// requests must both succeed and the counter must return to zero.
+func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) {
+ cl := NewConcurrencyLimiter()
+
+ token := &auth.APIToken{
+ Name: "decrement-token",
+ MaxConcurrent: 1,
+ }
+
+ handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // First request should succeed and complete, decrementing the counter.
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx := withAPIToken(req.Context(), token)
+ req = req.WithContext(ctx)
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("first request: status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ // Counter should have decremented. A second request should also succeed.
+ rec2 := httptest.NewRecorder()
+ req2 := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx2 := withAPIToken(req2.Context(), token)
+ req2 = req2.WithContext(ctx2)
+ handler.ServeHTTP(rec2, req2)
+
+ if rec2.Code != http.StatusOK {
+ t.Errorf("second request after first completed: status = %d, want %d", rec2.Code, http.StatusOK)
+ }
+
+ // Verify the internal counter is back to 0.
+ counter := cl.getCounter(token.Name)
+ val := counter.Load()
+ if val != 0 {
+ t.Errorf("counter = %d, want 0 after all requests completed", val)
+ }
+}
+
+// TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited verifies that a
+// zero or negative MaxConcurrent disables limiting entirely (the middleware
+// passes through before touching any counter).
+func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) {
+ tests := []struct {
+ name string
+ maxConcurrent int
+ numRequests int
+ }{
+ {
+ name: "zero allows unlimited concurrent requests",
+ maxConcurrent: 0,
+ numRequests: 50,
+ },
+ {
+ name: "negative allows unlimited concurrent requests",
+ maxConcurrent: -1,
+ numRequests: 50,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cl := NewConcurrencyLimiter()
+
+ token := &auth.APIToken{
+ Name: "unlimited-token",
+ MaxConcurrent: tt.maxConcurrent,
+ }
+
+ var allowed atomic.Int64
+ var wg sync.WaitGroup
+ gate := make(chan struct{})
+
+ handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ allowed.Add(1)
+ <-gate
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ for i := 0; i < tt.numRequests; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx := withAPIToken(req.Context(), token)
+ req = req.WithContext(ctx)
+ handler.ServeHTTP(rec, req)
+ }()
+ }
+
+ // Best-effort pause for overlap; the assertion holds for any
+ // interleaving since no request can be denied with limiting disabled.
+ time.Sleep(100 * time.Millisecond)
+ close(gate)
+ wg.Wait()
+
+ if int(allowed.Load()) != tt.numRequests {
+ t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests)
+ }
+ })
+ }
+}
+
+// TestConcurrencyLimiter_NoToken verifies that a request with no API token
+// in context bypasses the limiter entirely.
+func TestConcurrencyLimiter_NoToken(t *testing.T) {
+ cl := NewConcurrencyLimiter()
+
+ handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ // No API token in context.
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Errorf("status = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
+ }
+}
+
+// TestConcurrencyLimiter_PerTokenIsolation verifies that counters are keyed
+// by token name: token A saturating its own limit must not block token B.
+func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) {
+ cl := NewConcurrencyLimiter()
+
+ tokenA := &auth.APIToken{
+ Name: "token-a",
+ MaxConcurrent: 1,
+ }
+ tokenB := &auth.APIToken{
+ Name: "token-b",
+ MaxConcurrent: 1,
+ }
+
+ gateA := make(chan struct{})
+ var wg sync.WaitGroup
+
+ handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ tok := getAPIToken(r.Context())
+ if tok.Name == "token-a" {
+ <-gateA // Block token A's request.
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+
+ // Start a request for token A that blocks.
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx := withAPIToken(req.Context(), tokenA)
+ req = req.WithContext(ctx)
+ handler.ServeHTTP(rec, req)
+ }()
+
+ // Best-effort pause so token A's request is in flight; the assertion
+ // below holds either way because B's counter is independent of A's.
+ time.Sleep(50 * time.Millisecond)
+
+ // Token B should not be affected by token A's in-flight request.
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/", nil)
+ ctx := withAPIToken(req.Context(), tokenB)
+ req = req.WithContext(ctx)
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK)
+ }
+
+ close(gateA)
+ wg.Wait()
+}
diff --git a/llm-gateway/internal/proxy/handler.go b/llm-gateway/internal/proxy/handler.go
index ba4f36a..cb6c619 100644
--- a/llm-gateway/internal/proxy/handler.go
+++ b/llm-gateway/internal/proxy/handler.go
@@ -4,11 +4,16 @@ import (
"context"
"encoding/json"
"errors"
+ "fmt"
"io"
"log"
"net/http"
+ "sort"
+ "strings"
"time"
+ "github.com/go-chi/chi/v5/middleware"
+
"llm-gateway/internal/auth"
"llm-gateway/internal/cache"
"llm-gateway/internal/config"
@@ -47,6 +52,7 @@ type Handler struct {
metrics *metrics.Metrics
cfg *config.Config
healthTracker *provider.HealthTracker
+ debugLogger *storage.DebugLogger
}
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
@@ -60,6 +66,10 @@ func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cac
}
}
+func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
+ h.debugLogger = dl
+}
+
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
if err != nil {
@@ -84,31 +94,53 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
return
}
+ // Filter healthy routes (circuit breaker)
+ routes = h.filterHealthyRoutes(routes)
+
tokenName := getTokenName(r.Context())
+ requestID := middleware.GetReqID(r.Context())
// Check cache for non-streaming requests
if !req.Stream && h.cache != nil {
if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil {
- h.logRequest(tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
+ h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
+ if h.metrics != nil {
+ h.metrics.RecordCacheHit()
+ }
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Cache", "HIT")
+ w.Header().Set("X-Request-ID", requestID)
w.Write(cached)
return
}
+ if h.metrics != nil {
+ h.metrics.RecordCacheMiss()
+ }
}
if req.Stream {
- h.handleStream(w, r, &req, routes, tokenName)
+ h.handleStream(w, r, &req, routes, tokenName, requestID)
return
}
- h.handleNonStream(w, r, &req, routes, tokenName, body)
+ h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
}
-func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte) {
+func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) {
var lastErr error
- for _, route := range routes {
+ for i, route := range routes {
+ // Retry backoff between attempts (not before first attempt)
+ if i > 0 {
+ backoff := backoffDuration(i, h.cfg.Retry)
+ select {
+ case <-time.After(backoff):
+ case <-r.Context().Done():
+ writeError(w, http.StatusGatewayTimeout, "request cancelled")
+ return
+ }
+ }
+
start := time.Now()
resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req)
latency := time.Since(start).Milliseconds()
@@ -116,19 +148,19 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
if err != nil {
var pe *provider.ProviderError
if errors.As(err, &pe) && !pe.IsRetryable() {
- // Client error — don't retry
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
+ h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
+ w.Header().Set("X-Request-ID", requestID)
writeErrorRaw(w, pe.StatusCode, pe.Body)
return
}
lastErr = err
log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err)
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
+ h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
@@ -139,7 +171,6 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
h.healthTracker.Record(route.Provider.Name(), latency, nil)
}
- // Compute cost
inputTokens, outputTokens := 0, 0
if resp.Usage != nil {
inputTokens = resp.Usage.PromptTokens
@@ -148,9 +179,8 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
- h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
+ h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)
- // Override model name in response to match the requested model
resp.Model = req.Model
respBytes, err := json.Marshal(resp)
@@ -159,27 +189,84 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *p
return
}
- // Cache the response
if h.cache != nil {
h.cache.Set(r.Context(), req.Model, rawBody, respBytes)
}
+ // Debug logging
+ if h.debugLogger != nil && h.debugLogger.IsEnabled() {
+ reqBody := string(rawBody)
+ respBody := string(respBytes)
+ if h.cfg.Debug.MaxBodyBytes > 0 {
+ if len(reqBody) > h.cfg.Debug.MaxBodyBytes {
+ reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes]
+ }
+ if len(respBody) > h.cfg.Debug.MaxBodyBytes {
+ respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
+ }
+ }
+ h.debugLogger.Log(storage.DebugLogEntry{
+ RequestID: requestID,
+ TokenName: tokenName,
+ Model: req.Model,
+ Provider: route.Provider.Name(),
+ RequestBody: reqBody,
+ ResponseBody: respBody,
+ RequestHeaders: formatHeaders(r.Header),
+ ResponseStatus: http.StatusOK,
+ })
+ }
+
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Cache", "MISS")
+ w.Header().Set("X-Request-ID", requestID)
w.Write(respBytes)
return
}
- // All providers failed
if lastErr != nil {
+ w.Header().Set("X-Request-ID", requestID)
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
} else {
+ w.Header().Set("X-Request-ID", requestID)
writeError(w, http.StatusBadGateway, "all providers failed")
}
}
-func (h *Handler) logRequest(tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
+// filterHealthyRoutes removes providers with open circuit breakers.
+// If all are filtered out, returns original routes as fallback.
+func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
+ if h.healthTracker == nil {
+ return routes
+ }
+ var healthy []provider.Route
+ for _, r := range routes {
+ if h.healthTracker.IsAvailable(r.Provider.Name()) {
+ healthy = append(healthy, r)
+ }
+ }
+ if len(healthy) == 0 {
+ return routes // all-down fallback
+ }
+ return healthy
+}
+
+// backoffDuration computes exponential backoff for the given attempt
+// (attempt 1 gets InitialBackoff; each later attempt multiplies by
+// cfg.Multiplier).
+//
+// Fixes over the previous version: a MaxBackoff of zero or less now means
+// "no cap" — previously any grown delay satisfied d > 0 and was clamped to
+// zero, eliminating backoff entirely when MaxBackoff was unset. The cap is
+// also applied to the first attempt, so InitialBackoff > MaxBackoff cannot
+// exceed the configured ceiling.
+func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration {
+	d := cfg.InitialBackoff
+	for i := 1; i < attempt; i++ {
+		d = time.Duration(float64(d) * cfg.Multiplier)
+		if cfg.MaxBackoff > 0 && d >= cfg.MaxBackoff {
+			// Already at the ceiling; further multiplication cannot raise it.
+			break
+		}
+	}
+	if cfg.MaxBackoff > 0 && d > cfg.MaxBackoff {
+		d = cfg.MaxBackoff
+	}
+	return d
+}
+
+func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
h.logger.Log(storage.RequestLog{
+ RequestID: requestID,
Timestamp: time.Now().Unix(),
TokenName: tokenName,
Model: model,
@@ -217,3 +304,23 @@ func writeErrorRaw(w http.ResponseWriter, code int, body string) {
w.WriteHeader(code)
w.Write([]byte(body))
}
+
+// formatHeaders serializes HTTP headers to a readable string, sorted by key.
+// Credential-bearing headers are redacted so debug logs never persist API
+// keys, cookies, or proxy credentials — previously only Authorization was
+// redacted, which leaked Cookie / X-Api-Key / Proxy-Authorization values
+// into the debug log.
+func formatHeaders(h http.Header) string {
+	keys := make([]string, 0, len(h))
+	for k := range h {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	var b strings.Builder
+	for _, k := range keys {
+		val := strings.Join(h[k], ", ")
+		if isSensitiveHeader(k) {
+			val = "[REDACTED]"
+		}
+		fmt.Fprintf(&b, "%s: %s\n", k, val)
+	}
+	return b.String()
+}
+
+// isSensitiveHeader reports whether a header carries credentials that must
+// not appear in debug logs. Comparison is case-insensitive.
+func isSensitiveHeader(name string) bool {
+	switch {
+	case strings.EqualFold(name, "Authorization"),
+		strings.EqualFold(name, "Proxy-Authorization"),
+		strings.EqualFold(name, "Cookie"),
+		strings.EqualFold(name, "Set-Cookie"),
+		strings.EqualFold(name, "X-Api-Key"):
+		return true
+	}
+	return false
+}
diff --git a/llm-gateway/internal/proxy/ratelimit.go b/llm-gateway/internal/proxy/ratelimit.go
index 2278ea1..240bbe2 100644
--- a/llm-gateway/internal/proxy/ratelimit.go
+++ b/llm-gateway/internal/proxy/ratelimit.go
@@ -1,6 +1,8 @@
package proxy
import (
+ "fmt"
+ "math"
"net/http"
"sync"
"time"
@@ -40,7 +42,19 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
// Check rate limit
if apiToken.RateLimitRPM > 0 {
- if !rl.allow(tokenName, apiToken.RateLimitRPM) {
+ allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM)
+
+ // Set rate limit headers on all responses
+ w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM))
+ w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
+ w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt))
+
+ if !allowed {
+ retryAfter := resetAt - time.Now().Unix()
+ if retryAfter < 1 {
+ retryAfter = 1
+ }
+ w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
return
}
@@ -59,7 +73,7 @@ func (rl *RateLimiter) Check(next http.Handler) http.Handler {
})
}
-func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool {
+func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
@@ -82,9 +96,27 @@ func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) bool {
}
bucket.lastRefill = now
+ remaining := int(math.Floor(bucket.tokens))
+ if remaining < 0 {
+ remaining = 0
+ }
+
+ // Compute reset time: when bucket would be full again
+ deficit := bucket.maxTokens - bucket.tokens
+ var resetAt int64
+ if deficit > 0 && bucket.refillRate > 0 {
+ resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix()
+ } else {
+ resetAt = now.Unix()
+ }
+
if bucket.tokens < 1 {
- return false
+ return false, 0, resetAt
}
bucket.tokens--
- return true
+ remaining = int(math.Floor(bucket.tokens))
+ if remaining < 0 {
+ remaining = 0
+ }
+ return true, remaining, resetAt
}
diff --git a/llm-gateway/internal/proxy/ratelimit_test.go b/llm-gateway/internal/proxy/ratelimit_test.go
new file mode 100644
index 0000000..8d1afb7
--- /dev/null
+++ b/llm-gateway/internal/proxy/ratelimit_test.go
@@ -0,0 +1,374 @@
+package proxy
+
+import (
+ "context"
+ "database/sql"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+ "time"
+
+ _ "modernc.org/sqlite"
+
+ "llm-gateway/internal/auth"
+ "llm-gateway/internal/storage"
+)
+
+// newTestDB creates an in-memory SQLite database wrapped in storage.DB.
+// It creates the request_logs table needed by TodaySpend.
+func newTestDB(t *testing.T) *storage.DB {
+	t.Helper()
+	sqlDB, err := sql.Open("sqlite", ":memory:")
+	if err != nil {
+		t.Fatalf("opening in-memory sqlite: %v", err)
+	}
+	t.Cleanup(func() { sqlDB.Close() }) // closed automatically when the test and its subtests finish
+
+	// Create the minimal table needed for TodaySpend queries.
+	_, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs (
+		id INTEGER PRIMARY KEY AUTOINCREMENT,
+		token_name TEXT,
+		cost_usd REAL,
+		timestamp INTEGER
+	)`)
+	if err != nil {
+		t.Fatalf("creating request_logs table: %v", err)
+	}
+	return &storage.DB{DB: sqlDB} // NOTE(review): assumes storage.DB embeds *sql.DB as field DB — confirm
+}
+
+// okHandler is a simple handler that writes 200 OK.
+var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+	w.WriteHeader(http.StatusOK) // no body: tests only inspect status code and headers
+})
+
+func TestRateLimiter_Allow(t *testing.T) {
+	tests := []struct {
+		name         string
+		rateLimitRPM int // bucket capacity and refill rate, requests per minute
+		numRequests  int
+		wantAllowed  int
+		wantDenied   int
+	}{
+		{
+			name:         "allows requests within limit",
+			rateLimitRPM: 10,
+			numRequests:  5,
+			wantAllowed:  5,
+			wantDenied:   0,
+		},
+		{
+			name:         "denies requests over limit",
+			rateLimitRPM: 3,
+			numRequests:  6,
+			wantAllowed:  3,
+			wantDenied:   3,
+		},
+		{
+			name:         "allows exactly up to limit",
+			rateLimitRPM: 5,
+			numRequests:  5,
+			wantAllowed:  5,
+			wantDenied:   0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			db := newTestDB(t)
+			rl := NewRateLimiter(db)
+
+			allowed := 0
+			denied := 0
+			for i := 0; i < tt.numRequests; i++ { // bucket refill between iterations is far below one token at these RPM values
+				ok, _, _ := rl.allow("test-token", tt.rateLimitRPM)
+				if ok {
+					allowed++
+				} else {
+					denied++
+				}
+			}
+
+			if allowed != tt.wantAllowed {
+				t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed)
+			}
+			if denied != tt.wantDenied {
+				t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
+			}
+		})
+	}
+}
+
+func TestRateLimiter_TokenRefillsOverTime(t *testing.T) {
+	db := newTestDB(t)
+	rl := NewRateLimiter(db)
+
+	rpm := 60 // 1 token per second refill rate
+
+	// Exhaust all tokens.
+	for i := 0; i < rpm; i++ {
+		ok, _, _ := rl.allow("refill-token", rpm)
+		if !ok {
+			t.Fatalf("request %d should have been allowed", i)
+		}
+	}
+
+	// Next request should be denied.
+	ok, _, _ := rl.allow("refill-token", rpm)
+	if ok {
+		t.Fatal("request should have been denied after exhausting tokens")
+	}
+
+	// Manually advance the bucket's lastRefill to simulate time passing.
+	rl.mu.Lock() // white-box: reaches into RateLimiter internals under its own lock
+	bucket := rl.buckets["refill-token"]
+	bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second) // rewind instead of sleeping: fast and deterministic
+	rl.mu.Unlock()
+
+	// After 2 seconds at 1 token/sec, we should have ~2 tokens refilled.
+	ok, remaining, _ := rl.allow("refill-token", rpm)
+	if !ok {
+		t.Fatal("request should have been allowed after token refill")
+	}
+	// We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0.
+	if remaining < 0 {
+		t.Errorf("remaining = %d, want >= 0", remaining)
+	}
+}
+
+func TestRateLimiter_AllowReturnValues(t *testing.T) {
+	tests := []struct {
+		name              string
+		rateLimitRPM      int
+		numRequests       int
+		wantLastAllowed   bool // expectations apply to the final call only
+		wantLastRemaining int
+	}{
+		{
+			name:              "remaining decrements correctly",
+			rateLimitRPM:      5,
+			numRequests:       1,
+			wantLastAllowed:   true,
+			wantLastRemaining: 4,
+		},
+		{
+			name:              "remaining is zero at limit",
+			rateLimitRPM:      3,
+			numRequests:       3,
+			wantLastAllowed:   true,
+			wantLastRemaining: 0,
+		},
+		{
+			name:              "denied returns zero remaining",
+			rateLimitRPM:      2,
+			numRequests:       3,
+			wantLastAllowed:   false,
+			wantLastRemaining: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			db := newTestDB(t)
+			rl := NewRateLimiter(db)
+
+			var allowed bool
+			var remaining int
+			for i := 0; i < tt.numRequests; i++ { // only the last iteration's results are asserted below
+				allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM)
+			}
+
+			if allowed != tt.wantLastAllowed {
+				t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed)
+			}
+			if remaining != tt.wantLastRemaining {
+				t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining)
+			}
+		})
+	}
+}
+
+func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
+	tests := []struct {
+		name            string
+		rateLimitRPM    int
+		numRequests     int
+		wantStatusCode  int
+		wantLimitHeader string
+		wantRetryAfter  bool
+	}{
+		{
+			name:            "sets rate limit headers on allowed request",
+			rateLimitRPM:    10,
+			numRequests:     1,
+			wantStatusCode:  http.StatusOK,
+			wantLimitHeader: "10",
+			wantRetryAfter:  false,
+		},
+		{
+			name:            "sets Retry-After header on 429",
+			rateLimitRPM:    2,
+			numRequests:     3,
+			wantStatusCode:  http.StatusTooManyRequests,
+			wantLimitHeader: "2",
+			wantRetryAfter:  true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			db := newTestDB(t)
+			rl := NewRateLimiter(db)
+
+			token := &auth.APIToken{
+				Name:         "header-test-token",
+				RateLimitRPM: tt.rateLimitRPM,
+			}
+
+			handler := rl.Check(okHandler) // middleware under test, wrapping a trivial 200 handler
+
+			var rec *httptest.ResponseRecorder
+			for i := 0; i < tt.numRequests; i++ {
+				rec = httptest.NewRecorder()
+				req := httptest.NewRequest(http.MethodGet, "/", nil)
+				ctx := withAPIToken(req.Context(), token) // project helper: injects the token into the request context
+				req = req.WithContext(ctx)
+				handler.ServeHTTP(rec, req)
+			}
+
+			// Check the last response.
+			if rec.Code != tt.wantStatusCode {
+				t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode)
+			}
+
+			// X-RateLimit-Limit header.
+			limitHeader := rec.Header().Get("X-RateLimit-Limit")
+			if limitHeader != tt.wantLimitHeader {
+				t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader)
+			}
+
+			// X-RateLimit-Remaining header must be present and numeric.
+			remainingHeader := rec.Header().Get("X-RateLimit-Remaining")
+			if remainingHeader == "" {
+				t.Error("X-RateLimit-Remaining header is missing")
+			} else if _, err := strconv.Atoi(remainingHeader); err != nil {
+				t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader)
+			}
+
+			// X-RateLimit-Reset header must be present and numeric.
+			resetHeader := rec.Header().Get("X-RateLimit-Reset")
+			if resetHeader == "" {
+				t.Error("X-RateLimit-Reset header is missing")
+			} else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil {
+				t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader)
+			}
+
+			// Retry-After header.
+			retryAfter := rec.Header().Get("Retry-After")
+			if tt.wantRetryAfter && retryAfter == "" {
+				t.Error("Retry-After header is missing on 429 response")
+			}
+			if !tt.wantRetryAfter && retryAfter != "" {
+				t.Errorf("Retry-After header should not be present, got %q", retryAfter)
+			}
+		})
+	}
+}
+
+func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) {
+	db := newTestDB(t)
+	rl := NewRateLimiter(db)
+
+	handler := rl.Check(okHandler)
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/", nil)
+	// No API token in context.
+	handler.ServeHTTP(rec, req) // middleware must be a no-op and delegate to the wrapped handler
+
+	if rec.Code != http.StatusOK {
+		t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
+	}
+}
+
+func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) {
+	db := newTestDB(t)
+	rl := NewRateLimiter(db)
+
+	token := &auth.APIToken{
+		Name:         "unlimited-token",
+		RateLimitRPM: 0, // zero means unlimited
+	}
+
+	handler := rl.Check(okHandler)
+
+	for i := 0; i < 100; i++ { // many more requests than any plausible limit — all must pass
+		rec := httptest.NewRecorder()
+		req := httptest.NewRequest(http.MethodGet, "/", nil)
+		ctx := withAPIToken(req.Context(), token)
+		req = req.WithContext(ctx)
+		handler.ServeHTTP(rec, req)
+
+		if rec.Code != http.StatusOK {
+			t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK)
+		}
+	}
+}
+
+func TestRateLimiter_PerTokenIsolation(t *testing.T) {
+	db := newTestDB(t)
+	rl := NewRateLimiter(db)
+
+	rpm := 2
+
+	// Exhaust token A.
+	for i := 0; i < rpm; i++ {
+		rl.allow("token-a", rpm) // return values irrelevant while draining
+	}
+	ok, _, _ := rl.allow("token-a", rpm)
+	if ok {
+		t.Fatal("token-a should be rate limited")
+	}
+
+	// Token B should still have its own bucket.
+	ok, _, _ = rl.allow("token-b", rpm) // each token name gets an independent bucket
+	if !ok {
+		t.Fatal("token-b should not be affected by token-a's rate limit")
+	}
+}
+
+func TestRateLimiter_ResetAtIsFuture(t *testing.T) {
+	db := newTestDB(t)
+	rl := NewRateLimiter(db)
+
+	// Capture "now" BEFORE calling allow: resetAt is derived from the clock
+	// inside allow, so sampling time.Now() afterwards can cross a second
+	// boundary and make resetAt < now, failing the test spuriously.
+	now := time.Now().Unix()
+	_, _, resetAt := rl.allow("reset-token", 10)
+	if resetAt < now {
+		t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now)
+	}
+}
+
+func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) {
+	db := newTestDB(t)
+	rl := NewRateLimiter(db)
+
+	token := &auth.APIToken{
+		Name:         "ctx-token",
+		RateLimitRPM: 10,
+	}
+
+	handler := rl.Check(okHandler)
+	rec := httptest.NewRecorder()
+	req := httptest.NewRequest(http.MethodGet, "/", nil)
+	ctx, cancel := context.WithCancel(req.Context())
+	ctx = withAPIToken(ctx, token)
+	cancel() // Cancel immediately.
+	req = req.WithContext(ctx)
+
+	// Should still process (rate limiter does not check context cancellation).
+	handler.ServeHTTP(rec, req) // success criterion is simply that this call returns
+	// The handler itself may or may not respect cancelled context;
+	// the key point is no panic occurs.
+}
diff --git a/llm-gateway/internal/proxy/stream.go b/llm-gateway/internal/proxy/stream.go
index eb9e225..c1304d5 100644
--- a/llm-gateway/internal/proxy/stream.go
+++ b/llm-gateway/internal/proxy/stream.go
@@ -2,6 +2,7 @@ package proxy
import (
"bufio"
+ "context"
"encoding/json"
"errors"
"log"
@@ -12,7 +13,7 @@ import (
"llm-gateway/internal/provider"
)
-func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) {
+func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
flusher, ok := w.(http.Flusher)
if !ok {
writeError(w, http.StatusInternalServerError, "streaming not supported")
@@ -21,7 +22,18 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
var lastErr error
- for _, route := range routes {
+ for i, route := range routes {
+ // Retry backoff between attempts
+ if i > 0 {
+ backoff := backoffDuration(i, h.cfg.Retry)
+ select {
+ case <-time.After(backoff):
+ case <-r.Context().Done():
+ writeError(w, http.StatusGatewayTimeout, "request cancelled")
+ return
+ }
+ }
+
start := time.Now()
body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
@@ -30,67 +42,95 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
if errors.As(err, &pe) && !pe.IsRetryable() {
latency := time.Since(start).Milliseconds()
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
- h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
+ h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
+ w.Header().Set("X-Request-ID", requestID)
writeErrorRaw(w, pe.StatusCode, pe.Body)
return
}
lastErr = err
latency := time.Since(start).Milliseconds()
log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
- h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
+ h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, err)
}
continue
}
+ // Apply streaming timeout
+ var streamCtx context.Context
+ var streamCancel context.CancelFunc
+ if h.cfg.Server.StreamingTimeout > 0 {
+ streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
+ } else {
+ streamCtx, streamCancel = context.WithCancel(r.Context())
+ }
+
// Stream the response
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")
+ w.Header().Set("X-Request-ID", requestID)
w.WriteHeader(http.StatusOK)
inputTokens, outputTokens := 0, 0
scanner := bufio.NewScanner(body)
scanner.Buffer(make([]byte, 64*1024), 256*1024)
- for scanner.Scan() {
- line := scanner.Text()
+ scanDone := make(chan struct{})
+ go func() {
+ defer close(scanDone)
+ for scanner.Scan() {
+ select {
+ case <-streamCtx.Done():
+ return
+ default:
+ }
- // Parse usage from the final chunk if available
- if strings.HasPrefix(line, "data: ") {
- data := strings.TrimPrefix(line, "data: ")
- if data != "[DONE]" {
- var chunk streamChunk
- if json.Unmarshal([]byte(data), &chunk) == nil {
- if chunk.Usage != nil {
- inputTokens = chunk.Usage.PromptTokens
- outputTokens = chunk.Usage.CompletionTokens
- }
- // Override model name in chunk
- if chunk.Model != "" {
- chunk.Model = req.Model
- if rewritten, err := json.Marshal(chunk); err == nil {
- line = "data: " + string(rewritten)
+ line := scanner.Text()
+
+ if strings.HasPrefix(line, "data: ") {
+ data := strings.TrimPrefix(line, "data: ")
+ if data != "[DONE]" {
+ var chunk streamChunk
+ if json.Unmarshal([]byte(data), &chunk) == nil {
+ if chunk.Usage != nil {
+ inputTokens = chunk.Usage.PromptTokens
+ outputTokens = chunk.Usage.CompletionTokens
+ }
+ if chunk.Model != "" {
+ chunk.Model = req.Model
+ if rewritten, err := json.Marshal(chunk); err == nil {
+ line = "data: " + string(rewritten)
+ }
}
}
}
}
- }
- w.Write([]byte(line + "\n"))
- flusher.Flush()
+ w.Write([]byte(line + "\n"))
+ flusher.Flush()
+ }
+ }()
+
+	select {
+	case <-scanDone:
+		// Normal completion
+	case <-streamCtx.Done(): // NOTE(review): the reader goroutine may still be writing w/flusher and inputTokens/outputTokens after this fires — data race with the reads below and with the handler returning; confirm/fix
+		log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
+	}
+
body.Close()
+ streamCancel()
latency := time.Since(start).Milliseconds()
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
- h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
+ h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
if h.healthTracker != nil {
h.healthTracker.Record(route.Provider.Name(), latency, nil)
}
@@ -98,6 +138,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *prov
}
// All providers failed
+ w.Header().Set("X-Request-ID", requestID)
if lastErr != nil {
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
} else {
diff --git a/llm-gateway/internal/storage/audit.go b/llm-gateway/internal/storage/audit.go
new file mode 100644
index 0000000..d821f99
--- /dev/null
+++ b/llm-gateway/internal/storage/audit.go
@@ -0,0 +1,102 @@
+package storage
+
+import (
+ "log"
+ "time"
+)
+
+type AuditEntry struct {
+	ID         int64  `json:"id"`
+	Timestamp  int64  `json:"timestamp"` // Unix seconds; filled by Log if zero
+	UserID     int64  `json:"user_id"`
+	Username   string `json:"username"`
+	Action     string `json:"action"` // e.g. an admin operation name; used as a query filter
+	TargetType string `json:"target_type"`
+	TargetID   string `json:"target_id"`
+	Details    string `json:"details"`
+	IPAddress  string `json:"ip_address"`
+	RequestID  string `json:"request_id"`
+}
+
+type AuditLogger struct {
+	db *DB
+}
+
+func NewAuditLogger(db *DB) *AuditLogger {
+	return &AuditLogger{db: db}
+}
+
+// Log inserts an audit entry, defaulting Timestamp to now. Best-effort:
+// insert failures are logged, never returned to the caller.
+func (a *AuditLogger) Log(entry AuditEntry) {
+	if entry.Timestamp == 0 {
+		entry.Timestamp = time.Now().Unix()
+	}
+	_, err := a.db.Exec(`INSERT INTO audit_log
+		(timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id)
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
+		entry.Timestamp, entry.UserID, entry.Username, entry.Action,
+		entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID,
+	)
+	if err != nil {
+		log.Printf("ERROR: audit log: %v", err)
+	}
+}
+
+type AuditQueryResult struct {
+	Entries    []AuditEntry `json:"entries"`
+	Page       int          `json:"page"`
+	TotalPages int          `json:"total_pages"`
+	Total      int          `json:"total"`
+}
+
+// Query returns a page of audit entries with timestamp >= since, optionally
+// filtered by exact action, newest first. Never returns nil; DB errors yield
+// an empty page.
+func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult {
+	if page < 1 {
+		page = 1
+	}
+	if limit <= 0 {
+		limit = 50
+	}
+	offset := (page - 1) * limit
+
+	where := "WHERE timestamp >= ?"
+	args := []any{since}
+
+	if action != "" {
+		where += " AND action = ?"
+		args = append(args, action)
+	}
+
+	var total int
+	countArgs := make([]any, len(args))
+	copy(countArgs, args) // defensive copy so the later append cannot alias the count args
+	a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total) // NOTE(review): Scan error ignored; total stays 0 on failure
+
+	totalPages := (total + limit - 1) / limit
+	if totalPages < 1 {
+		totalPages = 1
+	}
+
+	query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action,
+		COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''),
+		COALESCE(ip_address, ''), COALESCE(request_id, '')
+		FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
+	args = append(args, limit, offset)
+
+	rows, err := a.db.Query(query, args...)
+	if err != nil {
+		return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total}
+	}
+	defer rows.Close()
+
+	var entries []AuditEntry
+	for rows.Next() {
+		var e AuditEntry
+		rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action, // NOTE(review): per-row Scan and rows.Err() unchecked — partial results on error
+			&e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID)
+		entries = append(entries, e)
+	}
+	if entries == nil {
+		entries = []AuditEntry{} // JSON-friendly: encode as [] rather than null
+	}
+
+	return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
+}
diff --git a/llm-gateway/internal/storage/debuglog.go b/llm-gateway/internal/storage/debuglog.go
new file mode 100644
index 0000000..be333dc
--- /dev/null
+++ b/llm-gateway/internal/storage/debuglog.go
@@ -0,0 +1,250 @@
+package storage
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "sync/atomic"
+ "time"
+)
+
+type DebugLogEntry struct {
+	ID             int64  `json:"id"`
+	RequestID      string `json:"request_id"`
+	Timestamp      int64  `json:"timestamp"` // Unix seconds; filled by Log if zero
+	TokenName      string `json:"token_name"`
+	Model          string `json:"model"`
+	Provider       string `json:"provider"`
+	RequestBody    string `json:"request_body"`  // loaded lazily from FilePath by populateFromFile
+	ResponseBody   string `json:"response_body"` // loaded lazily from FilePath by populateFromFile
+	RequestHeaders string `json:"request_headers"`
+	ResponseStatus int    `json:"response_status"`
+	FilePath       string `json:"-"` // on-disk JSON body file; empty for legacy DB-only rows
+}
+
+// debugFile is the JSON structure written to disk.
+type debugFile struct {
+	RequestHeaders string `json:"request_headers"`
+	RequestBody    string `json:"request_body"`
+	ResponseBody   string `json:"response_body"`
+}
+
+type DebugLogger struct {
+	db      *DB
+	enabled atomic.Bool // toggleable at runtime without locking
+	dataDir string
+}
+
+func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger {
+	dl := &DebugLogger{db: db, dataDir: dataDir}
+	dl.enabled.Store(enabled)
+	return dl
+}
+
+func (d *DebugLogger) SetEnabled(v bool) {
+	d.enabled.Store(v)
+}
+
+func (d *DebugLogger) IsEnabled() bool {
+	return d.enabled.Load()
+}
+
+// debugLogDir returns the base directory for debug log files.
+func (d *DebugLogger) debugLogDir() string {
+	return filepath.Join(d.dataDir, "debug-logs")
+}
+
+// debugFilePath builds the file path for a debug log entry,
+// grouped into one directory per day (YYYY-MM-DD) for easy retention cleanup.
+func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string {
+	date := ts.Format("2006-01-02")
+	return filepath.Join(d.debugLogDir(), date, requestID+".json")
+}
+
+// Log persists one debug entry: bodies/headers go to a per-day JSON file on
+// disk, metadata (no bodies) goes to the debug_log table. No-op when disabled.
+func (d *DebugLogger) Log(entry DebugLogEntry) {
+	if !d.IsEnabled() {
+		return
+	}
+	if entry.Timestamp == 0 {
+		entry.Timestamp = time.Now().Unix()
+	}
+
+	ts := time.Unix(entry.Timestamp, 0)
+	fp := d.debugFilePath(entry.RequestID, ts)
+
+	// Write body file
+	if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil {
+		log.Printf("ERROR: debug log mkdir: %v", err)
+		return
+	}
+
+	df := debugFile{
+		RequestHeaders: entry.RequestHeaders,
+		RequestBody:    entry.RequestBody,
+		ResponseBody:   entry.ResponseBody,
+	}
+	data, err := json.Marshal(df)
+	if err != nil {
+		log.Printf("ERROR: debug log marshal: %v", err)
+		return
+	}
+	if err := os.WriteFile(fp, data, 0644); err != nil { // NOTE(review): file may contain request headers (possibly credentials) and is world-readable — consider 0600
+		log.Printf("ERROR: debug log write: %v", err)
+		return
+	}
+
+	// Insert metadata into DB (no bodies)
+	_, err = d.db.Exec(`INSERT INTO debug_log
+		(request_id, timestamp, token_name, model, provider, response_status, file_path)
+		VALUES (?, ?, ?, ?, ?, ?, ?)`,
+		entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model,
+		entry.Provider, entry.ResponseStatus, fp,
+	)
+	if err != nil {
+		log.Printf("ERROR: debug log db insert: %v", err) // NOTE(review): leaves an orphan body file on disk — confirm acceptable
+	}
+}
+
+type DebugLogQueryResult struct {
+	Entries    []DebugLogEntry `json:"entries"`
+	Page       int             `json:"page"`
+	TotalPages int             `json:"total_pages"`
+	Total      int             `json:"total"`
+}
+
+// Query returns paginated debug log metadata (no bodies — fast).
+// Never returns nil; DB errors yield an empty page.
+func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult {
+	if page < 1 {
+		page = 1
+	}
+	if limit <= 0 {
+		limit = 50
+	}
+	offset := (page - 1) * limit
+
+	var total int
+	d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total) // NOTE(review): Scan error ignored; total stays 0 on failure
+
+	totalPages := (total + limit - 1) / limit
+	if totalPages < 1 {
+		totalPages = 1
+	}
+
+	rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
+		COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
+		FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset)
+	if err != nil {
+		return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
+	}
+	defer rows.Close()
+
+	var entries []DebugLogEntry
+	for rows.Next() {
+		var e DebugLogEntry
+		rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName, // NOTE(review): per-row Scan and rows.Err() unchecked
+			&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
+		entries = append(entries, e)
+	}
+	if entries == nil {
+		entries = []DebugLogEntry{} // JSON-friendly: encode as [] rather than null
+	}
+
+	return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
+}
+
+// QueryFull returns paginated debug log entries including request/response bodies read from files.
+// One file read per entry, so it is O(limit) disk reads per call.
+func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult {
+	result := d.Query(page, limit)
+	for i := range result.Entries {
+		d.populateFromFile(&result.Entries[i])
+	}
+	return result
+}
+
+// GetByRequestID returns a single debug log entry with bodies read from file.
+// Returns nil when the request ID is unknown (or the row cannot be scanned).
+func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry {
+	var e DebugLogEntry
+	err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
+		COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
+		FROM debug_log WHERE request_id = ?`, requestID).Scan(
+		&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
+		&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
+	if err != nil {
+		return nil
+	}
+	d.populateFromFile(&e)
+	return &e
+}
+
+// populateFromFile reads body data from the debug file on disk.
+// Falls back to DB columns for pre-migration entries that have no file_path.
+// Read/parse failures leave the body fields empty and only log a warning.
+func (d *DebugLogger) populateFromFile(e *DebugLogEntry) {
+	if e.FilePath == "" {
+		// Legacy entry: try reading bodies from DB columns
+		d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '')
+			FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders)
+		return
+	}
+	data, err := os.ReadFile(e.FilePath)
+	if err != nil {
+		log.Printf("WARN: debug log read file %s: %v", e.FilePath, err)
+		return
+	}
+	var df debugFile
+	if err := json.Unmarshal(data, &df); err != nil {
+		log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err)
+		return
+	}
+	e.RequestHeaders = df.RequestHeaders
+	e.RequestBody = df.RequestBody
+	e.ResponseBody = df.ResponseBody
+}
+
+// Cleanup removes debug log entries and files older than retentionDays:
+// first the DB rows by timestamp, then whole per-day directories on disk.
+// Per-directory removal failures are logged but do not abort the sweep.
+func (d *DebugLogger) Cleanup(retentionDays int) error {
+	cutoff := time.Now().AddDate(0, 0, -retentionDays)
+	cutoffUnix := cutoff.Unix()
+
+	// Delete old DB rows
+	result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix)
+	if err != nil {
+		return fmt.Errorf("delete old debug rows: %w", err)
+	}
+	affected, _ := result.RowsAffected() // best-effort count; only used for the log line
+	if affected > 0 {
+		log.Printf("Cleaned up %d old debug log entries", affected)
+	}
+
+	// Remove old date directories
+	baseDir := d.debugLogDir()
+	dirs, err := os.ReadDir(baseDir)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil // nothing ever logged to disk; not an error
+		}
+		return fmt.Errorf("read debug log dir: %w", err)
+	}
+
+	cutoffDate := cutoff.Format("2006-01-02")
+	sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) // NOTE(review): os.ReadDir already returns entries sorted by name — redundant but harmless
+
+	for _, dir := range dirs {
+		if !dir.IsDir() {
+			continue
+		}
+		// Date directories are named YYYY-MM-DD; string comparison works
+		if strings.Compare(dir.Name(), cutoffDate) < 0 {
+			dirPath := filepath.Join(baseDir, dir.Name())
+			if err := os.RemoveAll(dirPath); err != nil {
+				log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err)
+			} else {
+				log.Printf("Removed old debug log directory: %s", dir.Name())
+			}
+		}
+	}
+
+	return nil
+}
diff --git a/llm-gateway/internal/storage/logger.go b/llm-gateway/internal/storage/logger.go
index ad2e829..d832dd4 100644
--- a/llm-gateway/internal/storage/logger.go
+++ b/llm-gateway/internal/storage/logger.go
@@ -6,6 +6,7 @@ import (
)
type RequestLog struct {
+ RequestID string
Timestamp int64
TokenName string
Model string
@@ -93,8 +94,8 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
}
stmt, err := tx.Prepare(`INSERT INTO request_logs
- (timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
+ (request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
if err != nil {
log.Printf("ERROR: preparing log statement: %v", err)
tx.Rollback()
@@ -112,7 +113,7 @@ func (l *AsyncLogger) flush(batch []RequestLog) {
cached = 1
}
_, err := stmt.Exec(
- r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
+ r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
r.Status, r.ErrorMessage, streaming, cached,
)
diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql
new file mode 100644
index 0000000..e11bb40
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/004_token_concurrency.down.sql
@@ -0,0 +1,4 @@
+-- SQLite doesn't support DROP COLUMN in older versions, so we recreate the table
+CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens; -- NOTE(review): CREATE TABLE AS drops PRIMARY KEY/constraints/defaults/indexes -- confirm they are restored elsewhere
+DROP TABLE api_tokens;
+ALTER TABLE api_tokens_backup RENAME TO api_tokens;
diff --git a/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql b/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql
new file mode 100644
index 0000000..ccf0549
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/004_token_concurrency.up.sql
@@ -0,0 +1 @@
+ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0; -- presumably 0 means unlimited, matching rate_limit_rpm semantics -- confirm against limiter code
diff --git a/llm-gateway/internal/storage/migrations/005_request_id.down.sql b/llm-gateway/internal/storage/migrations/005_request_id.down.sql
new file mode 100644
index 0000000..819b90b
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/005_request_id.down.sql
@@ -0,0 +1 @@
+DROP INDEX IF EXISTS idx_request_logs_request_id; -- the request_id column added by the up migration is left in place (SQLite DROP COLUMN limitation)
diff --git a/llm-gateway/internal/storage/migrations/005_request_id.up.sql b/llm-gateway/internal/storage/migrations/005_request_id.up.sql
new file mode 100644
index 0000000..ff54384
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/005_request_id.up.sql
@@ -0,0 +1,2 @@
+ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT ''; -- empty string (not NULL) keeps existing rows queryable without COALESCE
+CREATE INDEX idx_request_logs_request_id ON request_logs(request_id); -- supports lookup of a log row by request ID
diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.down.sql b/llm-gateway/internal/storage/migrations/006_audit_log.down.sql
new file mode 100644
index 0000000..b750c3b
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/006_audit_log.down.sql
@@ -0,0 +1 @@
+DROP TABLE IF EXISTS audit_log; -- irreversible: discards all recorded audit history
diff --git a/llm-gateway/internal/storage/migrations/006_audit_log.up.sql b/llm-gateway/internal/storage/migrations/006_audit_log.up.sql
new file mode 100644
index 0000000..c2fc48a
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/006_audit_log.up.sql
@@ -0,0 +1,14 @@
+CREATE TABLE audit_log (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    timestamp INTEGER NOT NULL, -- Unix seconds
+    user_id INTEGER, -- nullable; queries COALESCE to 0
+    username TEXT NOT NULL DEFAULT '',
+    action TEXT NOT NULL, -- indexed: used as an exact-match query filter
+    target_type TEXT DEFAULT '',
+    target_id TEXT DEFAULT '',
+    details TEXT DEFAULT '',
+    ip_address TEXT DEFAULT '',
+    request_id TEXT DEFAULT ''
+);
+CREATE INDEX idx_audit_timestamp ON audit_log(timestamp); -- supports the since-based range scan
+CREATE INDEX idx_audit_action ON audit_log(action);
diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.down.sql b/llm-gateway/internal/storage/migrations/007_debug_log.down.sql
new file mode 100644
index 0000000..41353f5
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/007_debug_log.down.sql
@@ -0,0 +1 @@
+DROP TABLE IF EXISTS debug_log; -- on-disk body files under debug-logs/ are NOT removed by this migration
diff --git a/llm-gateway/internal/storage/migrations/007_debug_log.up.sql b/llm-gateway/internal/storage/migrations/007_debug_log.up.sql
new file mode 100644
index 0000000..9a8441a
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/007_debug_log.up.sql
@@ -0,0 +1,14 @@
+CREATE TABLE debug_log (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    request_id TEXT NOT NULL,
+    timestamp INTEGER NOT NULL, -- Unix seconds
+    token_name TEXT DEFAULT '',
+    model TEXT DEFAULT '',
+    provider TEXT DEFAULT '',
+    request_body TEXT DEFAULT '', -- body columns superseded by file storage in migration 008; kept for legacy rows
+    response_body TEXT DEFAULT '',
+    request_headers TEXT DEFAULT '',
+    response_status INTEGER DEFAULT 0
+);
+CREATE INDEX idx_debug_request_id ON debug_log(request_id);
+CREATE INDEX idx_debug_timestamp ON debug_log(timestamp); -- supports newest-first paging and retention cleanup
diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql
new file mode 100644
index 0000000..032a37d
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/008_debug_log_files.down.sql
@@ -0,0 +1 @@
+-- no-op: file_path column is harmless to keep
diff --git a/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql b/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql
new file mode 100644
index 0000000..7a5bf8b
--- /dev/null
+++ b/llm-gateway/internal/storage/migrations/008_debug_log_files.up.sql
@@ -0,0 +1 @@
+ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT ''; -- empty file_path marks legacy rows whose bodies still live in the debug_log columns