ai-servers/llm-gateway/internal/provider/registry.go

74 lines
1.8 KiB
Go

package provider
import (
"fmt"
"sort"
"llm-gateway/internal/config"
)
// Route describes one way of serving a model: which provider to use, the
// model name to use with that provider, and the pricing that applies.
type Route struct {
	// Provider is the provider this route targets.
	Provider Provider
	// ProviderModel is the model name taken from the route's configuration
	// (presumably the identifier the provider itself expects — confirm
	// against the Provider implementation).
	ProviderModel string
	// Priority orders the routes of a model; NewRegistry sorts routes in
	// ascending order, so a lower value means the route comes first.
	Priority int
	// InputPrice and OutputPrice are the prices per 1M tokens for input and
	// output tokens respectively.
	InputPrice, OutputPrice float64
}
// Registry is a lookup table from model name to the routes that can serve
// that model.
type Registry struct {
	// routes is keyed by model name; NewRegistry stores each model's routes
	// sorted by ascending Priority.
	routes map[string][]Route
}
// NewRegistry builds a Registry from cfg.
//
// One provider is created with NewOpenAIProvider for every entry in
// cfg.Providers, then each model's routes are resolved against those
// providers. The routes of a model are ordered by the priority of the
// provider they point at (lower value first); routes whose providers share
// the same priority keep the order in which they appear in the configuration.
//
// It returns an error if cfg is nil or if a model route names a provider that
// is not present in cfg.Providers.
func NewRegistry(cfg *config.Config) (*Registry, error) {
	if cfg == nil {
		return nil, fmt.Errorf("nil config")
	}

	// Build one provider per configured entry, keyed by provider name.
	providers := make(map[string]Provider, len(cfg.Providers))
	for _, pc := range cfg.Providers {
		providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout)
	}

	// Resolve every model's routes against the providers built above.
	routes := make(map[string][]Route, len(cfg.Models))
	for _, mc := range cfg.Models {
		var modelRoutes []Route
		for _, rc := range mc.Routes {
			p, ok := providers[rc.Provider]
			if !ok {
				return nil, fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
			}
			// Priority comes from the provider's configuration, not from the
			// individual route. The name was just found in the map built from
			// cfg.Providers, so this lookup is expected to succeed
			// (NOTE(review): confirm ProviderByName searches cfg.Providers).
			pc := cfg.ProviderByName(rc.Provider)
			modelRoutes = append(modelRoutes, Route{
				Provider:      p,
				ProviderModel: rc.Model,
				Priority:      pc.Priority,
				InputPrice:    rc.Pricing.Input,
				OutputPrice:   rc.Pricing.Output,
			})
		}
		// Lower value = higher priority. A stable sort keeps configuration
		// order for routes with equal priority, so the result is predictable.
		sort.SliceStable(modelRoutes, func(i, j int) bool {
			return modelRoutes[i].Priority < modelRoutes[j].Priority
		})
		routes[mc.Name] = modelRoutes
	}
	return &Registry{routes: routes}, nil
}
// Lookup reports the routes registered under the given model name. The
// boolean is false when the model is not registered, in which case the
// returned slice is nil.
func (r *Registry) Lookup(model string) ([]Route, bool) {
	if found, present := r.routes[model]; present {
		return found, true
	}
	return nil, false
}
// ModelNames lists every model name held by the registry in ascending
// lexical order. The result is a fresh, non-nil slice on each call.
func (r *Registry) ModelNames() []string {
	sorted := make([]string, len(r.routes))
	idx := 0
	for model := range r.routes {
		sorted[idx] = model
		idx++
	}
	sort.Strings(sorted)
	return sorted
}