ai-servers/llm-gateway/internal/proxy/stream.go

package proxy

import (
	"bufio"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"strings"
	"time"

	"llm-gateway/internal/provider"
)
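
// handleStream proxies a chat completion as a server-sent event (SSE) stream,
// trying each candidate route in order and failing over until one succeeds.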
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) {
	// SSE requires flushing each event to the client as it arrives.
	flusher, ok := w.(http.Flusher)
	if !ok {
		writeError(w, http.StatusInternalServerError, "streaming not supported")
		return
	}

	var lastErr error
	for _, route := range routes {
		start := time.Now()
		body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
		if err != nil {
			// Non-retryable provider errors (e.g. an invalid request) are
			// returned to the client verbatim instead of failing over.
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				latency := time.Since(start).Milliseconds()
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}
			// Retryable failure: record it and fall through to the next route.
			lastErr = err
			latency := time.Since(start).Milliseconds()
			log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}
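
		// A provider stream is open: commit to the response. The status and
		// SSE headers are deferred until this point so that earlier route
		// failures can still fall back to a plain JSON error.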
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("X-Accel-Buffering", "no") // tell nginx-style proxies not to buffer
		w.WriteHeader(http.StatusOK)

		inputTokens, outputTokens := 0, 0
		scanner := bufio.NewScanner(body)
		// Raise the scanner limit above the 64 KiB default so oversized
		// SSE lines are not dropped with bufio.ErrTooLong.
		scanner.Buffer(make([]byte, 64*1024), 256*1024)
		for scanner.Scan() {
			line := scanner.Text()
			// Capture usage from the final chunk when the provider reports it.
			if strings.HasPrefix(line, "data: ") {
				data := strings.TrimPrefix(line, "data: ")
				if data != "[DONE]" {
					var chunk streamChunk
					if json.Unmarshal([]byte(data), &chunk) == nil {
						if chunk.Usage != nil {
							inputTokens = chunk.Usage.PromptTokens
							outputTokens = chunk.Usage.CompletionTokens
						}
						// Rewrite the provider's model name back to the
						// gateway-facing model the client requested.
						if chunk.Model != "" {
							chunk.Model = req.Model
							if rewritten, err := json.Marshal(chunk); err == nil {
								line = "data: " + string(rewritten)
							}
						}
					}
				}
			}
			// Blank separator lines between SSE events survive: the scanner
			// yields them as empty strings and the newline is re-appended.
			w.Write([]byte(line + "\n"))
			flusher.Flush()
		}
		if err := scanner.Err(); err != nil {
			// The upstream stream broke mid-response; the client already has
			// a 200 and partial data, so all we can do is log it.
			log.Printf("Provider %s stream read error for %s: %v", route.Provider.Name(), req.Model, err)
		}
		body.Close()

		latency := time.Since(start).Milliseconds()
		cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
		h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}
		return
	}

	// All providers failed; surface the last retryable error to the client.
	if lastErr != nil {
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}
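
// streamChunk mirrors the subset of an OpenAI-style streaming chunk that the
// gateway inspects. An illustrative upstream line (the exact payload shape is
// an assumption based on the OpenAI wire format) looks like:
//
//	data: {"id":"chatcmpl-...","object":"chat.completion.chunk","model":"...","choices":[{"delta":{"content":"Hi"}}]}
//
// and the stream terminates with "data: [DONE]". Note that when a chunk is
// rewritten to replace the model name, top-level fields not listed below are
// dropped from the re-marshaled JSON.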
type streamChunk struct {
	ID      string          `json:"id,omitempty"`
	Object  string          `json:"object,omitempty"`
	Created int64           `json:"created,omitempty"`
	Model   string          `json:"model,omitempty"`
	Choices []any           `json:"choices,omitempty"`
	Usage   *provider.Usage `json:"usage,omitempty"`
}