// ai-servers/llm-gateway/internal/config/config.go
//
// 198 lines
// 5.1 KiB
// Go

package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"log"
"os"
"time"
"gopkg.in/yaml.v3"
)
// Config is the root of the gateway's YAML configuration, as parsed by Load.
// Each field corresponds to one top-level key of the configuration file.
type Config struct {
	// Server holds the settings under the "server" key.
	Server ServerConfig `yaml:"server"`

	// Database holds the settings under the "database" key.
	Database DatabaseConfig `yaml:"database"`

	// Cache holds the settings under the "cache" key.
	Cache CacheConfig `yaml:"cache"`

	// Pricing holds the settings under the "pricing_lookup" key.
	Pricing PricingLookupConfig `yaml:"pricing_lookup"`

	// Providers lists the configured providers. Validation requires at
	// least one entry and unique provider names.
	Providers []ProviderConfig `yaml:"providers"`

	// Models lists the configured models. Validation requires at least one
	// entry and unique model names.
	Models []ModelConfig `yaml:"models"`

	// Tokens is an optional list of token definitions.
	Tokens []TokenConfig `yaml:"tokens"`
}
// PricingLookupConfig is the "pricing_lookup" section of the configuration.
type PricingLookupConfig struct {
	// URL is read from the "url" key.
	URL string `yaml:"url"`

	// RefreshInterval is read from the "refresh_interval" key. Validation
	// replaces a zero value with six hours.
	RefreshInterval time.Duration `yaml:"refresh_interval"`
}
// DefaultAdminConfig carries the username and password configured under the
// server section's "default_admin" key.
type DefaultAdminConfig struct {
	// Username is read from the "username" key.
	Username string `yaml:"username"`

	// Password is read from the "password" key.
	Password string `yaml:"password"`
}
// TokenConfig is one entry of the optional "tokens" list.
type TokenConfig struct {
	// Name labels the token. Validation assigns "token-<index>" to entries
	// that leave it empty.
	Name string `yaml:"name"`

	// Key is the token's key. Validation logs a warning for entries whose
	// key is empty.
	Key string `yaml:"key"`

	// RateLimitRPM is the rate limit read from "rate_limit_rpm";
	// 0 means unlimited.
	RateLimitRPM int `yaml:"rate_limit_rpm"`

	// DailyBudgetUSD is the budget read from "daily_budget_usd";
	// 0 means unlimited.
	DailyBudgetUSD float64 `yaml:"daily_budget_usd"`
}
// ServerConfig is the "server" section of the configuration.
type ServerConfig struct {
	// Listen is read from the "listen" key. Validation replaces an empty
	// value with "0.0.0.0:3000".
	Listen string `yaml:"listen"`

	// RequestTimeout is read from the "request_timeout" key. Validation
	// replaces a zero value with 300 seconds.
	RequestTimeout time.Duration `yaml:"request_timeout"`

	// MaxRequestBodyMB is read from the "max_request_body_mb" key.
	// Validation replaces a zero value with 10.
	MaxRequestBodyMB int `yaml:"max_request_body_mb"`

	// SessionSecret is read from the "session_secret" key. When it is
	// empty, validation generates a random hex-encoded value and logs a
	// warning that sessions won't survive a restart.
	SessionSecret string `yaml:"session_secret"`

	// DefaultAdmin is read from the "default_admin" key.
	DefaultAdmin DefaultAdminConfig `yaml:"default_admin"`
}
// DatabaseConfig is the "database" section of the configuration.
type DatabaseConfig struct {
	// Path is read from the "path" key. Validation replaces an empty value
	// with "gateway.db".
	Path string `yaml:"path"`

	// RetentionDays is read from the "retention_days" key. Validation
	// replaces a zero value with 90.
	RetentionDays int `yaml:"retention_days"`
}
// CacheConfig is the "cache" section of the configuration.
type CacheConfig struct {
	// Enabled is read from the "enabled" key.
	Enabled bool `yaml:"enabled"`

	// Address is read from the "address" key.
	Address string `yaml:"address"`

	// TTL is read from the "ttl" key and is expressed in seconds.
	TTL int `yaml:"ttl"`
}
// ProviderConfig is one entry of the "providers" list. Validation requires
// Name, BaseURL and APIKey to be non-empty and Name to be unique among all
// providers.
type ProviderConfig struct {
	// Name identifies the provider; model routes refer to it by this name.
	Name string `yaml:"name"`

	// BaseURL is read from the "base_url" key.
	BaseURL string `yaml:"base_url"`

	// APIKey is read from the "api_key" key.
	APIKey string `yaml:"api_key"`

	// Priority is read from the "priority" key. Validation replaces a zero
	// value with 1.
	Priority int `yaml:"priority"`

	// Timeout is read from the "timeout" key. Validation replaces a zero
	// value with 120 seconds.
	Timeout time.Duration `yaml:"timeout"`
}
// ModelConfig is one entry of the "models" list. Validation requires a
// non-empty, unique Name and at least one route.
type ModelConfig struct {
	// Name is read from the "name" key.
	Name string `yaml:"name"`

	// Routes is read from the "routes" key.
	Routes []RouteConfig `yaml:"routes"`
}
// RouteConfig is one entry of a model's "routes" list. Validation requires
// Provider and Model to be non-empty and Provider to match the name of a
// configured provider.
type RouteConfig struct {
	// Provider is the name of the provider this route refers to.
	Provider string `yaml:"provider"`

	// Model is read from the "model" key.
	Model string `yaml:"model"`

	// Pricing is read from the "pricing" key.
	Pricing PricingConfig `yaml:"pricing"`
}
// PricingConfig is the "pricing" section of a route.
type PricingConfig struct {
	// Input is the input cost per 1M tokens.
	Input float64 `yaml:"input"`

	// Output is the output cost per 1M tokens.
	Output float64 `yaml:"output"`
}
// Load reads the YAML file at path, substitutes environment variable
// references in its text, decodes the result into a Config and validates it.
//
// Substitution is done with os.ExpandEnv on the raw file contents before
// parsing, so $VAR and ${VAR} references anywhere in the file are replaced.
//
// A non-nil error is returned, wrapped with the failing stage ("reading",
// "parsing" or "validating"), when any step fails; the returned *Config is
// nil in that case.
func Load(path string) (*Config, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading config: %w", err)
	}

	cfg := new(Config)

	// Environment references are resolved on the text, prior to decoding.
	text := os.ExpandEnv(string(raw))
	if err = yaml.Unmarshal([]byte(text), cfg); err != nil {
		return nil, fmt.Errorf("parsing config: %w", err)
	}

	if err = cfg.validate(); err != nil {
		return nil, fmt.Errorf("validating config: %w", err)
	}

	return cfg, nil
}
// validate fills in defaults for unset fields and checks that the
// configuration is internally consistent. It mutates c in place.
//
// An error is returned when no providers or no models are configured, when a
// provider or model entry is incomplete or duplicated, when a route names a
// provider that is not configured, or when a random session secret is needed
// but cannot be generated. Tokens with an empty key are logged and removed
// from c.Tokens; tokens without a name are named after their position in the
// configured list.
func (c *Config) validate() error {
	if err := c.applyDefaults(); err != nil {
		return err
	}

	providerNames, err := c.validateProviders()
	if err != nil {
		return err
	}

	if err := c.validateModels(providerNames); err != nil {
		return err
	}

	c.normalizeTokens()
	return nil
}

// applyDefaults fills in the server, database and pricing-lookup fields that
// were left at their zero value. It fails only if a session secret has to be
// generated and the random source returns an error.
func (c *Config) applyDefaults() error {
	if c.Server.Listen == "" {
		c.Server.Listen = "0.0.0.0:3000"
	}
	if c.Server.RequestTimeout == 0 {
		c.Server.RequestTimeout = 300 * time.Second
	}
	if c.Server.MaxRequestBodyMB == 0 {
		c.Server.MaxRequestBodyMB = 10
	}
	if c.Server.SessionSecret == "" {
		b := make([]byte, 32)
		// Never fall back to an unfilled buffer: that would yield a
		// predictable, all-zero secret.
		if _, err := rand.Read(b); err != nil {
			return fmt.Errorf("generating session secret: %w", err)
		}
		c.Server.SessionSecret = hex.EncodeToString(b)
		log.Println("WARNING: no session_secret configured, generated random one (sessions won't survive restart)")
	}
	if c.Database.Path == "" {
		c.Database.Path = "gateway.db"
	}
	if c.Database.RetentionDays == 0 {
		c.Database.RetentionDays = 90
	}
	if c.Pricing.RefreshInterval == 0 {
		c.Pricing.RefreshInterval = 6 * time.Hour
	}
	return nil
}

// validateProviders checks that at least one provider exists, that each has a
// name, base URL and API key, and that names are unique. It also applies the
// per-provider timeout and priority defaults. The returned set holds every
// provider name.
func (c *Config) validateProviders() (map[string]bool, error) {
	if len(c.Providers) == 0 {
		return nil, fmt.Errorf("at least one provider is required")
	}

	names := make(map[string]bool, len(c.Providers))
	for i := range c.Providers {
		p := &c.Providers[i]
		if p.Name == "" || p.BaseURL == "" || p.APIKey == "" {
			return nil, fmt.Errorf("provider %d: name, base_url, and api_key are required", i)
		}
		if names[p.Name] {
			return nil, fmt.Errorf("duplicate provider name: %s", p.Name)
		}
		names[p.Name] = true

		if p.Timeout == 0 {
			p.Timeout = 120 * time.Second
		}
		if p.Priority == 0 {
			p.Priority = 1
		}
	}
	return names, nil
}

// validateModels checks that at least one model exists, that model names are
// non-empty and unique, and that every model has at least one route whose
// provider and model are set and whose provider appears in providerNames.
func (c *Config) validateModels(providerNames map[string]bool) error {
	if len(c.Models) == 0 {
		return fmt.Errorf("at least one model is required")
	}

	seen := make(map[string]bool, len(c.Models))
	for i, m := range c.Models {
		if m.Name == "" {
			return fmt.Errorf("model %d: name is required", i)
		}
		if seen[m.Name] {
			return fmt.Errorf("duplicate model name: %s", m.Name)
		}
		seen[m.Name] = true

		if len(m.Routes) == 0 {
			return fmt.Errorf("model %s: at least one route is required", m.Name)
		}
		for j, r := range m.Routes {
			if r.Provider == "" || r.Model == "" {
				return fmt.Errorf("model %s route %d: provider and model are required", m.Name, j)
			}
			if !providerNames[r.Provider] {
				return fmt.Errorf("model %s route %d: unknown provider %s", m.Name, j, r.Provider)
			}
		}
	}
	return nil
}

// normalizeTokens processes the optional tokens section: entries with an
// empty key are logged and dropped, and entries without a name receive
// "token-<index>", where index is the entry's position in the configured
// list (before any entries were dropped).
func (c *Config) normalizeTokens() {
	// Filter in place; kept never outruns the read index, so no entry is
	// overwritten before it has been examined.
	kept := c.Tokens[:0]
	for i, t := range c.Tokens {
		if t.Key == "" {
			log.Printf("WARNING: token %d (%s) has empty key, skipping", i, t.Name)
			continue
		}
		if t.Name == "" {
			t.Name = fmt.Sprintf("token-%d", i)
		}
		kept = append(kept, t)
	}
	c.Tokens = kept
}
// ProviderByName looks up a provider by its Name field. It returns a pointer
// to the matching element of c.Providers (the first one, scanning in slice
// order), or nil when no provider has that name.
func (c *Config) ProviderByName(name string) *ProviderConfig {
	for idx := 0; idx < len(c.Providers); idx++ {
		candidate := &c.Providers[idx]
		if candidate.Name != name {
			continue
		}
		return candidate
	}
	return nil
}