package proxy
import (
	"bufio"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"strings"
	"time"

	"llm-gateway/internal/provider"
)

func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string) {
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
writeError(w, http.StatusInternalServerError, "streaming not supported")
|
|
return
|
|
}
|
|
|
|
var lastErr error
|
|
|
|
for _, route := range routes {
|
|
start := time.Now()
|
|
body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)
|
|
|
|
if err != nil {
|
|
var pe *provider.ProviderError
|
|
if errors.As(err, &pe) && !pe.IsRetryable() {
|
|
latency := time.Since(start).Milliseconds()
|
|
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
|
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
|
if h.healthTracker != nil {
|
|
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
|
}
|
|
writeErrorRaw(w, pe.StatusCode, pe.Body)
|
|
return
|
|
}
|
|
lastErr = err
|
|
latency := time.Since(start).Milliseconds()
|
|
log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
|
|
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
|
|
if h.healthTracker != nil {
|
|
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Stream the response
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
inputTokens, outputTokens := 0, 0
|
|
scanner := bufio.NewScanner(body)
|
|
scanner.Buffer(make([]byte, 64*1024), 256*1024)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
|
|
// Parse usage from the final chunk if available
|
|
if strings.HasPrefix(line, "data: ") {
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data != "[DONE]" {
|
|
var chunk streamChunk
|
|
if json.Unmarshal([]byte(data), &chunk) == nil {
|
|
if chunk.Usage != nil {
|
|
inputTokens = chunk.Usage.PromptTokens
|
|
outputTokens = chunk.Usage.CompletionTokens
|
|
}
|
|
// Override model name in chunk
|
|
if chunk.Model != "" {
|
|
chunk.Model = req.Model
|
|
if rewritten, err := json.Marshal(chunk); err == nil {
|
|
line = "data: " + string(rewritten)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
w.Write([]byte(line + "\n"))
|
|
flusher.Flush()
|
|
}
|
|
body.Close()
|
|
|
|
latency := time.Since(start).Milliseconds()
|
|
cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
|
|
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
|
|
h.logRequest(tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
|
|
if h.healthTracker != nil {
|
|
h.healthTracker.Record(route.Provider.Name(), latency, nil)
|
|
}
|
|
return
|
|
}
|
|
|
|
// All providers failed
|
|
if lastErr != nil {
|
|
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
|
|
} else {
|
|
writeError(w, http.StatusBadGateway, "all providers failed")
|
|
}
|
|
}
|
|
|
|
// streamChunk mirrors the subset of an OpenAI-style streaming
// chat-completion chunk that the gateway needs to inspect: the model
// name (rewritten to the client-facing alias before relaying) and the
// usage block carried by the final chunk. Choices are decoded
// opaquely and passed through untouched.
// NOTE(review): field order matters — json.Marshal emits fields in
// declaration order, so reordering changes the relayed chunk bytes.
type streamChunk struct {
	ID string `json:"id,omitempty"`
	Object string `json:"object,omitempty"`
	Created int64 `json:"created,omitempty"`
	// Model is the provider's model name; overwritten with the
	// gateway-facing name before the chunk is forwarded.
	Model string `json:"model,omitempty"`
	// Choices is deliberately left as []any — the gateway never
	// inspects choice contents, only relays them.
	Choices []any `json:"choices,omitempty"`
	// Usage is non-nil only on the final chunk (when the provider
	// reports token counts).
	Usage *provider.Usage `json:"usage,omitempty"`
}
|