ai-servers/llm-gateway/internal/config/config_test.go
Ray Andrew 90adf6f3a8
feat(gateway): add circuit breaker, retry, and concurrency limit support
feat(gateway): add debug logging with file storage and retention

feat(gateway): add audit logging for user actions

feat(gateway): add request ID tracking and rate limit headers

feat(gateway): add model aliases and load balancing strategies

feat(gateway): add config hot-reload via SIGHUP

feat(gateway): add CORS support

feat(gateway): add data export API and dashboard endpoints

feat(gateway): add dashboard pages for audit and debug logs

feat(gateway): add concurrent request limiting per token

feat(gateway): add streaming timeout support

feat(gateway): add migration support for new schema fields
2026-02-15 04:21:40 -06:00

738 lines
18 KiB
Go

package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// writeConfigFile writes content to a new temporary "config-*.yaml" file
// inside t.TempDir() and returns the file's path. Because the file lives in
// t.TempDir(), it is removed automatically once the test completes.
//
// Any failure aborts the test via t.Fatalf. That includes an error from
// closing the file, which os.File.Close can report and which would otherwise
// go unnoticed.
func writeConfigFile(t *testing.T, content string) string {
	t.Helper()
	f, err := os.CreateTemp(t.TempDir(), "config-*.yaml")
	if err != nil {
		t.Fatalf("creating temp file: %v", err)
	}
	if _, err := f.WriteString(content); err != nil {
		f.Close() // best effort: the write error is the one worth reporting
		t.Fatalf("writing temp file: %v", err)
	}
	if err := f.Close(); err != nil {
		t.Fatalf("closing temp file: %v", err)
	}
	return f.Name()
}
// minimalValidConfig returns a small YAML config document that the tests
// treat as valid. It declares exactly one provider ("openai") and one model
// ("gpt-4") routed to that provider, and sets no optional settings.
func minimalValidConfig() string {
	const doc = `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-test-key
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`
	return doc
}
// TestLoad_ValidConfig loads a config that populates the server, database,
// pricing_lookup, circuit_breaker, retry, debug, cors, providers, models and
// tokens sections with non-default values. It checks that those values are
// parsed into the corresponding Config fields. The YAML "pricing_lookup"
// section is read back through cfg.Pricing.
func TestLoad_ValidConfig(t *testing.T) {
	path := writeConfigFile(t, `
server:
listen: "127.0.0.1:8080"
request_timeout: 60s
streaming_timeout: 120s
max_request_body_mb: 5
session_secret: "test-secret-1234567890abcdef1234567890abcdef"
database:
path: "/tmp/test.db"
retention_days: 30
pricing_lookup:
url: "https://pricing.example.com"
refresh_interval: 1h
circuit_breaker:
enabled: true
error_threshold: 0.3
min_requests: 10
cooldown_duration: 60s
retry:
initial_backoff: 200ms
max_backoff: 10s
multiplier: 3.0
debug:
enabled: true
max_body_bytes: 65536
retention_days: 60
cors:
enabled: true
allowed_origins:
- "https://example.com"
allowed_methods:
- GET
- POST
allowed_headers:
- Authorization
max_age: 600
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-test-key
priority: 2
timeout: 60s
- name: anthropic
base_url: https://api.anthropic.com/v1
api_key: sk-ant-test
priority: 1
timeout: 30s
models:
- name: gpt-4
aliases:
- gpt4
routes:
- provider: openai
model: gpt-4
pricing:
input: 30.0
output: 60.0
load_balancing: first
- name: claude-3
routes:
- provider: anthropic
model: claude-3-opus-20240229
tokens:
- name: test-token
key: tok-abc123
rate_limit_rpm: 100
daily_budget_usd: 10.0
max_concurrent: 5
`)
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}
	// Server
	if cfg.Server.Listen != "127.0.0.1:8080" {
		t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080")
	}
	if cfg.Server.RequestTimeout != 60*time.Second {
		t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second)
	}
	if cfg.Server.StreamingTimeout != 120*time.Second {
		t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second)
	}
	if cfg.Server.MaxRequestBodyMB != 5 {
		t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5)
	}
	if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" {
		t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef")
	}
	// Database
	if cfg.Database.Path != "/tmp/test.db" {
		t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db")
	}
	if cfg.Database.RetentionDays != 30 {
		t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30)
	}
	// Pricing (YAML section "pricing_lookup")
	if cfg.Pricing.URL != "https://pricing.example.com" {
		t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com")
	}
	if cfg.Pricing.RefreshInterval != 1*time.Hour {
		t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour)
	}
	// Circuit breaker
	if !cfg.CircuitBreaker.Enabled {
		t.Error("CircuitBreaker.Enabled = false, want true")
	}
	if cfg.CircuitBreaker.ErrorThreshold != 0.3 {
		t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3)
	}
	if cfg.CircuitBreaker.MinRequests != 10 {
		t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10)
	}
	if cfg.CircuitBreaker.CooldownDuration != 60*time.Second {
		t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second)
	}
	// Retry
	if cfg.Retry.InitialBackoff != 200*time.Millisecond {
		t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond)
	}
	if cfg.Retry.MaxBackoff != 10*time.Second {
		t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second)
	}
	if cfg.Retry.Multiplier != 3.0 {
		t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0)
	}
	// Debug
	if !cfg.Debug.Enabled {
		t.Error("Debug.Enabled = false, want true")
	}
	if cfg.Debug.MaxBodyBytes != 65536 {
		t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536)
	}
	if cfg.Debug.RetentionDays != 60 {
		t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60)
	}
	// CORS
	if !cfg.CORS.Enabled {
		t.Error("CORS.Enabled = false, want true")
	}
	if cfg.CORS.MaxAge != 600 {
		t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600)
	}
	// Providers: both entries, including the explicit priorities. Provider 0
	// sets priority 2, which differs from the default of 1, so this proves
	// the value was parsed rather than defaulted.
	if len(cfg.Providers) != 2 {
		t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers))
	}
	if cfg.Providers[0].Name != "openai" {
		t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai")
	}
	if cfg.Providers[0].Priority != 2 {
		t.Errorf("Providers[0].Priority = %d, want %d", cfg.Providers[0].Priority, 2)
	}
	if cfg.Providers[0].Timeout != 60*time.Second {
		t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second)
	}
	if cfg.Providers[1].Name != "anthropic" {
		t.Errorf("Providers[1].Name = %q, want %q", cfg.Providers[1].Name, "anthropic")
	}
	if cfg.Providers[1].Priority != 1 {
		t.Errorf("Providers[1].Priority = %d, want %d", cfg.Providers[1].Priority, 1)
	}
	if cfg.Providers[1].Timeout != 30*time.Second {
		t.Errorf("Providers[1].Timeout = %v, want %v", cfg.Providers[1].Timeout, 30*time.Second)
	}
	// Models
	if len(cfg.Models) != 2 {
		t.Fatalf("len(Models) = %d, want 2", len(cfg.Models))
	}
	if cfg.Models[0].LoadBalancing != "first" {
		t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first")
	}
	if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" {
		t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases)
	}
	// Guard the Routes[0] index below so a parse regression fails the test
	// cleanly instead of panicking with an index-out-of-range.
	if len(cfg.Models[0].Routes) != 1 {
		t.Fatalf("len(Models[0].Routes) = %d, want 1", len(cfg.Models[0].Routes))
	}
	if cfg.Models[0].Routes[0].Pricing.Input != 30.0 {
		t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input)
	}
	// Tokens
	if len(cfg.Tokens) != 1 {
		t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens))
	}
	if cfg.Tokens[0].Name != "test-token" {
		t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token")
	}
	if cfg.Tokens[0].RateLimitRPM != 100 {
		t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM)
	}
}
// TestValidate_Defaults loads minimalValidConfig, which sets only providers
// and models, and checks the default value assigned to each optional field.
// It also checks that the empty session secret is replaced by a generated
// secret of 64 hexadecimal characters.
func TestValidate_Defaults(t *testing.T) {
	path := writeConfigFile(t, minimalValidConfig())
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}
	// The table below indexes Providers[0] while it is being built. Fail
	// cleanly rather than panic if Load ever returns no providers.
	if len(cfg.Providers) == 0 {
		t.Fatal("len(Providers) = 0, want at least 1")
	}
	tests := []struct {
		name string
		got  any
		want any
	}{
		// Server defaults
		{"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"},
		{"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second},
		{"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute},
		{"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10},
		// Database defaults
		{"Database.Path", cfg.Database.Path, "gateway.db"},
		{"Database.RetentionDays", cfg.Database.RetentionDays, 90},
		// Pricing defaults
		{"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour},
		// Circuit breaker defaults
		{"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5},
		{"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5},
		{"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second},
		// Retry defaults
		{"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond},
		{"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second},
		{"Retry.Multiplier", cfg.Retry.Multiplier, 2.0},
		// Debug defaults
		{"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0},
		{"Debug.RetentionDays", cfg.Debug.RetentionDays, 90},
		// CORS defaults
		{"CORS.MaxAge", cfg.CORS.MaxAge, 300},
		// Provider defaults
		{"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second},
		{"Providers[0].Priority", cfg.Providers[0].Priority, 1},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Compare textual forms so that, for example, an int64 field
			// matches an untyped int constant in the table.
			gotStr := formatValue(tt.got)
			wantStr := formatValue(tt.want)
			if gotStr != wantStr {
				t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
			}
		})
	}
	// SessionSecret should be auto-generated: non-empty, 64 characters, and
	// hexadecimal. Previously only the length was checked, although the
	// comment promised hex.
	secret := cfg.Server.SessionSecret
	if secret == "" {
		t.Error("SessionSecret should be auto-generated when empty")
	}
	if len(secret) != 64 {
		t.Errorf("SessionSecret length = %d, want 64 hex chars", len(secret))
	}
	for i, r := range secret {
		if !strings.ContainsRune("0123456789abcdefABCDEF", r) {
			t.Errorf("SessionSecret byte %d = %q, want a hex digit", i, r)
			break
		}
	}
}
// formatValue renders v as a string so that table-driven tests can compare
// values whose static types differ but whose meaning is the same.
//
// A time.Duration uses its String form (e.g. "1m30s"). A float64 uses %g. An
// int uses %d, and a string is returned unchanged. Every other value falls
// back to %v.
func formatValue(v any) string {
	if d, ok := v.(time.Duration); ok {
		return d.String()
	}
	if f, ok := v.(float64); ok {
		return fmt.Sprintf("%g", f)
	}
	if n, ok := v.(int); ok {
		return fmt.Sprintf("%d", n)
	}
	if s, ok := v.(string); ok {
		return s
	}
	return fmt.Sprintf("%v", v)
}
// TestLoad_FileNotFound checks that Load returns an error when asked to read
// a path that does not exist inside a fresh temporary directory.
func TestLoad_FileNotFound(t *testing.T) {
	missing := filepath.Join(t.TempDir(), "nonexistent.yaml")
	if _, err := Load(missing); err == nil {
		t.Fatal("Load() should return error for nonexistent file")
	}
}
func TestLoad_InvalidYAML(t *testing.T) {
path := writeConfigFile(t, `{{{invalid yaml`)
_, err := Load(path)
if err == nil {
t.Fatal("Load() should return error for invalid YAML")
}
}
// TestValidate_DuplicateProviderNames checks that two providers sharing the
// name "openai" are rejected, and that the error names the duplicate.
func TestValidate_DuplicateProviderNames(t *testing.T) {
	const wantSubstr = "duplicate provider name: openai"
	cfgPath := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
- name: openai
base_url: https://api.openai.com/v2
api_key: sk-key2
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`)
	_, loadErr := Load(cfgPath)
	switch {
	case loadErr == nil:
		t.Fatal("Load() should return error for duplicate provider names")
	case !strings.Contains(loadErr.Error(), wantSubstr):
		t.Errorf("error = %q, want to contain %q", loadErr.Error(), wantSubstr)
	}
}
// TestValidate_DuplicateModelNames checks that two model entries both named
// "gpt-4" are rejected, even though they route to different upstream models.
// It also checks that the error names the duplicate.
func TestValidate_DuplicateModelNames(t *testing.T) {
	const wantSubstr = "duplicate model name: gpt-4"
	cfgPath := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
- name: gpt-4
routes:
- provider: openai
model: gpt-4-turbo
`)
	_, loadErr := Load(cfgPath)
	switch {
	case loadErr == nil:
		t.Fatal("Load() should return error for duplicate model names")
	case !strings.Contains(loadErr.Error(), wantSubstr):
		t.Errorf("error = %q, want to contain %q", loadErr.Error(), wantSubstr)
	}
}
func TestValidate_AliasConflicts(t *testing.T) {
tests := []struct {
name string
config string
wantErr string
}{
{
name: "alias conflicts with model name",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
- name: claude-3
aliases:
- gpt-4
routes:
- provider: openai
model: claude-3
`,
wantErr: "model alias gpt-4 conflicts with existing model or alias",
},
{
name: "alias conflicts with another alias",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key1
models:
- name: gpt-4
aliases:
- fast-model
routes:
- provider: openai
model: gpt-4
- name: claude-3
aliases:
- fast-model
routes:
- provider: openai
model: claude-3
`,
wantErr: "model alias fast-model conflicts with existing model or alias",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := writeConfigFile(t, tt.config)
_, err := Load(path)
if err == nil {
t.Fatal("Load() should return error for alias conflicts")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
}
})
}
}
func TestValidate_MissingRequiredFields(t *testing.T) {
tests := []struct {
name string
config string
wantErr string
}{
{
name: "no providers",
config: `models: [{name: test, routes: [{provider: x, model: y}]}]`,
wantErr: "at least one provider is required",
},
{
name: "no models",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
`,
wantErr: "at least one model is required",
},
{
name: "provider missing name",
config: `
providers:
- base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`,
wantErr: "provider 0: name, base_url, and api_key are required",
},
{
name: "provider missing base_url",
config: `
providers:
- name: openai
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`,
wantErr: "provider 0: name, base_url, and api_key are required",
},
{
name: "provider missing api_key",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`,
wantErr: "provider 0: name, base_url, and api_key are required",
},
{
name: "model missing name",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- routes:
- provider: openai
model: gpt-4
`,
wantErr: "model 0: name is required",
},
{
name: "model missing routes",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
`,
wantErr: "model gpt-4: at least one route is required",
},
{
name: "route missing provider",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- model: gpt-4
`,
wantErr: "model gpt-4 route 0: provider and model are required",
},
{
name: "route missing model",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: openai
`,
wantErr: "model gpt-4 route 0: provider and model are required",
},
{
name: "route references unknown provider",
config: `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-key
models:
- name: gpt-4
routes:
- provider: anthropic
model: gpt-4
`,
wantErr: "model gpt-4 route 0: unknown provider anthropic",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := writeConfigFile(t, tt.config)
_, err := Load(path)
if err == nil {
t.Fatalf("Load() should return error, want %q", tt.wantErr)
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
}
})
}
}
// TestProviderByName checks that Config.ProviderByName returns the provider
// with the requested name. It returns nil for unknown names and for the empty
// string. The provider it returns also carries that entry's other fields,
// such as the API key.
func TestProviderByName(t *testing.T) {
	path := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: sk-openai
- name: anthropic
base_url: https://api.anthropic.com/v1
api_key: sk-anthropic
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`)
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}
	tests := []struct {
		name     string
		lookup   string
		wantNil  bool
		wantName string
	}{
		{"existing provider openai", "openai", false, "openai"},
		{"existing provider anthropic", "anthropic", false, "anthropic"},
		{"nonexistent provider", "google", true, ""},
		{"empty name", "", true, ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := cfg.ProviderByName(tt.lookup)
			if tt.wantNil {
				if p != nil {
					t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p)
				}
				return
			}
			if p == nil {
				t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup)
			}
			if p.Name != tt.wantName {
				t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName)
			}
		})
	}
	// Check that the returned provider carries the entry's other fields too,
	// not just the name the lookup matched on. Guard against nil first so a
	// regression fails the test instead of panicking on p.APIKey.
	p := cfg.ProviderByName("openai")
	if p == nil {
		t.Fatalf("ProviderByName(%q) = nil, want provider", "openai")
	}
	if p.APIKey != "sk-openai" {
		t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai")
	}
}
func TestLoad_EnvironmentVariableExpansion(t *testing.T) {
t.Setenv("TEST_API_KEY", "sk-from-env")
t.Setenv("TEST_BASE_URL", "https://env.example.com/v1")
t.Setenv("TEST_PROVIDER_NAME", "env-provider")
path := writeConfigFile(t, `
providers:
- name: $TEST_PROVIDER_NAME
base_url: ${TEST_BASE_URL}
api_key: ${TEST_API_KEY}
models:
- name: test-model
routes:
- provider: env-provider
model: gpt-4
`)
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() returned error: %v", err)
}
if cfg.Providers[0].Name != "env-provider" {
t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider")
}
if cfg.Providers[0].BaseURL != "https://env.example.com/v1" {
t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1")
}
if cfg.Providers[0].APIKey != "sk-from-env" {
t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env")
}
}
// TestLoad_UnsetEnvVarExpandsToEmpty checks what happens when api_key
// references an environment variable that is not set. The reference
// expands to an empty string, so Load must fail with the api_key
// required-field validation error.
func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) {
	// t.Setenv is called first only so the testing package restores any
	// pre-existing value at cleanup. os.Unsetenv then removes the variable
	// for the rest of this test. Its error was previously ignored.
	t.Setenv("TEST_UNSET_VAR", "")
	if err := os.Unsetenv("TEST_UNSET_VAR"); err != nil {
		t.Fatalf("unsetting TEST_UNSET_VAR: %v", err)
	}
	path := writeConfigFile(t, `
providers:
- name: openai
base_url: https://api.openai.com/v1
api_key: ${TEST_UNSET_VAR}
models:
- name: gpt-4
routes:
- provider: openai
model: gpt-4
`)
	_, err := Load(path)
	if err == nil {
		t.Fatal("Load() should return error when env var expands to empty required field")
	}
	// api_key will be empty, so validation should catch it.
	if !strings.Contains(err.Error(), "api_key are required") {
		t.Errorf("error = %q, want to contain api_key validation message", err.Error())
	}
}