ai-servers/llm-gateway/internal/proxy/stream.go
Ray Andrew 90adf6f3a8
feat(gateway): add circuit breaker, retry, and concurrency limit support
feat(gateway): add debug logging with file storage and retention

feat(gateway): add audit logging for user actions

feat(gateway): add request ID tracking and rate limit headers

feat(gateway): add model aliases and load balancing strategies

feat(gateway): add config hot-reload via SIGHUP

feat(gateway): add CORS support

feat(gateway): add data export API and dashboard endpoints

feat(gateway): add dashboard pages for audit and debug logs

feat(gateway): add concurrent request limiting per token

feat(gateway): add streaming timeout support

feat(gateway): add migration support for new schema fields
2026-02-15 04:21:40 -06:00

156 lines
4.6 KiB
Go

package proxy
import (
"bufio"
"context"
"encoding/json"
"errors"
"log"
"net/http"
"strings"
"time"
"llm-gateway/internal/provider"
)
// handleStream serves a streaming chat completion by trying each route in
// order and relaying the first provider stream that opens successfully.
//
// Behaviour, as implemented below:
//   - If w does not implement http.Flusher, a 500 error is written.
//   - Before every attempt after the first, the handler waits for
//     backoffDuration(i, h.cfg.Retry) or until the request context is done.
//   - A *provider.ProviderError that is not retryable ends the request
//     immediately and its status code and body are written to the client.
//   - Any other error is logged and the next route is tried.
//   - Once a stream opens, the response headers are sent with status 200 and
//     every line read from the provider is written and flushed. For "data: "
//     lines that decode as a streamChunk, usage counters are captured and a
//     non-empty model name is replaced with req.Model.
//   - If every route fails, a 502 error is written.
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		writeError(w, http.StatusInternalServerError, "streaming not supported")
		return
	}

	var lastErr error
	for i, route := range routes {
		// Retry backoff between attempts
		if i > 0 {
			backoff := backoffDuration(i, h.cfg.Retry)
			select {
			case <-time.After(backoff):
			case <-r.Context().Done():
				writeError(w, http.StatusGatewayTimeout, "request cancelled")
				return
			}
		}

		start := time.Now()
		body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
		if err != nil {
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				// Non-retryable provider error: report it verbatim and stop.
				latency := time.Since(start).Milliseconds()
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				w.Header().Set("X-Request-ID", requestID)
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}

			// Retryable (or untyped) error: remember it and fall through to
			// the next route.
			lastErr = err
			latency := time.Since(start).Milliseconds()
			log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}

		// Apply streaming timeout
		var streamCtx context.Context
		var streamCancel context.CancelFunc
		if h.cfg.Server.StreamingTimeout > 0 {
			streamCtx, streamCancel = context.WithTimeout(r.Context(), h.cfg.Server.StreamingTimeout)
		} else {
			streamCtx, streamCancel = context.WithCancel(r.Context())
		}

		// Stream the response
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("X-Accel-Buffering", "no")
		w.Header().Set("X-Request-ID", requestID)
		w.WriteHeader(http.StatusOK)

		// inputTokens, outputTokens and scanErr are written only by the
		// scanner goroutine and read only after scanDone is closed.
		inputTokens, outputTokens := 0, 0
		var scanErr error

		scanner := bufio.NewScanner(body)
		scanner.Buffer(make([]byte, 64*1024), 256*1024)
		scanDone := make(chan struct{})
		go func() {
			defer close(scanDone)
			for scanner.Scan() {
				select {
				case <-streamCtx.Done():
					return
				default:
				}

				line := scanner.Text()
				if strings.HasPrefix(line, "data: ") {
					data := strings.TrimPrefix(line, "data: ")
					if data != "[DONE]" {
						var chunk streamChunk
						if json.Unmarshal([]byte(data), &chunk) == nil {
							if chunk.Usage != nil {
								inputTokens = chunk.Usage.PromptTokens
								outputTokens = chunk.Usage.CompletionTokens
							}
							if chunk.Model != "" {
								// Report the model name the client asked for
								// instead of the one the provider returned.
								chunk.Model = req.Model
								if rewritten, err := json.Marshal(chunk); err == nil {
									line = "data: " + string(rewritten)
								}
							}
						}
					}
				}

				// A failed write cannot be recovered from mid-stream, so stop
				// relaying instead of continuing to read from the provider.
				if _, werr := w.Write([]byte(line + "\n")); werr != nil {
					return
				}
				flusher.Flush()
			}
			scanErr = scanner.Err()
		}()

		select {
		case <-scanDone:
			// Normal completion
		case <-streamCtx.Done():
			if errors.Is(streamCtx.Err(), context.DeadlineExceeded) {
				log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
			} else {
				log.Printf("Stream cancelled for %s via %s", req.Model, route.Provider.Name())
			}
		}

		body.Close()
		// Wait for the scanner goroutine before touching the token counters or
		// returning: the ResponseWriter must not be used after the handler
		// returns, and the counters are written by that goroutine.
		// NOTE(review): this relies on body.Close() unblocking a pending Read,
		// as net/http response bodies do — confirm for every provider.
		<-scanDone
		ctxErr := streamCtx.Err()
		streamCancel()

		// Surface read failures (including lines over the scanner's 256 KiB
		// limit) that would otherwise truncate the stream silently. Errors
		// provoked by our own Close after timeout/cancel are not reported.
		if scanErr != nil && ctxErr == nil {
			log.Printf("Provider %s stream read error for %s: %v", route.Provider.Name(), req.Model, scanErr)
		}

		// NOTE(review): timed-out, cancelled and truncated streams are still
		// recorded with status "success", as before — confirm this is intended.
		latency := time.Since(start).Milliseconds()
		cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
		h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}
		return
	}

	// All providers failed
	w.Header().Set("X-Request-ID", requestID)
	if lastErr != nil {
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}
// streamChunk is the shape handleStream decodes each streamed "data: " payload
// into so it can read the usage counters and overwrite the model name before
// re-encoding the line.
//
// Because a rewritten line is produced by marshalling this struct, only the
// fields declared here appear in it; any other keys in the provider's payload
// are not carried over.
type streamChunk struct {
ID string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
// Model is replaced with the client-requested model name when non-empty.
Model string `json:"model,omitempty"`
// Choices is kept as untyped values so it round-trips without a schema.
Choices []any `json:"choices,omitempty"`
// Usage, when present, supplies the prompt and completion token counts.
Usage *provider.Usage `json:"usage,omitempty"`
}