package proxy

import (
	"context"
	"encoding/json"
	"errors"
	"io"
	"log"
	"net/http"
	"time"

	"llm-gateway/internal/auth"
	"llm-gateway/internal/cache"
	"llm-gateway/internal/config"
	"llm-gateway/internal/metrics"
	"llm-gateway/internal/provider"
	"llm-gateway/internal/storage"
)

// contextKey is a private type for request-scoped context values, so keys
// defined here cannot collide with keys from other packages.
type contextKey string

const tokenNameKey contextKey = "token_name"
const apiTokenKey contextKey = "api_token"

func withTokenName(ctx context.Context, name string) context.Context {
	return context.WithValue(ctx, tokenNameKey, name)
}

func getTokenName(ctx context.Context) string {
	name, _ := ctx.Value(tokenNameKey).(string)
	return name
}

func withAPIToken(ctx context.Context, token *auth.APIToken) context.Context {
	return context.WithValue(ctx, apiTokenKey, token)
}

func getAPIToken(ctx context.Context) *auth.APIToken {
	t, _ := ctx.Value(apiTokenKey).(*auth.APIToken)
	return t
}

// Handler routes incoming chat completion requests to upstream providers,
// with response caching, metrics, provider health tracking, and async
// request logging.
type Handler struct {
	registry      *provider.Registry
	logger        *storage.AsyncLogger
	cache         *cache.Cache
	metrics       *metrics.Metrics
	cfg           *config.Config
	healthTracker *provider.HealthTracker
}

// NewHandler wires a Handler with its dependencies.
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
	return &Handler{
		registry:      registry,
		logger:        logger,
		cache:         c,
		metrics:       m,
		cfg:           cfg,
		healthTracker: ht,
	}
}

// ChatCompletions validates the request body, resolves the requested model
// to provider routes, serves cache hits for non-streaming requests, and
// dispatches to the streaming or non-streaming path.
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
	body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
	if err != nil {
		writeError(w, http.StatusBadRequest, "failed to read request body")
		return
	}

	var req provider.ChatRequest
	if err := json.Unmarshal(body, &req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
		return
	}
	if req.Model == "" {
		writeError(w, http.StatusBadRequest, "model is required")
		return
	}

	routes, ok := h.registry.Lookup(req.Model)
	if !ok {
		writeError(w, http.StatusNotFound, "model not found: "+req.Model)
		return
	}

	tokenName := getTokenName(r.Context())

	// Check the cache for non-streaming requests only; streamed responses
	// are never cached.
	if !req.Stream && h.cache != nil {
		if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil {
			h.logRequest(tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("X-Cache", "HIT")
			w.Write(cached)
			return
		}
	}

	if req.Stream {
		h.handleStream(w, r, &req, routes, tokenName)
		return
	}
	h.handleNonStream(w, r, &req, routes, tokenName, body)
}
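// Example wiring (a minimal sketch; the mux path and the authMiddleware
// helper are assumptions, not defined in this package). An auth middleware
// would typically validate the bearer token and stash its name in the
// request context via withTokenName before ChatCompletions runs:
//
//	h := NewHandler(registry, asyncLogger, cacheStore, metricsStore, cfg, tracker)
//	mux := http.NewServeMux()
//	mux.Handle("/v1/chat/completions", authMiddleware(http.HandlerFunc(h.ChatCompletions)))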
// handleNonStream tries each route in order until one succeeds. Retryable
// provider errors fall through to the next route; non-retryable errors
// (typically client errors) are returned to the caller immediately.
func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte) {
	var lastErr error
	for _, route := range routes {
		start := time.Now()
		resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req)
		latency := time.Since(start).Milliseconds()

		if err != nil {
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				// Client error; retrying another provider won't help, so
				// surface the upstream error body as-is.
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}

			lastErr = err
			log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
			h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}

		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}

		// Compute cost from reported usage, if the provider returned any.
		inputTokens, outputTokens := 0, 0
		if resp.Usage != nil {
			inputTokens = resp.Usage.PromptTokens
			outputTokens = resp.Usage.CompletionTokens
		}
		cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)

		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
		h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)

		// Override the model name in the response so callers see the alias
		// they requested rather than the provider-specific model ID.
		resp.Model = req.Model
		respBytes, err := json.Marshal(resp)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "failed to marshal response")
			return
		}

		// Cache the response for future identical requests.
		if h.cache != nil {
			h.cache.Set(r.Context(), req.Model, rawBody, respBytes)
		}

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Cache", "MISS")
		w.Write(respBytes)
		return
	}

	// All providers failed.
	if lastErr != nil {
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}

// logRequest enqueues a request log entry on the async logger.
func (h *Handler) logRequest(tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
	h.logger.Log(storage.RequestLog{
		Timestamp:     time.Now().Unix(),
		TokenName:     tokenName,
		Model:         model,
		Provider:      providerName,
		ProviderModel: providerModel,
		InputTokens:   inputTokens,
		OutputTokens:  outputTokens,
		CostUSD:       cost,
		LatencyMS:     latencyMS,
		Status:        status,
		ErrorMessage:  errMsg,
		Streaming:     streaming,
		Cached:        cached,
	})
}

// computeCost converts token counts to USD using per-million-token prices.
func computeCost(inputTokens, outputTokens int, inputPrice, outputPrice float64) float64 {
	return (float64(inputTokens) / 1_000_000.0 * inputPrice) +
		(float64(outputTokens) / 1_000_000.0 * outputPrice)
}

// writeError writes a JSON error envelope with the given status code.
func writeError(w http.ResponseWriter, code int, msg string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	json.NewEncoder(w).Encode(map[string]any{
		"error": map[string]any{
			"message": msg,
			"type":    "error",
			"code":    code,
		},
	})
}

// writeErrorRaw passes an upstream error body through verbatim; the body is
// expected to already be serialized JSON.
func writeErrorRaw(w http.ResponseWriter, code int, body string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	w.Write([]byte(body))
}
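// Worked example for computeCost (the prices are illustrative, not real
// rates): 1,500 input tokens at $3.00 per million plus 500 output tokens at
// $15.00 per million costs
//
//	(1500/1e6)*3.00 + (500/1e6)*15.00 = 0.0045 + 0.0075 = $0.012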