feat(gateway): add debug logging with file storage and retention feat(gateway): add audit logging for user actions feat(gateway): add request ID tracking and rate limit headers feat(gateway): add model aliases and load balancing strategies feat(gateway): add config hot-reload via SIGHUP feat(gateway): add CORS support feat(gateway): add data export API and dashboard endpoints feat(gateway): add dashboard pages for audit and debug logs feat(gateway): add concurrent request limiting per token feat(gateway): add streaming timeout support feat(gateway): add migration support for new schema fields
156 lines
4.6 KiB
Go
156 lines
4.6 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"llm-gateway/internal/provider"
|
|
)
|
|
|
|
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
writeError(w, http.StatusInternalServerError, "streaming not supported")
|
|
return
|
|
}
|
|
|
|
var lastErr error
|
|
|
|
for i, route := range routes {
|
|
// Retry backoff between attempts
|
|
if i > 0 {
|
|
backoff := backoffDuration(i, h.cfg.Retry)
|
|
select {
|
|
case <-time.After(backoff):
|
|
case <-r.Context().Done():
|
|
writeError(w, http.StatusGatewayTimeout, "request cancelled")
|
|
return
|
|
}
|
|
}
|
|
|
|
start := time.Now()
|
|
body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
|
|
|
|
if err != nil {
|
|
var pe *provider.ProviderError
|
|
if errors.As(err, &pe) && !pe.IsRetryable() {
|
|
latency := time.Since(start).Milliseconds()
|
|
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
|
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
|
if h.healthTracker != nil {
|
|
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
|
}
|
|
w.Header().Set("X-Request-ID", requestID)
|
|
writeErrorRaw(w, pe.StatusCode, pe.Body)
|
|
return
|
|
}
|
|
lastErr = err
|
|
latency := time.Since(start).Milliseconds()
|
|
log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
|
|
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
|
if h.healthTracker != nil {
|
|
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Apply streaming timeout
|
|
var streamCtx context.Context
|
|
var streamCancel context.CancelFunc
|
|
if h.cfg.Server.StreamingTimeout > 0 {
|
|
streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
|
|
} else {
|
|
streamCtx, streamCancel = context.WithCancel(r.Context())
|
|
}
|
|
|
|
// Stream the response
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
w.Header().Set("X-Request-ID", requestID)
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
inputTokens, outputTokens := 0, 0
|
|
scanner := bufio.NewScanner(body)
|
|
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
|
|
|
scanDone := make(chan struct{})
|
|
go func() {
|
|
defer close(scanDone)
|
|
for scanner.Scan() {
|
|
select {
|
|
case <-streamCtx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
line := scanner.Text()
|
|
|
|
if strings.HasPrefix(line, "data: ") {
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data != "[DONE]" {
|
|
var chunk streamChunk
|
|
if json.Unmarshal([]byte(data), &chunk) == nil {
|
|
if chunk.Usage != nil {
|
|
inputTokens = chunk.Usage.PromptTokens
|
|
outputTokens = chunk.Usage.CompletionTokens
|
|
}
|
|
if chunk.Model != "" {
|
|
chunk.Model = req.Model
|
|
if rewritten, err := json.Marshal(chunk); err == nil {
|
|
line = "data: " + string(rewritten)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
w.Write([]byte(line + "\n"))
|
|
flusher.Flush()
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-scanDone:
|
|
// Normal completion
|
|
case <-streamCtx.Done():
|
|
log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
|
|
}
|
|
|
|
body.Close()
|
|
streamCancel()
|
|
|
|
latency := time.Since(start).Milliseconds()
|
|
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
|
|
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
|
|
h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
|
|
if h.healthTracker != nil {
|
|
h.healthTracker.Record(route.Provider.Name(), latency, nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
// All providers failed
|
|
w.Header().Set("X-Request-ID", requestID)
|
|
if lastErr != nil {
|
|
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
|
|
} else {
|
|
writeError(w, http.StatusBadGateway, "all providers failed")
|
|
}
|
|
}
|
|
|
|
type streamChunk struct {
|
|
ID string `json:"id,omitempty"`
|
|
Object string `json:"object,omitempty"`
|
|
Created int64 `json:"created,omitempty"`
|
|
Model string `json:"model,omitempty"`
|
|
Choices []any `json:"choices,omitempty"`
|
|
Usage *provider.Usage `json:"usage,omitempty"`
|
|
}
|