chore: remove all llm-gateway service files and configurations
This commit is contained in:
parent
e459a5f58e
commit
92b6ec72b1
80 changed files with 0 additions and 12258 deletions
372
llm-gateway.yaml
372
llm-gateway.yaml
|
|
@ -1,372 +0,0 @@
|
||||||
server:
|
|
||||||
listen: "0.0.0.0:3000"
|
|
||||||
request_timeout: 300s
|
|
||||||
max_request_body_mb: 10
|
|
||||||
session_secret: "${SESSION_SECRET}"
|
|
||||||
default_admin:
|
|
||||||
username: "${ADMIN_USERNAME}"
|
|
||||||
password: "${ADMIN_PASSWORD}"
|
|
||||||
|
|
||||||
tokens:
|
|
||||||
- name: "open-webui"
|
|
||||||
key: "${OPENWEBUI_API_KEY}"
|
|
||||||
rate_limit_rpm: 0 # unlimited
|
|
||||||
daily_budget_usd: 0
|
|
||||||
- name: "opencode"
|
|
||||||
key: "${OPENCODE_API_KEY}"
|
|
||||||
rate_limit_rpm: 0 # unlimited
|
|
||||||
daily_budget_usd: 0
|
|
||||||
|
|
||||||
pricing_lookup:
|
|
||||||
# url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default
|
|
||||||
refresh_interval: 6h
|
|
||||||
|
|
||||||
database:
|
|
||||||
path: "/data/gateway.db"
|
|
||||||
retention_days: 90
|
|
||||||
|
|
||||||
debug:
|
|
||||||
enabled: true
|
|
||||||
retention_days: 90
|
|
||||||
# data_dir: "/data" # defaults to directory of database.path
|
|
||||||
# max_body_bytes: 0 # 0 = unlimited (save full bodies)
|
|
||||||
|
|
||||||
cache:
|
|
||||||
enabled: true
|
|
||||||
address: "valkey:6379"
|
|
||||||
ttl: 3600
|
|
||||||
|
|
||||||
providers:
|
|
||||||
- name: deepinfra
|
|
||||||
base_url: "https://api.deepinfra.com/v1/openai"
|
|
||||||
api_key: "${DEEPINFRA_API_KEY}"
|
|
||||||
priority: 1
|
|
||||||
timeout: 120s
|
|
||||||
- name: siliconflow
|
|
||||||
base_url: "https://api.siliconflow.com/v1"
|
|
||||||
api_key: "${SILICONFLOW_API_KEY}"
|
|
||||||
priority: 2
|
|
||||||
timeout: 120s
|
|
||||||
- name: openrouter
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
api_key: "${OPENROUTER_API_KEY}"
|
|
||||||
priority: 3
|
|
||||||
timeout: 120s
|
|
||||||
- name: groq
|
|
||||||
base_url: "https://api.groq.com/openai/v1"
|
|
||||||
api_key: "${GROQ_API_KEY}"
|
|
||||||
priority: 1
|
|
||||||
timeout: 120s
|
|
||||||
- name: cerebras
|
|
||||||
base_url: "https://api.cerebras.ai/v1"
|
|
||||||
api_key: "${CEREBRAS_API_KEY}"
|
|
||||||
priority: 1
|
|
||||||
timeout: 120s
|
|
||||||
|
|
||||||
models:
|
|
||||||
# ═══ TIER 1: Free (OpenRouter free models, $0) ═══
|
|
||||||
# NOTE: Commented out — free models are heavily rate-limited upstream.
|
|
||||||
# Uncomment if you want best-effort free access.
|
|
||||||
# - name: "llama-3.3-70b-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "meta-llama/llama-3.3-70b-instruct:free"
|
|
||||||
# - name: "deepseek-r1-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "deepseek/deepseek-r1-0528:free"
|
|
||||||
# - name: "gpt-oss-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "openai/gpt-oss-120b:free"
|
|
||||||
# - name: "gpt-oss-20b-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "openai/gpt-oss-20b:free"
|
|
||||||
# - name: "qwen3-coder-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "qwen/qwen3-coder:free"
|
|
||||||
# - name: "qwen3-235b-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "qwen/qwen3-235b-a22b-thinking-2507"
|
|
||||||
# - name: "glm-4.5-air-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "z-ai/glm-4.5-air:free"
|
|
||||||
# - name: "nemotron-nano-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "nvidia/nemotron-nano-9b-v2:free"
|
|
||||||
# - name: "trinity-large-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "arcee-ai/trinity-large-preview:free"
|
|
||||||
# - name: "mistral-small-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "mistralai/mistral-small-3.1-24b-instruct:free"
|
|
||||||
# - name: "gemma-3-27b-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "google/gemma-3-27b-it:free"
|
|
||||||
# - name: "step-3.5-flash-free"
|
|
||||||
# routes:
|
|
||||||
# - provider: openrouter
|
|
||||||
# model: "stepfun/step-3.5-flash:free"
|
|
||||||
|
|
||||||
# ═══ TIER 2: Low cost (Groq, Cerebras — free tier with rate limits) ═══
|
|
||||||
- name: "llama-3.1-8b"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "llama-3.1-8b-instant"
|
|
||||||
pricing: { input: 0.05, output: 0.08 }
|
|
||||||
- provider: cerebras
|
|
||||||
model: "llama3.1-8b"
|
|
||||||
pricing: { input: 0.10, output: 0.10 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
|
||||||
pricing: { input: 0.03, output: 0.05 }
|
|
||||||
|
|
||||||
- name: "llama-3.3-70b"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "meta-llama/Llama-3.3-70B-Instruct-Turbo"
|
|
||||||
pricing: { input: 0.23, output: 0.40 }
|
|
||||||
- provider: groq
|
|
||||||
model: "llama-3.3-70b-versatile"
|
|
||||||
pricing: { input: 0.59, output: 0.79 }
|
|
||||||
- provider: cerebras
|
|
||||||
model: "llama-3.3-70b"
|
|
||||||
pricing: { input: 0.85, output: 1.20 }
|
|
||||||
|
|
||||||
- name: "gpt-oss"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "openai/gpt-oss-120b"
|
|
||||||
pricing: { input: 0.15, output: 0.60 }
|
|
||||||
- provider: cerebras
|
|
||||||
model: "gpt-oss-120b"
|
|
||||||
pricing: { input: 0.35, output: 0.75 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "openai/gpt-oss-120b"
|
|
||||||
pricing: { input: 0.05, output: 0.24 }
|
|
||||||
|
|
||||||
- name: "gpt-oss-20b"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "openai/gpt-oss-20b"
|
|
||||||
pricing: { input: 0.075, output: 0.30 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "openai/gpt-oss-20b"
|
|
||||||
pricing: { input: 0.04, output: 0.16 }
|
|
||||||
|
|
||||||
- name: "llama-4-scout"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "meta-llama/llama-4-scout-17b-16e-instruct"
|
|
||||||
pricing: { input: 0.11, output: 0.34 }
|
|
||||||
|
|
||||||
- name: "llama-4-maverick"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "meta-llama/llama-4-maverick-17b-128e-instruct"
|
|
||||||
pricing: { input: 0.20, output: 0.60 }
|
|
||||||
|
|
||||||
- name: "qwen3-32b"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "qwen/qwen3-32b"
|
|
||||||
pricing: { input: 0.29, output: 0.59 }
|
|
||||||
- provider: cerebras
|
|
||||||
model: "qwen-3-32b"
|
|
||||||
|
|
||||||
# ═══ TIER 3: DeepSeek V3.2 (cheapest flagship) ═══
|
|
||||||
- name: "deepseek-v3.2"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "deepseek-ai/DeepSeek-V3.2"
|
|
||||||
pricing: { input: 0.26, output: 0.38 }
|
|
||||||
- provider: siliconflow
|
|
||||||
model: "deepseek-ai/DeepSeek-V3.2"
|
|
||||||
pricing: { input: 0.27, output: 0.42 }
|
|
||||||
- provider: openrouter
|
|
||||||
model: "deepseek/deepseek-chat-v3-0324"
|
|
||||||
pricing: { input: 0.30, output: 0.88 }
|
|
||||||
|
|
||||||
# ═══ TIER 4: Ultra-cheap DeepInfra ═══
|
|
||||||
- name: "nemotron-super"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "nvidia/Llama-3.3-Nemotron-Super-49B-v1.5"
|
|
||||||
pricing: { input: 0.10, output: 0.40 }
|
|
||||||
|
|
||||||
- name: "nemotron-nano"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
|
|
||||||
pricing: { input: 0.04, output: 0.16 }
|
|
||||||
|
|
||||||
# ═══ TIER 5: DeepSeek R1 & reasoning ═══
|
|
||||||
- name: "deepseek-r1"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "deepseek-ai/DeepSeek-R1-0528"
|
|
||||||
- provider: openrouter
|
|
||||||
model: "deepseek/deepseek-r1"
|
|
||||||
|
|
||||||
- name: "deepseek-r1-distill-llama-70b"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
|
|
||||||
|
|
||||||
- name: "devstral-small"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "mistralai/devstral-small"
|
|
||||||
|
|
||||||
- name: "devstral-medium"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "mistralai/devstral-medium"
|
|
||||||
|
|
||||||
# ═══ TIER 6: GLM ═══
|
|
||||||
- name: "glm-4.6"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "zai-org/GLM-4.6"
|
|
||||||
pricing: { input: 0.60, output: 1.90 }
|
|
||||||
|
|
||||||
- name: "glm-4.7"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "zai-org/GLM-4.7"
|
|
||||||
pricing: { input: 0.40, output: 1.75 }
|
|
||||||
- provider: cerebras
|
|
||||||
model: "zai-glm-4.7"
|
|
||||||
pricing: { input: 2.25, output: 2.75 }
|
|
||||||
- provider: siliconflow
|
|
||||||
model: "THUDM/GLM-4-32B-0414"
|
|
||||||
|
|
||||||
- name: "glm-5"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "zai-org/GLM-5"
|
|
||||||
pricing: { input: 0.80, output: 2.56 }
|
|
||||||
|
|
||||||
# ═══ TIER 7: Kimi ═══
|
|
||||||
- name: "kimi-k2"
|
|
||||||
routes:
|
|
||||||
- provider: groq
|
|
||||||
model: "moonshotai/kimi-k2-instruct-0905"
|
|
||||||
pricing: { input: 1.00, output: 3.00 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "moonshotai/Kimi-K2-Instruct-0905"
|
|
||||||
pricing: { input: 0.50, output: 2.00 }
|
|
||||||
- provider: siliconflow
|
|
||||||
model: "moonshotai/Kimi-K2-Instruct-0905"
|
|
||||||
pricing: { input: 0.58, output: 2.29 }
|
|
||||||
|
|
||||||
- name: "kimi-k2.5"
|
|
||||||
routes:
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "moonshotai/Kimi-K2.5"
|
|
||||||
pricing: { input: 0.45, output: 2.25 }
|
|
||||||
- provider: openrouter
|
|
||||||
model: "moonshotai/kimi-k2.5"
|
|
||||||
|
|
||||||
# ═══ TIER 8: SiliconFlow (Qwen) ═══
|
|
||||||
- name: "qwen3-coder"
|
|
||||||
routes:
|
|
||||||
- provider: siliconflow
|
|
||||||
model: "Qwen/Qwen3-Coder-480B-A35B-Instruct"
|
|
||||||
pricing: { input: 1.14, output: 2.28 }
|
|
||||||
|
|
||||||
- name: "qwen3-coder-30b"
|
|
||||||
routes:
|
|
||||||
- provider: siliconflow
|
|
||||||
model: "Qwen/Qwen3-Coder-30B-A3B-Instruct"
|
|
||||||
|
|
||||||
# ═══ TIER 9: OpenRouter premium (paid) ═══
|
|
||||||
- name: "minimax-m2.5"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "minimax/minimax-m2.5"
|
|
||||||
|
|
||||||
- name: "gpt-4.1-mini"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "openai/gpt-4.1-mini"
|
|
||||||
|
|
||||||
- name: "gpt-4.1"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "openai/gpt-4.1"
|
|
||||||
|
|
||||||
- name: "gemini-3-flash-preview"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "google/gemini-3-flash-preview"
|
|
||||||
|
|
||||||
- name: "gemini-2.5-pro"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "google/gemini-2.5-pro-preview"
|
|
||||||
|
|
||||||
# ═══ TIER 10: Vision / Multimodal ═══
|
|
||||||
- name: "gemma-3-4b"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "google/gemma-3-4b-it"
|
|
||||||
pricing: { input: 0.017, output: 0.068 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "google/gemma-3-4b-it"
|
|
||||||
pricing: { input: 0.04, output: 0.08 }
|
|
||||||
|
|
||||||
- name: "gemma-3-12b"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "google/gemma-3-12b-it"
|
|
||||||
pricing: { input: 0.03, output: 0.10 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "google/gemma-3-12b-it"
|
|
||||||
pricing: { input: 0.04, output: 0.13 }
|
|
||||||
|
|
||||||
- name: "gemma-3-27b"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "google/gemma-3-27b-it"
|
|
||||||
pricing: { input: 0.04, output: 0.15 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "google/gemma-3-27b-it"
|
|
||||||
pricing: { input: 0.08, output: 0.16 }
|
|
||||||
|
|
||||||
- name: "qwen3-vl-8b"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "qwen/qwen3-vl-8b-instruct"
|
|
||||||
pricing: { input: 0.08, output: 0.50 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "Qwen/Qwen3-VL-8B-Instruct"
|
|
||||||
pricing: { input: 0.18, output: 0.69 }
|
|
||||||
|
|
||||||
- name: "qwen3-vl-32b"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "qwen/qwen3-vl-32b-instruct"
|
|
||||||
pricing: { input: 0.104, output: 0.416 }
|
|
||||||
|
|
||||||
- name: "qwen2.5-vl-32b"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "qwen/qwen2.5-vl-32b-instruct"
|
|
||||||
pricing: { input: 0.05, output: 0.22 }
|
|
||||||
- provider: deepinfra
|
|
||||||
model: "Qwen/Qwen2.5-VL-32B-Instruct"
|
|
||||||
pricing: { input: 0.20, output: 0.60 }
|
|
||||||
|
|
||||||
- name: "claude-sonnet"
|
|
||||||
routes:
|
|
||||||
- provider: openrouter
|
|
||||||
model: "anthropic/claude-sonnet-4"
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
# LLM Gateway Environment Variables
|
|
||||||
|
|
||||||
# Session secret (required for persistent sessions)
|
|
||||||
SESSION_SECRET=change-me-to-a-random-string
|
|
||||||
|
|
||||||
# Default admin (created on first run if no users exist)
|
|
||||||
ADMIN_USERNAME=admin
|
|
||||||
ADMIN_PASSWORD=change-me-min-8-chars
|
|
||||||
|
|
||||||
# Static API tokens (seeded on startup)
|
|
||||||
OPENWEBUI_API_KEY=sk-your-openwebui-key
|
|
||||||
PERSONAL_API_KEY=sk-your-personal-key
|
|
||||||
|
|
||||||
# Provider API keys
|
|
||||||
DEEPINFRA_API_KEY=
|
|
||||||
SILICONFLOW_API_KEY=
|
|
||||||
OPENROUTER_API_KEY=
|
|
||||||
GROQ_API_KEY=
|
|
||||||
CEREBRAS_API_KEY=
|
|
||||||
18
llm-gateway/.gitignore
vendored
18
llm-gateway/.gitignore
vendored
|
|
@ -1,18 +0,0 @@
|
||||||
# Binaries
|
|
||||||
gateway
|
|
||||||
llm-gateway
|
|
||||||
|
|
||||||
# Database
|
|
||||||
*.db
|
|
||||||
*.db-journal
|
|
||||||
*.db-wal
|
|
||||||
*.db-shm
|
|
||||||
|
|
||||||
# Debug log files
|
|
||||||
debug-logs/
|
|
||||||
|
|
||||||
# Local config
|
|
||||||
configs/config.local.yaml
|
|
||||||
|
|
||||||
# Environment
|
|
||||||
.env
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
FROM golang:1.24-alpine AS builder
|
|
||||||
WORKDIR /src
|
|
||||||
COPY go.mod go.sum ./
|
|
||||||
RUN go mod download
|
|
||||||
COPY . .
|
|
||||||
RUN CGO_ENABLED=0 go build -ldflags="-s -w" -o /llm-gateway ./cmd/gateway
|
|
||||||
|
|
||||||
FROM alpine:3.19
|
|
||||||
RUN apk add --no-cache ca-certificates tzdata
|
|
||||||
COPY --from=builder /llm-gateway /usr/local/bin/llm-gateway
|
|
||||||
RUN mkdir -p /data
|
|
||||||
VOLUME /data
|
|
||||||
EXPOSE 3000
|
|
||||||
ENTRYPOINT ["llm-gateway"]
|
|
||||||
CMD ["-config", "/etc/llm-gateway/config.yaml"]
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
||||||
.PHONY: build run clean docker
|
|
||||||
|
|
||||||
BINARY=llm-gateway
|
|
||||||
VERSION=$(shell git describe --tags --always --dirty 2>/dev/null || echo dev)
|
|
||||||
|
|
||||||
build:
|
|
||||||
go build -ldflags="-s -w -X main.version=$(VERSION)" -o $(BINARY) ./cmd/gateway
|
|
||||||
|
|
||||||
run: build
|
|
||||||
./$(BINARY) -config configs/config.yaml
|
|
||||||
|
|
||||||
clean:
|
|
||||||
rm -f $(BINARY)
|
|
||||||
|
|
||||||
docker:
|
|
||||||
docker build -t llm-gateway:latest .
|
|
||||||
|
|
@ -1,434 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"flag"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
|
||||||
gocors "github.com/go-chi/cors"
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
"llm-gateway/internal/cache"
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
"llm-gateway/internal/dashboard"
|
|
||||||
"llm-gateway/internal/metrics"
|
|
||||||
"llm-gateway/internal/pricing"
|
|
||||||
"llm-gateway/internal/provider"
|
|
||||||
"llm-gateway/internal/proxy"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
"llm-gateway/internal/webhook"
|
|
||||||
)
|
|
||||||
|
|
||||||
var version = "dev"
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
configPath := flag.String("config", "configs/config.yaml", "path to config file")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
log.Printf("llm-gateway %s starting", version)
|
|
||||||
|
|
||||||
cfg, err := config.Load(*configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to load config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pricing lookup (fetches from URL, refreshes periodically)
|
|
||||||
pricingLookup := pricing.NewLookup(cfg.Pricing.URL, cfg.Pricing.RefreshInterval)
|
|
||||||
defer pricingLookup.Close()
|
|
||||||
|
|
||||||
// Auto-fill missing pricing from fetched data
|
|
||||||
for i, m := range cfg.Models {
|
|
||||||
for j, r := range m.Routes {
|
|
||||||
if r.Pricing.Input == 0 && r.Pricing.Output == 0 {
|
|
||||||
if pricingLookup.FillMissing(r.Provider, r.Model, &cfg.Models[i].Routes[j].Pricing.Input, &cfg.Models[i].Routes[j].Pricing.Output) {
|
|
||||||
log.Printf("Auto-filled pricing for %s via %s: $%.2f/$%.2f per 1M tokens",
|
|
||||||
m.Name, r.Provider, cfg.Models[i].Routes[j].Pricing.Input, cfg.Models[i].Routes[j].Pricing.Output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Database
|
|
||||||
db, err := storage.Open(cfg.Database.Path)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
if err := db.CleanupOldRecords(cfg.Database.RetentionDays); err != nil {
|
|
||||||
log.Printf("WARNING: retention cleanup failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
asyncLogger := storage.NewAsyncLogger(db, 1000)
|
|
||||||
defer asyncLogger.Close()
|
|
||||||
|
|
||||||
// SSE broker for real-time dashboard updates
|
|
||||||
sseBroker := dashboard.NewSSEBroker()
|
|
||||||
asyncLogger.OnFlush = sseBroker.Notify
|
|
||||||
|
|
||||||
// Cache (optional)
|
|
||||||
var c *cache.Cache
|
|
||||||
if cfg.Cache.Enabled {
|
|
||||||
c, err = cache.New(cfg.Cache.Address, cfg.Cache.TTL)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("WARNING: cache disabled: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Printf("Cache connected to %s", cfg.Cache.Address)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Provider registry
|
|
||||||
registry, err := provider.NewRegistry(cfg)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to build provider registry: %v", err)
|
|
||||||
}
|
|
||||||
log.Printf("Registered %d models", len(cfg.Models))
|
|
||||||
|
|
||||||
// Provider health tracker
|
|
||||||
healthTracker := provider.NewHealthTracker(5*time.Minute, cfg.CircuitBreaker)
|
|
||||||
|
|
||||||
// Webhook notifier
|
|
||||||
var notifier *webhook.Notifier
|
|
||||||
if len(cfg.Webhooks) > 0 {
|
|
||||||
notifier = webhook.NewNotifier(cfg.Webhooks)
|
|
||||||
defer notifier.Close()
|
|
||||||
log.Printf("Webhooks configured: %d endpoints", len(cfg.Webhooks))
|
|
||||||
|
|
||||||
// Wire health tracker state changes to webhook
|
|
||||||
healthTracker.OnStateChange = func(providerName string, from, to provider.CircuitState) {
|
|
||||||
eventType := webhook.EventCircuitBreakerOpen
|
|
||||||
if to == provider.CircuitClosed {
|
|
||||||
eventType = webhook.EventCircuitBreakerClosed
|
|
||||||
}
|
|
||||||
notifier.Notify(webhook.Event{
|
|
||||||
Type: eventType,
|
|
||||||
Data: map[string]any{
|
|
||||||
"provider": providerName,
|
|
||||||
"from": from.String(),
|
|
||||||
"to": to.String(),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auth store (static tokens checked in-memory, not seeded to DB)
|
|
||||||
var staticTokens []auth.StaticToken
|
|
||||||
for _, t := range cfg.Tokens {
|
|
||||||
if t.Key != "" {
|
|
||||||
staticTokens = append(staticTokens, auth.StaticToken{
|
|
||||||
Name: t.Name,
|
|
||||||
Key: t.Key,
|
|
||||||
RateLimitRPM: t.RateLimitRPM,
|
|
||||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
|
||||||
MonthlyBudgetUSD: t.MonthlyBudgetUSD,
|
|
||||||
MaxConcurrent: t.MaxConcurrent,
|
|
||||||
})
|
|
||||||
log.Printf("Loaded static token: %s", t.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
authStore := auth.NewStore(db.DB, staticTokens)
|
|
||||||
authMiddleware := auth.NewMiddleware(authStore)
|
|
||||||
authHandlers := auth.NewHandlers(authStore, cfg.Server.SessionSecret)
|
|
||||||
|
|
||||||
// Audit logger
|
|
||||||
auditLogger := storage.NewAuditLogger(db)
|
|
||||||
auditLogger.OnWrite = sseBroker.Notify
|
|
||||||
authHandlers.SetAuditLogger(auditLogger)
|
|
||||||
|
|
||||||
// Debug logger
|
|
||||||
debugDataDir := cfg.Debug.DataDir
|
|
||||||
if debugDataDir == "" {
|
|
||||||
debugDataDir = filepath.Dir(cfg.Database.Path)
|
|
||||||
}
|
|
||||||
debugLogger := storage.NewDebugLogger(db, cfg.Debug.Enabled, debugDataDir)
|
|
||||||
debugLogger.OnWrite = sseBroker.Notify
|
|
||||||
|
|
||||||
// Seed default admin
|
|
||||||
seedDefaultAdmin(cfg, authStore)
|
|
||||||
|
|
||||||
// Metrics
|
|
||||||
m := metrics.New()
|
|
||||||
|
|
||||||
// Handlers
|
|
||||||
proxyHandler := proxy.NewHandler(registry, asyncLogger, c, m, cfg, healthTracker)
|
|
||||||
proxyHandler.SetDebugLogger(debugLogger)
|
|
||||||
|
|
||||||
// Request deduplication
|
|
||||||
if cfg.Dedup.Enabled {
|
|
||||||
dedup := proxy.NewDeduplicator(cfg.Dedup.Window)
|
|
||||||
defer dedup.Close()
|
|
||||||
proxyHandler.SetDeduplicator(dedup)
|
|
||||||
log.Printf("Request deduplication enabled (window: %v)", cfg.Dedup.Window)
|
|
||||||
}
|
|
||||||
|
|
||||||
modelsHandler := proxy.NewModelsHandler(registry, healthTracker, cfg)
|
|
||||||
proxyAuth := proxy.NewAuthMiddleware(authStore)
|
|
||||||
rateLimiter := proxy.NewRateLimiter(db)
|
|
||||||
if notifier != nil {
|
|
||||||
rateLimiter.SetNotifier(notifier)
|
|
||||||
}
|
|
||||||
concurrencyLimiter := proxy.NewConcurrencyLimiter()
|
|
||||||
statsAPI := dashboard.NewStatsAPI(db, authStore)
|
|
||||||
statsAPI.SetHealthTracker(healthTracker)
|
|
||||||
statsAPI.SetAuditLogger(auditLogger)
|
|
||||||
statsAPI.SetDebugLogger(debugLogger)
|
|
||||||
statsAPI.SetConfigPath(*configPath)
|
|
||||||
if c != nil {
|
|
||||||
statsAPI.SetCache(c)
|
|
||||||
}
|
|
||||||
dash := dashboard.NewDashboard(authStore, statsAPI)
|
|
||||||
dash.SetRegistry(registry)
|
|
||||||
dash.SetAuditLogger(auditLogger)
|
|
||||||
dash.SetDebugLogger(debugLogger)
|
|
||||||
if c != nil {
|
|
||||||
dash.SetCache(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Export handler
|
|
||||||
exportHandler := dashboard.NewExportHandler(db, authStore)
|
|
||||||
|
|
||||||
// Router
|
|
||||||
r := chi.NewRouter()
|
|
||||||
|
|
||||||
// CORS (before other middleware)
|
|
||||||
if cfg.CORS.Enabled {
|
|
||||||
r.Use(gocors.Handler(gocors.Options{
|
|
||||||
AllowedOrigins: cfg.CORS.AllowedOrigins,
|
|
||||||
AllowedMethods: cfg.CORS.AllowedMethods,
|
|
||||||
AllowedHeaders: cfg.CORS.AllowedHeaders,
|
|
||||||
MaxAge: cfg.CORS.MaxAge,
|
|
||||||
AllowCredentials: true,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
r.Use(middleware.RealIP)
|
|
||||||
r.Use(middleware.Recoverer)
|
|
||||||
r.Use(middleware.RequestID)
|
|
||||||
|
|
||||||
// Health & metrics (public)
|
|
||||||
r.Get("/health", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if err := db.Ping(); err != nil {
|
|
||||||
http.Error(w, "database unhealthy", http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if c != nil {
|
|
||||||
if err := c.Ping(r.Context()); err != nil {
|
|
||||||
http.Error(w, "cache unhealthy", http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Write([]byte("OK"))
|
|
||||||
})
|
|
||||||
r.Handle("/metrics", promhttp.Handler())
|
|
||||||
|
|
||||||
// OpenAI-compatible API (API token auth via Bearer header)
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(proxyAuth.Authenticate)
|
|
||||||
r.Use(rateLimiter.Check)
|
|
||||||
r.Use(concurrencyLimiter.Check)
|
|
||||||
r.Post("/v1/chat/completions", proxyHandler.ChatCompletions)
|
|
||||||
r.Post("/v1/embeddings", proxyHandler.Embeddings)
|
|
||||||
r.Get("/v1/models", modelsHandler.ListModels)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Auth pages (public)
|
|
||||||
r.Get("/login", dash.LoginPage)
|
|
||||||
r.Get("/setup", dash.SetupPage)
|
|
||||||
|
|
||||||
// Auth API endpoints (public)
|
|
||||||
r.Post("/api/auth/login", authHandlers.Login)
|
|
||||||
r.Post("/api/auth/setup", authHandlers.Setup)
|
|
||||||
r.Post("/api/auth/login/totp", authHandlers.LoginTOTP)
|
|
||||||
|
|
||||||
// Favicon (prevent 401 noise in browser console)
|
|
||||||
r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusNoContent)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Root redirect
|
|
||||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Authenticated pages and API
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(authMiddleware.RequireAuth)
|
|
||||||
|
|
||||||
// Dashboard pages (HTMX)
|
|
||||||
r.Get("/dashboard", dash.DashboardPage)
|
|
||||||
r.Get("/logs", dash.LogsPage)
|
|
||||||
r.Get("/models", dash.ModelsPage)
|
|
||||||
r.Get("/tokens", dash.TokensPage)
|
|
||||||
r.Get("/settings", dash.SettingsPage)
|
|
||||||
|
|
||||||
// Admin-only pages
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(authMiddleware.RequireAdmin)
|
|
||||||
r.Get("/users", dash.UsersPage)
|
|
||||||
r.Get("/audit", dash.AuditPage)
|
|
||||||
r.Get("/debug", dash.DebugPage)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Auth API
|
|
||||||
r.Post("/api/auth/logout", authHandlers.Logout)
|
|
||||||
r.Get("/api/auth/me", authHandlers.Me)
|
|
||||||
r.Put("/api/auth/me/password", authHandlers.ChangePassword)
|
|
||||||
r.Put("/api/auth/me/username", authHandlers.ChangeUsername)
|
|
||||||
r.Put("/api/auth/me/email", authHandlers.ChangeEmail)
|
|
||||||
r.Post("/api/auth/totp/setup", authHandlers.TOTPSetup)
|
|
||||||
r.Post("/api/auth/totp/verify", authHandlers.TOTPVerify)
|
|
||||||
r.Delete("/api/auth/totp", authHandlers.TOTPDisable)
|
|
||||||
|
|
||||||
// API token management
|
|
||||||
r.Get("/api/tokens", authHandlers.ListTokens)
|
|
||||||
r.Post("/api/tokens", authHandlers.CreateToken)
|
|
||||||
r.Delete("/api/tokens/{id}", authHandlers.DeleteToken)
|
|
||||||
|
|
||||||
// SSE events
|
|
||||||
r.Get("/api/events", sseBroker.ServeHTTP)
|
|
||||||
|
|
||||||
// Dashboard stats
|
|
||||||
r.Get("/api/stats/summary", statsAPI.Summary)
|
|
||||||
r.Get("/api/stats/models", statsAPI.Models)
|
|
||||||
r.Get("/api/stats/providers", statsAPI.Providers)
|
|
||||||
r.Get("/api/stats/tokens", statsAPI.Tokens)
|
|
||||||
r.Get("/api/stats/timeseries", statsAPI.Timeseries)
|
|
||||||
r.Get("/api/stats/logs", statsAPI.Logs)
|
|
||||||
r.Get("/api/stats/latency", statsAPI.Latency)
|
|
||||||
r.Get("/api/stats/cost-breakdown", statsAPI.CostBreakdown)
|
|
||||||
r.Get("/api/stats/provider-health", statsAPI.ProviderHealthHandler)
|
|
||||||
r.Get("/api/stats/cache", statsAPI.CacheStats)
|
|
||||||
|
|
||||||
// Data export
|
|
||||||
r.Get("/api/export/logs", exportHandler.ExportLogs)
|
|
||||||
r.Get("/api/export/stats", exportHandler.ExportStats)
|
|
||||||
|
|
||||||
// Admin-only: user management, audit, debug, config validation
|
|
||||||
r.Group(func(r chi.Router) {
|
|
||||||
r.Use(authMiddleware.RequireAdmin)
|
|
||||||
r.Get("/api/auth/users", authHandlers.ListUsers)
|
|
||||||
r.Post("/api/auth/users", authHandlers.CreateUser)
|
|
||||||
r.Delete("/api/auth/users/{id}", authHandlers.DeleteUser)
|
|
||||||
|
|
||||||
// Audit log
|
|
||||||
r.Get("/api/stats/audit", statsAPI.AuditLogs)
|
|
||||||
|
|
||||||
// Config validation
|
|
||||||
r.Get("/api/config/validate", statsAPI.ValidateConfig)
|
|
||||||
|
|
||||||
// Debug logging
|
|
||||||
r.Post("/api/debug/toggle", statsAPI.DebugToggle)
|
|
||||||
r.Get("/api/debug/status", statsAPI.DebugStatus)
|
|
||||||
r.Get("/api/debug/logs", statsAPI.DebugLogs)
|
|
||||||
r.Get("/api/debug/logs/{requestID}", statsAPI.DebugLogByRequestID)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
// Periodic session cleanup and debug log cleanup
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(1 * time.Hour)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for range ticker.C {
|
|
||||||
if err := authStore.CleanExpiredSessions(); err != nil {
|
|
||||||
log.Printf("WARNING: session cleanup failed: %v", err)
|
|
||||||
}
|
|
||||||
if err := debugLogger.Cleanup(cfg.Debug.RetentionDays); err != nil {
|
|
||||||
log.Printf("WARNING: debug log cleanup failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Server
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: cfg.Server.Listen,
|
|
||||||
Handler: r,
|
|
||||||
ReadTimeout: 30 * time.Second,
|
|
||||||
WriteTimeout: cfg.Server.RequestTimeout + 10*time.Second,
|
|
||||||
IdleTimeout: 120 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config hot-reload via SIGHUP
|
|
||||||
config.WatchReload(*configPath, func(newCfg *config.Config) {
|
|
||||||
// Reload registry (models, providers, routes)
|
|
||||||
if err := registry.Reload(newCfg); err != nil {
|
|
||||||
log.Printf("ERROR: registry reload failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("Reloaded %d models", len(newCfg.Models))
|
|
||||||
|
|
||||||
// Reload pricing
|
|
||||||
for i, m := range newCfg.Models {
|
|
||||||
for j, rt := range m.Routes {
|
|
||||||
if rt.Pricing.Input == 0 && rt.Pricing.Output == 0 {
|
|
||||||
pricingLookup.FillMissing(rt.Provider, rt.Model,
|
|
||||||
&newCfg.Models[i].Routes[j].Pricing.Input,
|
|
||||||
&newCfg.Models[i].Routes[j].Pricing.Output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reload static tokens
|
|
||||||
var newStaticTokens []auth.StaticToken
|
|
||||||
for _, t := range newCfg.Tokens {
|
|
||||||
if t.Key != "" {
|
|
||||||
newStaticTokens = append(newStaticTokens, auth.StaticToken{
|
|
||||||
Name: t.Name,
|
|
||||||
Key: t.Key,
|
|
||||||
RateLimitRPM: t.RateLimitRPM,
|
|
||||||
DailyBudgetUSD: t.DailyBudgetUSD,
|
|
||||||
MonthlyBudgetUSD: t.MonthlyBudgetUSD,
|
|
||||||
MaxConcurrent: t.MaxConcurrent,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
authStore.SetStaticTokens(newStaticTokens)
|
|
||||||
|
|
||||||
// Update config pointer for retry/debug/etc
|
|
||||||
cfg = newCfg
|
|
||||||
})
|
|
||||||
|
|
||||||
// Graceful shutdown
|
|
||||||
done := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(done, os.Interrupt, syscall.SIGTERM)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
log.Printf("Listening on %s", cfg.Server.Listen)
|
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
||||||
log.Fatalf("Server failed: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-done
|
|
||||||
log.Println("Shutting down...")
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
srv.Shutdown(ctx)
|
|
||||||
|
|
||||||
log.Println("Stopped")
|
|
||||||
}
|
|
||||||
|
|
||||||
// seedDefaultAdmin creates the default admin user if no users exist.
|
|
||||||
func seedDefaultAdmin(cfg *config.Config, authStore *auth.Store) {
|
|
||||||
if !authStore.HasAnyUser() {
|
|
||||||
da := cfg.Server.DefaultAdmin
|
|
||||||
if da.Username != "" && da.Password != "" {
|
|
||||||
user, err := authStore.CreateUser(da.Username, da.Password, true)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("WARNING: failed to create default admin: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Printf("Created default admin user: %s (id=%d)", user.Username, user.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
# llm-gateway configuration (this is the version being deleted by this commit).
# ${VAR} placeholders are presumably expanded from the environment — TODO confirm
# against the config loader.

server:
  listen: "0.0.0.0:3000"
  request_timeout: 300s
  max_request_body_mb: 10
  session_secret: "${SESSION_SECRET}"
  # default_admin is seeded only when no users exist yet (see seedDefaultAdmin).
  default_admin:
    username: "${ADMIN_USERNAME}"
    password: "${ADMIN_PASSWORD}"

# Static API tokens, checked in-memory and never stored in the database.
# rate_limit_rpm of 0 means unlimited.
tokens:
  - name: "open-webui"
    key: "${OPENWEBUI_API_KEY}"
    rate_limit_rpm: 0 # unlimited
    daily_budget_usd: 5.0
  - name: "rayandrew"
    key: "${PERSONAL_API_KEY}"
    rate_limit_rpm: 0 # unlimited
    daily_budget_usd: 10.0

# External price table; routes whose pricing is { input: 0, output: 0 } are
# filled from this lookup at (re)load time.
pricing_lookup:
  # url: "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json" # default
  refresh_interval: 6h

database:
  path: "/data/gateway.db"
  retention_days: 90

cache:
  enabled: true
  address: "valkey:6379"
  ttl: 3600

# Upstream OpenAI-compatible providers. NOTE(review): priority presumably
# orders failover attempts (lower first) — confirm against the router code.
providers:
  - name: deepinfra
    base_url: "https://api.deepinfra.com/v1/openai"
    api_key: "${DEEPINFRA_API_KEY}"
    priority: 1
    timeout: 120s
  - name: siliconflow
    base_url: "https://api.siliconflow.com/v1"
    api_key: "${SILICONFLOW_API_KEY}"
    priority: 2
    timeout: 120s
  - name: openrouter
    base_url: "https://openrouter.ai/api/v1"
    api_key: "${OPENROUTER_API_KEY}"
    priority: 3
    timeout: 120s
  - name: groq
    base_url: "https://api.groq.com/openai/v1"
    api_key: "${GROQ_API_KEY}"
    priority: 1
    timeout: 120s
  - name: cerebras
    base_url: "https://api.cerebras.ai/v1"
    api_key: "${CEREBRAS_API_KEY}"
    priority: 1
    timeout: 120s

# Gateway-facing model names mapped to provider-specific model IDs.
# pricing values look like USD per 1M tokens — TODO confirm; zero pricing is
# backfilled from pricing_lookup where available.
models:
  - name: "deepseek-v3.2"
    routes:
      - provider: deepinfra
        model: "deepseek-ai/DeepSeek-V3.2"
        pricing: { input: 0.26, output: 0.38 }
      - provider: siliconflow
        model: "deepseek-ai/DeepSeek-V3.2"
        pricing: { input: 0.27, output: 0.42 }
      - provider: openrouter
        model: "deepseek/deepseek-chat-v3-0324"
        pricing: { input: 0.30, output: 0.88 }

  - name: "llama-3.3-70b"
    routes:
      - provider: groq
        model: "llama-3.3-70b-versatile"
        pricing: { input: 0, output: 0 }
      - provider: deepinfra
        model: "meta-llama/Llama-3.3-70B-Instruct"
        pricing: { input: 0.23, output: 0.40 }

  - name: "llama-3.1-8b"
    routes:
      - provider: groq
        model: "llama-3.1-8b-instant"
        pricing: { input: 0, output: 0 }
      - provider: cerebras
        model: "llama-3.1-8b"
        pricing: { input: 0, output: 0 }
      - provider: deepinfra
        model: "meta-llama/Meta-Llama-3.1-8B-Instruct"
        pricing: { input: 0.03, output: 0.05 }

  - name: "qwen-2.5-72b"
    routes:
      - provider: groq
        model: "qwen-2.5-72b"
        pricing: { input: 0, output: 0 }
      - provider: deepinfra
        model: "Qwen/Qwen2.5-72B-Instruct"
        pricing: { input: 0.23, output: 0.40 }

  - name: "qwen-2.5-coder-32b"
    routes:
      - provider: groq
        model: "qwen-2.5-coder-32b"
        pricing: { input: 0, output: 0 }
      - provider: deepinfra
        model: "Qwen/Qwen2.5-Coder-32B-Instruct"
        pricing: { input: 0.07, output: 0.16 }

  - name: "gemma-2-9b"
    routes:
      - provider: groq
        model: "gemma2-9b-it"
        pricing: { input: 0, output: 0 }

  - name: "deepseek-r1"
    routes:
      - provider: deepinfra
        model: "deepseek-ai/DeepSeek-R1"
        pricing: { input: 0.40, output: 1.60 }
      - provider: openrouter
        model: "deepseek/deepseek-r1"
        pricing: { input: 0.55, output: 2.19 }

  - name: "deepseek-r1-distill-llama-70b"
    routes:
      - provider: groq
        model: "deepseek-r1-distill-llama-70b"
        pricing: { input: 0, output: 0 }
      - provider: deepinfra
        model: "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        pricing: { input: 0.23, output: 0.69 }

  - name: "deepseek-r1-distill-qwen-32b"
    routes:
      - provider: deepinfra
        model: "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
        pricing: { input: 0.07, output: 0.16 }
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
module llm-gateway
|
|
||||||
|
|
||||||
go 1.24.0
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/go-chi/chi/v5 v5.2.5
|
|
||||||
github.com/go-chi/cors v1.2.2
|
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1
|
|
||||||
github.com/pquerna/otp v1.5.0
|
|
||||||
github.com/prometheus/client_golang v1.23.2
|
|
||||||
github.com/redis/go-redis/v9 v9.17.3
|
|
||||||
golang.org/x/crypto v0.48.0
|
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
|
||||||
modernc.org/sqlite v1.45.0
|
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
|
||||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
|
||||||
github.com/kr/text v0.2.0 // indirect
|
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
|
||||||
github.com/prometheus/client_model v0.6.2 // indirect
|
|
||||||
github.com/prometheus/common v0.66.1 // indirect
|
|
||||||
github.com/prometheus/procfs v0.16.1 // indirect
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
|
||||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
|
||||||
golang.org/x/sys v0.41.0 // indirect
|
|
||||||
google.golang.org/protobuf v1.36.8 // indirect
|
|
||||||
modernc.org/libc v1.67.6 // indirect
|
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
|
||||||
modernc.org/memory v1.11.0 // indirect
|
|
||||||
)
|
|
||||||
|
|
@ -1,123 +0,0 @@
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
|
||||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
|
|
||||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
|
||||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
|
||||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
|
||||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
|
||||||
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
|
|
||||||
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
|
|
||||||
github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE=
|
|
||||||
github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA=
|
|
||||||
github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE=
|
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
|
||||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
|
||||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
|
||||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
|
||||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
|
||||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
|
||||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
|
||||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
|
||||||
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
|
|
||||||
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
|
|
||||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
|
||||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
|
||||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
|
||||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
|
||||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
|
||||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
|
||||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
|
||||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
|
||||||
github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4=
|
|
||||||
github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
|
||||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
|
||||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
|
||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
|
||||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
|
||||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
|
||||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
|
||||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
|
||||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
|
||||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
|
||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
|
||||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
|
||||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
|
||||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
|
||||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
|
||||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
|
||||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
|
||||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
|
||||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
|
||||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
|
||||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
|
||||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
|
||||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
|
||||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
|
||||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
|
||||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
|
||||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
|
||||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|
||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
|
||||||
modernc.org/sqlite v1.45.0 h1:r51cSGzKpbptxnby+EIIz5fop4VuE4qFoVEjNvWoObs=
|
|
||||||
modernc.org/sqlite v1.45.0/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,83 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// contextKey is a private type for request-context keys so they cannot
// collide with keys set by other packages.
type contextKey string

// userContextKey is the context key under which RequireAuth stores the
// authenticated *User.
const userContextKey contextKey = "auth_user"

const (
	// sessionCookieName is the cookie that carries the web session ID.
	sessionCookieName = "llmgw_session"
	// sessionTTLDays is the session lifetime in days. NOTE(review): not
	// referenced in this chunk — presumably used where sessions are
	// created; confirm against the login handler.
	sessionTTLDays = 7
)

// Middleware provides cookie-session authentication middleware backed by
// an auth Store.
type Middleware struct {
	store *Store
}
|
|
||||||
|
|
||||||
func NewMiddleware(store *Store) *Middleware {
|
|
||||||
return &Middleware{store: store}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UserFromContext(ctx context.Context) *User {
|
|
||||||
u, _ := ctx.Value(userContextKey).(*User)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
|
||||||
if err != nil || cookie.Value == "" {
|
|
||||||
m.unauthorized(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := m.store.GetSession(cookie.Value)
|
|
||||||
if err != nil {
|
|
||||||
m.unauthorized(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := m.store.GetUserByID(sess.UserID)
|
|
||||||
if err != nil {
|
|
||||||
m.unauthorized(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.WithValue(r.Context(), userContextKey, user)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) RequireAdmin(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := UserFromContext(r.Context())
|
|
||||||
if user == nil || !user.IsAdmin {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "admin access required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Middleware) unauthorized(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
|
||||||
w.Header().Set("HX-Redirect", "/login")
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(r.URL.Path, "/api/") {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, "/login", http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
@ -1,391 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// User is an account row from the users table. Timestamp fields are unix
// seconds.
type User struct {
	ID           int64  `json:"id"`
	Username     string `json:"username"`
	Email        string `json:"email"`
	PasswordHash string `json:"-"` // bcrypt hash; excluded from JSON
	IsAdmin      bool   `json:"is_admin"`
	TOTPSecret   string `json:"-"` // TOTP secret; excluded from JSON
	TOTPEnabled  bool   `json:"totp_enabled"`
	CreatedAt    int64  `json:"created_at"`
	UpdatedAt    int64  `json:"updated_at"`
}

// Session is a web login session row; ID is the random value stored in the
// session cookie. Timestamps are unix seconds.
type Session struct {
	ID        string
	UserID    int64
	CreatedAt int64
	ExpiresAt int64
}

// APIToken is an API key record. Only the SHA-256 KeyHash and the short
// KeyPrefix are stored; the plaintext key is returned once at creation.
// ID == -1 is a sentinel for config-defined static tokens (see
// LookupAPIToken), which never exist in the database.
type APIToken struct {
	ID               int64   `json:"id"`
	Name             string  `json:"name"`
	KeyPrefix        string  `json:"key_prefix"`
	KeyHash          string  `json:"-"` // excluded from JSON
	UserID           int64   `json:"user_id"`
	RateLimitRPM     int     `json:"rate_limit_rpm"`
	DailyBudgetUSD   float64 `json:"daily_budget_usd"`
	MonthlyBudgetUSD float64 `json:"monthly_budget_usd"`
	MaxConcurrent    int     `json:"max_concurrent"`
	CreatedAt        int64   `json:"created_at"`
	LastUsedAt       int64   `json:"last_used_at"`
}

// StaticToken represents a token defined in config (checked in-memory, never stored in DB).
type StaticToken struct {
	Name             string
	Key              string
	RateLimitRPM     int
	DailyBudgetUSD   float64
	MonthlyBudgetUSD float64
	MaxConcurrent    int
}

// Store provides user, session, and API-token persistence on top of db,
// plus in-memory matching of config-defined static tokens.
type Store struct {
	db           *sql.DB
	staticTokens []StaticToken
}
|
|
||||||
|
|
||||||
func NewStore(db *sql.DB, staticTokens []StaticToken) *Store {
|
|
||||||
return &Store{db: db, staticTokens: staticTokens}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStaticTokens updates the static tokens list (used for config hot-reload).
//
// NOTE(review): the slice is replaced without synchronization; if
// LookupAPIToken can run concurrently with a config reload this is a data
// race — confirm, and guard with a mutex or atomic pointer if so.
func (s *Store) SetStaticTokens(tokens []StaticToken) {
	s.staticTokens = tokens
}
|
|
||||||
|
|
||||||
// HasAnyUser reports whether at least one row exists in the users table.
// The Scan error is deliberately ignored: on a query failure count stays 0
// and the function reports false. NOTE(review): a failing query therefore
// looks like "no users", which would let default-admin seeding run —
// confirm this is acceptable.
func (s *Store) HasAnyUser() bool {
	var count int
	s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
	return count > 0
}
|
|
||||||
|
|
||||||
func (s *Store) CreateUser(username, password string, isAdmin bool) (*User, error) {
|
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("hashing password: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().Unix()
|
|
||||||
adminInt := 0
|
|
||||||
if isAdmin {
|
|
||||||
adminInt = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := s.db.Exec(
|
|
||||||
"INSERT INTO users (username, password_hash, is_admin, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
username, string(hash), adminInt, now, now,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
id, _ := result.LastInsertId()
|
|
||||||
return &User{
|
|
||||||
ID: id,
|
|
||||||
Username: username,
|
|
||||||
IsAdmin: isAdmin,
|
|
||||||
CreatedAt: now,
|
|
||||||
UpdatedAt: now,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetUserByUsername(username string) (*User, error) {
|
|
||||||
return s.scanUser(s.db.QueryRow(
|
|
||||||
"SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE username = ?",
|
|
||||||
username,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetUserByID(id int64) (*User, error) {
|
|
||||||
return s.scanUser(s.db.QueryRow(
|
|
||||||
"SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users WHERE id = ?",
|
|
||||||
id,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) scanUser(row *sql.Row) (*User, error) {
|
|
||||||
var u User
|
|
||||||
var isAdmin, totpEnabled int
|
|
||||||
var totpSecret sql.NullString
|
|
||||||
var email sql.NullString
|
|
||||||
err := row.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
u.Email = email.String
|
|
||||||
u.IsAdmin = isAdmin == 1
|
|
||||||
u.TOTPEnabled = totpEnabled == 1
|
|
||||||
u.TOTPSecret = totpSecret.String
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) ListUsers() ([]User, error) {
|
|
||||||
rows, err := s.db.Query("SELECT id, username, email, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users ORDER BY id")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
for rows.Next() {
|
|
||||||
var u User
|
|
||||||
var isAdmin, totpEnabled int
|
|
||||||
var totpSecret sql.NullString
|
|
||||||
var email sql.NullString
|
|
||||||
if err := rows.Scan(&u.ID, &u.Username, &email, &u.PasswordHash, &isAdmin, &totpSecret, &totpEnabled, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
u.Email = email.String
|
|
||||||
u.IsAdmin = isAdmin == 1
|
|
||||||
u.TOTPEnabled = totpEnabled == 1
|
|
||||||
u.TOTPSecret = totpSecret.String
|
|
||||||
users = append(users, u)
|
|
||||||
}
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteUser(id int64) error {
|
|
||||||
// Prevent deleting the last admin
|
|
||||||
var adminCount int
|
|
||||||
s.db.QueryRow("SELECT COUNT(*) FROM users WHERE is_admin = 1").Scan(&adminCount)
|
|
||||||
|
|
||||||
var isAdmin int
|
|
||||||
s.db.QueryRow("SELECT is_admin FROM users WHERE id = ?", id).Scan(&isAdmin)
|
|
||||||
if isAdmin == 1 && adminCount <= 1 {
|
|
||||||
return fmt.Errorf("cannot delete the last admin user")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := s.db.Exec("DELETE FROM users WHERE id = ?", id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) UpdatePassword(userID int64, newPassword string) error {
|
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("hashing password: %w", err)
|
|
||||||
}
|
|
||||||
_, err = s.db.Exec("UPDATE users SET password_hash = ?, updated_at = ? WHERE id = ?", string(hash), time.Now().Unix(), userID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) CheckPassword(user *User, password string) bool {
|
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) SetTOTPSecret(userID int64, secret string) error {
|
|
||||||
_, err := s.db.Exec("UPDATE users SET totp_secret = ?, updated_at = ? WHERE id = ?", secret, time.Now().Unix(), userID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) EnableTOTP(userID int64) error {
|
|
||||||
_, err := s.db.Exec("UPDATE users SET totp_enabled = 1, updated_at = ? WHERE id = ?", time.Now().Unix(), userID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DisableTOTP(userID int64) error {
|
|
||||||
_, err := s.db.Exec("UPDATE users SET totp_enabled = 0, totp_secret = '', updated_at = ? WHERE id = ?", time.Now().Unix(), userID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Session management
|
|
||||||
|
|
||||||
func (s *Store) CreateSession(userID int64, ttl time.Duration) (string, error) {
|
|
||||||
b := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
return "", fmt.Errorf("generating session ID: %w", err)
|
|
||||||
}
|
|
||||||
id := hex.EncodeToString(b)
|
|
||||||
now := time.Now().Unix()
|
|
||||||
expiresAt := time.Now().Add(ttl).Unix()
|
|
||||||
|
|
||||||
_, err := s.db.Exec(
|
|
||||||
"INSERT INTO sessions (id, user_id, created_at, expires_at) VALUES (?, ?, ?, ?)",
|
|
||||||
id, userID, now, expiresAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("creating session: %w", err)
|
|
||||||
}
|
|
||||||
return id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetSession(sessionID string) (*Session, error) {
|
|
||||||
var sess Session
|
|
||||||
err := s.db.QueryRow(
|
|
||||||
"SELECT id, user_id, created_at, expires_at FROM sessions WHERE id = ? AND expires_at > ?",
|
|
||||||
sessionID, time.Now().Unix(),
|
|
||||||
).Scan(&sess.ID, &sess.UserID, &sess.CreatedAt, &sess.ExpiresAt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &sess, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteSession(id string) error {
|
|
||||||
_, err := s.db.Exec("DELETE FROM sessions WHERE id = ?", id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) CleanExpiredSessions() error {
|
|
||||||
_, err := s.db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now().Unix())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// API Token management
|
|
||||||
|
|
||||||
func (s *Store) CreateAPIToken(userID int64, name string, rateLimitRPM int, dailyBudgetUSD float64) (string, *APIToken, error) {
|
|
||||||
// Generate sk- prefixed random key
|
|
||||||
b := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(b); err != nil {
|
|
||||||
return "", nil, fmt.Errorf("generating token: %w", err)
|
|
||||||
}
|
|
||||||
plainKey := "sk-" + hex.EncodeToString(b)
|
|
||||||
keyPrefix := plainKey[:11] // "sk-" + first 8 hex chars
|
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(plainKey))
|
|
||||||
keyHash := hex.EncodeToString(hash[:])
|
|
||||||
|
|
||||||
now := time.Now().Unix()
|
|
||||||
result, err := s.db.Exec(
|
|
||||||
"INSERT INTO api_tokens (name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
|
||||||
name, keyHash, keyPrefix, userID, rateLimitRPM, dailyBudgetUSD, now,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("creating API token: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
id, _ := result.LastInsertId()
|
|
||||||
token := &APIToken{
|
|
||||||
ID: id,
|
|
||||||
Name: name,
|
|
||||||
KeyPrefix: keyPrefix,
|
|
||||||
KeyHash: keyHash,
|
|
||||||
UserID: userID,
|
|
||||||
RateLimitRPM: rateLimitRPM,
|
|
||||||
DailyBudgetUSD: dailyBudgetUSD,
|
|
||||||
CreatedAt: now,
|
|
||||||
}
|
|
||||||
return plainKey, token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) LookupAPIToken(key string) (*APIToken, error) {
|
|
||||||
// Check static tokens first (from config, never stored in DB)
|
|
||||||
for _, st := range s.staticTokens {
|
|
||||||
if st.Key == key {
|
|
||||||
prefix := st.Key
|
|
||||||
if len(prefix) > 11 {
|
|
||||||
prefix = prefix[:11]
|
|
||||||
}
|
|
||||||
return &APIToken{
|
|
||||||
ID: -1, // sentinel: static token
|
|
||||||
Name: st.Name,
|
|
||||||
KeyPrefix: prefix,
|
|
||||||
RateLimitRPM: st.RateLimitRPM,
|
|
||||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
|
||||||
MonthlyBudgetUSD: st.MonthlyBudgetUSD,
|
|
||||||
MaxConcurrent: st.MaxConcurrent,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fall back to DB tokens
|
|
||||||
hash := sha256.Sum256([]byte(key))
|
|
||||||
keyHash := hex.EncodeToString(hash[:])
|
|
||||||
|
|
||||||
var t APIToken
|
|
||||||
err := s.db.QueryRow(
|
|
||||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE key_hash = ?",
|
|
||||||
keyHash,
|
|
||||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) ListAPITokens(userID int64) ([]APIToken, error) {
|
|
||||||
// Include static tokens (shown for all users, not deletable)
|
|
||||||
var tokens []APIToken
|
|
||||||
for _, st := range s.staticTokens {
|
|
||||||
prefix := st.Key
|
|
||||||
if len(prefix) > 11 {
|
|
||||||
prefix = prefix[:11]
|
|
||||||
}
|
|
||||||
tokens = append(tokens, APIToken{
|
|
||||||
ID: -1, // sentinel: static token
|
|
||||||
Name: st.Name,
|
|
||||||
KeyPrefix: prefix,
|
|
||||||
RateLimitRPM: st.RateLimitRPM,
|
|
||||||
DailyBudgetUSD: st.DailyBudgetUSD,
|
|
||||||
MonthlyBudgetUSD: st.MonthlyBudgetUSD,
|
|
||||||
MaxConcurrent: st.MaxConcurrent,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DB tokens
|
|
||||||
var rows *sql.Rows
|
|
||||||
var err error
|
|
||||||
if userID == 0 {
|
|
||||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens ORDER BY id")
|
|
||||||
} else {
|
|
||||||
rows, err = s.db.Query("SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE user_id = ? ORDER BY id", userID)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return tokens, nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
var t APIToken
|
|
||||||
if err := rows.Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt); err != nil {
|
|
||||||
return tokens, nil
|
|
||||||
}
|
|
||||||
tokens = append(tokens, t)
|
|
||||||
}
|
|
||||||
return tokens, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) DeleteAPIToken(id int64) error {
|
|
||||||
_, err := s.db.Exec("DELETE FROM api_tokens WHERE id = ?", id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) GetAPIToken(id int64) (*APIToken, error) {
|
|
||||||
var t APIToken
|
|
||||||
err := s.db.QueryRow(
|
|
||||||
"SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, COALESCE(monthly_budget_usd, 0), COALESCE(max_concurrent, 0), created_at, last_used_at FROM api_tokens WHERE id = ?",
|
|
||||||
id,
|
|
||||||
).Scan(&t.ID, &t.Name, &t.KeyHash, &t.KeyPrefix, &t.UserID, &t.RateLimitRPM, &t.DailyBudgetUSD, &t.MonthlyBudgetUSD, &t.MaxConcurrent, &t.CreatedAt, &t.LastUsedAt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateAPITokenLastUsed records the current time as the token's
// last_used_at. The write error is deliberately ignored: this is
// best-effort bookkeeping (the function has no error return).
func (s *Store) UpdateAPITokenLastUsed(id int64) {
	s.db.Exec("UPDATE api_tokens SET last_used_at = ? WHERE id = ?", time.Now().Unix(), id)
}
|
|
||||||
|
|
||||||
func (s *Store) UpdateUsername(userID int64, newUsername string) error {
|
|
||||||
_, err := s.db.Exec("UPDATE users SET username = ?, updated_at = ? WHERE id = ?", newUsername, time.Now().Unix(), userID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Store) UpdateEmail(userID int64, email string) error {
|
|
||||||
_, err := s.db.Exec("UPDATE users SET email = ?, updated_at = ? WHERE id = ?", email, time.Now().Unix(), userID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
@ -1,301 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupTestDB opens an in-memory SQLite database and creates the users,
// sessions, and api_tokens tables that Store expects. Callers are
// responsible for closing the returned handle.
func setupTestDB(t *testing.T) *sql.DB {
	t.Helper()
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		t.Fatalf("opening test db: %v", err)
	}

	// Create tables
	_, err = db.Exec(`
		CREATE TABLE users (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			username TEXT UNIQUE NOT NULL,
			email TEXT DEFAULT '',
			password_hash TEXT NOT NULL,
			is_admin INTEGER DEFAULT 0,
			totp_secret TEXT DEFAULT '',
			totp_enabled INTEGER DEFAULT 0,
			created_at INTEGER NOT NULL,
			updated_at INTEGER NOT NULL
		);
		CREATE TABLE sessions (
			id TEXT PRIMARY KEY,
			user_id INTEGER NOT NULL,
			created_at INTEGER NOT NULL,
			expires_at INTEGER NOT NULL
		);
		CREATE TABLE api_tokens (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			name TEXT NOT NULL,
			key_hash TEXT NOT NULL,
			key_prefix TEXT NOT NULL,
			user_id INTEGER NOT NULL,
			rate_limit_rpm INTEGER DEFAULT 0,
			daily_budget_usd REAL DEFAULT 0,
			monthly_budget_usd REAL DEFAULT 0,
			max_concurrent INTEGER DEFAULT 0,
			created_at INTEGER NOT NULL,
			last_used_at INTEGER DEFAULT 0
		);
	`)
	if err != nil {
		t.Fatalf("creating tables: %v", err)
	}
	return db
}
|
|
||||||
|
|
||||||
// TestCreateUser verifies that CreateUser persists the user and returns
// the expected username, admin flag, and a non-zero row ID.
func TestCreateUser(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	user, err := store.CreateUser("alice", "password123", true)
	if err != nil {
		t.Fatalf("CreateUser: %v", err)
	}
	if user.Username != "alice" {
		t.Errorf("expected username 'alice', got '%s'", user.Username)
	}
	if !user.IsAdmin {
		t.Error("expected admin user")
	}
	if user.ID == 0 {
		t.Error("expected non-zero ID")
	}
}
|
|
||||||
|
|
||||||
// TestGetUserByUsername covers both the found path and the error return
// for an unknown username.
func TestGetUserByUsername(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	store.CreateUser("bob", "password123", false)

	user, err := store.GetUserByUsername("bob")
	if err != nil {
		t.Fatalf("GetUserByUsername: %v", err)
	}
	if user.Username != "bob" {
		t.Errorf("expected 'bob', got '%s'", user.Username)
	}

	_, err = store.GetUserByUsername("nonexistent")
	if err == nil {
		t.Error("expected error for nonexistent user")
	}
}
|
|
||||||
|
|
||||||
// TestCheckPassword verifies that CheckPassword accepts the stored
// password and rejects an incorrect one.
func TestCheckPassword(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	store.CreateUser("charlie", "correctpassword", false)
	user, _ := store.GetUserByUsername("charlie")

	if !store.CheckPassword(user, "correctpassword") {
		t.Error("correct password should match")
	}
	if store.CheckPassword(user, "wrongpassword") {
		t.Error("wrong password should not match")
	}
}
|
|
||||||
|
|
||||||
// TestUpdatePassword verifies that UpdatePassword invalidates the old
// password and makes the new one effective.
func TestUpdatePassword(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	user, _ := store.CreateUser("dave", "oldpass12", false)

	if err := store.UpdatePassword(user.ID, "newpass12"); err != nil {
		t.Fatalf("UpdatePassword: %v", err)
	}

	// Re-fetch so the in-memory hash reflects the update.
	user, _ = store.GetUserByUsername("dave")
	if store.CheckPassword(user, "oldpass12") {
		t.Error("old password should not work")
	}
	if !store.CheckPassword(user, "newpass12") {
		t.Error("new password should work")
	}
}
|
|
||||||
|
|
||||||
// TestDeleteUser verifies that a non-admin can be deleted but the last
// remaining admin cannot.
func TestDeleteUser(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	user1, _ := store.CreateUser("admin1", "password1234", true)
	user2, _ := store.CreateUser("user2", "password1234", false)

	// Can delete non-admin
	if err := store.DeleteUser(user2.ID); err != nil {
		t.Fatalf("DeleteUser: %v", err)
	}

	// Cannot delete last admin
	if err := store.DeleteUser(user1.ID); err == nil {
		t.Error("should not be able to delete last admin")
	}
}
|
|
||||||
|
|
||||||
// TestHasAnyUser verifies HasAnyUser is false on an empty DB and true
// after the first user is created.
func TestHasAnyUser(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	if store.HasAnyUser() {
		t.Error("should have no users initially")
	}

	store.CreateUser("first", "password1234", true)

	if !store.HasAnyUser() {
		t.Error("should have users after creation")
	}
}
|
|
||||||
|
|
||||||
// TestSessionCRUD exercises the full session lifecycle: create, fetch,
// delete, and fetch-after-delete.
func TestSessionCRUD(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	user, _ := store.CreateUser("sessuser", "password1234", false)

	sessionID, err := store.CreateSession(user.ID, 1*time.Hour)
	if err != nil {
		t.Fatalf("CreateSession: %v", err)
	}

	sess, err := store.GetSession(sessionID)
	if err != nil {
		t.Fatalf("GetSession: %v", err)
	}
	if sess.UserID != user.ID {
		t.Errorf("expected user ID %d, got %d", user.ID, sess.UserID)
	}

	if err := store.DeleteSession(sessionID); err != nil {
		t.Fatalf("DeleteSession: %v", err)
	}

	_, err = store.GetSession(sessionID)
	if err == nil {
		t.Error("session should be deleted")
	}
}
|
|
||||||
|
|
||||||
// TestStaticTokenLookup verifies config-defined static tokens resolve via
// LookupAPIToken with the sentinel ID -1 and their configured limits, and
// that unknown keys return an error.
func TestStaticTokenLookup(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	staticTokens := []StaticToken{
		{Name: "test-token", Key: "sk-test-key-12345678", RateLimitRPM: 60, DailyBudgetUSD: 10.0, MaxConcurrent: 5},
	}
	store := NewStore(db, staticTokens)

	token, err := store.LookupAPIToken("sk-test-key-12345678")
	if err != nil {
		t.Fatalf("LookupAPIToken: %v", err)
	}
	if token.Name != "test-token" {
		t.Errorf("expected 'test-token', got '%s'", token.Name)
	}
	if token.ID != -1 {
		t.Errorf("static token should have ID -1, got %d", token.ID)
	}
	if token.RateLimitRPM != 60 {
		t.Errorf("expected RPM 60, got %d", token.RateLimitRPM)
	}
	if token.MaxConcurrent != 5 {
		t.Errorf("expected max_concurrent 5, got %d", token.MaxConcurrent)
	}

	// Non-existent token
	_, err = store.LookupAPIToken("nonexistent")
	if err == nil {
		t.Error("should error on nonexistent token")
	}
}
|
|
||||||
|
|
||||||
// TestDBTokenCRUD exercises the DB-backed token lifecycle: create, lookup
// by plaintext key, list, delete, and lookup-after-delete.
func TestDBTokenCRUD(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()
	store := NewStore(db, nil)

	user, _ := store.CreateUser("tokenuser", "password1234", false)

	plainKey, token, err := store.CreateAPIToken(user.ID, "my-token", 100, 5.0)
	if err != nil {
		t.Fatalf("CreateAPIToken: %v", err)
	}
	if plainKey == "" {
		t.Error("plain key should not be empty")
	}
	if token.Name != "my-token" {
		t.Errorf("expected 'my-token', got '%s'", token.Name)
	}

	// Lookup by key
	found, err := store.LookupAPIToken(plainKey)
	if err != nil {
		t.Fatalf("LookupAPIToken: %v", err)
	}
	if found.Name != "my-token" {
		t.Errorf("expected 'my-token', got '%s'", found.Name)
	}

	// List tokens
	tokens, err := store.ListAPITokens(user.ID)
	if err != nil {
		t.Fatalf("ListAPITokens: %v", err)
	}
	if len(tokens) != 1 {
		t.Errorf("expected 1 token, got %d", len(tokens))
	}

	// Delete
	if err := store.DeleteAPIToken(token.ID); err != nil {
		t.Fatalf("DeleteAPIToken: %v", err)
	}

	_, err = store.LookupAPIToken(plainKey)
	if err == nil {
		t.Error("token should be deleted")
	}
}
|
|
||||||
|
|
||||||
// TestSetStaticTokens verifies that SetStaticTokens hot-swaps the static
// token set: a key is unknown before the call and resolvable after.
func TestSetStaticTokens(t *testing.T) {
	db := setupTestDB(t)
	defer db.Close()

	store := NewStore(db, nil)

	_, err := store.LookupAPIToken("key1")
	if err == nil {
		t.Error("should not find token before setting")
	}

	store.SetStaticTokens([]StaticToken{
		{Name: "new-token", Key: "key1"},
	})

	token, err := store.LookupAPIToken("key1")
	if err != nil {
		t.Fatalf("after SetStaticTokens: %v", err)
	}
	if token.Name != "new-token" {
		t.Errorf("expected 'new-token', got '%s'", token.Name)
	}
}
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/pquerna/otp"
|
|
||||||
"github.com/pquerna/otp/totp"
|
|
||||||
)
|
|
||||||
|
|
||||||
func GenerateTOTPKey(username string) (*otp.Key, error) {
|
|
||||||
return totp.Generate(totp.GenerateOpts{
|
|
||||||
Issuer: "LLM Gateway",
|
|
||||||
AccountName: username,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateTOTPCode reports whether code is currently valid for the given
// TOTP secret.
func ValidateTOTPCode(secret, code string) bool {
	return totp.Validate(code, secret)
}
|
|
||||||
178
llm-gateway/internal/cache/cache.go
vendored
178
llm-gateway/internal/cache/cache.go
vendored
|
|
@ -1,178 +0,0 @@
|
||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cache is a response cache backed by a Valkey/Redis server.
type Cache struct {
	client *redis.Client // underlying Redis/Valkey connection
	ttl    time.Duration // expiry applied to every stored entry
}
|
|
||||||
|
|
||||||
func New(addr string, ttlSeconds int) (*Cache, error) {
|
|
||||||
client := redis.NewClient(&redis.Options{
|
|
||||||
Addr: addr,
|
|
||||||
})
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := client.Ping(ctx).Err(); err != nil {
|
|
||||||
return nil, fmt.Errorf("connecting to Valkey: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ttl := time.Duration(ttlSeconds) * time.Second
|
|
||||||
if ttl == 0 {
|
|
||||||
ttl = 1 * time.Hour
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Cache{client: client, ttl: ttl}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cache) Get(ctx context.Context, model string, requestBody []byte) ([]byte, error) {
|
|
||||||
key := c.cacheKey(model, requestBody)
|
|
||||||
data, err := c.client.Get(ctx, key).Bytes()
|
|
||||||
if err == redis.Nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return data, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cache) Set(ctx context.Context, model string, requestBody, responseBody []byte) error {
|
|
||||||
key := c.cacheKey(model, requestBody)
|
|
||||||
return c.client.Set(ctx, key, responseBody, c.ttl).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ping checks connectivity to the cache server.
func (c *Cache) Ping(ctx context.Context) error {
	return c.client.Ping(ctx).Err()
}
|
|
||||||
|
|
||||||
// Close releases the underlying client connection.
func (c *Cache) Close() error {
	return c.client.Close()
}
|
|
||||||
|
|
||||||
// CacheStats holds cache statistics from the Valkey/Redis server.
|
|
||||||
// CacheStats holds cache statistics from the Valkey/Redis server.
type CacheStats struct {
	Hits       int64   `json:"hits"`        // keyspace_hits from INFO stats
	Misses     int64   `json:"misses"`      // keyspace_misses from INFO stats
	HitRate    float64 `json:"hit_rate"`    // Hits / (Hits + Misses); 0 when no traffic
	MemoryUsed string  `json:"memory_used"` // used_memory_human from INFO memory
	Keys       int64   `json:"keys"`        // key count from INFO keyspace
	Connected  bool    `json:"connected"`   // false when the server is unreachable
}
|
|
||||||
|
|
||||||
// Stats returns cache statistics by querying Valkey/Redis INFO.
|
|
||||||
func (c *Cache) Stats(ctx context.Context) *CacheStats {
|
|
||||||
stats := &CacheStats{}
|
|
||||||
|
|
||||||
// Check connectivity
|
|
||||||
if err := c.client.Ping(ctx).Err(); err != nil {
|
|
||||||
return stats
|
|
||||||
}
|
|
||||||
stats.Connected = true
|
|
||||||
|
|
||||||
// Parse INFO stats for hits/misses
|
|
||||||
info, err := c.client.Info(ctx, "stats").Result()
|
|
||||||
if err == nil {
|
|
||||||
stats.Hits = parseInfoInt(info, "keyspace_hits")
|
|
||||||
stats.Misses = parseInfoInt(info, "keyspace_misses")
|
|
||||||
total := stats.Hits + stats.Misses
|
|
||||||
if total > 0 {
|
|
||||||
stats.HitRate = float64(stats.Hits) / float64(total)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse INFO memory
|
|
||||||
memInfo, err := c.client.Info(ctx, "memory").Result()
|
|
||||||
if err == nil {
|
|
||||||
stats.MemoryUsed = parseInfoString(memInfo, "used_memory_human")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse INFO keyspace
|
|
||||||
ksInfo, err := c.client.Info(ctx, "keyspace").Result()
|
|
||||||
if err == nil {
|
|
||||||
stats.Keys = parseKeyspaceKeys(ksInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
return stats
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInfoInt(info, key string) int64 {
|
|
||||||
prefix := key + ":"
|
|
||||||
for _, line := range splitLines(info) {
|
|
||||||
if len(line) > len(prefix) && line[:len(prefix)] == prefix {
|
|
||||||
var v int64
|
|
||||||
fmt.Sscanf(line[len(prefix):], "%d", &v)
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInfoString(info, key string) string {
|
|
||||||
prefix := key + ":"
|
|
||||||
for _, line := range splitLines(info) {
|
|
||||||
if len(line) > len(prefix) && line[:len(prefix)] == prefix {
|
|
||||||
val := line[len(prefix):]
|
|
||||||
// Trim trailing \r
|
|
||||||
if len(val) > 0 && val[len(val)-1] == '\r' {
|
|
||||||
val = val[:len(val)-1]
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseKeyspaceKeys(info string) int64 {
|
|
||||||
// Format: db0:keys=123,expires=45,avg_ttl=6789
|
|
||||||
for _, line := range splitLines(info) {
|
|
||||||
if len(line) > 3 && line[:2] == "db" {
|
|
||||||
prefix := "keys="
|
|
||||||
idx := -1
|
|
||||||
for i := 0; i <= len(line)-len(prefix); i++ {
|
|
||||||
if line[i:i+len(prefix)] == prefix {
|
|
||||||
idx = i + len(prefix)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if idx >= 0 {
|
|
||||||
end := idx
|
|
||||||
for end < len(line) && line[end] >= '0' && line[end] <= '9' {
|
|
||||||
end++
|
|
||||||
}
|
|
||||||
var v int64
|
|
||||||
fmt.Sscanf(line[idx:end], "%d", &v)
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitLines splits s on '\n' without allocating for a trailing newline:
// "a\nb" -> ["a","b"], "a\n" -> ["a"], "" -> nil.
func splitLines(s string) []string {
	var out []string
	last := 0
	for i := 0; i < len(s); i++ {
		if s[i] != '\n' {
			continue
		}
		out = append(out, s[last:i])
		last = i + 1
	}
	if last < len(s) {
		out = append(out, s[last:])
	}
	return out
}
|
|
||||||
|
|
||||||
func (c *Cache) cacheKey(model string, requestBody []byte) string {
|
|
||||||
h := sha256.New()
|
|
||||||
h.Write([]byte(model))
|
|
||||||
h.Write(requestBody)
|
|
||||||
return fmt.Sprintf("llm-gw:%x", h.Sum(nil))
|
|
||||||
}
|
|
||||||
112
llm-gateway/internal/cache/cache_test.go
vendored
112
llm-gateway/internal/cache/cache_test.go
vendored
|
|
@ -1,112 +0,0 @@
|
||||||
package cache
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestCacheKey_Deterministic verifies the same (model, body) pair always
// hashes to the same non-empty key.
func TestCacheKey_Deterministic(t *testing.T) {
	c := &Cache{}

	model := "gpt-4"
	body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)

	key1 := c.cacheKey(model, body)
	key2 := c.cacheKey(model, body)

	if key1 != key2 {
		t.Errorf("cache key not deterministic: %s != %s", key1, key2)
	}

	if key1 == "" {
		t.Error("cache key is empty")
	}
}
|
|
||||||
|
|
||||||
// TestCacheKey_DifferentInputs verifies that changing either the model or
// the request body changes the derived cache key.
func TestCacheKey_DifferentInputs(t *testing.T) {
	c := &Cache{}

	body := []byte(`{"messages":[{"role":"user","content":"hello"}]}`)

	key1 := c.cacheKey("gpt-4", body)
	key2 := c.cacheKey("gpt-3.5", body)

	if key1 == key2 {
		t.Error("different models should produce different cache keys")
	}

	key3 := c.cacheKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"world"}]}`))
	if key1 == key3 {
		t.Error("different bodies should produce different cache keys")
	}
}
|
|
||||||
|
|
||||||
// TestCacheKey_HasPrefix verifies the key carries the "llm-gw:" namespace.
func TestCacheKey_HasPrefix(t *testing.T) {
	c := &Cache{}
	key := c.cacheKey("gpt-4", []byte("test"))

	if len(key) < 7 || key[:7] != "llm-gw:" {
		t.Errorf("cache key should start with 'llm-gw:', got: %s", key)
	}
}
|
|
||||||
|
|
||||||
// TestParseInfoInt covers present keys and the zero default for an
// unknown key.
func TestParseInfoInt(t *testing.T) {
	info := "keyspace_hits:42\nkeyspace_misses:10\n"

	hits := parseInfoInt(info, "keyspace_hits")
	if hits != 42 {
		t.Errorf("expected 42, got %d", hits)
	}

	misses := parseInfoInt(info, "keyspace_misses")
	if misses != 10 {
		t.Errorf("expected 10, got %d", misses)
	}

	unknown := parseInfoInt(info, "nonexistent")
	if unknown != 0 {
		t.Errorf("expected 0 for unknown key, got %d", unknown)
	}
}
|
|
||||||
|
|
||||||
// TestParseInfoString covers \r-stripping of values and the empty default
// for an unknown key.
func TestParseInfoString(t *testing.T) {
	info := "used_memory_human:1.5M\r\nother:value\r\n"

	mem := parseInfoString(info, "used_memory_human")
	if mem != "1.5M" {
		t.Errorf("expected '1.5M', got '%s'", mem)
	}

	unknown := parseInfoString(info, "nonexistent")
	if unknown != "" {
		t.Errorf("expected empty for unknown key, got '%s'", unknown)
	}
}
|
|
||||||
|
|
||||||
// TestParseKeyspaceKeys covers a populated db line and an empty keyspace
// section.
func TestParseKeyspaceKeys(t *testing.T) {
	info := "# Keyspace\ndb0:keys=123,expires=45,avg_ttl=6789\n"

	keys := parseKeyspaceKeys(info)
	if keys != 123 {
		t.Errorf("expected 123, got %d", keys)
	}

	empty := parseKeyspaceKeys("# Keyspace\n")
	if empty != 0 {
		t.Errorf("expected 0 for empty keyspace, got %d", empty)
	}
}
|
|
||||||
|
|
||||||
// TestSplitLines covers multi-line splitting and the single-line (no
// trailing newline) case.
func TestSplitLines(t *testing.T) {
	lines := splitLines("a\nb\nc")
	if len(lines) != 3 {
		t.Errorf("expected 3 lines, got %d", len(lines))
	}
	if lines[0] != "a" || lines[1] != "b" || lines[2] != "c" {
		t.Errorf("unexpected lines: %v", lines)
	}

	single := splitLines("hello")
	if len(single) != 1 || single[0] != "hello" {
		t.Errorf("single line: %v", single)
	}
}
|
|
||||||
|
|
@ -1,312 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Config is the root of the gateway's YAML configuration file.
type Config struct {
	Server         ServerConfig         `yaml:"server"`
	Database       DatabaseConfig       `yaml:"database"`
	Cache          CacheConfig          `yaml:"cache"`
	Pricing        PricingLookupConfig  `yaml:"pricing_lookup"`
	CircuitBreaker CircuitBreakerConfig `yaml:"circuit_breaker"`
	Retry          RetryConfig          `yaml:"retry"`
	Debug          DebugConfig          `yaml:"debug"`
	CORS           CORSConfig           `yaml:"cors"`
	Dedup          DedupConfig          `yaml:"dedup"`
	Webhooks       []WebhookConfig      `yaml:"webhooks"`
	Providers      []ProviderConfig     `yaml:"providers"`
	Models         []ModelConfig        `yaml:"models"`
	Tokens         []TokenConfig        `yaml:"tokens"`
}

// DedupConfig controls in-flight request deduplication.
type DedupConfig struct {
	Enabled bool          `yaml:"enabled"`
	Window  time.Duration `yaml:"window"` // max time to wait for dedup result
}

// WebhookConfig describes one outbound webhook endpoint.
type WebhookConfig struct {
	URL    string   `yaml:"url"`
	Events []string `yaml:"events"` // event types to send
	Secret string   `yaml:"secret"` // optional HMAC secret
}

// PricingLookupConfig configures the external pricing data source.
type PricingLookupConfig struct {
	URL             string        `yaml:"url"`
	RefreshInterval time.Duration `yaml:"refresh_interval"`
}

// DefaultAdminConfig holds bootstrap credentials for the initial admin.
type DefaultAdminConfig struct {
	Username string `yaml:"username"`
	Password string `yaml:"password"`
}

// TokenConfig defines a static API token declared in the config file.
type TokenConfig struct {
	Name             string  `yaml:"name"`
	Key              string  `yaml:"key"`
	RateLimitRPM     int     `yaml:"rate_limit_rpm"`     // 0 = unlimited
	DailyBudgetUSD   float64 `yaml:"daily_budget_usd"`   // 0 = unlimited
	MonthlyBudgetUSD float64 `yaml:"monthly_budget_usd"` // 0 = unlimited
	MaxConcurrent    int     `yaml:"max_concurrent"`     // 0 = unlimited
}

// ServerConfig holds HTTP listener and session settings.
type ServerConfig struct {
	Listen           string             `yaml:"listen"`
	RequestTimeout   time.Duration      `yaml:"request_timeout"`
	StreamingTimeout time.Duration      `yaml:"streaming_timeout"`
	MaxRequestBodyMB int                `yaml:"max_request_body_mb"`
	SessionSecret    string             `yaml:"session_secret"`
	DefaultAdmin     DefaultAdminConfig `yaml:"default_admin"`
}

// CircuitBreakerConfig tunes per-provider failure detection.
type CircuitBreakerConfig struct {
	Enabled          bool          `yaml:"enabled"`
	ErrorThreshold   float64       `yaml:"error_threshold"`
	MinRequests      int           `yaml:"min_requests"`
	CooldownDuration time.Duration `yaml:"cooldown_duration"`
}

// RetryConfig tunes exponential backoff for retried requests.
type RetryConfig struct {
	InitialBackoff time.Duration `yaml:"initial_backoff"`
	MaxBackoff     time.Duration `yaml:"max_backoff"`
	Multiplier     float64       `yaml:"multiplier"`
}

// DebugConfig controls request/response capture for debugging.
type DebugConfig struct {
	Enabled       bool   `yaml:"enabled"`
	MaxBodyBytes  int    `yaml:"max_body_bytes"` // 0 = unlimited (save full bodies)
	RetentionDays int    `yaml:"retention_days"`
	DataDir       string `yaml:"data_dir"`
}

// CORSConfig configures cross-origin request handling.
type CORSConfig struct {
	Enabled        bool     `yaml:"enabled"`
	AllowedOrigins []string `yaml:"allowed_origins"`
	AllowedMethods []string `yaml:"allowed_methods"`
	AllowedHeaders []string `yaml:"allowed_headers"`
	MaxAge         int      `yaml:"max_age"`
}

// DatabaseConfig holds SQLite storage settings.
type DatabaseConfig struct {
	Path          string `yaml:"path"`
	RetentionDays int    `yaml:"retention_days"`
}

// CacheConfig holds Valkey/Redis response-cache settings.
type CacheConfig struct {
	Enabled bool   `yaml:"enabled"`
	Address string `yaml:"address"`
	TTL     int    `yaml:"ttl"` // seconds
}

// ProviderConfig describes one upstream LLM provider endpoint.
type ProviderConfig struct {
	Name     string        `yaml:"name"`
	BaseURL  string        `yaml:"base_url"`
	APIKey   string        `yaml:"api_key"`
	Priority int           `yaml:"priority"`
	Timeout  time.Duration `yaml:"timeout"`
}

// ModelConfig maps a public model name to one or more provider routes.
type ModelConfig struct {
	Name             string        `yaml:"name"`
	Aliases          []string      `yaml:"aliases"`
	Routes           []RouteConfig `yaml:"routes"`
	LoadBalancing    string        `yaml:"load_balancing"`    // first, round-robin, random, least-cost
	RequestTimeout   time.Duration `yaml:"request_timeout"`   // per-model override; 0 = use server default
	StreamingTimeout time.Duration `yaml:"streaming_timeout"` // per-model override; 0 = use server default
}

// RouteConfig binds a model to a concrete provider/model pair.
type RouteConfig struct {
	Provider string        `yaml:"provider"`
	Model    string        `yaml:"model"`
	Pricing  PricingConfig `yaml:"pricing"`
}

// PricingConfig holds per-route token pricing overrides.
type PricingConfig struct {
	Input  float64 `yaml:"input"`  // cost per 1M tokens
	Output float64 `yaml:"output"` // cost per 1M tokens
}
|
|
||||||
|
|
||||||
func Load(path string) (*Config, error) {
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("reading config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Expand environment variables
|
|
||||||
expanded := os.ExpandEnv(string(data))
|
|
||||||
|
|
||||||
var cfg Config
|
|
||||||
if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
|
|
||||||
return nil, fmt.Errorf("parsing config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
return nil, fmt.Errorf("validating config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &cfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate checks the config for correctness and applies defaults.
|
|
||||||
func (c *Config) Validate() error {
|
|
||||||
if c.Server.Listen == "" {
|
|
||||||
c.Server.Listen = "0.0.0.0:3000"
|
|
||||||
}
|
|
||||||
if c.Server.RequestTimeout == 0 {
|
|
||||||
c.Server.RequestTimeout = 300 * time.Second
|
|
||||||
}
|
|
||||||
if c.Server.MaxRequestBodyMB == 0 {
|
|
||||||
c.Server.MaxRequestBodyMB = 10
|
|
||||||
}
|
|
||||||
if c.Server.SessionSecret == "" {
|
|
||||||
b := make([]byte, 32)
|
|
||||||
rand.Read(b)
|
|
||||||
c.Server.SessionSecret = hex.EncodeToString(b)
|
|
||||||
log.Println("WARNING: no session_secret configured, generated random one (sessions won't survive restart)")
|
|
||||||
}
|
|
||||||
if c.Database.Path == "" {
|
|
||||||
c.Database.Path = "gateway.db"
|
|
||||||
}
|
|
||||||
if c.Database.RetentionDays == 0 {
|
|
||||||
c.Database.RetentionDays = 90
|
|
||||||
}
|
|
||||||
if c.Pricing.RefreshInterval == 0 {
|
|
||||||
c.Pricing.RefreshInterval = 6 * time.Hour
|
|
||||||
}
|
|
||||||
|
|
||||||
// Server defaults
|
|
||||||
if c.Server.StreamingTimeout == 0 {
|
|
||||||
c.Server.StreamingTimeout = 5 * time.Minute
|
|
||||||
}
|
|
||||||
|
|
||||||
// Circuit breaker defaults
|
|
||||||
if c.CircuitBreaker.ErrorThreshold == 0 {
|
|
||||||
c.CircuitBreaker.ErrorThreshold = 0.5
|
|
||||||
}
|
|
||||||
if c.CircuitBreaker.MinRequests == 0 {
|
|
||||||
c.CircuitBreaker.MinRequests = 5
|
|
||||||
}
|
|
||||||
if c.CircuitBreaker.CooldownDuration == 0 {
|
|
||||||
c.CircuitBreaker.CooldownDuration = 30 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retry defaults
|
|
||||||
if c.Retry.InitialBackoff == 0 {
|
|
||||||
c.Retry.InitialBackoff = 100 * time.Millisecond
|
|
||||||
}
|
|
||||||
if c.Retry.MaxBackoff == 0 {
|
|
||||||
c.Retry.MaxBackoff = 5 * time.Second
|
|
||||||
}
|
|
||||||
if c.Retry.Multiplier == 0 {
|
|
||||||
c.Retry.Multiplier = 2.0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug defaults
|
|
||||||
if c.Debug.RetentionDays == 0 {
|
|
||||||
c.Debug.RetentionDays = 90
|
|
||||||
}
|
|
||||||
|
|
||||||
// CORS defaults
|
|
||||||
if c.CORS.MaxAge == 0 {
|
|
||||||
c.CORS.MaxAge = 300
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dedup defaults
|
|
||||||
if c.Dedup.Window == 0 {
|
|
||||||
c.Dedup.Window = 30 * time.Second
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(c.Providers) == 0 {
|
|
||||||
return fmt.Errorf("at least one provider is required")
|
|
||||||
}
|
|
||||||
providerNames := make(map[string]bool)
|
|
||||||
for i, p := range c.Providers {
|
|
||||||
if p.Name == "" || p.BaseURL == "" || p.APIKey == "" {
|
|
||||||
return fmt.Errorf("provider %d: name, base_url, and api_key are required", i)
|
|
||||||
}
|
|
||||||
if providerNames[p.Name] {
|
|
||||||
return fmt.Errorf("duplicate provider name: %s", p.Name)
|
|
||||||
}
|
|
||||||
providerNames[p.Name] = true
|
|
||||||
if c.Providers[i].Timeout == 0 {
|
|
||||||
c.Providers[i].Timeout = 120 * time.Second
|
|
||||||
}
|
|
||||||
if c.Providers[i].Priority == 0 {
|
|
||||||
c.Providers[i].Priority = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(c.Models) == 0 {
|
|
||||||
return fmt.Errorf("at least one model is required")
|
|
||||||
}
|
|
||||||
modelNames := make(map[string]bool)
|
|
||||||
for i, m := range c.Models {
|
|
||||||
if m.Name == "" {
|
|
||||||
return fmt.Errorf("model %d: name is required", i)
|
|
||||||
}
|
|
||||||
if modelNames[m.Name] {
|
|
||||||
return fmt.Errorf("duplicate model name: %s", m.Name)
|
|
||||||
}
|
|
||||||
modelNames[m.Name] = true
|
|
||||||
for _, alias := range m.Aliases {
|
|
||||||
if modelNames[alias] {
|
|
||||||
return fmt.Errorf("model alias %s conflicts with existing model or alias", alias)
|
|
||||||
}
|
|
||||||
modelNames[alias] = true
|
|
||||||
}
|
|
||||||
if len(m.Routes) == 0 {
|
|
||||||
return fmt.Errorf("model %s: at least one route is required", m.Name)
|
|
||||||
}
|
|
||||||
for j, r := range m.Routes {
|
|
||||||
if r.Provider == "" || r.Model == "" {
|
|
||||||
return fmt.Errorf("model %s route %d: provider and model are required", m.Name, j)
|
|
||||||
}
|
|
||||||
if !providerNames[r.Provider] {
|
|
||||||
return fmt.Errorf("model %s route %d: unknown provider %s", m.Name, j, r.Provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate tokens (optional section)
|
|
||||||
for i, t := range c.Tokens {
|
|
||||||
if t.Key == "" {
|
|
||||||
log.Printf("WARNING: token %d (%s) has empty key, skipping", i, t.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if t.Name == "" {
|
|
||||||
c.Tokens[i].Name = fmt.Sprintf("token-%d", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateBytes parses raw YAML and returns a list of validation errors.
|
|
||||||
func ValidateBytes(data []byte) []string {
|
|
||||||
expanded := os.ExpandEnv(string(data))
|
|
||||||
var cfg Config
|
|
||||||
if err := yaml.Unmarshal([]byte(expanded), &cfg); err != nil {
|
|
||||||
return []string{"parse error: " + err.Error()}
|
|
||||||
}
|
|
||||||
if err := cfg.Validate(); err != nil {
|
|
||||||
return []string{err.Error()}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProviderByName returns the provider config by name.
|
|
||||||
func (c *Config) ProviderByName(name string) *ProviderConfig {
|
|
||||||
for i := range c.Providers {
|
|
||||||
if c.Providers[i].Name == name {
|
|
||||||
return &c.Providers[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,737 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeConfigFile creates a temporary YAML config file and returns its path.
// The file lives in t.TempDir() and is cleaned up automatically.
func writeConfigFile(t *testing.T, content string) string {
	t.Helper()
	path := filepath.Join(t.TempDir(), "config.yaml")
	if err := os.WriteFile(path, []byte(content), 0o600); err != nil {
		t.Fatalf("writing temp file: %v", err)
	}
	return path
}
|
|
||||||
|
|
||||||
// minimalValidConfig returns a minimal valid YAML config string:
// exactly one provider and one model/route, with every optional section
// omitted so Validate's defaults kick in (exercised by TestValidate_Defaults).
func minimalValidConfig() string {
	return `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-test-key

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`
}
|
|
||||||
|
|
||||||
// TestLoad_ValidConfig loads a fully-populated config file and verifies that
// every section (server, database, pricing, circuit breaker, retry, debug,
// CORS, providers, models, tokens) unmarshals into the expected fields.
func TestLoad_ValidConfig(t *testing.T) {
	path := writeConfigFile(t, `
server:
  listen: "127.0.0.1:8080"
  request_timeout: 60s
  streaming_timeout: 120s
  max_request_body_mb: 5
  session_secret: "test-secret-1234567890abcdef1234567890abcdef"

database:
  path: "/tmp/test.db"
  retention_days: 30

pricing_lookup:
  url: "https://pricing.example.com"
  refresh_interval: 1h

circuit_breaker:
  enabled: true
  error_threshold: 0.3
  min_requests: 10
  cooldown_duration: 60s

retry:
  initial_backoff: 200ms
  max_backoff: 10s
  multiplier: 3.0

debug:
  enabled: true
  max_body_bytes: 65536
  retention_days: 60

cors:
  enabled: true
  allowed_origins:
    - "https://example.com"
  allowed_methods:
    - GET
    - POST
  allowed_headers:
    - Authorization
  max_age: 600

providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-test-key
    priority: 2
    timeout: 60s
  - name: anthropic
    base_url: https://api.anthropic.com/v1
    api_key: sk-ant-test
    priority: 1
    timeout: 30s

models:
  - name: gpt-4
    aliases:
      - gpt4
    routes:
      - provider: openai
        model: gpt-4
        pricing:
          input: 30.0
          output: 60.0
    load_balancing: first
  - name: claude-3
    routes:
      - provider: anthropic
        model: claude-3-opus-20240229

tokens:
  - name: test-token
    key: tok-abc123
    rate_limit_rpm: 100
    daily_budget_usd: 10.0
    max_concurrent: 5
`)

	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}

	// Server
	if cfg.Server.Listen != "127.0.0.1:8080" {
		t.Errorf("Listen = %q, want %q", cfg.Server.Listen, "127.0.0.1:8080")
	}
	if cfg.Server.RequestTimeout != 60*time.Second {
		t.Errorf("RequestTimeout = %v, want %v", cfg.Server.RequestTimeout, 60*time.Second)
	}
	if cfg.Server.StreamingTimeout != 120*time.Second {
		t.Errorf("StreamingTimeout = %v, want %v", cfg.Server.StreamingTimeout, 120*time.Second)
	}
	if cfg.Server.MaxRequestBodyMB != 5 {
		t.Errorf("MaxRequestBodyMB = %d, want %d", cfg.Server.MaxRequestBodyMB, 5)
	}
	if cfg.Server.SessionSecret != "test-secret-1234567890abcdef1234567890abcdef" {
		t.Errorf("SessionSecret = %q, want %q", cfg.Server.SessionSecret, "test-secret-1234567890abcdef1234567890abcdef")
	}

	// Database
	if cfg.Database.Path != "/tmp/test.db" {
		t.Errorf("Database.Path = %q, want %q", cfg.Database.Path, "/tmp/test.db")
	}
	if cfg.Database.RetentionDays != 30 {
		t.Errorf("Database.RetentionDays = %d, want %d", cfg.Database.RetentionDays, 30)
	}

	// Pricing
	if cfg.Pricing.URL != "https://pricing.example.com" {
		t.Errorf("Pricing.URL = %q, want %q", cfg.Pricing.URL, "https://pricing.example.com")
	}
	if cfg.Pricing.RefreshInterval != 1*time.Hour {
		t.Errorf("Pricing.RefreshInterval = %v, want %v", cfg.Pricing.RefreshInterval, 1*time.Hour)
	}

	// Circuit breaker
	if !cfg.CircuitBreaker.Enabled {
		t.Error("CircuitBreaker.Enabled = false, want true")
	}
	if cfg.CircuitBreaker.ErrorThreshold != 0.3 {
		t.Errorf("CircuitBreaker.ErrorThreshold = %v, want %v", cfg.CircuitBreaker.ErrorThreshold, 0.3)
	}
	if cfg.CircuitBreaker.MinRequests != 10 {
		t.Errorf("CircuitBreaker.MinRequests = %d, want %d", cfg.CircuitBreaker.MinRequests, 10)
	}
	if cfg.CircuitBreaker.CooldownDuration != 60*time.Second {
		t.Errorf("CircuitBreaker.CooldownDuration = %v, want %v", cfg.CircuitBreaker.CooldownDuration, 60*time.Second)
	}

	// Retry
	if cfg.Retry.InitialBackoff != 200*time.Millisecond {
		t.Errorf("Retry.InitialBackoff = %v, want %v", cfg.Retry.InitialBackoff, 200*time.Millisecond)
	}
	if cfg.Retry.MaxBackoff != 10*time.Second {
		t.Errorf("Retry.MaxBackoff = %v, want %v", cfg.Retry.MaxBackoff, 10*time.Second)
	}
	if cfg.Retry.Multiplier != 3.0 {
		t.Errorf("Retry.Multiplier = %v, want %v", cfg.Retry.Multiplier, 3.0)
	}

	// Debug
	if !cfg.Debug.Enabled {
		t.Error("Debug.Enabled = false, want true")
	}
	if cfg.Debug.MaxBodyBytes != 65536 {
		t.Errorf("Debug.MaxBodyBytes = %d, want %d", cfg.Debug.MaxBodyBytes, 65536)
	}
	if cfg.Debug.RetentionDays != 60 {
		t.Errorf("Debug.RetentionDays = %d, want %d", cfg.Debug.RetentionDays, 60)
	}

	// CORS
	if !cfg.CORS.Enabled {
		t.Error("CORS.Enabled = false, want true")
	}
	if cfg.CORS.MaxAge != 600 {
		t.Errorf("CORS.MaxAge = %d, want %d", cfg.CORS.MaxAge, 600)
	}

	// Providers
	if len(cfg.Providers) != 2 {
		t.Fatalf("len(Providers) = %d, want 2", len(cfg.Providers))
	}
	if cfg.Providers[0].Name != "openai" {
		t.Errorf("Providers[0].Name = %q, want %q", cfg.Providers[0].Name, "openai")
	}
	if cfg.Providers[0].Timeout != 60*time.Second {
		t.Errorf("Providers[0].Timeout = %v, want %v", cfg.Providers[0].Timeout, 60*time.Second)
	}

	// Models
	if len(cfg.Models) != 2 {
		t.Fatalf("len(Models) = %d, want 2", len(cfg.Models))
	}
	if cfg.Models[0].LoadBalancing != "first" {
		t.Errorf("Models[0].LoadBalancing = %q, want %q", cfg.Models[0].LoadBalancing, "first")
	}
	if len(cfg.Models[0].Aliases) != 1 || cfg.Models[0].Aliases[0] != "gpt4" {
		t.Errorf("Models[0].Aliases = %v, want [gpt4]", cfg.Models[0].Aliases)
	}
	if cfg.Models[0].Routes[0].Pricing.Input != 30.0 {
		t.Errorf("Models[0].Routes[0].Pricing.Input = %v, want 30.0", cfg.Models[0].Routes[0].Pricing.Input)
	}

	// Tokens
	if len(cfg.Tokens) != 1 {
		t.Fatalf("len(Tokens) = %d, want 1", len(cfg.Tokens))
	}
	if cfg.Tokens[0].Name != "test-token" {
		t.Errorf("Tokens[0].Name = %q, want %q", cfg.Tokens[0].Name, "test-token")
	}
	if cfg.Tokens[0].RateLimitRPM != 100 {
		t.Errorf("Tokens[0].RateLimitRPM = %d, want 100", cfg.Tokens[0].RateLimitRPM)
	}
}
|
|
||||||
|
|
||||||
// TestValidate_Defaults loads the minimal config and checks that every
// default Validate is supposed to apply actually lands, plus that an empty
// session_secret is replaced by a generated 32-byte (64 hex char) value.
func TestValidate_Defaults(t *testing.T) {
	path := writeConfigFile(t, minimalValidConfig())
	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}

	tests := []struct {
		name string
		got  any
		want any
	}{
		// Server defaults
		{"Server.Listen", cfg.Server.Listen, "0.0.0.0:3000"},
		{"Server.RequestTimeout", cfg.Server.RequestTimeout, 300 * time.Second},
		{"Server.StreamingTimeout", cfg.Server.StreamingTimeout, 5 * time.Minute},
		{"Server.MaxRequestBodyMB", cfg.Server.MaxRequestBodyMB, 10},

		// Database defaults
		{"Database.Path", cfg.Database.Path, "gateway.db"},
		{"Database.RetentionDays", cfg.Database.RetentionDays, 90},

		// Pricing defaults
		{"Pricing.RefreshInterval", cfg.Pricing.RefreshInterval, 6 * time.Hour},

		// Circuit breaker defaults
		{"CircuitBreaker.ErrorThreshold", cfg.CircuitBreaker.ErrorThreshold, 0.5},
		{"CircuitBreaker.MinRequests", cfg.CircuitBreaker.MinRequests, 5},
		{"CircuitBreaker.CooldownDuration", cfg.CircuitBreaker.CooldownDuration, 30 * time.Second},

		// Retry defaults
		{"Retry.InitialBackoff", cfg.Retry.InitialBackoff, 100 * time.Millisecond},
		{"Retry.MaxBackoff", cfg.Retry.MaxBackoff, 5 * time.Second},
		{"Retry.Multiplier", cfg.Retry.Multiplier, 2.0},

		// Debug defaults
		{"Debug.MaxBodyBytes", cfg.Debug.MaxBodyBytes, 0},
		{"Debug.RetentionDays", cfg.Debug.RetentionDays, 90},

		// CORS defaults
		{"CORS.MaxAge", cfg.CORS.MaxAge, 300},

		// Provider defaults
		{"Providers[0].Timeout", cfg.Providers[0].Timeout, 120 * time.Second},
		{"Providers[0].Priority", cfg.Providers[0].Priority, 1},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Compare using formatted strings to handle different numeric types
			// (the table mixes int, float64, Duration, and string in `any`).
			gotStr := formatValue(tt.got)
			wantStr := formatValue(tt.want)
			if gotStr != wantStr {
				t.Errorf("%s = %v, want %v", tt.name, tt.got, tt.want)
			}
		})
	}

	// SessionSecret should be auto-generated (non-empty, 64 hex chars)
	if cfg.Server.SessionSecret == "" {
		t.Error("SessionSecret should be auto-generated when empty")
	}
	if len(cfg.Server.SessionSecret) != 64 {
		t.Errorf("SessionSecret length = %d, want 64 hex chars", len(cfg.Server.SessionSecret))
	}
}
|
|
||||||
|
|
||||||
// formatValue renders a table-test expectation as a comparable string so
// that values of different dynamic types (Duration, float64, int, string)
// can be compared uniformly.
func formatValue(v any) string {
	switch x := v.(type) {
	case string:
		return x
	case time.Duration:
		return x.String()
	case int:
		return fmt.Sprintf("%d", x)
	case float64:
		return fmt.Sprintf("%g", x)
	}
	// Fallback for anything the table doesn't currently use.
	return fmt.Sprintf("%v", v)
}
|
|
||||||
|
|
||||||
func TestLoad_FileNotFound(t *testing.T) {
|
|
||||||
_, err := Load(filepath.Join(t.TempDir(), "nonexistent.yaml"))
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Load() should return error for nonexistent file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoad_InvalidYAML(t *testing.T) {
|
|
||||||
path := writeConfigFile(t, `{{{invalid yaml`)
|
|
||||||
_, err := Load(path)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Load() should return error for invalid YAML")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestValidate_DuplicateProviderNames verifies Validate rejects two
// providers sharing the same name.
func TestValidate_DuplicateProviderNames(t *testing.T) {
	path := writeConfigFile(t, `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key1
  - name: openai
    base_url: https://api.openai.com/v2
    api_key: sk-key2

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`)

	_, err := Load(path)
	if err == nil {
		t.Fatal("Load() should return error for duplicate provider names")
	}
	wantSubstr := "duplicate provider name: openai"
	if !strings.Contains(err.Error(), wantSubstr) {
		t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
	}
}
|
|
||||||
|
|
||||||
// TestValidate_DuplicateModelNames verifies Validate rejects two models
// sharing the same name.
func TestValidate_DuplicateModelNames(t *testing.T) {
	path := writeConfigFile(t, `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key1

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4-turbo
`)

	_, err := Load(path)
	if err == nil {
		t.Fatal("Load() should return error for duplicate model names")
	}
	wantSubstr := "duplicate model name: gpt-4"
	if !strings.Contains(err.Error(), wantSubstr) {
		t.Errorf("error = %q, want to contain %q", err.Error(), wantSubstr)
	}
}
|
|
||||||
|
|
||||||
// TestValidate_AliasConflicts verifies the alias namespace is shared with
// model names: an alias may collide with neither an existing model name nor
// a previously declared alias.
func TestValidate_AliasConflicts(t *testing.T) {
	tests := []struct {
		name    string
		config  string
		wantErr string
	}{
		{
			name: "alias conflicts with model name",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key1

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
  - name: claude-3
    aliases:
      - gpt-4
    routes:
      - provider: openai
        model: claude-3
`,
			wantErr: "model alias gpt-4 conflicts with existing model or alias",
		},
		{
			name: "alias conflicts with another alias",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key1

models:
  - name: gpt-4
    aliases:
      - fast-model
    routes:
      - provider: openai
        model: gpt-4
  - name: claude-3
    aliases:
      - fast-model
    routes:
      - provider: openai
        model: claude-3
`,
			wantErr: "model alias fast-model conflicts with existing model or alias",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			path := writeConfigFile(t, tt.config)
			_, err := Load(path)
			if err == nil {
				t.Fatal("Load() should return error for alias conflicts")
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
			}
		})
	}
}
|
|
||||||
|
|
||||||
// TestValidate_MissingRequiredFields table-tests every required-field check
// in Validate: missing providers/models sections, each missing provider
// field, missing model name/routes, missing route fields, and a route
// referencing an undeclared provider.
func TestValidate_MissingRequiredFields(t *testing.T) {
	tests := []struct {
		name    string
		config  string
		wantErr string
	}{
		{
			name:    "no providers",
			config:  `models: [{name: test, routes: [{provider: x, model: y}]}]`,
			wantErr: "at least one provider is required",
		},
		{
			name: "no models",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key
`,
			wantErr: "at least one model is required",
		},
		{
			name: "provider missing name",
			config: `
providers:
  - base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "provider 0: name, base_url, and api_key are required",
		},
		{
			name: "provider missing base_url",
			config: `
providers:
  - name: openai
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "provider 0: name, base_url, and api_key are required",
		},
		{
			name: "provider missing api_key",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "provider 0: name, base_url, and api_key are required",
		},
		{
			name: "model missing name",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - routes:
      - provider: openai
        model: gpt-4
`,
			wantErr: "model 0: name is required",
		},
		{
			name: "model missing routes",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
`,
			wantErr: "model gpt-4: at least one route is required",
		},
		{
			name: "route missing provider",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - model: gpt-4
`,
			wantErr: "model gpt-4 route 0: provider and model are required",
		},
		{
			name: "route missing model",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: openai
`,
			wantErr: "model gpt-4 route 0: provider and model are required",
		},
		{
			name: "route references unknown provider",
			config: `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-key

models:
  - name: gpt-4
    routes:
      - provider: anthropic
        model: gpt-4
`,
			wantErr: "model gpt-4 route 0: unknown provider anthropic",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			path := writeConfigFile(t, tt.config)
			_, err := Load(path)
			if err == nil {
				t.Fatalf("Load() should return error, want %q", tt.wantErr)
			}
			if !strings.Contains(err.Error(), tt.wantErr) {
				t.Errorf("error = %q, want to contain %q", err.Error(), tt.wantErr)
			}
		})
	}
}
|
|
||||||
|
|
||||||
// TestProviderByName verifies name-based lookup (hits, misses, empty name)
// and that the returned pointer aliases the live config entry rather than
// a copy.
func TestProviderByName(t *testing.T) {
	path := writeConfigFile(t, `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: sk-openai
  - name: anthropic
    base_url: https://api.anthropic.com/v1
    api_key: sk-anthropic

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`)

	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}

	tests := []struct {
		name     string
		lookup   string
		wantNil  bool
		wantName string
	}{
		{"existing provider openai", "openai", false, "openai"},
		{"existing provider anthropic", "anthropic", false, "anthropic"},
		{"nonexistent provider", "google", true, ""},
		{"empty name", "", true, ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			p := cfg.ProviderByName(tt.lookup)
			if tt.wantNil {
				if p != nil {
					t.Errorf("ProviderByName(%q) = %v, want nil", tt.lookup, p)
				}
			} else {
				if p == nil {
					t.Fatalf("ProviderByName(%q) = nil, want provider", tt.lookup)
				}
				if p.Name != tt.wantName {
					t.Errorf("ProviderByName(%q).Name = %q, want %q", tt.lookup, p.Name, tt.wantName)
				}
			}
		})
	}

	// Verify returned pointer refers to the actual config entry
	p := cfg.ProviderByName("openai")
	if p.APIKey != "sk-openai" {
		t.Errorf("ProviderByName(openai).APIKey = %q, want %q", p.APIKey, "sk-openai")
	}
}
|
|
||||||
|
|
||||||
// TestLoad_EnvironmentVariableExpansion verifies both $VAR and ${VAR}
// forms are substituted from the environment before YAML parsing.
// t.Setenv restores the previous values automatically after the test.
func TestLoad_EnvironmentVariableExpansion(t *testing.T) {
	t.Setenv("TEST_API_KEY", "sk-from-env")
	t.Setenv("TEST_BASE_URL", "https://env.example.com/v1")
	t.Setenv("TEST_PROVIDER_NAME", "env-provider")

	path := writeConfigFile(t, `
providers:
  - name: $TEST_PROVIDER_NAME
    base_url: ${TEST_BASE_URL}
    api_key: ${TEST_API_KEY}

models:
  - name: test-model
    routes:
      - provider: env-provider
        model: gpt-4
`)

	cfg, err := Load(path)
	if err != nil {
		t.Fatalf("Load() returned error: %v", err)
	}

	if cfg.Providers[0].Name != "env-provider" {
		t.Errorf("Provider.Name = %q, want %q", cfg.Providers[0].Name, "env-provider")
	}
	if cfg.Providers[0].BaseURL != "https://env.example.com/v1" {
		t.Errorf("Provider.BaseURL = %q, want %q", cfg.Providers[0].BaseURL, "https://env.example.com/v1")
	}
	if cfg.Providers[0].APIKey != "sk-from-env" {
		t.Errorf("Provider.APIKey = %q, want %q", cfg.Providers[0].APIKey, "sk-from-env")
	}
}
|
|
||||||
|
|
||||||
// TestLoad_UnsetEnvVarExpandsToEmpty verifies that an unset ${VAR} expands
// to the empty string, which then trips the required-field validation for
// api_key rather than passing silently.
func TestLoad_UnsetEnvVarExpandsToEmpty(t *testing.T) {
	// Ensure the variable is not set. t.Setenv first so the original value
	// (if any) is restored on cleanup, then Unsetenv to clear it for this test.
	t.Setenv("TEST_UNSET_VAR", "")
	os.Unsetenv("TEST_UNSET_VAR")

	path := writeConfigFile(t, `
providers:
  - name: openai
    base_url: https://api.openai.com/v1
    api_key: ${TEST_UNSET_VAR}

models:
  - name: gpt-4
    routes:
      - provider: openai
        model: gpt-4
`)

	_, err := Load(path)
	if err == nil {
		t.Fatal("Load() should return error when env var expands to empty required field")
	}
	// api_key will be empty, so validation should catch it
	if !strings.Contains(err.Error(), "api_key are required") {
		t.Errorf("error = %q, want to contain api_key validation message", err.Error())
	}
}
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WatchReload listens for SIGHUP and calls the callback with the new config.
|
|
||||||
func WatchReload(configPath string, callback func(*Config)) {
|
|
||||||
sighup := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sighup, syscall.SIGHUP)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range sighup {
|
|
||||||
log.Println("SIGHUP received, reloading config...")
|
|
||||||
newCfg, err := Load(configPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("ERROR: config reload failed: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
callback(newCfg)
|
|
||||||
log.Println("Config reloaded successfully")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
@ -1,775 +0,0 @@
|
||||||
package dashboard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"sort"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
"llm-gateway/internal/cache"
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
"llm-gateway/internal/provider"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Exported types for template rendering and JSON API.

// Period aggregates request activity over one reporting window.
type Period struct {
	Requests     int     `json:"requests"`
	InputTokens  int     `json:"input_tokens"`
	OutputTokens int     `json:"output_tokens"`
	CostUSD      float64 `json:"cost_usd"`
	Errors       int     `json:"errors"`      // rows with status = 'error'
	CachedHits   int     `json:"cached_hits"` // rows with cached = 1
}

// SummaryResult bundles usage for the three standard dashboard windows
// (populated by GetSummary).
type SummaryResult struct {
	Today *Period `json:"today"` // since midnight (epoch-relative truncation; see GetSummary)
	Week  *Period `json:"week"`  // trailing 7 days
	Month *Period `json:"month"` // trailing calendar month
}

// ModelStats aggregates usage for a single model over the stats window.
type ModelStats struct {
	Model        string  `json:"model"`
	Requests     int     `json:"requests"`
	InputTokens  int     `json:"input_tokens"`
	OutputTokens int     `json:"output_tokens"`
	CostUSD      float64 `json:"cost_usd"`
	AvgLatencyMS float64 `json:"avg_latency_ms"`
}

// ProviderStats aggregates success/error counts, latency, and cost for
// a single upstream provider.
type ProviderStats struct {
	Provider     string  `json:"provider"`
	Requests     int     `json:"requests"`
	Successes    int     `json:"successes"`
	Errors       int     `json:"errors"`
	AvgLatencyMS float64 `json:"avg_latency_ms"`
	CostUSD      float64 `json:"cost_usd"`
}

// TokenUsageStats aggregates usage for a single API token name.
type TokenUsageStats struct {
	TokenName    string  `json:"token_name"`
	Requests     int     `json:"requests"`
	InputTokens  int     `json:"input_tokens"`
	OutputTokens int     `json:"output_tokens"`
	CostUSD      float64 `json:"cost_usd"`
}

// RequestLogEntry represents a single request log row.
type RequestLogEntry struct {
	RequestID     string  `json:"request_id"`
	Timestamp     int64   `json:"timestamp"` // unix seconds
	TokenName     string  `json:"token_name"`
	Model         string  `json:"model"`
	Provider      string  `json:"provider"`
	ProviderModel string  `json:"provider_model"`
	InputTokens   int     `json:"input_tokens"`
	OutputTokens  int     `json:"output_tokens"`
	CostUSD       float64 `json:"cost_usd"`
	LatencyMS     int64   `json:"latency_ms"`
	Status        string  `json:"status"`
	ErrorMessage  string  `json:"error_message"`
	Streaming     bool    `json:"streaming"` // stored as 0/1 in the DB
	Cached        bool    `json:"cached"`    // stored as 0/1 in the DB
}

// LogsResult holds paginated logs.
type LogsResult struct {
	Logs       []RequestLogEntry `json:"logs"`
	Page       int               `json:"page"`        // 1-based
	TotalPages int               `json:"total_pages"` // always >= 1
	Total      int               `json:"total"`       // rows matching the filters
}

// LatencyResult holds latency percentiles (milliseconds, successful
// requests only; see GetLatency).
type LatencyResult struct {
	P50 float64 `json:"p50"`
	P95 float64 `json:"p95"`
	P99 float64 `json:"p99"`
	Avg float64 `json:"avg"`
	Min float64 `json:"min"`
	Max float64 `json:"max"`
}

// CostBreakdownEntry holds cost data grouped by day and dimension.
type CostBreakdownEntry struct {
	Day      string  `json:"day"`      // YYYY-MM-DD
	GroupBy  string  `json:"group_by"` // value of the grouping column (model/token/provider)
	CostUSD  float64 `json:"cost_usd"`
	Requests int     `json:"requests"`
}

// StatsAPI serves dashboard statistics out of the request-log database.
// Only db and authStore are required; the remaining collaborators are
// optional and injected after construction via the Set* methods — every
// handler checks for nil before using them.
type StatsAPI struct {
	db            *storage.DB
	authStore     *auth.Store
	healthTracker *provider.HealthTracker // nil until SetHealthTracker
	cache         *cache.Cache            // nil until SetCache
	auditLogger   *storage.AuditLogger    // nil until SetAuditLogger
	debugLogger   *storage.DebugLogger    // nil until SetDebugLogger
	configPath    string                  // set via SetConfigPath; read by ValidateConfig
}
|
|
||||||
|
|
||||||
// NewStatsAPI constructs a StatsAPI over the request-log database and
// auth store. Optional collaborators are attached afterwards with the
// Set* methods below.
func NewStatsAPI(db *storage.DB, authStore *auth.Store) *StatsAPI {
	return &StatsAPI{db: db, authStore: authStore}
}

// SetHealthTracker sets the provider health tracker (read by
// ProviderHealthHandler).
func (s *StatsAPI) SetHealthTracker(ht *provider.HealthTracker) {
	s.healthTracker = ht
}

// SetCache sets the cache whose statistics CacheStats reports.
func (s *StatsAPI) SetCache(c *cache.Cache) {
	s.cache = c
}

// SetAuditLogger sets the audit logger queried by AuditLogs.
func (s *StatsAPI) SetAuditLogger(al *storage.AuditLogger) {
	s.auditLogger = al
}

// SetDebugLogger sets the debug logger behind the Debug* handlers.
func (s *StatsAPI) SetDebugLogger(dl *storage.DebugLogger) {
	s.debugLogger = dl
}

// SetConfigPath sets the config file path read by ValidateConfig.
func (s *StatsAPI) SetConfigPath(path string) {
	s.configPath = path
}
|
|
||||||
|
|
||||||
// TokenNamesForUser returns the token names that belong to the user.
|
|
||||||
// Admins get nil (no filter), non-admins get their token names.
|
|
||||||
func (s *StatsAPI) TokenNamesForUser(user *auth.User) []string {
|
|
||||||
if user == nil || user.IsAdmin {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
tokens, err := s.authStore.ListAPITokens(user.ID)
|
|
||||||
if err != nil {
|
|
||||||
return []string{"__none__"}
|
|
||||||
}
|
|
||||||
names := make([]string, len(tokens))
|
|
||||||
for i, t := range tokens {
|
|
||||||
names[i] = t.Name
|
|
||||||
}
|
|
||||||
if len(names) == 0 {
|
|
||||||
return []string{"__none__"}
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// tokenNamesForUser returns token names from request context (for HTTP handlers).
|
|
||||||
func (s *StatsAPI) tokenNamesForUser(r *http.Request) []string {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
return s.TokenNamesForUser(user)
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildTokenFilter turns a token-name allow-list into a SQL fragment
// and its bound arguments.
//
// A nil slice means "no filtering" (admin view) and yields an empty
// fragment with nil args. Otherwise the fragment has the form
// " AND token_name IN (?,?,...)" with one placeholder per name, and the
// names are returned in order as the matching query arguments.
func buildTokenFilter(tokenNames []string) (string, []any) {
	if tokenNames == nil {
		return "", nil
	}
	// One "?" per name; strings.Repeat avoids the quadratic string +=
	// loop the previous implementation used.
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(tokenNames)), ",")
	args := make([]any, len(tokenNames))
	for i, name := range tokenNames {
		args[i] = name
	}
	return " AND token_name IN (" + placeholders + ")", args
}
|
|
||||||
|
|
||||||
// Data-fetching methods (used by both JSON handlers and template handlers).

// GetSummary aggregates request count, token usage, cost, errors, and
// cache hits over three windows: today, trailing week, and trailing
// month. tokenNames restricts rows to those token names (nil = no
// restriction; see buildTokenFilter).
func (s *StatsAPI) GetSummary(tokenNames []string) *SummaryResult {
	now := time.Now()
	// NOTE(review): Truncate(24h) truncates relative to the Unix epoch,
	// i.e. UTC midnight — "today" may not match local midnight; confirm
	// this is intended.
	todayStart := now.Truncate(24 * time.Hour).Unix()
	weekStart := now.AddDate(0, 0, -7).Unix()
	monthStart := now.AddDate(0, -1, 0).Unix()

	tokenFilter, filterArgs := buildTokenFilter(tokenNames)

	result := &SummaryResult{
		Today: &Period{},
		Week:  &Period{},
		Month: &Period{},
	}

	// One aggregate query per window; each scans directly into the
	// Period struct it shares a map entry with.
	periods := map[string]struct {
		since  int64
		period *Period
	}{
		"today": {todayStart, result.Today},
		"week":  {weekStart, result.Week},
		"month": {monthStart, result.Month},
	}

	for _, p := range periods {
		args := append([]any{p.since}, filterArgs...)
		row := s.db.QueryRow(`SELECT
			COUNT(*),
			COALESCE(SUM(input_tokens), 0),
			COALESCE(SUM(output_tokens), 0),
			COALESCE(SUM(cost_usd), 0),
			COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0),
			COALESCE(SUM(CASE WHEN cached = 1 THEN 1 ELSE 0 END), 0)
			FROM request_logs WHERE timestamp >= ?`+tokenFilter, args...)
		// Scan error deliberately ignored: on failure the Period simply
		// stays zero-valued.
		row.Scan(&p.period.Requests, &p.period.InputTokens, &p.period.OutputTokens, &p.period.CostUSD, &p.period.Errors, &p.period.CachedHits)
	}

	return result
}
|
|
||||||
|
|
||||||
func (s *StatsAPI) GetModels(tokenNames []string) []ModelStats {
|
|
||||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
|
||||||
tokenFilter, filterArgs := buildTokenFilter(tokenNames)
|
|
||||||
|
|
||||||
args := append([]any{since}, filterArgs...)
|
|
||||||
rows, err := s.db.Query(`SELECT
|
|
||||||
model,
|
|
||||||
COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
|
||||||
COALESCE(SUM(cost_usd), 0) as cost,
|
|
||||||
COALESCE(AVG(latency_ms), 0) as avg_latency
|
|
||||||
FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
|
|
||||||
GROUP BY model ORDER BY requests DESC`, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var results []ModelStats
|
|
||||||
for rows.Next() {
|
|
||||||
var m ModelStats
|
|
||||||
rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS)
|
|
||||||
results = append(results, m)
|
|
||||||
}
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StatsAPI) GetProviders(tokenNames []string) []ProviderStats {
|
|
||||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
|
||||||
tokenFilter, filterArgs := buildTokenFilter(tokenNames)
|
|
||||||
|
|
||||||
args := append([]any{since}, filterArgs...)
|
|
||||||
rows, err := s.db.Query(`SELECT
|
|
||||||
provider,
|
|
||||||
COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(CASE WHEN status = 'success' THEN 1 ELSE 0 END), 0) as successes,
|
|
||||||
COALESCE(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END), 0) as errors,
|
|
||||||
COALESCE(AVG(latency_ms), 0) as avg_latency,
|
|
||||||
COALESCE(SUM(cost_usd), 0) as cost
|
|
||||||
FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
|
|
||||||
GROUP BY provider ORDER BY requests DESC`, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var results []ProviderStats
|
|
||||||
for rows.Next() {
|
|
||||||
var p ProviderStats
|
|
||||||
rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD)
|
|
||||||
results = append(results, p)
|
|
||||||
}
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StatsAPI) GetTokenUsage(tokenNames []string) []TokenUsageStats {
|
|
||||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
|
||||||
tokenFilter, filterArgs := buildTokenFilter(tokenNames)
|
|
||||||
|
|
||||||
args := append([]any{since}, filterArgs...)
|
|
||||||
rows, err := s.db.Query(`SELECT
|
|
||||||
token_name,
|
|
||||||
COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(input_tokens), 0) as input_tokens,
|
|
||||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
|
||||||
COALESCE(SUM(cost_usd), 0) as cost
|
|
||||||
FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
|
|
||||||
GROUP BY token_name ORDER BY requests DESC`, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var results []TokenUsageStats
|
|
||||||
for rows.Next() {
|
|
||||||
var t TokenUsageStats
|
|
||||||
rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD)
|
|
||||||
results = append(results, t)
|
|
||||||
}
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLogs returns paginated request logs with filters.
//
// page is 1-based (values < 1 are coerced to 1); each page holds 50
// rows. model, token, and status are optional exact-match filters
// applied on top of the tokenNames visibility filter. On query failure
// the result still carries the page/total metadata with an empty Logs
// slice.
func (s *StatsAPI) GetLogs(tokenNames []string, page int, model, token, status string) *LogsResult {
	if page < 1 {
		page = 1
	}
	limit := 50
	offset := (page - 1) * limit

	tokenFilter, filterArgs := buildTokenFilter(tokenNames)

	where := "WHERE 1=1" + tokenFilter
	args := make([]any, 0)
	args = append(args, filterArgs...)

	if model != "" {
		where += " AND model = ?"
		args = append(args, model)
	}
	if token != "" {
		where += " AND token_name = ?"
		args = append(args, token)
	}
	if status != "" {
		where += " AND status = ?"
		args = append(args, status)
	}

	// Get total count
	var total int
	// Copy args so the later append of limit/offset cannot alias the
	// slice handed to QueryRow. Scan error ignored: total stays 0.
	countArgs := make([]any, len(args))
	copy(countArgs, args)
	s.db.QueryRow("SELECT COUNT(*) FROM request_logs "+where, countArgs...).Scan(&total)

	totalPages := (total + limit - 1) / limit
	if totalPages < 1 {
		totalPages = 1
	}

	// Get page
	query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
		input_tokens, output_tokens, cost_usd, latency_ms, status,
		COALESCE(error_message, ''), streaming, cached
		FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
	args = append(args, limit, offset)

	rows, err := s.db.Query(query, args...)
	if err != nil {
		return &LogsResult{Logs: []RequestLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
	}
	defer rows.Close()

	var logs []RequestLogEntry
	for rows.Next() {
		var l RequestLogEntry
		// streaming/cached are stored as 0/1 integers; converted to
		// bools after the scan.
		var streaming, cached int
		rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
			&l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
			&l.ErrorMessage, &streaming, &cached)
		l.Streaming = streaming == 1
		l.Cached = cached == 1
		logs = append(logs, l)
	}
	if logs == nil {
		// Never serialize null: an empty page is an empty JSON array.
		logs = []RequestLogEntry{}
	}

	return &LogsResult{
		Logs:       logs,
		Page:       page,
		TotalPages: totalPages,
		Total:      total,
	}
}
|
|
||||||
|
|
||||||
// GetDistinctModels returns distinct model names from logs.
|
|
||||||
func (s *StatsAPI) GetDistinctModels() []string {
|
|
||||||
rows, err := s.db.Query("SELECT DISTINCT model FROM request_logs ORDER BY model")
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var models []string
|
|
||||||
for rows.Next() {
|
|
||||||
var m string
|
|
||||||
rows.Scan(&m)
|
|
||||||
models = append(models, m)
|
|
||||||
}
|
|
||||||
return models
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDistinctTokens returns distinct token names from logs.
|
|
||||||
func (s *StatsAPI) GetDistinctTokens() []string {
|
|
||||||
rows, err := s.db.Query("SELECT DISTINCT token_name FROM request_logs ORDER BY token_name")
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
var tokens []string
|
|
||||||
for rows.Next() {
|
|
||||||
var t string
|
|
||||||
rows.Scan(&t)
|
|
||||||
tokens = append(tokens, t)
|
|
||||||
}
|
|
||||||
return tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLatency computes latency percentiles from request_logs.
|
|
||||||
func (s *StatsAPI) GetLatency(tokenNames []string, period, model, providerName string) *LatencyResult {
|
|
||||||
var since int64
|
|
||||||
switch period {
|
|
||||||
case "7d":
|
|
||||||
since = time.Now().AddDate(0, 0, -7).Unix()
|
|
||||||
case "30d":
|
|
||||||
since = time.Now().AddDate(0, -1, 0).Unix()
|
|
||||||
default:
|
|
||||||
since = time.Now().Add(-24 * time.Hour).Unix()
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenFilter, filterArgs := buildTokenFilter(tokenNames)
|
|
||||||
|
|
||||||
where := "WHERE timestamp >= ? AND status = 'success'" + tokenFilter
|
|
||||||
args := []any{since}
|
|
||||||
args = append(args, filterArgs...)
|
|
||||||
|
|
||||||
if model != "" {
|
|
||||||
where += " AND model = ?"
|
|
||||||
args = append(args, model)
|
|
||||||
}
|
|
||||||
if providerName != "" {
|
|
||||||
where += " AND provider = ?"
|
|
||||||
args = append(args, providerName)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := s.db.Query("SELECT latency_ms FROM request_logs "+where+" ORDER BY latency_ms", args...)
|
|
||||||
if err != nil {
|
|
||||||
return &LatencyResult{}
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var latencies []float64
|
|
||||||
for rows.Next() {
|
|
||||||
var l float64
|
|
||||||
rows.Scan(&l)
|
|
||||||
latencies = append(latencies, l)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(latencies) == 0 {
|
|
||||||
return &LatencyResult{}
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Float64s(latencies)
|
|
||||||
n := len(latencies)
|
|
||||||
var sum float64
|
|
||||||
for _, l := range latencies {
|
|
||||||
sum += l
|
|
||||||
}
|
|
||||||
|
|
||||||
return &LatencyResult{
|
|
||||||
P50: latencies[n*50/100],
|
|
||||||
P95: latencies[n*95/100],
|
|
||||||
P99: latencies[min(n*99/100, n-1)],
|
|
||||||
Avg: sum / float64(n),
|
|
||||||
Min: latencies[0],
|
|
||||||
Max: latencies[n-1],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCostBreakdown returns cost data grouped by day and dimension.
//
// period selects the window: "30d", "7d", or (default) the last 24
// hours. groupBy selects the second grouping column: "token",
// "provider", or (default) "model". Returns nil when the query fails.
func (s *StatsAPI) GetCostBreakdown(tokenNames []string, period, groupBy string) []CostBreakdownEntry {
	var since int64
	switch period {
	case "30d":
		since = time.Now().AddDate(0, -1, 0).Unix()
	case "7d":
		since = time.Now().AddDate(0, 0, -7).Unix()
	default:
		since = time.Now().Add(-24 * time.Hour).Unix()
	}

	tokenFilter, filterArgs := buildTokenFilter(tokenNames)

	// groupCol is concatenated into the SQL below, but it only ever
	// takes one of these three fixed column names — raw user input
	// never reaches the query text.
	groupCol := "model"
	if groupBy == "token" {
		groupCol = "token_name"
	} else if groupBy == "provider" {
		groupCol = "provider"
	}

	args := []any{since}
	args = append(args, filterArgs...)

	query := `SELECT date(timestamp, 'unixepoch') as day, ` + groupCol + `,
		COALESCE(SUM(cost_usd), 0), COUNT(*)
		FROM request_logs WHERE timestamp >= ?` + tokenFilter + `
		GROUP BY day, ` + groupCol + ` ORDER BY day, ` + groupCol

	rows, err := s.db.Query(query, args...)
	if err != nil {
		return nil
	}
	defer rows.Close()

	var results []CostBreakdownEntry
	for rows.Next() {
		var e CostBreakdownEntry
		rows.Scan(&e.Day, &e.GroupBy, &e.CostUSD, &e.Requests)
		results = append(results, e)
	}
	return results
}
|
|
||||||
|
|
||||||
// JSON HTTP handlers (thin wrappers).

// Summary serves the usage summary API; see GetSummary.
func (s *StatsAPI) Summary(w http.ResponseWriter, r *http.Request) {
	tokenNames := s.tokenNamesForUser(r)
	result := s.GetSummary(tokenNames)
	writeJSON(w, result)
}

// Models serves per-model usage statistics; see GetModels.
func (s *StatsAPI) Models(w http.ResponseWriter, r *http.Request) {
	tokenNames := s.tokenNamesForUser(r)
	results := s.GetModels(tokenNames)
	writeJSON(w, results)
}

// Providers serves per-provider statistics; see GetProviders.
func (s *StatsAPI) Providers(w http.ResponseWriter, r *http.Request) {
	tokenNames := s.tokenNamesForUser(r)
	results := s.GetProviders(tokenNames)
	writeJSON(w, results)
}

// Tokens serves per-token usage statistics; see GetTokenUsage.
func (s *StatsAPI) Tokens(w http.ResponseWriter, r *http.Request) {
	tokenNames := s.tokenNamesForUser(r)
	results := s.GetTokenUsage(tokenNames)
	writeJSON(w, results)
}
|
|
||||||
|
|
||||||
// Timeseries serves bucketed request/cost/token counts for charting.
//
// ?period=7d or 30d buckets by day; any other value covers the last 24
// hours bucketed by hour. Rows are restricted to the caller's visible
// tokens.
func (s *StatsAPI) Timeseries(w http.ResponseWriter, r *http.Request) {
	period := r.URL.Query().Get("period")
	var since int64
	var groupFmt string
	switch period {
	case "7d":
		since = time.Now().AddDate(0, 0, -7).Unix()
		groupFmt = "%Y-%m-%d"
	case "30d":
		since = time.Now().AddDate(0, -1, 0).Unix()
		groupFmt = "%Y-%m-%d"
	default:
		since = time.Now().Add(-24 * time.Hour).Unix()
		groupFmt = "%Y-%m-%d %H:00"
	}

	tokenNames := s.tokenNamesForUser(r)
	tokenFilter, filterArgs := buildTokenFilter(tokenNames)

	args := append([]any{since}, filterArgs...)
	// groupFmt is concatenated into the SQL, but only ever takes one of
	// the fixed strftime formats above — no user input reaches it.
	rows, err := s.db.Query(`SELECT
		strftime('`+groupFmt+`', timestamp, 'unixepoch') as bucket,
		COUNT(*) as requests,
		COALESCE(SUM(cost_usd), 0) as cost,
		COALESCE(SUM(input_tokens + output_tokens), 0) as total_tokens
		FROM request_logs WHERE timestamp >= ?`+tokenFilter+`
		GROUP BY bucket ORDER BY bucket`, args...)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	defer rows.Close()

	// point is the JSON shape of one chart bucket.
	type point struct {
		Bucket      string  `json:"bucket"`
		Requests    int     `json:"requests"`
		CostUSD     float64 `json:"cost_usd"`
		TotalTokens int     `json:"total_tokens"`
	}

	var results []point
	for rows.Next() {
		var p point
		rows.Scan(&p.Bucket, &p.Requests, &p.CostUSD, &p.TotalTokens)
		results = append(results, p)
	}
	writeJSON(w, results)
}
|
|
||||||
|
|
||||||
// Logs serves the paginated logs API.
|
|
||||||
func (s *StatsAPI) Logs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
tokenNames := s.tokenNamesForUser(r)
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
|
||||||
model := r.URL.Query().Get("model")
|
|
||||||
token := r.URL.Query().Get("token")
|
|
||||||
status := r.URL.Query().Get("status")
|
|
||||||
result := s.GetLogs(tokenNames, page, model, token, status)
|
|
||||||
writeJSON(w, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Latency serves latency percentiles API.
|
|
||||||
func (s *StatsAPI) Latency(w http.ResponseWriter, r *http.Request) {
|
|
||||||
tokenNames := s.tokenNamesForUser(r)
|
|
||||||
period := r.URL.Query().Get("period")
|
|
||||||
model := r.URL.Query().Get("model")
|
|
||||||
providerName := r.URL.Query().Get("provider")
|
|
||||||
result := s.GetLatency(tokenNames, period, model, providerName)
|
|
||||||
writeJSON(w, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CostBreakdown serves cost breakdown API.
|
|
||||||
func (s *StatsAPI) CostBreakdown(w http.ResponseWriter, r *http.Request) {
|
|
||||||
tokenNames := s.tokenNamesForUser(r)
|
|
||||||
period := r.URL.Query().Get("period")
|
|
||||||
groupBy := r.URL.Query().Get("group_by")
|
|
||||||
if groupBy == "" {
|
|
||||||
groupBy = "model"
|
|
||||||
}
|
|
||||||
result := s.GetCostBreakdown(tokenNames, period, groupBy)
|
|
||||||
writeJSON(w, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProviderHealthHandler serves provider health status API.
|
|
||||||
func (s *StatsAPI) ProviderHealthHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.healthTracker == nil {
|
|
||||||
writeJSON(w, []provider.ProviderHealth{})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, s.healthTracker.Status())
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheStats serves cache statistics API.
|
|
||||||
func (s *StatsAPI) CacheStats(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.cache == nil {
|
|
||||||
writeJSON(w, map[string]any{"enabled": false})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
stats := s.cache.Stats(r.Context())
|
|
||||||
writeJSON(w, stats)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuditLogs serves the audit log API (admin-only).
|
|
||||||
func (s *StatsAPI) AuditLogs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.auditLogger == nil {
|
|
||||||
writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
|
||||||
action := r.URL.Query().Get("action")
|
|
||||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
|
||||||
if sinceStr := r.URL.Query().Get("since"); sinceStr != "" {
|
|
||||||
if s, err := strconv.ParseInt(sinceStr, 10, 64); err == nil {
|
|
||||||
since = s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result := s.auditLogger.Query(since, action, page, 50)
|
|
||||||
writeJSON(w, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugToggle enables/disables debug logging at runtime.
|
|
||||||
func (s *StatsAPI) DebugToggle(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.debugLogger == nil {
|
|
||||||
writeJSON(w, map[string]any{"error": "debug logger not configured"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var req struct {
|
|
||||||
Enabled bool `json:"enabled"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
writeJSON(w, map[string]string{"error": "invalid JSON"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.debugLogger.SetEnabled(req.Enabled)
|
|
||||||
writeJSON(w, map[string]any{"enabled": s.debugLogger.IsEnabled()})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugStatus returns whether debug logging is enabled.
|
|
||||||
func (s *StatsAPI) DebugStatus(w http.ResponseWriter, r *http.Request) {
|
|
||||||
enabled := false
|
|
||||||
if s.debugLogger != nil {
|
|
||||||
enabled = s.debugLogger.IsEnabled()
|
|
||||||
}
|
|
||||||
writeJSON(w, map[string]any{"enabled": enabled})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugLogs serves paginated debug log entries.
|
|
||||||
func (s *StatsAPI) DebugLogs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.debugLogger == nil {
|
|
||||||
writeJSON(w, map[string]any{"entries": []any{}, "total": 0})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
|
||||||
result := s.debugLogger.Query(page, 50)
|
|
||||||
writeJSON(w, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugLogByRequestID serves a single debug log entry by request ID.
|
|
||||||
func (s *StatsAPI) DebugLogByRequestID(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.debugLogger == nil {
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
writeJSON(w, map[string]string{"error": "debug logger not configured"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
requestID := chi.URLParam(r, "requestID")
|
|
||||||
entry := s.debugLogger.GetByRequestID(requestID)
|
|
||||||
if entry == nil {
|
|
||||||
w.WriteHeader(http.StatusNotFound)
|
|
||||||
writeJSON(w, map[string]string{"error": "not found"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, entry)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateConfig validates the config file at the stored path.
|
|
||||||
// Returns HTML for HTMX requests, JSON otherwise.
|
|
||||||
func (s *StatsAPI) ValidateConfig(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if s.configPath == "" {
|
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
w.Write([]byte(`<div class="error-msg">Config path not set</div>`))
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
writeJSON(w, map[string]any{"valid": false, "errors": []string{"config path not set"}})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data, err := os.ReadFile(s.configPath)
|
|
||||||
if err != nil {
|
|
||||||
msg := "failed to read config: " + err.Error()
|
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
w.Write([]byte(`<div class="error-msg">` + msg + `</div>`))
|
|
||||||
} else {
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
writeJSON(w, map[string]any{"valid": false, "errors": []string{msg}})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
errs := config.ValidateBytes(data)
|
|
||||||
|
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
if len(errs) > 0 {
|
|
||||||
html := `<div class="error-msg">Configuration errors:<ul style="margin:4px 0 0 16px;">`
|
|
||||||
for _, e := range errs {
|
|
||||||
html += "<li>" + e + "</li>"
|
|
||||||
}
|
|
||||||
html += "</ul></div>"
|
|
||||||
w.Write([]byte(html))
|
|
||||||
} else {
|
|
||||||
w.Write([]byte(`<div class="success-msg">Configuration is valid.</div>`))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errs) > 0 {
|
|
||||||
writeJSON(w, map[string]any{"valid": false, "errors": errs})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
writeJSON(w, map[string]any{"valid": true, "errors": []string{}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeJSON(w http.ResponseWriter, v any) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(v)
|
|
||||||
}
|
|
||||||
|
|
@ -1,297 +0,0 @@
|
||||||
package dashboard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/csv"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ExportHandler streams request logs out of the database as CSV or
// JSON downloads, applying the same per-user token visibility rules as
// the dashboard stats API.
type ExportHandler struct {
	db        *storage.DB
	authStore *auth.Store // resolves which tokens a non-admin may see
}

// NewExportHandler constructs an ExportHandler over the given database
// and auth store.
func NewExportHandler(db *storage.DB, authStore *auth.Store) *ExportHandler {
	return &ExportHandler{db: db, authStore: authStore}
}
|
|
||||||
|
|
||||||
// ExportLogs exports request logs as CSV or JSON.
|
|
||||||
func (e *ExportHandler) ExportLogs(w http.ResponseWriter, r *http.Request) {
|
|
||||||
format := r.URL.Query().Get("format")
|
|
||||||
if format == "" {
|
|
||||||
format = "json"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build query
|
|
||||||
where := "WHERE 1=1"
|
|
||||||
var args []any
|
|
||||||
|
|
||||||
if from := r.URL.Query().Get("from"); from != "" {
|
|
||||||
if ts, err := strconv.ParseInt(from, 10, 64); err == nil {
|
|
||||||
where += " AND timestamp >= ?"
|
|
||||||
args = append(args, ts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if to := r.URL.Query().Get("to"); to != "" {
|
|
||||||
if ts, err := strconv.ParseInt(to, 10, 64); err == nil {
|
|
||||||
where += " AND timestamp <= ?"
|
|
||||||
args = append(args, ts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if model := r.URL.Query().Get("model"); model != "" {
|
|
||||||
where += " AND model = ?"
|
|
||||||
args = append(args, model)
|
|
||||||
}
|
|
||||||
if token := r.URL.Query().Get("token"); token != "" {
|
|
||||||
where += " AND token_name = ?"
|
|
||||||
args = append(args, token)
|
|
||||||
}
|
|
||||||
if status := r.URL.Query().Get("status"); status != "" {
|
|
||||||
where += " AND status = ?"
|
|
||||||
args = append(args, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token filtering for non-admins
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
if user != nil && !user.IsAdmin {
|
|
||||||
tokens, err := e.authStore.ListAPITokens(user.ID)
|
|
||||||
if err != nil || len(tokens) == 0 {
|
|
||||||
where += " AND 1=0"
|
|
||||||
} else {
|
|
||||||
where += " AND token_name IN ("
|
|
||||||
for i, t := range tokens {
|
|
||||||
if i > 0 {
|
|
||||||
where += ","
|
|
||||||
}
|
|
||||||
where += "?"
|
|
||||||
args = append(args, t.Name)
|
|
||||||
}
|
|
||||||
where += ")"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
query := `SELECT COALESCE(request_id, ''), timestamp, token_name, model, provider, provider_model,
|
|
||||||
input_tokens, output_tokens, cost_usd, latency_ms, status,
|
|
||||||
COALESCE(error_message, ''), streaming, cached
|
|
||||||
FROM request_logs ` + where + ` ORDER BY timestamp DESC LIMIT 100000`
|
|
||||||
|
|
||||||
rows, err := e.db.Query(query, args...)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
type logRow struct {
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
Timestamp int64 `json:"timestamp"`
|
|
||||||
TokenName string `json:"token_name"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Provider string `json:"provider"`
|
|
||||||
ProviderModel string `json:"provider_model"`
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
CostUSD float64 `json:"cost_usd"`
|
|
||||||
LatencyMS int64 `json:"latency_ms"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
ErrorMessage string `json:"error_message"`
|
|
||||||
Streaming bool `json:"streaming"`
|
|
||||||
Cached bool `json:"cached"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var results []logRow
|
|
||||||
for rows.Next() {
|
|
||||||
var l logRow
|
|
||||||
var streaming, cached int
|
|
||||||
rows.Scan(&l.RequestID, &l.Timestamp, &l.TokenName, &l.Model, &l.Provider, &l.ProviderModel,
|
|
||||||
&l.InputTokens, &l.OutputTokens, &l.CostUSD, &l.LatencyMS, &l.Status,
|
|
||||||
&l.ErrorMessage, &streaming, &cached)
|
|
||||||
l.Streaming = streaming == 1
|
|
||||||
l.Cached = cached == 1
|
|
||||||
results = append(results, l)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().Format("20060102-150405")
|
|
||||||
|
|
||||||
switch format {
|
|
||||||
case "csv":
|
|
||||||
w.Header().Set("Content-Type", "text/csv")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.csv", now))
|
|
||||||
writer := csv.NewWriter(w)
|
|
||||||
writer.Write([]string{"request_id", "timestamp", "token_name", "model", "provider", "provider_model",
|
|
||||||
"input_tokens", "output_tokens", "cost_usd", "latency_ms", "status", "error_message", "streaming", "cached"})
|
|
||||||
for _, l := range results {
|
|
||||||
writer.Write([]string{
|
|
||||||
l.RequestID,
|
|
||||||
strconv.FormatInt(l.Timestamp, 10),
|
|
||||||
l.TokenName, l.Model, l.Provider, l.ProviderModel,
|
|
||||||
strconv.Itoa(l.InputTokens), strconv.Itoa(l.OutputTokens),
|
|
||||||
fmt.Sprintf("%.8f", l.CostUSD),
|
|
||||||
strconv.FormatInt(l.LatencyMS, 10),
|
|
||||||
l.Status, l.ErrorMessage,
|
|
||||||
strconv.FormatBool(l.Streaming), strconv.FormatBool(l.Cached),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
writer.Flush()
|
|
||||||
default:
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=logs-%s.json", now))
|
|
||||||
json.NewEncoder(w).Encode(results)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExportStats exports aggregated stats as CSV or JSON.
|
|
||||||
func (e *ExportHandler) ExportStats(w http.ResponseWriter, r *http.Request) {
|
|
||||||
format := r.URL.Query().Get("format")
|
|
||||||
if format == "" {
|
|
||||||
format = "json"
|
|
||||||
}
|
|
||||||
statsType := r.URL.Query().Get("type")
|
|
||||||
if statsType == "" {
|
|
||||||
statsType = "summary"
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().Format("20060102-150405")
|
|
||||||
since := time.Now().AddDate(0, -1, 0).Unix()
|
|
||||||
|
|
||||||
switch statsType {
|
|
||||||
case "models":
|
|
||||||
rows, err := e.db.Query(`SELECT model, COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
|
|
||||||
COALESCE(SUM(cost_usd), 0), COALESCE(AVG(latency_ms), 0)
|
|
||||||
FROM request_logs WHERE timestamp >= ? GROUP BY model ORDER BY requests DESC`, since)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
type modelRow struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Requests int `json:"requests"`
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
CostUSD float64 `json:"cost_usd"`
|
|
||||||
AvgLatencyMS float64 `json:"avg_latency_ms"`
|
|
||||||
}
|
|
||||||
var results []modelRow
|
|
||||||
for rows.Next() {
|
|
||||||
var m modelRow
|
|
||||||
rows.Scan(&m.Model, &m.Requests, &m.InputTokens, &m.OutputTokens, &m.CostUSD, &m.AvgLatencyMS)
|
|
||||||
results = append(results, m)
|
|
||||||
}
|
|
||||||
|
|
||||||
if format == "csv" {
|
|
||||||
w.Header().Set("Content-Type", "text/csv")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.csv", now))
|
|
||||||
writer := csv.NewWriter(w)
|
|
||||||
writer.Write([]string{"model", "requests", "input_tokens", "output_tokens", "cost_usd", "avg_latency_ms"})
|
|
||||||
for _, m := range results {
|
|
||||||
writer.Write([]string{m.Model, strconv.Itoa(m.Requests), strconv.Itoa(m.InputTokens),
|
|
||||||
strconv.Itoa(m.OutputTokens), fmt.Sprintf("%.8f", m.CostUSD), fmt.Sprintf("%.2f", m.AvgLatencyMS)})
|
|
||||||
}
|
|
||||||
writer.Flush()
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-models-%s.json", now))
|
|
||||||
json.NewEncoder(w).Encode(results)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "providers":
|
|
||||||
rows, err := e.db.Query(`SELECT provider, COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(CASE WHEN status='success' THEN 1 ELSE 0 END), 0),
|
|
||||||
COALESCE(SUM(CASE WHEN status='error' THEN 1 ELSE 0 END), 0),
|
|
||||||
COALESCE(AVG(latency_ms), 0), COALESCE(SUM(cost_usd), 0)
|
|
||||||
FROM request_logs WHERE timestamp >= ? GROUP BY provider ORDER BY requests DESC`, since)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
type providerRow struct {
|
|
||||||
Provider string `json:"provider"`
|
|
||||||
Requests int `json:"requests"`
|
|
||||||
Successes int `json:"successes"`
|
|
||||||
Errors int `json:"errors"`
|
|
||||||
AvgLatencyMS float64 `json:"avg_latency_ms"`
|
|
||||||
CostUSD float64 `json:"cost_usd"`
|
|
||||||
}
|
|
||||||
var results []providerRow
|
|
||||||
for rows.Next() {
|
|
||||||
var p providerRow
|
|
||||||
rows.Scan(&p.Provider, &p.Requests, &p.Successes, &p.Errors, &p.AvgLatencyMS, &p.CostUSD)
|
|
||||||
results = append(results, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
if format == "csv" {
|
|
||||||
w.Header().Set("Content-Type", "text/csv")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.csv", now))
|
|
||||||
writer := csv.NewWriter(w)
|
|
||||||
writer.Write([]string{"provider", "requests", "successes", "errors", "avg_latency_ms", "cost_usd"})
|
|
||||||
for _, p := range results {
|
|
||||||
writer.Write([]string{p.Provider, strconv.Itoa(p.Requests), strconv.Itoa(p.Successes),
|
|
||||||
strconv.Itoa(p.Errors), fmt.Sprintf("%.2f", p.AvgLatencyMS), fmt.Sprintf("%.8f", p.CostUSD)})
|
|
||||||
}
|
|
||||||
writer.Flush()
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-providers-%s.json", now))
|
|
||||||
json.NewEncoder(w).Encode(results)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "tokens":
|
|
||||||
rows, err := e.db.Query(`SELECT token_name, COUNT(*) as requests,
|
|
||||||
COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0),
|
|
||||||
COALESCE(SUM(cost_usd), 0)
|
|
||||||
FROM request_logs WHERE timestamp >= ? GROUP BY token_name ORDER BY requests DESC`, since)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "query failed", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
type tokenRow struct {
|
|
||||||
TokenName string `json:"token_name"`
|
|
||||||
Requests int `json:"requests"`
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
CostUSD float64 `json:"cost_usd"`
|
|
||||||
}
|
|
||||||
var results []tokenRow
|
|
||||||
for rows.Next() {
|
|
||||||
var t tokenRow
|
|
||||||
rows.Scan(&t.TokenName, &t.Requests, &t.InputTokens, &t.OutputTokens, &t.CostUSD)
|
|
||||||
results = append(results, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
if format == "csv" {
|
|
||||||
w.Header().Set("Content-Type", "text/csv")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.csv", now))
|
|
||||||
writer := csv.NewWriter(w)
|
|
||||||
writer.Write([]string{"token_name", "requests", "input_tokens", "output_tokens", "cost_usd"})
|
|
||||||
for _, t := range results {
|
|
||||||
writer.Write([]string{t.TokenName, strconv.Itoa(t.Requests), strconv.Itoa(t.InputTokens),
|
|
||||||
strconv.Itoa(t.OutputTokens), fmt.Sprintf("%.8f", t.CostUSD)})
|
|
||||||
}
|
|
||||||
writer.Flush()
|
|
||||||
} else {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-tokens-%s.json", now))
|
|
||||||
json.NewEncoder(w).Encode(results)
|
|
||||||
}
|
|
||||||
|
|
||||||
default: // summary
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=stats-summary-%s.json", now))
|
|
||||||
statsAPI := NewStatsAPI(e.db, e.authStore)
|
|
||||||
result := statsAPI.GetSummary(nil)
|
|
||||||
json.NewEncoder(w).Encode(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,426 +0,0 @@
|
||||||
package dashboard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"embed"
|
|
||||||
"fmt"
|
|
||||||
"html/template"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
"llm-gateway/internal/cache"
|
|
||||||
"llm-gateway/internal/provider"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
// templateFiles embeds all dashboard page and partial templates into
// the binary so the server has no runtime file dependencies.
//
//go:embed templates/*.html templates/partials/*.html
var templateFiles embed.FS
|
|
||||||
|
|
||||||
// templateFuncs holds the helper functions made available to every
// dashboard template (formatting, pagination math, budget coloring).
var templateFuncs = template.FuncMap{
	// formatTime renders a unix timestamp as a date, or "never" for zero.
	"formatTime": func(ts int64) string {
		if ts == 0 {
			return "never"
		}
		return time.Unix(ts, 0).Format("2006-01-02")
	},
	// formatTimeDetail renders a unix timestamp with time of day.
	"formatTimeDetail": func(ts int64) string {
		if ts == 0 {
			return "never"
		}
		return time.Unix(ts, 0).Format("2006-01-02 15:04:05")
	},
	"addInt": func(a, b int) int { return a + b },
	"subInt": func(a, b int) int { return a - b },
	// formatCost shows more precision for sub-cent amounts.
	"formatCost": func(usd float64) string {
		switch {
		case usd == 0:
			return "$0.00"
		case usd < 0.01:
			return fmt.Sprintf("$%.6f", usd)
		default:
			return fmt.Sprintf("$%.4f", usd)
		}
	},
	// formatPrice renders a price, using "-" for zero/unset.
	"formatPrice": func(usd float64) string {
		if usd == 0 {
			return "-"
		}
		return fmt.Sprintf("$%.2f", usd)
	},
	// formatPct renders a 0..1 ratio as a percentage.
	"formatPct": func(ratio float64) string {
		return fmt.Sprintf("%.1f%%", ratio*100)
	},
	// budgetPct returns spend as a percentage of budget (0 when no budget).
	"budgetPct": func(spend, budget float64) float64 {
		if budget <= 0 {
			return 0
		}
		return spend / budget * 100
	},
	// budgetColor maps a budget percentage to a traffic-light color.
	"budgetColor": func(pct float64) string {
		switch {
		case pct >= 80:
			return "#f87171"
		case pct >= 50:
			return "#fbbf24"
		default:
			return "#4ade80"
		}
	},
	// seq returns the integers from lo to hi inclusive (nil when lo > hi).
	"seq": func(lo, hi int) []int {
		var out []int
		for n := lo; n <= hi; n++ {
			out = append(out, n)
		}
		return out
	},
	// paginationStart picks the first page number of a 5-wide window
	// centered on the current page, clamped to valid pages.
	"paginationStart": func(page, totalPages int) int {
		first := page - 2
		if first < 1 {
			first = 1
		}
		if totalPages > 4 && totalPages-first < 4 {
			first = totalPages - 4
		}
		return first
	},
	// paginationEnd picks the last page number of the same window.
	"paginationEnd": func(page, totalPages int) int {
		first := page - 2
		if first < 1 {
			first = 1
		}
		last := first + 4
		if last > totalPages {
			last = totalPages
		}
		return last
	},
}
|
|
||||||
|
|
||||||
// PageData is the common data passed to all templates. Only the fields
// relevant to the page being rendered are populated; the rest stay at
// their zero values.
type PageData struct {
	ActivePage string     // which nav item to highlight
	User       *auth.User // authenticated user (nil only on public pages)
	// Dashboard data
	Summary        *SummaryResult
	Models         []ModelStats
	Providers      []ProviderStats
	TokenStats     []TokenUsageStats
	ProviderHealth []provider.ProviderHealth
	Latency        *LatencyResult
	CacheEnabled   bool              // true when a cache backend is configured
	CacheInfo      *cache.CacheStats // populated only when CacheEnabled
	// Tokens page data
	Tokens     []auth.APIToken
	TokenSpend map[string]float64 // today's spend keyed by token name
	// Users page data
	Users []auth.User
	// Logs page data
	LogsResult   *LogsResult
	LogModels    []string // distinct model names for the filter dropdown
	LogTokens    []string // distinct token names for the filter dropdown
	FilterModel  string   // currently selected model filter
	FilterToken  string   // currently selected token filter
	FilterStatus string   // currently selected status filter
	// Models routing page data
	ModelRoutes []provider.ModelRouteInfo
	// Audit page data
	AuditResult        *storage.AuditQueryResult
	AuditFilterActions []string // action types offered in the filter dropdown
	FilterAction       string   // currently selected action filter
	// Debug page data
	DebugResult  *storage.DebugLogQueryResult
	DebugEnabled bool // whether debug logging is currently active
}
|
|
||||||
|
|
||||||
// Dashboard serves the HTMX-based dashboard pages.
//
// templates, authStore and statsAPI are set by NewDashboard; the
// remaining fields are optional collaborators injected via the SetXxx
// methods and may be nil, in which case the corresponding page sections
// render without that data.
type Dashboard struct {
	templates   *template.Template   // pre-parsed pages and partials
	authStore   *auth.Store          // users, sessions, API tokens
	statsAPI    *StatsAPI            // aggregated request statistics
	registry    *provider.Registry   // optional: model routing info
	cache       *cache.Cache         // optional: cache stats display
	auditLogger *storage.AuditLogger // optional: audit page backend
	debugLogger *storage.DebugLogger // optional: debug page backend
}
|
|
||||||
|
|
||||||
// NewDashboard creates a new Dashboard handler.
|
|
||||||
func NewDashboard(authStore *auth.Store, statsAPI *StatsAPI) *Dashboard {
|
|
||||||
tmpl := template.Must(
|
|
||||||
template.New("").Funcs(templateFuncs).ParseFS(templateFiles,
|
|
||||||
"templates/*.html",
|
|
||||||
"templates/partials/*.html",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return &Dashboard{
|
|
||||||
templates: tmpl,
|
|
||||||
authStore: authStore,
|
|
||||||
statsAPI: statsAPI,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRegistry sets the provider registry for model routing display.
// Optional; when unset the models page shows no routes.
func (d *Dashboard) SetRegistry(r *provider.Registry) {
	d.registry = r
}
|
|
||||||
|
|
||||||
// SetCache sets the cache reference for cache stats display.
// Optional; when unset the dashboard omits the cache section.
func (d *Dashboard) SetCache(c *cache.Cache) {
	d.cache = c
}
|
|
||||||
|
|
||||||
// SetAuditLogger sets the audit logger for the audit page.
// Optional; when unset the audit page renders an empty result set.
func (d *Dashboard) SetAuditLogger(al *storage.AuditLogger) {
	d.auditLogger = al
}
|
|
||||||
|
|
||||||
// SetDebugLogger sets the debug logger for the debug page.
// Optional; when unset the debug page renders an empty result set.
func (d *Dashboard) SetDebugLogger(dl *storage.DebugLogger) {
	d.debugLogger = dl
}
|
|
||||||
|
|
||||||
// LoginPage serves the login page.
|
|
||||||
func (d *Dashboard) LoginPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !d.authStore.HasAnyUser() {
|
|
||||||
http.Redirect(w, r, "/setup", http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if user := d.getSessionUser(r); user != nil {
|
|
||||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
d.templates.ExecuteTemplate(w, "login", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupPage serves the initial setup page.
|
|
||||||
func (d *Dashboard) SetupPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if d.authStore.HasAnyUser() {
|
|
||||||
http.Redirect(w, r, "/login", http.StatusFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
d.templates.ExecuteTemplate(w, "setup", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DashboardPage serves the main dashboard view.
|
|
||||||
func (d *Dashboard) DashboardPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
tokenNames := d.statsAPI.TokenNamesForUser(user)
|
|
||||||
|
|
||||||
data := PageData{
|
|
||||||
ActivePage: "dashboard",
|
|
||||||
User: user,
|
|
||||||
Summary: d.statsAPI.GetSummary(tokenNames),
|
|
||||||
Models: d.statsAPI.GetModels(tokenNames),
|
|
||||||
Providers: d.statsAPI.GetProviders(tokenNames),
|
|
||||||
TokenStats: d.statsAPI.GetTokenUsage(tokenNames),
|
|
||||||
Latency: d.statsAPI.GetLatency(tokenNames, "24h", "", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Provider health
|
|
||||||
if d.statsAPI.healthTracker != nil {
|
|
||||||
data.ProviderHealth = d.statsAPI.healthTracker.Status()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache stats
|
|
||||||
if d.cache != nil {
|
|
||||||
data.CacheEnabled = true
|
|
||||||
data.CacheInfo = d.cache.Stats(r.Context())
|
|
||||||
}
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/dashboard.html", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogsPage serves the request logs view.
|
|
||||||
func (d *Dashboard) LogsPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
tokenNames := d.statsAPI.TokenNamesForUser(user)
|
|
||||||
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
model := r.URL.Query().Get("model")
|
|
||||||
token := r.URL.Query().Get("token")
|
|
||||||
status := r.URL.Query().Get("status")
|
|
||||||
|
|
||||||
data := PageData{
|
|
||||||
ActivePage: "logs",
|
|
||||||
User: user,
|
|
||||||
LogsResult: d.statsAPI.GetLogs(tokenNames, page, model, token, status),
|
|
||||||
LogModels: d.statsAPI.GetDistinctModels(),
|
|
||||||
LogTokens: d.statsAPI.GetDistinctTokens(),
|
|
||||||
FilterModel: model,
|
|
||||||
FilterToken: token,
|
|
||||||
FilterStatus: status,
|
|
||||||
}
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/logs.html", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelsPage serves the model routing table view.
|
|
||||||
func (d *Dashboard) ModelsPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
|
|
||||||
data := PageData{
|
|
||||||
ActivePage: "models",
|
|
||||||
User: user,
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.registry != nil {
|
|
||||||
data.ModelRoutes = d.registry.AllRoutes()
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.statsAPI.healthTracker != nil {
|
|
||||||
data.ProviderHealth = d.statsAPI.healthTracker.Status()
|
|
||||||
}
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/models-page.html", data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokensPage serves the tokens management view.
|
|
||||||
func (d *Dashboard) TokensPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
|
|
||||||
var userID int64
|
|
||||||
if !user.IsAdmin {
|
|
||||||
userID = user.ID
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens, _ := d.authStore.ListAPITokens(userID)
|
|
||||||
if tokens == nil {
|
|
||||||
tokens = []auth.APIToken{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get today's spend for budget display
|
|
||||||
spend, _ := d.statsAPI.db.TodaySpendAll()
|
|
||||||
if spend == nil {
|
|
||||||
spend = make(map[string]float64)
|
|
||||||
}
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/tokens.html", PageData{
|
|
||||||
ActivePage: "tokens",
|
|
||||||
User: user,
|
|
||||||
Tokens: tokens,
|
|
||||||
TokenSpend: spend,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsersPage serves the user management view (admin only).
|
|
||||||
func (d *Dashboard) UsersPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
users, _ := d.authStore.ListUsers()
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/users.html", PageData{
|
|
||||||
ActivePage: "users",
|
|
||||||
User: user,
|
|
||||||
Users: users,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuditPage serves the audit log view (admin only).
|
|
||||||
func (d *Dashboard) AuditPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
action := r.URL.Query().Get("action")
|
|
||||||
since := time.Now().AddDate(0, 0, -30).Unix()
|
|
||||||
|
|
||||||
var auditResult *storage.AuditQueryResult
|
|
||||||
if d.auditLogger != nil {
|
|
||||||
auditResult = d.auditLogger.Query(since, action, page, 50)
|
|
||||||
} else {
|
|
||||||
auditResult = &storage.AuditQueryResult{Entries: []storage.AuditEntry{}, Page: 1, TotalPages: 1}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Common audit action types for the filter dropdown
|
|
||||||
actions := []string{"login", "logout", "create_user", "delete_user", "create_token", "delete_token", "change_password", "setup_totp", "disable_totp"}
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/audit.html", PageData{
|
|
||||||
ActivePage: "audit",
|
|
||||||
User: user,
|
|
||||||
AuditResult: auditResult,
|
|
||||||
AuditFilterActions: actions,
|
|
||||||
FilterAction: action,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugPage serves the debug logging view (admin only).
|
|
||||||
func (d *Dashboard) DebugPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
|
|
||||||
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var debugResult *storage.DebugLogQueryResult
|
|
||||||
debugEnabled := false
|
|
||||||
if d.debugLogger != nil {
|
|
||||||
debugResult = d.debugLogger.QueryFull(page, 50)
|
|
||||||
debugEnabled = d.debugLogger.IsEnabled()
|
|
||||||
} else {
|
|
||||||
debugResult = &storage.DebugLogQueryResult{Entries: []storage.DebugLogEntry{}, Page: 1, TotalPages: 1}
|
|
||||||
}
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/debug.html", PageData{
|
|
||||||
ActivePage: "debug",
|
|
||||||
User: user,
|
|
||||||
DebugResult: debugResult,
|
|
||||||
DebugEnabled: debugEnabled,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SettingsPage serves the settings view.
|
|
||||||
func (d *Dashboard) SettingsPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
user := auth.UserFromContext(r.Context())
|
|
||||||
user, _ = d.authStore.GetUserByID(user.ID)
|
|
||||||
|
|
||||||
d.renderDashboardPage(w, r, "partials/settings.html", PageData{
|
|
||||||
ActivePage: "settings",
|
|
||||||
User: user,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// renderDashboardPage renders either the full layout or just the content partial.
|
|
||||||
func (d *Dashboard) renderDashboardPage(w http.ResponseWriter, r *http.Request, partialFile string, data PageData) {
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
|
|
||||||
if r.Header.Get("HX-Request") == "true" {
|
|
||||||
tmpl := template.Must(
|
|
||||||
template.New("").Funcs(templateFuncs).ParseFS(templateFiles, "templates/"+partialFile),
|
|
||||||
)
|
|
||||||
tmpl.ExecuteTemplate(w, "content", data)
|
|
||||||
} else {
|
|
||||||
tmpl := template.Must(
|
|
||||||
template.New("").Funcs(templateFuncs).ParseFS(templateFiles,
|
|
||||||
"templates/layout.html",
|
|
||||||
"templates/"+partialFile,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
tmpl.ExecuteTemplate(w, "layout", data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dashboard) getSessionUser(r *http.Request) *auth.User {
|
|
||||||
cookie, err := r.Cookie("llmgw_session")
|
|
||||||
if err != nil || cookie.Value == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sess, err := d.authStore.GetSession(cookie.Value)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
user, err := d.authStore.GetUserByID(sess.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return user
|
|
||||||
}
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
package dashboard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SSEBroker manages Server-Sent Events connections.
//
// Each connected client is represented by a signal channel registered
// in clients; Notify fans a refresh signal out to all of them. The
// RWMutex guards the clients map (broadcasts take the read lock,
// connect/disconnect take the write lock).
type SSEBroker struct {
	mu      sync.RWMutex
	clients map[chan struct{}]struct{}
}
|
|
||||||
|
|
||||||
// NewSSEBroker creates a new SSE broker.
|
|
||||||
func NewSSEBroker() *SSEBroker {
|
|
||||||
return &SSEBroker{
|
|
||||||
clients: make(map[chan struct{}]struct{}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify sends a refresh signal to all connected SSE clients.
|
|
||||||
func (b *SSEBroker) Notify() {
|
|
||||||
b.mu.RLock()
|
|
||||||
defer b.mu.RUnlock()
|
|
||||||
for ch := range b.clients {
|
|
||||||
select {
|
|
||||||
case ch <- struct{}{}:
|
|
||||||
default:
|
|
||||||
// Client not ready, skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP handles SSE connections.
|
|
||||||
func (b *SSEBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
flusher, ok := w.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
|
||||||
w.Header().Set("Connection", "keep-alive")
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
|
|
||||||
ch := make(chan struct{}, 1)
|
|
||||||
b.mu.Lock()
|
|
||||||
b.clients[ch] = struct{}{}
|
|
||||||
b.mu.Unlock()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
b.mu.Lock()
|
|
||||||
delete(b.clients, ch)
|
|
||||||
b.mu.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Send initial connection event
|
|
||||||
fmt.Fprintf(w, "event: connected\ndata: ok\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ch:
|
|
||||||
fmt.Fprintf(w, "event: refresh\ndata: updated\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,366 +0,0 @@
|
||||||
{{define "layout"}}
|
|
||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>LLM Gateway</title>
|
|
||||||
<script>
|
|
||||||
// Prevent flash of wrong theme
|
|
||||||
(function() {
|
|
||||||
var pref = localStorage.getItem('theme') || 'auto';
|
|
||||||
var effective = pref;
|
|
||||||
if (pref === 'auto') effective = window.matchMedia('(prefers-color-scheme: light)').matches ? 'light' : 'dark';
|
|
||||||
if (effective === 'light') document.documentElement.setAttribute('data-theme', 'light');
|
|
||||||
})();
|
|
||||||
</script>
|
|
||||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
|
||||||
<script src="https://unpkg.com/htmx-ext-json-enc@2.0.3/json-enc.js"></script>
|
|
||||||
<script src="https://cdn.jsdelivr.net/npm/chart.js@4"></script>
|
|
||||||
<style>
|
|
||||||
:root {
|
|
||||||
--bg-primary: #0f172a;
|
|
||||||
--bg-secondary: #1e293b;
|
|
||||||
--bg-tertiary: #334155;
|
|
||||||
--border-color: #334155;
|
|
||||||
--text-primary: #e2e8f0;
|
|
||||||
--text-secondary: #94a3b8;
|
|
||||||
--text-muted: #64748b;
|
|
||||||
--text-heading: #f8fafc;
|
|
||||||
--text-subheading: #cbd5e1;
|
|
||||||
--accent-blue: #3b82f6;
|
|
||||||
--accent-blue-hover: #2563eb;
|
|
||||||
--accent-blue-bg: #3b82f620;
|
|
||||||
--accent-green: #4ade80;
|
|
||||||
--accent-green-bg: #4ade8020;
|
|
||||||
--accent-red: #f87171;
|
|
||||||
--accent-red-bg: #7f1d1d40;
|
|
||||||
--accent-red-border: #991b1b;
|
|
||||||
--accent-red-text: #fca5a5;
|
|
||||||
--accent-yellow: #fbbf24;
|
|
||||||
--accent-yellow-bg: #92400e40;
|
|
||||||
--accent-purple: #a78bfa;
|
|
||||||
--accent-purple-bg: #a78bfa20;
|
|
||||||
--success-bg: #14532d40;
|
|
||||||
--success-border: #166534;
|
|
||||||
--success-text: #86efac;
|
|
||||||
--modal-overlay: #00000080;
|
|
||||||
--chart-grid: #1e293b;
|
|
||||||
}
|
|
||||||
[data-theme="light"] {
|
|
||||||
--bg-primary: #f8fafc;
|
|
||||||
--bg-secondary: #ffffff;
|
|
||||||
--bg-tertiary: #e2e8f0;
|
|
||||||
--border-color: #cbd5e1;
|
|
||||||
--text-primary: #1e293b;
|
|
||||||
--text-secondary: #475569;
|
|
||||||
--text-muted: #94a3b8;
|
|
||||||
--text-heading: #0f172a;
|
|
||||||
--text-subheading: #334155;
|
|
||||||
--accent-blue: #2563eb;
|
|
||||||
--accent-blue-hover: #1d4ed8;
|
|
||||||
--accent-blue-bg: #dbeafe;
|
|
||||||
--accent-green: #16a34a;
|
|
||||||
--accent-green-bg: #dcfce7;
|
|
||||||
--accent-red: #dc2626;
|
|
||||||
--accent-red-bg: #fef2f2;
|
|
||||||
--accent-red-border: #fca5a5;
|
|
||||||
--accent-red-text: #991b1b;
|
|
||||||
--accent-yellow: #d97706;
|
|
||||||
--accent-yellow-bg: #fef3c7;
|
|
||||||
--accent-purple: #7c3aed;
|
|
||||||
--accent-purple-bg: #ede9fe;
|
|
||||||
--success-bg: #f0fdf4;
|
|
||||||
--success-border: #86efac;
|
|
||||||
--success-text: #166534;
|
|
||||||
--modal-overlay: #00000040;
|
|
||||||
--chart-grid: #e2e8f0;
|
|
||||||
}
|
|
||||||
|
|
||||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
|
||||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: var(--bg-primary); color: var(--text-primary); min-height: 100vh; display: flex; }
|
|
||||||
|
|
||||||
/* Sidebar */
|
|
||||||
.sidebar { width: 220px; background: var(--bg-secondary); border-right: 1px solid var(--border-color); min-height: 100vh; display: flex; flex-direction: column; position: fixed; top: 0; left: 0; }
|
|
||||||
.sidebar-brand { padding: 20px 16px; font-size: 1.1rem; font-weight: 700; color: var(--text-heading); border-bottom: 1px solid var(--border-color); }
|
|
||||||
.sidebar-nav { flex: 1; padding: 12px 0; }
|
|
||||||
.sidebar-nav a { display: block; padding: 10px 20px; color: var(--text-secondary); text-decoration: none; font-size: 0.9rem; transition: all 0.15s; }
|
|
||||||
.sidebar-nav a:hover { background: var(--bg-tertiary); color: var(--text-primary); }
|
|
||||||
.sidebar-nav a.active { background: var(--accent-blue-bg); color: var(--accent-blue); border-right: 3px solid var(--accent-blue); }
|
|
||||||
.sidebar-footer { padding: 16px; border-top: 1px solid var(--border-color); }
|
|
||||||
.sidebar-footer .user-info { font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 8px; }
|
|
||||||
.sidebar-footer a { display: block; padding: 6px 0; color: var(--text-secondary); text-decoration: none; font-size: 0.85rem; }
|
|
||||||
.sidebar-footer a:hover { color: var(--accent-red); }
|
|
||||||
.theme-toggle { cursor: pointer; background: var(--bg-tertiary); border: 1px solid var(--border-color); color: var(--text-secondary); padding: 6px 12px; border-radius: 6px; font-size: 0.8rem; width: 100%; margin-bottom: 8px; }
|
|
||||||
.theme-toggle:hover { color: var(--text-primary); }
|
|
||||||
|
|
||||||
/* Main content */
|
|
||||||
.main { flex: 1; margin-left: 220px; padding: 24px; min-height: 100vh; max-width: calc(100vw - 220px); overflow-x: hidden; }
|
|
||||||
|
|
||||||
/* Cards & tables */
|
|
||||||
.cards { display: grid; grid-template-columns: repeat(auto-fit, minmax(160px, 1fr)); gap: 12px; margin-bottom: 24px; max-width: 100%; }
|
|
||||||
.card { background: var(--bg-secondary); border-radius: 8px; padding: 16px; }
|
|
||||||
.card .label { font-size: 0.75rem; color: var(--text-secondary); text-transform: uppercase; letter-spacing: 0.05em; }
|
|
||||||
.card .value { font-size: 1.5rem; font-weight: 700; margin-top: 4px; }
|
|
||||||
.card .sub { font-size: 0.75rem; color: var(--text-muted); margin-top: 2px; }
|
|
||||||
.section { background: var(--bg-secondary); border-radius: 8px; padding: 16px; margin-bottom: 16px; overflow-x: auto; }
|
|
||||||
.section h2 { font-size: 1.1rem; margin-bottom: 12px; color: var(--text-subheading); }
|
|
||||||
.tabs { display: flex; gap: 8px; margin-bottom: 16px; }
|
|
||||||
.tabs button { background: var(--bg-secondary); border: 1px solid var(--border-color); color: var(--text-secondary); padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 0.8rem; }
|
|
||||||
.tabs button.active { background: var(--accent-blue); border-color: var(--accent-blue); color: #fff; }
|
|
||||||
table { width: 100%; border-collapse: collapse; font-size: 0.85rem; }
|
|
||||||
th { text-align: left; padding: 8px; color: var(--text-secondary); border-bottom: 1px solid var(--border-color); font-weight: 500; }
|
|
||||||
td { padding: 8px; border-bottom: 1px solid var(--border-color); }
|
|
||||||
.green { color: var(--accent-green); }
|
|
||||||
.red { color: var(--accent-red); }
|
|
||||||
.blue { color: var(--accent-blue); }
|
|
||||||
.yellow { color: var(--accent-yellow); }
|
|
||||||
|
|
||||||
/* Buttons */
|
|
||||||
.btn { display: inline-block; padding: 10px 20px; border-radius: 6px; border: none; cursor: pointer; font-size: 0.9rem; font-weight: 500; text-decoration: none; }
|
|
||||||
.btn-primary { background: var(--accent-blue); color: #fff; }
|
|
||||||
.btn-primary:hover { background: var(--accent-blue-hover); }
|
|
||||||
.btn-danger { background: #ef4444; color: #fff; }
|
|
||||||
.btn-danger:hover { background: #dc2626; }
|
|
||||||
.btn-sm { padding: 6px 12px; font-size: 0.8rem; }
|
|
||||||
.btn-outline { background: transparent; border: 1px solid var(--border-color); color: var(--text-secondary); }
|
|
||||||
.btn-outline:hover { border-color: var(--text-muted); color: var(--text-primary); }
|
|
||||||
|
|
||||||
/* Forms */
|
|
||||||
.form-group { margin-bottom: 16px; }
|
|
||||||
.form-group label { display: block; font-size: 0.85rem; color: var(--text-secondary); margin-bottom: 4px; }
|
|
||||||
.form-group input, .form-group select { width: 100%; padding: 10px 12px; background: var(--bg-primary); border: 1px solid var(--border-color); border-radius: 6px; color: var(--text-primary); font-size: 0.95rem; }
|
|
||||||
.form-group input:focus, .form-group select:focus { outline: none; border-color: var(--accent-blue); }
|
|
||||||
.error-msg { background: var(--accent-red-bg); border: 1px solid var(--accent-red-border); color: var(--accent-red-text); padding: 10px; border-radius: 6px; margin-bottom: 16px; font-size: 0.85rem; }
|
|
||||||
.success-msg { background: var(--success-bg); border: 1px solid var(--success-border); color: var(--success-text); padding: 10px; border-radius: 6px; margin-bottom: 16px; font-size: 0.85rem; }
|
|
||||||
|
|
||||||
/* Modal */
|
|
||||||
.modal-overlay { position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: var(--modal-overlay); display: none; align-items: center; justify-content: center; z-index: 100; }
|
|
||||||
.modal-overlay.show { display: flex; }
|
|
||||||
.modal { background: var(--bg-secondary); border-radius: 12px; padding: 24px; width: 100%; max-width: 440px; }
|
|
||||||
.modal h2 { margin-bottom: 16px; color: var(--text-subheading); }
|
|
||||||
.modal-actions { display: flex; gap: 8px; justify-content: flex-end; margin-top: 16px; }
|
|
||||||
|
|
||||||
/* Token display */
|
|
||||||
.token-key { background: var(--bg-primary); padding: 8px 12px; border-radius: 6px; font-family: monospace; font-size: 0.85rem; word-break: break-all; margin: 8px 0; display: flex; align-items: center; gap: 8px; }
|
|
||||||
.token-key code { flex: 1; }
|
|
||||||
.copy-btn { background: var(--bg-tertiary); border: none; color: var(--text-secondary); padding: 4px 8px; border-radius: 4px; cursor: pointer; font-size: 0.75rem; }
|
|
||||||
.copy-btn:hover { color: var(--text-primary); }
|
|
||||||
|
|
||||||
/* Badge */
|
|
||||||
.badge { display: inline-block; padding: 2px 8px; border-radius: 12px; font-size: 0.7rem; font-weight: 600; }
|
|
||||||
.badge-admin { background: var(--accent-blue-bg); color: var(--accent-blue); }
|
|
||||||
.badge-user { background: var(--accent-green-bg); color: var(--accent-green); }
|
|
||||||
.badge-totp { background: var(--accent-purple-bg); color: var(--accent-purple); }
|
|
||||||
.badge-healthy { background: var(--accent-green-bg); color: var(--accent-green); }
|
|
||||||
.badge-degraded { background: var(--accent-yellow-bg); color: var(--accent-yellow); }
|
|
||||||
.badge-down { background: var(--accent-red-bg); color: var(--accent-red); }
|
|
||||||
.badge-success { background: var(--accent-green-bg); color: var(--accent-green); }
|
|
||||||
.badge-error { background: var(--accent-red-bg); color: var(--accent-red); }
|
|
||||||
.badge-cached { background: var(--accent-blue-bg); color: var(--accent-blue); }
|
|
||||||
.badge-priority { background: var(--bg-tertiary); color: var(--text-secondary); }
|
|
||||||
.badge-open { background: var(--accent-red-bg); color: var(--accent-red); }
|
|
||||||
.badge-half-open { background: var(--accent-yellow-bg); color: var(--accent-yellow); }
|
|
||||||
|
|
||||||
/* Toggle switch */
|
|
||||||
.toggle-switch { position: relative; display: inline-block; width: 44px; height: 24px; }
|
|
||||||
.toggle-switch input { opacity: 0; width: 0; height: 0; }
|
|
||||||
.toggle-slider { position: absolute; cursor: pointer; top: 0; left: 0; right: 0; bottom: 0; background: var(--bg-tertiary); border-radius: 24px; transition: 0.2s; }
|
|
||||||
.toggle-slider:before { content: ""; position: absolute; height: 18px; width: 18px; left: 3px; bottom: 3px; background: var(--text-secondary); border-radius: 50%; transition: 0.2s; }
|
|
||||||
.toggle-switch input:checked + .toggle-slider { background: var(--accent-blue); }
|
|
||||||
.toggle-switch input:checked + .toggle-slider:before { transform: translateX(20px); background: #fff; }
|
|
||||||
|
|
||||||
/* Code block for debug bodies */
|
|
||||||
.code-block { background: var(--bg-primary); border: 1px solid var(--border-color); border-radius: 6px; padding: 12px; font-family: monospace; font-size: 0.8rem; white-space: pre-wrap; word-break: break-all; max-height: 300px; overflow-y: auto; }
|
|
||||||
|
|
||||||
/* Export button inline */
|
|
||||||
.export-links { display: inline-flex; gap: 6px; margin-left: 12px; }
|
|
||||||
.export-links a { font-size: 0.7rem; color: var(--text-muted); text-decoration: none; padding: 2px 6px; border: 1px solid var(--border-color); border-radius: 4px; }
|
|
||||||
.export-links a:hover { color: var(--text-primary); border-color: var(--text-muted); }
|
|
||||||
|
|
||||||
.page-header { display: flex; align-items: center; gap: 12px; margin-bottom: 20px; }
|
|
||||||
.page-header h1 { font-size: 1.3rem; color: var(--text-heading); }
|
|
||||||
|
|
||||||
/* Filter bar */
|
|
||||||
.filter-bar { display: flex; gap: 12px; align-items: center; margin-bottom: 16px; flex-wrap: wrap; }
|
|
||||||
.filter-bar select { padding: 6px 10px; background: var(--bg-primary); border: 1px solid var(--border-color); border-radius: 6px; color: var(--text-primary); font-size: 0.85rem; }
|
|
||||||
.filter-bar select:focus { outline: none; border-color: var(--accent-blue); }
|
|
||||||
|
|
||||||
/* Pagination */
|
|
||||||
.pagination { display: flex; gap: 4px; align-items: center; margin-top: 16px; justify-content: center; }
|
|
||||||
.pagination button, .pagination span { padding: 6px 12px; border-radius: 6px; font-size: 0.8rem; border: 1px solid var(--border-color); background: var(--bg-secondary); color: var(--text-secondary); cursor: pointer; }
|
|
||||||
.pagination button:hover { background: var(--bg-tertiary); color: var(--text-primary); }
|
|
||||||
.pagination button.active { background: var(--accent-blue); border-color: var(--accent-blue); color: #fff; }
|
|
||||||
.pagination button:disabled { opacity: 0.5; cursor: not-allowed; }
|
|
||||||
.pagination .page-info { border: none; background: none; cursor: default; color: var(--text-muted); font-size: 0.8rem; }
|
|
||||||
|
|
||||||
/* Progress bar */
|
|
||||||
.progress-bar { width: 100%; height: 8px; background: var(--bg-tertiary); border-radius: 4px; overflow: hidden; }
|
|
||||||
.progress-bar-fill { height: 100%; border-radius: 4px; transition: width 0.3s; }
|
|
||||||
.budget-info { font-size: 0.75rem; color: var(--text-muted); margin-top: 2px; }
|
|
||||||
|
|
||||||
/* Expandable row */
|
|
||||||
.expandable { cursor: pointer; }
|
|
||||||
.expand-content { display: none; padding: 8px 12px; background: var(--bg-primary); border-radius: 6px; margin: 4px 0; font-size: 0.8rem; font-family: monospace; white-space: pre-wrap; word-break: break-all; }
|
|
||||||
.expand-content.show { display: block; }
|
|
||||||
|
|
||||||
/* Health status row */
|
|
||||||
.health-row { display: flex; gap: 12px; flex-wrap: wrap; margin-bottom: 16px; }
|
|
||||||
.health-item { display: flex; align-items: center; gap: 8px; padding: 8px 16px; background: var(--bg-secondary); border-radius: 8px; }
|
|
||||||
.health-item .provider-name { font-weight: 600; font-size: 0.85rem; }
|
|
||||||
|
|
||||||
/* Mobile menu button */
|
|
||||||
.mobile-menu-btn { display: none; position: fixed; top: 12px; left: 12px; z-index: 300; background: var(--bg-secondary); border: 1px solid var(--border-color); color: var(--text-primary); width: 40px; height: 40px; border-radius: 8px; cursor: pointer; align-items: center; justify-content: center; }
|
|
||||||
.mobile-menu-btn svg { width: 20px; height: 20px; }
|
|
||||||
|
|
||||||
/* Sidebar overlay */
|
|
||||||
.sidebar-overlay { display: none; position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: var(--modal-overlay); z-index: 150; }
|
|
||||||
.sidebar-overlay.show { display: block; }
|
|
||||||
|
|
||||||
/* Responsive: tablet & mobile */
|
|
||||||
@media (max-width: 768px) {
|
|
||||||
.mobile-menu-btn { display: flex; }
|
|
||||||
.sidebar { left: -220px; z-index: 200; transition: left 0.25s ease; }
|
|
||||||
.sidebar.open { left: 0; }
|
|
||||||
.main { margin-left: 0; max-width: 100vw; padding: 60px 12px 12px; }
|
|
||||||
.cards { grid-template-columns: repeat(2, 1fr); }
|
|
||||||
.card .value { font-size: 1.2rem; }
|
|
||||||
table { min-width: 500px; font-size: 0.8rem; }
|
|
||||||
.modal { max-width: calc(100vw - 32px); padding: 20px; }
|
|
||||||
.filter-bar select { flex: 1; min-width: 120px; }
|
|
||||||
.page-header h1 { font-size: 1.1rem; }
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Responsive: small phone */
|
|
||||||
@media (max-width: 480px) {
|
|
||||||
.cards { grid-template-columns: 1fr; }
|
|
||||||
.card .value { font-size: 1rem; }
|
|
||||||
.main { padding: 56px 8px 8px; }
|
|
||||||
.section { padding: 12px; }
|
|
||||||
.modal { padding: 16px; }
|
|
||||||
.page-header h1 { font-size: 1rem; }
|
|
||||||
table { font-size: 0.75rem; }
|
|
||||||
th, td { padding: 6px; }
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<script>
|
|
||||||
function getThemePref() { return localStorage.getItem('theme') || 'auto'; }
|
|
||||||
function applyTheme(pref) {
|
|
||||||
var effective = pref;
|
|
||||||
if (pref === 'auto') effective = window.matchMedia('(prefers-color-scheme: light)').matches ? 'light' : 'dark';
|
|
||||||
if (effective === 'light') document.documentElement.setAttribute('data-theme', 'light');
|
|
||||||
else document.documentElement.removeAttribute('data-theme');
|
|
||||||
}
|
|
||||||
function themeLabel(pref) {
|
|
||||||
if (pref === 'light') return 'Light';
|
|
||||||
if (pref === 'dark') return 'Dark';
|
|
||||||
return 'Auto';
|
|
||||||
}
|
|
||||||
function toggleTheme() {
|
|
||||||
var order = ['dark', 'light', 'auto'];
|
|
||||||
var cur = getThemePref();
|
|
||||||
var next = order[(order.indexOf(cur) + 1) % order.length];
|
|
||||||
localStorage.setItem('theme', next);
|
|
||||||
applyTheme(next);
|
|
||||||
var btn = document.getElementById('theme-btn');
|
|
||||||
if (btn) btn.textContent = 'Theme: ' + themeLabel(next);
|
|
||||||
}
|
|
||||||
// Follow system changes when set to auto
|
|
||||||
window.matchMedia('(prefers-color-scheme: light)').addEventListener('change', function() {
|
|
||||||
if (getThemePref() === 'auto') applyTheme('auto');
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
<button class="mobile-menu-btn" onclick="toggleMobileSidebar()" aria-label="Menu">
|
|
||||||
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round"><line x1="3" y1="6" x2="21" y2="6"/><line x1="3" y1="12" x2="21" y2="12"/><line x1="3" y1="18" x2="21" y2="18"/></svg>
|
|
||||||
</button>
|
|
||||||
<div class="sidebar-overlay" onclick="toggleMobileSidebar()"></div>
|
|
||||||
<div class="sidebar">
|
|
||||||
<div class="sidebar-brand">LLM Gateway</div>
|
|
||||||
<nav class="sidebar-nav">
|
|
||||||
<a href="/dashboard" hx-get="/dashboard" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "dashboard"}}class="active"{{end}}>Dashboard</a>
|
|
||||||
<a href="/logs" hx-get="/logs" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "logs"}}class="active"{{end}}>Logs</a>
|
|
||||||
<a href="/models" hx-get="/models" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "models"}}class="active"{{end}}>Models</a>
|
|
||||||
<a href="/tokens" hx-get="/tokens" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "tokens"}}class="active"{{end}}>API Tokens</a>
|
|
||||||
{{if .User.IsAdmin}}
|
|
||||||
<a href="/users" hx-get="/users" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "users"}}class="active"{{end}}>Users</a>
|
|
||||||
<a href="/audit" hx-get="/audit" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "audit"}}class="active"{{end}}>Audit Log</a>
|
|
||||||
<a href="/debug" hx-get="/debug" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "debug"}}class="active"{{end}}>Debug</a>
|
|
||||||
{{end}}
|
|
||||||
<a href="/settings" hx-get="/settings" hx-target="#content" hx-push-url="true" {{if eq .ActivePage "settings"}}class="active"{{end}}>Settings</a>
|
|
||||||
</nav>
|
|
||||||
<div class="sidebar-footer">
|
|
||||||
<button class="theme-toggle" onclick="toggleTheme()" id="theme-btn">Switch Theme</button>
|
|
||||||
<div class="user-info">{{.User.Username}}</div>
|
|
||||||
<a href="#" hx-post="/api/auth/logout" hx-swap="none">Logout</a>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="main">
|
|
||||||
<div id="sse-source" style="display:none;"></div>
|
|
||||||
<div id="content">
|
|
||||||
{{template "content" .}}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<script>
|
|
||||||
// SSE auto-refresh: reload current page content when server sends refresh event
|
|
||||||
(function() {
|
|
||||||
var refreshable = ['/dashboard', '/logs', '/models', '/debug', '/audit'];
|
|
||||||
var source = null;
|
|
||||||
var retryDelay = 1000;
|
|
||||||
|
|
||||||
function connect() {
|
|
||||||
if (source) { source.close(); source = null; }
|
|
||||||
source = new EventSource('/api/events');
|
|
||||||
source.addEventListener('refresh', function() {
|
|
||||||
var path = window.location.pathname;
|
|
||||||
for (var i = 0; i < refreshable.length; i++) {
|
|
||||||
if (path === refreshable[i]) {
|
|
||||||
htmx.ajax('GET', path + window.location.search, {target: '#content', swap: 'innerHTML'});
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
source.onopen = function() { retryDelay = 1000; };
|
|
||||||
source.onerror = function() {
|
|
||||||
source.close();
|
|
||||||
source = null;
|
|
||||||
setTimeout(connect, retryDelay);
|
|
||||||
retryDelay = Math.min(retryDelay * 2, 30000);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close cleanly on page unload to avoid browser "interrupted" errors
|
|
||||||
window.addEventListener('beforeunload', function() {
|
|
||||||
if (source) { source.close(); source = null; }
|
|
||||||
});
|
|
||||||
|
|
||||||
connect();
|
|
||||||
})();
|
|
||||||
// Update active sidebar link on HTMX navigation
|
|
||||||
document.body.addEventListener('htmx:pushedIntoHistory', function(e) {
|
|
||||||
document.querySelectorAll('.sidebar-nav a').forEach(function(a) {
|
|
||||||
a.classList.toggle('active', a.getAttribute('href') === window.location.pathname);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
// Set initial theme button label
|
|
||||||
(function() {
|
|
||||||
var btn = document.getElementById('theme-btn');
|
|
||||||
if (btn) btn.textContent = 'Theme: ' + themeLabel(getThemePref());
|
|
||||||
})();
|
|
||||||
// Mobile sidebar toggle
|
|
||||||
function toggleMobileSidebar() {
|
|
||||||
document.querySelector('.sidebar').classList.toggle('open');
|
|
||||||
document.querySelector('.sidebar-overlay').classList.toggle('show');
|
|
||||||
}
|
|
||||||
// Close sidebar on nav link click (mobile)
|
|
||||||
document.querySelectorAll('.sidebar-nav a').forEach(function(a) {
|
|
||||||
a.addEventListener('click', function() {
|
|
||||||
document.querySelector('.sidebar').classList.remove('open');
|
|
||||||
document.querySelector('.sidebar-overlay').classList.remove('show');
|
|
||||||
});
|
|
||||||
});
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
{{define "login"}}
|
|
||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>Login - LLM Gateway</title>
|
|
||||||
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
|
|
||||||
    <script src="https://unpkg.com/htmx-ext-json-enc@2.0.3/json-enc.js"></script>
|
|
||||||
<style>
|
|
||||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
|
||||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: #0f172a; color: #e2e8f0; min-height: 100vh; display: flex; align-items: center; justify-content: center; }
|
|
||||||
.auth-box { background: #1e293b; border-radius: 12px; padding: 32px; width: 100%; max-width: 400px; }
|
|
||||||
.auth-box h1 { text-align: center; margin-bottom: 24px; font-size: 1.5rem; color: #f8fafc; }
|
|
||||||
.form-group { margin-bottom: 16px; }
|
|
||||||
.form-group label { display: block; font-size: 0.85rem; color: #94a3b8; margin-bottom: 4px; }
|
|
||||||
.form-group input { width: 100%; padding: 10px 12px; background: #0f172a; border: 1px solid #334155; border-radius: 6px; color: #e2e8f0; font-size: 0.95rem; }
|
|
||||||
.form-group input:focus { outline: none; border-color: #3b82f6; }
|
|
||||||
.btn-primary { display: block; width: 100%; padding: 10px 20px; border-radius: 6px; border: none; cursor: pointer; font-size: 0.9rem; font-weight: 500; background: #3b82f6; color: #fff; }
|
|
||||||
.btn-primary:hover { background: #2563eb; }
|
|
||||||
.error-msg { background: #7f1d1d40; border: 1px solid #991b1b; color: #fca5a5; padding: 10px; border-radius: 6px; margin-bottom: 16px; font-size: 0.85rem; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="auth-box">
|
|
||||||
<h1>LLM Gateway</h1>
|
|
||||||
<div id="login-form-area">
|
|
||||||
<form hx-post="/api/auth/login" hx-target="#login-form-area" hx-swap="innerHTML" hx-ext="json-enc">
|
|
||||||
<div id="login-error"></div>
|
|
||||||
<div class="form-group">
|
|
||||||
<label>Username</label>
|
|
||||||
<input type="text" name="username" required autocomplete="username">
|
|
||||||
</div>
|
|
||||||
<div class="form-group">
|
|
||||||
<label>Password</label>
|
|
||||||
<input type="password" name="password" required autocomplete="current-password">
|
|
||||||
</div>
|
|
||||||
<button type="submit" class="btn-primary">Sign In</button>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="page-header">
|
|
||||||
<h1>Audit Log</h1>
|
|
||||||
<span style="font-size:0.85rem;color:var(--text-muted)">{{.AuditResult.Total}} total</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="filter-bar">
|
|
||||||
<select id="filter-action" onchange="applyAuditFilter()">
|
|
||||||
<option value="">All Actions</option>
|
|
||||||
{{range .AuditFilterActions}}<option value="{{.}}" {{if eq . $.FilterAction}}selected{{end}}>{{.}}</option>{{end}}
|
|
||||||
</select>
|
|
||||||
<button class="btn btn-sm btn-outline" onclick="clearAuditFilter()">Clear</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<table>
|
|
||||||
<thead>
|
|
||||||
<tr>
|
|
||||||
<th>Time</th>
|
|
||||||
<th>User</th>
|
|
||||||
<th>Action</th>
|
|
||||||
<th>Target</th>
|
|
||||||
<th>Details</th>
|
|
||||||
<th>IP</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{{range .AuditResult.Entries}}
|
|
||||||
<tr>
|
|
||||||
<td>{{formatTimeDetail .Timestamp}}</td>
|
|
||||||
<td>{{.Username}}</td>
|
|
||||||
<td><span class="badge badge-priority">{{.Action}}</span></td>
|
|
||||||
<td>{{if .TargetType}}{{.TargetType}}{{if .TargetID}}/{{.TargetID}}{{end}}{{else}}-{{end}}</td>
|
|
||||||
<td style="max-width:300px;overflow:hidden;text-overflow:ellipsis;white-space:nowrap" title="{{.Details}}">{{if .Details}}{{.Details}}{{else}}-{{end}}</td>
|
|
||||||
<td>{{if .IPAddress}}{{.IPAddress}}{{else}}-{{end}}</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
{{if not .AuditResult.Entries}}
|
|
||||||
<tr><td colspan="6" style="text-align:center;color:var(--text-muted);padding:24px;">No audit log entries</td></tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
{{if gt .AuditResult.TotalPages 1}}
|
|
||||||
<div class="pagination">
|
|
||||||
<button {{if le .AuditResult.Page 1}}disabled{{end}} onclick="goToAuditPage(1)">First</button>
|
|
||||||
<button {{if le .AuditResult.Page 1}}disabled{{end}} onclick="goToAuditPage({{subInt .AuditResult.Page 1}})">Prev</button>
|
|
||||||
{{$page := .AuditResult.Page}}
|
|
||||||
{{$total := .AuditResult.TotalPages}}
|
|
||||||
{{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
|
|
||||||
<button class="{{if eq . $page}}active{{end}}" onclick="goToAuditPage({{.}})">{{.}}</button>
|
|
||||||
{{end}}
|
|
||||||
<button {{if ge .AuditResult.Page .AuditResult.TotalPages}}disabled{{end}} onclick="goToAuditPage({{addInt .AuditResult.Page 1}})">Next</button>
|
|
||||||
<button {{if ge .AuditResult.Page .AuditResult.TotalPages}}disabled{{end}} onclick="goToAuditPage({{.AuditResult.TotalPages}})">Last</button>
|
|
||||||
<span class="page-info">Page {{.AuditResult.Page}} of {{.AuditResult.TotalPages}}</span>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
function buildAuditURL(page) {
|
|
||||||
var params = [];
|
|
||||||
var action = document.getElementById('filter-action').value;
|
|
||||||
if (action) params.push('action=' + encodeURIComponent(action));
|
|
||||||
if (page > 1) params.push('page=' + page);
|
|
||||||
return '/audit' + (params.length ? '?' + params.join('&') : '');
|
|
||||||
}
|
|
||||||
function applyAuditFilter() {
|
|
||||||
var url = buildAuditURL(1);
|
|
||||||
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
|
|
||||||
history.pushState({}, '', url);
|
|
||||||
}
|
|
||||||
function goToAuditPage(page) {
|
|
||||||
var url = buildAuditURL(page);
|
|
||||||
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
|
|
||||||
history.pushState({}, '', url);
|
|
||||||
}
|
|
||||||
function clearAuditFilter() {
|
|
||||||
document.getElementById('filter-action').value = '';
|
|
||||||
applyAuditFilter();
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,230 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="page-header">
|
|
||||||
<h1>Dashboard</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="cards">
|
|
||||||
{{with .Summary.Today}}
|
|
||||||
<div class="card"><div class="label">Requests Today</div><div class="value">{{.Requests}}</div></div>
|
|
||||||
<div class="card"><div class="label">Cost Today</div><div class="value green">{{formatCost .CostUSD}}</div></div>
|
|
||||||
<div class="card"><div class="label">Tokens Today</div><div class="value blue">{{addInt .InputTokens .OutputTokens}}</div><div class="sub">{{.InputTokens}} in / {{.OutputTokens}} out</div></div>
|
|
||||||
<div class="card"><div class="label">Errors Today</div><div class="value {{if gt .Errors 0}}red{{end}}">{{.Errors}}</div></div>
|
|
||||||
<div class="card"><div class="label">Cache Hits</div><div class="value">{{.CachedHits}}</div></div>
|
|
||||||
{{end}}
|
|
||||||
{{with .Summary.Week}}
|
|
||||||
<div class="card"><div class="label">Cost (7d)</div><div class="value green">{{formatCost .CostUSD}}</div></div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{{if .ProviderHealth}}
|
|
||||||
<div class="section">
|
|
||||||
<h2>Provider Health</h2>
|
|
||||||
<div class="health-row">
|
|
||||||
{{range .ProviderHealth}}
|
|
||||||
<div class="health-item">
|
|
||||||
<span class="provider-name">{{.Provider}}</span>
|
|
||||||
<span class="badge badge-{{.Status}}">{{.Status}}</span>
|
|
||||||
{{if eq .CircuitState "open"}}<span class="badge badge-open">circuit open</span>{{end}}
|
|
||||||
{{if eq .CircuitState "half-open"}}<span class="badge badge-half-open">half-open</span>{{end}}
|
|
||||||
<span style="font-size:0.75rem;color:var(--text-muted)">{{printf "%.0f" .AvgLatency}}ms avg | {{formatPct .ErrorRate}} errors</span>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
{{if .Latency}}{{if gt .Latency.Max 0.0}}
|
|
||||||
<div class="cards">
|
|
||||||
<div class="card"><div class="label">P50 Latency</div><div class="value">{{printf "%.0f" .Latency.P50}}ms</div></div>
|
|
||||||
<div class="card"><div class="label">P95 Latency</div><div class="value yellow">{{printf "%.0f" .Latency.P95}}ms</div></div>
|
|
||||||
<div class="card"><div class="label">P99 Latency</div><div class="value red">{{printf "%.0f" .Latency.P99}}ms</div></div>
|
|
||||||
<div class="card"><div class="label">Avg Latency</div><div class="value">{{printf "%.0f" .Latency.Avg}}ms</div></div>
|
|
||||||
</div>
|
|
||||||
{{end}}{{end}}
|
|
||||||
|
|
||||||
{{if .CacheEnabled}}{{if .CacheInfo}}{{if .CacheInfo.Connected}}
|
|
||||||
<div class="cards">
|
|
||||||
<div class="card"><div class="label">Cache Hit Rate</div><div class="value green">{{formatPct .CacheInfo.HitRate}}</div><div class="sub">{{.CacheInfo.Hits}} hits / {{.CacheInfo.Misses}} misses</div></div>
|
|
||||||
<div class="card"><div class="label">Cache Memory</div><div class="value">{{.CacheInfo.MemoryUsed}}</div></div>
|
|
||||||
<div class="card"><div class="label">Cached Keys</div><div class="value">{{.CacheInfo.Keys}}</div></div>
|
|
||||||
</div>
|
|
||||||
{{end}}{{end}}{{end}}
|
|
||||||
|
|
||||||
<div class="tabs">
|
|
||||||
<button class="active" onclick="loadTimeseries('24h', this)">24h</button>
|
|
||||||
<button onclick="loadTimeseries('7d', this)">7d</button>
|
|
||||||
<button onclick="loadTimeseries('30d', this)">30d</button>
|
|
||||||
</div>
|
|
||||||
<div class="section">
|
|
||||||
<h2>Requests & Cost</h2>
|
|
||||||
<canvas id="chart" height="200"></canvas>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<h2>Cost Breakdown</h2>
|
|
||||||
<div class="tabs" id="cost-tabs">
|
|
||||||
<button class="active" onclick="loadCostBreakdown('model', this)">By Model</button>
|
|
||||||
<button onclick="loadCostBreakdown('token', this)">By Token</button>
|
|
||||||
<button onclick="loadCostBreakdown('provider', this)">By Provider</button>
|
|
||||||
</div>
|
|
||||||
<canvas id="cost-chart" height="200"></canvas>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{{if .Models}}
|
|
||||||
<div class="section">
|
|
||||||
<h2>Models<span class="export-links"><a href="/api/export/stats?format=csv&type=models" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=models" target="_blank">JSON</a></span></h2>
|
|
||||||
<table>
|
|
||||||
<thead><tr><th>Model</th><th>Requests</th><th>Tokens (in/out)</th><th>Cost</th><th>Avg Latency</th></tr></thead>
|
|
||||||
<tbody>
|
|
||||||
{{range .Models}}
|
|
||||||
<tr>
|
|
||||||
<td>{{.Model}}</td>
|
|
||||||
<td>{{.Requests}}</td>
|
|
||||||
<td>{{.InputTokens}} / {{.OutputTokens}}</td>
|
|
||||||
<td class="green">{{formatCost .CostUSD}}</td>
|
|
||||||
<td>{{printf "%.0f" .AvgLatencyMS}}ms</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
{{if .Providers}}
|
|
||||||
<div class="section">
|
|
||||||
<h2>Providers<span class="export-links"><a href="/api/export/stats?format=csv&type=providers" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=providers" target="_blank">JSON</a></span></h2>
|
|
||||||
<table>
|
|
||||||
<thead><tr><th>Provider</th><th>Requests</th><th>Success</th><th>Errors</th><th>Avg Latency</th><th>Cost</th></tr></thead>
|
|
||||||
<tbody>
|
|
||||||
{{range .Providers}}
|
|
||||||
<tr>
|
|
||||||
<td>{{.Provider}}</td>
|
|
||||||
<td>{{.Requests}}</td>
|
|
||||||
<td class="green">{{.Successes}}</td>
|
|
||||||
<td class="red">{{.Errors}}</td>
|
|
||||||
<td>{{printf "%.0f" .AvgLatencyMS}}ms</td>
|
|
||||||
<td class="green">{{formatCost .CostUSD}}</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
{{if .TokenStats}}
|
|
||||||
<div class="section">
|
|
||||||
<h2>API Token Usage<span class="export-links"><a href="/api/export/stats?format=csv&type=tokens" target="_blank">CSV</a><a href="/api/export/stats?format=json&type=tokens" target="_blank">JSON</a></span></h2>
|
|
||||||
<table>
|
|
||||||
<thead><tr><th>Token</th><th>Requests</th><th>Tokens (in/out)</th><th>Cost</th></tr></thead>
|
|
||||||
<tbody>
|
|
||||||
{{range .TokenStats}}
|
|
||||||
<tr>
|
|
||||||
<td>{{.TokenName}}</td>
|
|
||||||
<td>{{.Requests}}</td>
|
|
||||||
<td>{{.InputTokens}} / {{.OutputTokens}}</td>
|
|
||||||
<td class="green">{{formatCost .CostUSD}}</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
<script>
|
|
||||||
var _chart, _costChart;
|
|
||||||
function chartColors() {
|
|
||||||
var isLight = document.documentElement.hasAttribute('data-theme');
|
|
||||||
return {
|
|
||||||
text: isLight ? '#475569' : '#94a3b8',
|
|
||||||
grid: isLight ? '#e2e8f020' : '#334155',
|
|
||||||
green: '#4ade80',
|
|
||||||
legend: isLight ? '#1e293b' : '#e2e8f0'
|
|
||||||
};
|
|
||||||
}
|
|
||||||
function formatCostTick(v) {
|
|
||||||
if (v === 0) return '$0';
|
|
||||||
if (v < 0.001) return '$' + v.toFixed(6);
|
|
||||||
if (v < 0.01) return '$' + v.toFixed(4);
|
|
||||||
return '$' + v.toFixed(2);
|
|
||||||
}
|
|
||||||
function loadTimeseries(period, btn) {
|
|
||||||
document.querySelectorAll('.tabs button').forEach(function(b) { b.classList.remove('active'); });
|
|
||||||
if (btn) btn.classList.add('active');
|
|
||||||
else document.querySelector('.tabs button').classList.add('active');
|
|
||||||
fetch('/api/stats/timeseries?period=' + period, {credentials: 'same-origin'})
|
|
||||||
.then(function(r) { return r.json(); })
|
|
||||||
.then(function(data) {
|
|
||||||
var c = chartColors();
|
|
||||||
var labels = (data||[]).map(function(d) { return d.bucket; });
|
|
||||||
var requests = (data||[]).map(function(d) { return d.requests; });
|
|
||||||
var costs = (data||[]).map(function(d) { return d.cost_usd; });
|
|
||||||
if (_chart) _chart.destroy();
|
|
||||||
_chart = new Chart(document.getElementById('chart'), {
|
|
||||||
type: 'bar',
|
|
||||||
data: {
|
|
||||||
labels: labels,
|
|
||||||
datasets: [
|
|
||||||
{ label: 'Requests', data: requests, backgroundColor: 'rgba(59,130,246,0.5)', yAxisID: 'y' },
|
|
||||||
{ label: 'Cost ($)', data: costs, type: 'line', borderColor: c.green, backgroundColor: 'rgba(74,222,128,0.1)', yAxisID: 'y1', tension: 0.3, pointRadius: 3 }
|
|
||||||
]
|
|
||||||
},
|
|
||||||
options: {
|
|
||||||
responsive: true,
|
|
||||||
interaction: { mode: 'index', intersect: false },
|
|
||||||
scales: {
|
|
||||||
y: { position: 'left', beginAtZero: true, ticks: { color: c.text, precision: 0 }, grid: { color: c.grid } },
|
|
||||||
y1: { position: 'right', beginAtZero: true, ticks: { color: c.green, callback: formatCostTick }, grid: { display: false } },
|
|
||||||
x: { ticks: { color: c.text, maxRotation: 45 }, grid: { color: c.grid } }
|
|
||||||
},
|
|
||||||
plugins: { legend: { labels: { color: c.legend } } }
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}).catch(function(){});
|
|
||||||
}
|
|
||||||
loadTimeseries('24h');
|
|
||||||
|
|
||||||
function loadCostBreakdown(groupBy, btn) {
|
|
||||||
document.querySelectorAll('#cost-tabs button').forEach(function(b) { b.classList.remove('active'); });
|
|
||||||
if (btn) btn.classList.add('active');
|
|
||||||
fetch('/api/stats/cost-breakdown?period=7d&group_by=' + groupBy, {credentials: 'same-origin'})
|
|
||||||
.then(function(r) { return r.json(); })
|
|
||||||
.then(function(data) {
|
|
||||||
if (!data || data.length === 0) {
|
|
||||||
if (_costChart) _costChart.destroy();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
var c = chartColors();
|
|
||||||
var days = [], groups = {};
|
|
||||||
(data||[]).forEach(function(d) {
|
|
||||||
if (days.indexOf(d.day) === -1) days.push(d.day);
|
|
||||||
if (!groups[d.group_by]) groups[d.group_by] = {};
|
|
||||||
groups[d.group_by][d.day] = d.cost_usd;
|
|
||||||
});
|
|
||||||
var palette = ['#3b82f6','#4ade80','#f87171','#fbbf24','#a78bfa','#f472b6','#22d3ee','#fb923c'];
|
|
||||||
var datasets = [], ci = 0;
|
|
||||||
for (var g in groups) {
|
|
||||||
datasets.push({
|
|
||||||
label: g,
|
|
||||||
data: days.map(function(day) { return groups[g][day] || 0; }),
|
|
||||||
backgroundColor: palette[ci % palette.length] + '80'
|
|
||||||
});
|
|
||||||
ci++;
|
|
||||||
}
|
|
||||||
if (_costChart) _costChart.destroy();
|
|
||||||
_costChart = new Chart(document.getElementById('cost-chart'), {
|
|
||||||
type: 'bar',
|
|
||||||
data: { labels: days, datasets: datasets },
|
|
||||||
options: {
|
|
||||||
responsive: true,
|
|
||||||
scales: {
|
|
||||||
x: { stacked: true, ticks: { color: c.text }, grid: { color: c.grid } },
|
|
||||||
y: { stacked: true, beginAtZero: true, ticks: { color: c.text, callback: formatCostTick }, grid: { color: c.grid } }
|
|
||||||
},
|
|
||||||
plugins: { legend: { labels: { color: c.legend } } }
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}).catch(function(){});
|
|
||||||
}
|
|
||||||
loadCostBreakdown('model');
|
|
||||||
</script>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,94 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="page-header">
|
|
||||||
<h1>Debug Logging</h1>
|
|
||||||
<span style="font-size:0.85rem;color:var(--text-muted)">{{.DebugResult.Total}} entries</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section" style="display:flex;align-items:center;gap:16px;padding:12px 16px;">
|
|
||||||
<span style="font-size:0.9rem;font-weight:600;">Debug Mode</span>
|
|
||||||
<label class="toggle-switch">
|
|
||||||
<input type="checkbox" id="debug-toggle" {{if .DebugEnabled}}checked{{end}}
|
|
||||||
hx-post="/api/debug/toggle" hx-swap="none" hx-ext="json-enc"
|
|
||||||
hx-vals='js:{enabled: document.getElementById("debug-toggle").checked}'
|
|
||||||
hx-trigger="change"
|
|
||||||
hx-on::after-request="htmx.ajax('GET', '/debug', {target: '#content', swap: 'innerHTML'})">
|
|
||||||
<span class="toggle-slider"></span>
|
|
||||||
</label>
|
|
||||||
<span id="debug-status" style="font-size:0.8rem;color:var(--text-muted)">{{if .DebugEnabled}}Enabled — requests are being logged{{else}}Disabled{{end}}</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<table>
|
|
||||||
<thead>
|
|
||||||
<tr>
|
|
||||||
<th></th>
|
|
||||||
<th>Time</th>
|
|
||||||
<th>Request ID</th>
|
|
||||||
<th>Token</th>
|
|
||||||
<th>Model</th>
|
|
||||||
<th>Provider</th>
|
|
||||||
<th>Status</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{{range $i, $entry := .DebugResult.Entries}}
|
|
||||||
<tr class="expandable" onclick="toggleDebugExpand('debug-expand-{{$i}}')">
|
|
||||||
<td style="width:20px;text-align:center;color:var(--text-muted)">▶</td>
|
|
||||||
<td>{{formatTimeDetail $entry.Timestamp}}</td>
|
|
||||||
<td><code style="font-size:0.75rem">{{$entry.RequestID}}</code></td>
|
|
||||||
<td>{{$entry.TokenName}}</td>
|
|
||||||
<td>{{$entry.Model}}</td>
|
|
||||||
<td>{{$entry.Provider}}</td>
|
|
||||||
<td>
|
|
||||||
{{if and (ge $entry.ResponseStatus 200) (lt $entry.ResponseStatus 300)}}<span class="badge badge-success">{{$entry.ResponseStatus}}</span>
|
|
||||||
{{else if ge $entry.ResponseStatus 400}}<span class="badge badge-error">{{$entry.ResponseStatus}}</span>
|
|
||||||
{{else}}<span class="badge">{{$entry.ResponseStatus}}</span>{{end}}
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td colspan="7" style="padding:0;">
|
|
||||||
<div id="debug-expand-{{$i}}" class="expand-content">
|
|
||||||
<div style="margin-bottom:8px"><strong>Request Headers:</strong></div>
|
|
||||||
<div class="code-block">{{if $entry.RequestHeaders}}{{$entry.RequestHeaders}}{{else}}(none){{end}}</div>
|
|
||||||
<div style="margin:8px 0"><strong>Request Body:</strong></div>
|
|
||||||
<div class="code-block">{{if $entry.RequestBody}}{{$entry.RequestBody}}{{else}}(none){{end}}</div>
|
|
||||||
<div style="margin:8px 0"><strong>Response Body:</strong></div>
|
|
||||||
<div class="code-block">{{if $entry.ResponseBody}}{{$entry.ResponseBody}}{{else}}(none){{end}}</div>
|
|
||||||
</div>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
{{if not .DebugResult.Entries}}
|
|
||||||
<tr><td colspan="7" style="text-align:center;color:var(--text-muted);padding:24px;">No debug log entries</td></tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
{{if gt .DebugResult.TotalPages 1}}
|
|
||||||
<div class="pagination">
|
|
||||||
<button {{if le .DebugResult.Page 1}}disabled{{end}} onclick="goToDebugPage(1)">First</button>
|
|
||||||
<button {{if le .DebugResult.Page 1}}disabled{{end}} onclick="goToDebugPage({{subInt .DebugResult.Page 1}})">Prev</button>
|
|
||||||
{{$page := .DebugResult.Page}}
|
|
||||||
{{$total := .DebugResult.TotalPages}}
|
|
||||||
{{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
|
|
||||||
<button class="{{if eq . $page}}active{{end}}" onclick="goToDebugPage({{.}})">{{.}}</button>
|
|
||||||
{{end}}
|
|
||||||
<button {{if ge .DebugResult.Page .DebugResult.TotalPages}}disabled{{end}} onclick="goToDebugPage({{addInt .DebugResult.Page 1}})">Next</button>
|
|
||||||
<button {{if ge .DebugResult.Page .DebugResult.TotalPages}}disabled{{end}} onclick="goToDebugPage({{.DebugResult.TotalPages}})">Last</button>
|
|
||||||
<span class="page-info">Page {{.DebugResult.Page}} of {{.DebugResult.TotalPages}}</span>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
function toggleDebugExpand(id) {
|
|
||||||
var el = document.getElementById(id);
|
|
||||||
if (el) el.classList.toggle('show');
|
|
||||||
}
|
|
||||||
function goToDebugPage(page) {
|
|
||||||
var url = '/debug' + (page > 1 ? '?page=' + page : '');
|
|
||||||
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
|
|
||||||
history.pushState({}, '', url);
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,133 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="page-header">
|
|
||||||
<h1>Request Logs</h1>
|
|
||||||
<span style="font-size:0.85rem;color:var(--text-muted)">{{.LogsResult.Total}} total</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="filter-bar">
|
|
||||||
<select id="filter-model" onchange="applyLogsFilter()">
|
|
||||||
<option value="">All Models</option>
|
|
||||||
{{range .LogModels}}<option value="{{.}}" {{if eq . $.FilterModel}}selected{{end}}>{{.}}</option>{{end}}
|
|
||||||
</select>
|
|
||||||
<select id="filter-token" onchange="applyLogsFilter()">
|
|
||||||
<option value="">All Tokens</option>
|
|
||||||
{{range .LogTokens}}<option value="{{.}}" {{if eq . $.FilterToken}}selected{{end}}>{{.}}</option>{{end}}
|
|
||||||
</select>
|
|
||||||
<select id="filter-status" onchange="applyLogsFilter()">
|
|
||||||
<option value="">All Status</option>
|
|
||||||
<option value="success" {{if eq .FilterStatus "success"}}selected{{end}}>Success</option>
|
|
||||||
<option value="error" {{if eq .FilterStatus "error"}}selected{{end}}>Errors Only</option>
|
|
||||||
<option value="cached" {{if eq .FilterStatus "cached"}}selected{{end}}>Cached</option>
|
|
||||||
</select>
|
|
||||||
<button class="btn btn-sm btn-outline" onclick="clearLogsFilter()">Clear</button>
|
|
||||||
<span style="margin-left:auto"></span>
|
|
||||||
<button class="btn btn-sm btn-outline" onclick="exportLogs('csv')">Export CSV</button>
|
|
||||||
<button class="btn btn-sm btn-outline" onclick="exportLogs('json')">Export JSON</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<table>
|
|
||||||
<thead>
|
|
||||||
<tr>
|
|
||||||
<th>Time</th>
|
|
||||||
<th>Token</th>
|
|
||||||
<th>Model</th>
|
|
||||||
<th>Provider</th>
|
|
||||||
<th>Status</th>
|
|
||||||
<th>Latency</th>
|
|
||||||
<th>Tokens</th>
|
|
||||||
<th>Cost</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{{range $i, $log := .LogsResult.Logs}}
|
|
||||||
<tr class="{{if $log.ErrorMessage}}expandable{{end}}" {{if $log.ErrorMessage}}onclick="toggleExpand('expand-{{$i}}')"{{end}}>
|
|
||||||
<td>{{formatTimeDetail $log.Timestamp}}</td>
|
|
||||||
<td>{{$log.TokenName}}</td>
|
|
||||||
<td>{{$log.Model}}</td>
|
|
||||||
<td>{{$log.Provider}}</td>
|
|
||||||
<td>
|
|
||||||
{{if eq $log.Status "success"}}<span class="badge badge-success">success</span>
|
|
||||||
{{else if eq $log.Status "error"}}<span class="badge badge-error">error</span>
|
|
||||||
{{else if eq $log.Status "cached"}}<span class="badge badge-cached">cached</span>
|
|
||||||
{{else}}<span class="badge">{{$log.Status}}</span>{{end}}
|
|
||||||
{{if $log.Streaming}} <span class="badge badge-totp">stream</span>{{end}}
|
|
||||||
</td>
|
|
||||||
<td>{{$log.LatencyMS}}ms</td>
|
|
||||||
<td>{{$log.InputTokens}} / {{$log.OutputTokens}}</td>
|
|
||||||
<td class="green">{{formatCost $log.CostUSD}}</td>
|
|
||||||
</tr>
|
|
||||||
{{if $log.ErrorMessage}}
|
|
||||||
<tr>
|
|
||||||
<td colspan="8" style="padding:0;">
|
|
||||||
<div id="expand-{{$i}}" class="expand-content {{if eq $.FilterStatus "error"}}show{{end}}">{{$log.ErrorMessage}}</div>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
{{end}}
|
|
||||||
{{if not .LogsResult.Logs}}
|
|
||||||
<tr><td colspan="8" style="text-align:center;color:var(--text-muted);padding:24px;">No logs found</td></tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
|
|
||||||
{{if gt .LogsResult.TotalPages 1}}
|
|
||||||
<div class="pagination">
|
|
||||||
<button {{if le .LogsResult.Page 1}}disabled{{end}} onclick="goToLogsPage(1)">First</button>
|
|
||||||
<button {{if le .LogsResult.Page 1}}disabled{{end}} onclick="goToLogsPage({{subInt .LogsResult.Page 1}})">Prev</button>
|
|
||||||
{{$page := .LogsResult.Page}}
|
|
||||||
{{$total := .LogsResult.TotalPages}}
|
|
||||||
{{range seq (paginationStart $page $total) (paginationEnd $page $total)}}
|
|
||||||
<button class="{{if eq . $page}}active{{end}}" onclick="goToLogsPage({{.}})">{{.}}</button>
|
|
||||||
{{end}}
|
|
||||||
<button {{if ge .LogsResult.Page .LogsResult.TotalPages}}disabled{{end}} onclick="goToLogsPage({{addInt .LogsResult.Page 1}})">Next</button>
|
|
||||||
<button {{if ge .LogsResult.Page .LogsResult.TotalPages}}disabled{{end}} onclick="goToLogsPage({{.LogsResult.TotalPages}})">Last</button>
|
|
||||||
<span class="page-info">Page {{.LogsResult.Page}} of {{.LogsResult.TotalPages}}</span>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
function buildLogsURL(page) {
|
|
||||||
var params = [];
|
|
||||||
var model = document.getElementById('filter-model').value;
|
|
||||||
var token = document.getElementById('filter-token').value;
|
|
||||||
var status = document.getElementById('filter-status').value;
|
|
||||||
if (model) params.push('model=' + encodeURIComponent(model));
|
|
||||||
if (token) params.push('token=' + encodeURIComponent(token));
|
|
||||||
if (status) params.push('status=' + encodeURIComponent(status));
|
|
||||||
if (page > 1) params.push('page=' + page);
|
|
||||||
return '/logs' + (params.length ? '?' + params.join('&') : '');
|
|
||||||
}
|
|
||||||
function applyLogsFilter() {
|
|
||||||
var url = buildLogsURL(1);
|
|
||||||
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
|
|
||||||
history.pushState({}, '', url);
|
|
||||||
}
|
|
||||||
function goToLogsPage(page) {
|
|
||||||
var url = buildLogsURL(page);
|
|
||||||
htmx.ajax('GET', url, {target: '#content', swap: 'innerHTML'});
|
|
||||||
history.pushState({}, '', url);
|
|
||||||
}
|
|
||||||
function clearLogsFilter() {
|
|
||||||
document.getElementById('filter-model').value = '';
|
|
||||||
document.getElementById('filter-token').value = '';
|
|
||||||
document.getElementById('filter-status').value = '';
|
|
||||||
applyLogsFilter();
|
|
||||||
}
|
|
||||||
function toggleExpand(id) {
|
|
||||||
var el = document.getElementById(id);
|
|
||||||
if (el) el.classList.toggle('show');
|
|
||||||
}
|
|
||||||
function exportLogs(format) {
|
|
||||||
var params = ['format=' + format];
|
|
||||||
var model = document.getElementById('filter-model').value;
|
|
||||||
var token = document.getElementById('filter-token').value;
|
|
||||||
var status = document.getElementById('filter-status').value;
|
|
||||||
if (model) params.push('model=' + encodeURIComponent(model));
|
|
||||||
if (token) params.push('token=' + encodeURIComponent(token));
|
|
||||||
if (status) params.push('status=' + encodeURIComponent(status));
|
|
||||||
window.open('/api/export/logs?' + params.join('&'), '_blank');
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,53 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="page-header">
|
|
||||||
<h1>Model Routing</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{{if .ModelRoutes}}
|
|
||||||
{{range .ModelRoutes}}
|
|
||||||
<div class="section">
|
|
||||||
<h2>{{.Name}}{{if .Aliases}} <span style="font-size:0.75rem;color:var(--text-muted);font-weight:400;">aliases: {{range $i, $a := .Aliases}}{{if $i}}, {{end}}{{$a}}{{end}}</span>{{end}}</h2>
|
|
||||||
<table>
|
|
||||||
<thead>
|
|
||||||
<tr>
|
|
||||||
<th>Provider</th>
|
|
||||||
<th>Provider Model</th>
|
|
||||||
<th>Priority</th>
|
|
||||||
<th>Input Price (per 1M)</th>
|
|
||||||
<th>Output Price (per 1M)</th>
|
|
||||||
<th>Health</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{{$health := $.ProviderHealth}}
|
|
||||||
{{range .Routes}}
|
|
||||||
<tr>
|
|
||||||
<td>{{.ProviderName}}</td>
|
|
||||||
<td><code>{{.ProviderModel}}</code></td>
|
|
||||||
<td><span class="badge badge-priority">{{.Priority}}</span></td>
|
|
||||||
<td>{{formatPrice .InputPrice}}</td>
|
|
||||||
<td>{{formatPrice .OutputPrice}}</td>
|
|
||||||
<td>
|
|
||||||
{{$pname := .ProviderName}}
|
|
||||||
{{range $health}}
|
|
||||||
{{if eq .Provider $pname}}
|
|
||||||
{{if eq .Status "healthy"}}<span class="badge" style="background:#166534;color:#4ade80;">healthy</span>
|
|
||||||
{{else if eq .Status "degraded"}}<span class="badge" style="background:#92400e;color:#fbbf24;">degraded</span>
|
|
||||||
{{else}}<span class="badge" style="background:#991b1b;color:#f87171;">down</span>
|
|
||||||
{{end}}
|
|
||||||
{{end}}
|
|
||||||
{{end}}
|
|
||||||
{{if not $health}}<span style="color:var(--text-muted);">-</span>{{end}}
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
{{end}}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
{{else}}
|
|
||||||
<div class="section" style="text-align:center;color:var(--text-muted);padding:24px;">
|
|
||||||
No models configured
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,141 +0,0 @@
|
||||||
{{define "content"}}
|
|
||||||
<div class="page-header">
|
|
||||||
<h1>Settings</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<h2>Profile</h2>
|
|
||||||
<div id="profile-msg"></div>
|
|
||||||
<form hx-put="/api/auth/me/username" hx-target="#profile-msg" hx-swap="innerHTML" hx-ext="json-enc"
|
|
||||||
style="max-width:400px;margin-bottom:16px;">
|
|
||||||
<div class="form-group">
|
|
||||||
<label>Username</label>
|
|
||||||
<div style="display:flex;gap:8px;">
|
|
||||||
<input type="text" name="new_username" value="{{.User.Username}}" required>
|
|
||||||
<button type="submit" class="btn btn-sm btn-primary" style="white-space:nowrap;">Update</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
<form hx-put="/api/auth/me/email" hx-target="#profile-msg" hx-swap="innerHTML" hx-ext="json-enc"
|
|
||||||
style="max-width:400px;">
|
|
||||||
<div class="form-group">
|
|
||||||
<label>Email</label>
|
|
||||||
<div style="display:flex;gap:8px;">
|
|
||||||
<input type="email" name="email" value="{{.User.Email}}" placeholder="optional">
|
|
||||||
<button type="submit" class="btn btn-sm btn-primary" style="white-space:nowrap;">Update</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<h2>Change Password</h2>
|
|
||||||
<div id="password-msg"></div>
|
|
||||||
<form id="password-form" hx-put="/api/auth/me/password" hx-target="#password-msg" hx-swap="innerHTML" hx-ext="json-enc"
|
|
||||||
style="max-width:400px;">
|
|
||||||
<div class="form-group">
|
|
||||||
<label>Current Password</label>
|
|
||||||
<input type="password" name="current_password" required autocomplete="current-password">
|
|
||||||
</div>
|
|
||||||
<div class="form-group">
|
|
||||||
<label>New Password (min 8 characters)</label>
|
|
||||||
<input type="password" name="new_password" required minlength="8" autocomplete="new-password">
|
|
||||||
</div>
|
|
||||||
<div class="form-group">
|
|
||||||
<label>Confirm New Password</label>
|
|
||||||
<input type="password" id="new-password2" required minlength="8" autocomplete="new-password">
|
|
||||||
</div>
|
|
||||||
<button type="submit" class="btn btn-primary" style="width:auto;">Change Password</button>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="section">
|
|
||||||
<h2>Two-Factor Authentication</h2>
|
|
||||||
<div id="totp-status">
|
|
||||||
{{if .User.TOTPEnabled}}
|
|
||||||
<p style="color:#4ade80;margin-bottom:12px;">Two-factor authentication is <strong>enabled</strong>.</p>
|
|
||||||
<button class="btn btn-sm btn-danger"
|
|
||||||
hx-delete="/api/auth/totp" hx-target="#totp-status" hx-swap="innerHTML"
|
|
||||||
hx-confirm="Disable two-factor authentication?">Disable 2FA</button>
|
|
||||||
{{else}}
|
|
||||||
<p style="color:#94a3b8;margin-bottom:12px;">Two-factor authentication is <strong>not enabled</strong>.</p>
|
|
||||||
<button class="btn btn-sm btn-primary" onclick="setupTOTP()">Enable 2FA</button>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
<div id="totp-setup-area" style="display:none;">
|
|
||||||
<p style="color:#94a3b8;font-size:0.85rem;margin-bottom:12px;">Scan this QR code with your authenticator app, then enter the code below to verify.</p>
|
|
||||||
<div id="totp-qr" style="text-align:center;margin:16px 0;"></div>
|
|
||||||
<div id="totp-secret-display" style="text-align:center;margin:8px 0;font-family:monospace;color:#94a3b8;font-size:0.8rem;"></div>
|
|
||||||
<form hx-post="/api/auth/totp/verify" hx-target="#totp-status" hx-swap="innerHTML" hx-ext="json-enc"
|
|
||||||
style="max-width:300px;margin:0 auto;">
|
|
||||||
<div class="form-group">
|
|
||||||
<input type="text" name="code" required pattern="[0-9]{6}" maxlength="6" placeholder="Enter 6-digit code" autocomplete="one-time-code" inputmode="numeric" style="text-align:center;font-size:1.2rem;letter-spacing:0.2em;">
|
|
||||||
</div>
|
|
||||||
<button type="submit" class="btn btn-primary">Verify & Enable</button>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{{if .User.IsAdmin}}
|
|
||||||
<div class="section">
|
|
||||||
<h2>Config Validation</h2>
|
|
||||||
<p style="color:#94a3b8;font-size:0.85rem;margin-bottom:12px;">Validate the current gateway configuration file for errors.</p>
|
|
||||||
<button class="btn btn-sm btn-primary"
|
|
||||||
hx-get="/api/config/validate"
|
|
||||||
hx-target="#config-validation-result"
|
|
||||||
hx-swap="innerHTML">Validate Config</button>
|
|
||||||
<div id="config-validation-result" style="margin-top:12px;"></div>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
<script src="https://cdn.jsdelivr.net/npm/qrious@4.0.2/dist/qrious.min.js"></script>
|
|
||||||
<script>
|
|
||||||
// Password confirm validation
|
|
||||||
document.body.addEventListener('htmx:confirm', function(e) {
|
|
||||||
var form = e.target;
|
|
||||||
if (form.id !== 'password-form') return;
|
|
||||||
var np = form.querySelector('[name=new_password]').value;
|
|
||||||
if (np !== document.getElementById('new-password2').value) {
|
|
||||||
e.preventDefault();
|
|
||||||
document.getElementById('password-msg').innerHTML = '<div class="error-msg">Passwords do not match</div>';
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Clear password fields after successful change
|
|
||||||
document.body.addEventListener('htmx:afterRequest', function(e) {
|
|
||||||
if (e.target.id === 'password-form' && e.detail.successful) {
|
|
||||||
e.target.querySelectorAll('input[type=password]').forEach(function(i) { i.value = ''; });
|
|
||||||
document.getElementById('new-password2').value = '';
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Auto-clear messages after 5 seconds
|
|
||||||
document.body.addEventListener('htmx:afterSwap', function(e) {
|
|
||||||
var target = e.detail.target;
|
|
||||||
if (target.id === 'profile-msg' || target.id === 'password-msg') {
|
|
||||||
setTimeout(function() { target.innerHTML = ''; }, 5000);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Reload settings page when TOTP status changes
|
|
||||||
document.body.addEventListener('settingsRefresh', function() {
|
|
||||||
htmx.ajax('GET', '/settings', {target: '#content', swap: 'innerHTML'});
|
|
||||||
});
|
|
||||||
|
|
||||||
// TOTP setup - needs JS for QR code rendering
|
|
||||||
async function setupTOTP() {
|
|
||||||
try {
|
|
||||||
var resp = await fetch('/api/auth/totp/setup', { method: 'POST', credentials: 'same-origin' });
|
|
||||||
var data = await resp.json();
|
|
||||||
if (!resp.ok) { alert(data.error||'Failed'); return; }
|
|
||||||
document.getElementById('totp-setup-area').style.display = 'block';
|
|
||||||
document.getElementById('totp-secret-display').textContent = 'Secret: ' + data.secret;
|
|
||||||
var qrDiv = document.getElementById('totp-qr');
|
|
||||||
qrDiv.innerHTML = '';
|
|
||||||
var canvas = document.createElement('canvas');
|
|
||||||
new QRious({ element: canvas, value: data.uri, size: 200, level: 'M' });
|
|
||||||
qrDiv.appendChild(canvas);
|
|
||||||
} catch (e) { alert(e.message); }
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
{{end}}
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
{{define "content"}}
<!-- API Tokens page. Static tokens (defined in config, negative IDs) and
     dynamic tokens (created via this dashboard, positive IDs) are rendered
     from the same .Tokens slice, split by the sign of .ID. -->
<div class="page-header">
<h1>API Tokens</h1>
<button class="btn btn-sm btn-primary" onclick="showCreateTokenModal()">Create Token</button>
</div>

<!-- Hidden until a token is created; the POST response is swapped in here. -->
<div id="new-token-display" style="display:none; margin-bottom:16px;"></div>

<div class="section">
<h2>Static Tokens <span style="font-size:0.75rem;color:var(--text-muted);font-weight:400;">(from config, managed via environment variables)</span></h2>
<table>
<thead><tr><th>Name</th><th>Prefix</th><th>Rate Limit</th><th>Budget</th><th>Today's Spend</th><th></th></tr></thead>
<tbody>
<!-- Static/config tokens carry a negative ID. -->
{{range .Tokens}}{{if lt .ID 0}}
<tr>
<td>{{.Name}}</td>
<td><code>{{.KeyPrefix}}...</code></td>
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
<td>
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
</td>
<td>
{{$spend := index $.TokenSpend .Name}}
{{if gt .DailyBudgetUSD 0.0}}
{{$pct := budgetPct $spend .DailyBudgetUSD}}
<!-- Budget progress bar; width is clamped at 100% for over-budget tokens. -->
<div style="min-width:120px;">
<div class="progress-bar"><div class="progress-bar-fill" style="width:{{if gt $pct 100.0}}100{{else}}{{printf "%.0f" $pct}}{{end}}%;background:{{budgetColor $pct}};"></div></div>
<div class="budget-info">${{printf "%.4f" $spend}} / ${{printf "%.2f" .DailyBudgetUSD}} ({{printf "%.1f" $pct}}%)</div>
</div>
{{else}}
{{if gt $spend 0.0}}{{formatCost $spend}}{{else}}-{{end}}
{{end}}
</td>
<td><span class="badge badge-totp">config</span></td>
</tr>
{{end}}{{end}}
</tbody>
</table>
</div>

<div class="section">
<h2>Dynamic Tokens <span style="font-size:0.75rem;color:var(--text-muted);font-weight:400;">(created via dashboard)</span></h2>
<table>
<thead><tr><th>Name</th><th>Prefix</th><th>Rate Limit</th><th>Budget</th><th>Today's Spend</th><th>Created</th><th>Last Used</th><th></th></tr></thead>
<!-- tbody has an id so the tokenCreated handler below can refresh just this table. -->
<tbody id="tokens-tbody">
{{range .Tokens}}{{if gt .ID 0}}
<tr>
<td>{{.Name}}</td>
<td><code>{{.KeyPrefix}}...</code></td>
<td>{{if eq .RateLimitRPM 0}}unlimited{{else}}{{.RateLimitRPM}} rpm{{end}}</td>
<td>
{{if gt .DailyBudgetUSD 0.0}}${{printf "%.2f" .DailyBudgetUSD}}/day{{else}}-{{end}}
{{if gt .MonthlyBudgetUSD 0.0}}<br>${{printf "%.2f" .MonthlyBudgetUSD}}/mo{{end}}
</td>
<td>
{{$spend := index $.TokenSpend .Name}}
{{if gt .DailyBudgetUSD 0.0}}
{{$pct := budgetPct $spend .DailyBudgetUSD}}
<div style="min-width:120px;">
<div class="progress-bar"><div class="progress-bar-fill" style="width:{{if gt $pct 100.0}}100{{else}}{{printf "%.0f" $pct}}{{end}}%;background:{{budgetColor $pct}};"></div></div>
<div class="budget-info">${{printf "%.4f" $spend}} / ${{printf "%.2f" .DailyBudgetUSD}} ({{printf "%.1f" $pct}}%)</div>
</div>
{{else}}
{{if gt $spend 0.0}}{{formatCost $spend}}{{else}}-{{end}}
{{end}}
</td>
<td>{{formatTime .CreatedAt}}</td>
<td>{{if gt .LastUsedAt 0}}{{formatTime .LastUsedAt}}{{else}}never{{end}}</td>
<td><button class="btn btn-sm btn-danger"
hx-delete="/api/tokens/{{.ID}}" hx-swap="none"
hx-confirm="Revoke this API token? This cannot be undone.">Revoke</button></td>
</tr>
{{end}}{{end}}
</tbody>
</table>
</div>

<!-- Create Token Modal -->
<div id="modal-create-token" class="modal-overlay">
<div class="modal">
<h2>Create API Token</h2>
<div id="create-token-error"></div>
<form hx-post="/api/tokens" hx-target="#new-token-display" hx-swap="innerHTML" hx-ext="json-enc"
hx-vals='js:{name: document.getElementById("token-name").value, rate_limit_rpm: parseInt(document.getElementById("token-rpm").value) || 0, daily_budget_usd: parseFloat(document.getElementById("token-budget").value) || 0}'>
<div class="form-group">
<label>Token Name</label>
<input type="text" id="token-name" required placeholder="e.g. my-app">
</div>
<div class="form-group">
<label>Rate Limit (requests/min, 0 = unlimited)</label>
<input type="number" id="token-rpm" value="0" min="0">
</div>
<div class="form-group">
<label>Daily Budget (USD, 0 = unlimited)</label>
<input type="number" id="token-budget" value="0" min="0" step="0.01">
</div>
<div class="modal-actions">
<button type="button" class="btn btn-outline" onclick="closeModal()">Cancel</button>
<button type="submit" class="btn btn-primary" style="width:auto">Create</button>
</div>
</form>
</div>
</div>

<script>
// Reset all modal form state before showing it.
function showCreateTokenModal() {
document.getElementById('new-token-display').style.display = 'none';
document.getElementById('create-token-error').innerHTML = '';
document.getElementById('token-name').value = '';
document.getElementById('token-rpm').value = '0';
document.getElementById('token-budget').value = '0';
document.getElementById('modal-create-token').classList.add('show');
}
function closeModal() {
document.getElementById('modal-create-token').classList.remove('show');
}
// Clicking the overlay backdrop (not the modal itself) dismisses the dialog.
document.getElementById('modal-create-token').addEventListener('click', function(e) {
if (e.target === this) closeModal();
});

// After token creation: show key, close modal, refresh token table only
document.body.addEventListener('tokenCreated', function() {
closeModal();
document.getElementById('new-token-display').style.display = 'block';
// Refresh only the token table body, preserving the new-token-display
setTimeout(function() {
htmx.ajax('GET', '/tokens', {target: '#tokens-tbody', select: '#tokens-tbody', swap: 'outerHTML'});
}, 100);
});

// Handle token create errors (non-200 responses swap into error div)
document.body.addEventListener('htmx:beforeSwap', function(e) {
if (e.detail.target.id === 'new-token-display' && !e.detail.isError && !e.detail.xhr.getResponseHeader('HX-Trigger')) {
// Error response - redirect to error div
if (e.detail.xhr.status >= 400) {
e.detail.target = document.getElementById('create-token-error');
}
}
});
</script>
{{end}}
|
|
||||||
|
|
@ -1,70 +0,0 @@
|
||||||
{{define "content"}}
<!-- User management page: lists all accounts and lets an admin create or
     delete users. The currently logged-in user ($.User) cannot delete itself. -->
<div class="page-header">
<h1>Users</h1>
<button class="btn btn-sm btn-primary" onclick="showCreateUserModal()">Create User</button>
</div>

<div class="section">
<table>
<thead><tr><th>ID</th><th>Username</th><th>Role</th><th>2FA</th><th>Created</th><th></th></tr></thead>
<tbody>
{{range .Users}}
<tr>
<td>{{.ID}}</td>
<td>{{.Username}}</td>
<td><span class="badge {{if .IsAdmin}}badge-admin{{else}}badge-user{{end}}">{{if .IsAdmin}}Admin{{else}}User{{end}}</span></td>
<td>{{if .TOTPEnabled}}<span class="badge badge-totp">Enabled</span>{{else}}Off{{end}}</td>
<td>{{formatTime .CreatedAt}}</td>
<!-- No delete button for the logged-in user's own row. -->
<td>{{if ne .ID $.User.ID}}<button class="btn btn-sm btn-danger"
hx-delete="/api/auth/users/{{.ID}}" hx-swap="none"
hx-confirm="Delete this user? All their sessions and tokens will be removed.">Delete</button>{{end}}</td>
</tr>
{{end}}
</tbody>
</table>
</div>

<!-- Create User Modal -->
<div id="modal-create-user" class="modal-overlay">
<div class="modal">
<h2>Create User</h2>
<div id="create-user-error"></div>
<form hx-post="/api/auth/users" hx-target="#create-user-error" hx-swap="innerHTML" hx-ext="json-enc"
hx-vals='js:{username: document.getElementById("new-user-username").value, password: document.getElementById("new-user-password").value, is_admin: document.getElementById("new-user-admin").checked}'>
<div class="form-group">
<label>Username</label>
<input type="text" id="new-user-username" required>
</div>
<div class="form-group">
<label>Password (min 8 characters)</label>
<input type="password" id="new-user-password" required minlength="8">
</div>
<div class="form-group">
<label style="display:flex;align-items:center;gap:8px;">
<input type="checkbox" id="new-user-admin"> Admin
</label>
</div>
<div class="modal-actions">
<button type="button" class="btn btn-outline" onclick="closeUserModal()">Cancel</button>
<button type="submit" class="btn btn-primary" style="width:auto">Create</button>
</div>
</form>
</div>
</div>

<script>
// Clear any previous form state before opening the modal.
function showCreateUserModal() {
document.getElementById('create-user-error').innerHTML = '';
document.getElementById('new-user-username').value = '';
document.getElementById('new-user-password').value = '';
document.getElementById('new-user-admin').checked = false;
document.getElementById('modal-create-user').classList.add('show');
}
function closeUserModal() {
document.getElementById('modal-create-user').classList.remove('show');
}
// Clicking the overlay backdrop dismisses the dialog.
document.getElementById('modal-create-user').addEventListener('click', function(e) {
if (e.target === this) closeUserModal();
});
</script>
{{end}}
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
{{define "setup"}}
<!-- First-run setup page: standalone HTML document (no shared layout) that
     creates the initial admin account via POST /api/auth/setup. -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Setup - LLM Gateway</title>
<script src="https://unpkg.com/htmx.org@2.0.4"></script>
<script src="https://unpkg.com/htmx-ext-json-enc@2.0.1/json-enc.js"></script>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; background: #0f172a; color: #e2e8f0; min-height: 100vh; display: flex; align-items: center; justify-content: center; }
.auth-box { background: #1e293b; border-radius: 12px; padding: 32px; width: 100%; max-width: 400px; }
.auth-box h1 { text-align: center; margin-bottom: 8px; font-size: 1.5rem; color: #f8fafc; }
.subtitle { text-align: center; color: #94a3b8; font-size: 0.9rem; margin-bottom: 24px; }
.form-group { margin-bottom: 16px; }
.form-group label { display: block; font-size: 0.85rem; color: #94a3b8; margin-bottom: 4px; }
.form-group input { width: 100%; padding: 10px 12px; background: #0f172a; border: 1px solid #334155; border-radius: 6px; color: #e2e8f0; font-size: 0.95rem; }
.form-group input:focus { outline: none; border-color: #3b82f6; }
.btn-primary { display: block; width: 100%; padding: 10px 20px; border-radius: 6px; border: none; cursor: pointer; font-size: 0.9rem; font-weight: 500; background: #3b82f6; color: #fff; }
.btn-primary:hover { background: #2563eb; }
.error-msg { background: #7f1d1d40; border: 1px solid #991b1b; color: #fca5a5; padding: 10px; border-radius: 6px; margin-bottom: 16px; font-size: 0.85rem; }
</style>
</head>
<body>
<div class="auth-box">
<h1>LLM Gateway Setup</h1>
<p class="subtitle">Create the first admin account</p>
<div id="setup-error"></div>
<form hx-post="/api/auth/setup" hx-target="#setup-error" hx-swap="innerHTML" hx-ext="json-enc">
<div class="form-group">
<label>Username</label>
<input type="text" name="username" required autocomplete="username">
</div>
<div class="form-group">
<label>Password (min 8 characters)</label>
<input type="password" name="password" required minlength="8" autocomplete="new-password">
</div>
<div class="form-group">
<!-- The confirm field has an id (no name) so it is never posted; it is
     compared client-side in the htmx:confirm handler below. -->
<label>Confirm Password</label>
<input type="password" id="setup-password2" required minlength="8" autocomplete="new-password">
</div>
<button type="submit" class="btn-primary">Create Admin Account</button>
</form>
</div>
<script>
// Block submission when the two password fields disagree.
document.body.addEventListener('htmx:confirm', function(e) {
var form = e.target;
if (!form.querySelector || !form.querySelector('#setup-password2')) return;
var pw = form.querySelector('[name=password]').value;
if (pw !== document.getElementById('setup-password2').value) {
e.preventDefault();
document.getElementById('setup-error').innerHTML = '<div class="error-msg">Passwords do not match</div>';
}
});
</script>
</body>
</html>
{{end}}
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
package metrics
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Metrics bundles the Prometheus collectors exported by the gateway.
type Metrics struct {
	requestsTotal   *prometheus.CounterVec   // requests by model/provider/token_name/status
	requestDuration *prometheus.HistogramVec // request latency in ms, by model/provider
	tokensTotal     *prometheus.CounterVec   // tokens processed, by model/provider and type ("input"/"output")
	costTotal       *prometheus.CounterVec   // accumulated USD cost, by model/provider/token_name
	cacheHits       prometheus.Counter       // response-cache hits
	cacheMisses     prometheus.Counter       // response-cache misses
}
|
|
||||||
|
|
||||||
// New constructs the gateway's metric collectors. promauto registers each
// collector with the default Prometheus registry as a side effect, so New
// should be called once per process.
func New() *Metrics {
	return &Metrics{
		requestsTotal: promauto.NewCounterVec(prometheus.CounterOpts{
			Name: "llm_gateway_requests_total",
			Help: "Total number of LLM requests",
		}, []string{"model", "provider", "token_name", "status"}),

		requestDuration: promauto.NewHistogramVec(prometheus.HistogramOpts{
			Name: "llm_gateway_request_duration_ms",
			Help: "Request duration in milliseconds",
			// Buckets span 100ms to 2min, covering fast cached responses
			// through long streaming completions.
			Buckets: []float64{100, 250, 500, 1000, 2500, 5000, 10000, 30000, 60000, 120000},
		}, []string{"model", "provider"}),

		tokensTotal: promauto.NewCounterVec(prometheus.CounterOpts{
			Name: "llm_gateway_tokens_total",
			Help: "Total tokens processed",
		}, []string{"model", "provider", "type"}),

		costTotal: promauto.NewCounterVec(prometheus.CounterOpts{
			Name: "llm_gateway_cost_usd_total",
			Help: "Total cost in USD",
		}, []string{"model", "provider", "token_name"}),

		cacheHits: promauto.NewCounter(prometheus.CounterOpts{
			Name: "llm_gateway_cache_hits_total",
			Help: "Total number of cache hits",
		}),

		cacheMisses: promauto.NewCounter(prometheus.CounterOpts{
			Name: "llm_gateway_cache_misses_total",
			Help: "Total number of cache misses",
		}),
	}
}
|
|
||||||
|
|
||||||
func (m *Metrics) RecordRequest(model, providerName, tokenName, status string, latencyMS int64, inputTokens, outputTokens int, cost float64) {
|
|
||||||
m.requestsTotal.WithLabelValues(model, providerName, tokenName, status).Inc()
|
|
||||||
m.requestDuration.WithLabelValues(model, providerName).Observe(float64(latencyMS))
|
|
||||||
|
|
||||||
if inputTokens > 0 {
|
|
||||||
m.tokensTotal.WithLabelValues(model, providerName, "input").Add(float64(inputTokens))
|
|
||||||
}
|
|
||||||
if outputTokens > 0 {
|
|
||||||
m.tokensTotal.WithLabelValues(model, providerName, "output").Add(float64(outputTokens))
|
|
||||||
}
|
|
||||||
if cost > 0 {
|
|
||||||
m.costTotal.WithLabelValues(model, providerName, tokenName).Add(cost)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Metrics) RecordCacheHit() {
|
|
||||||
m.cacheHits.Inc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Metrics) RecordCacheMiss() {
|
|
||||||
m.cacheMisses.Inc()
|
|
||||||
}
|
|
||||||
|
|
@ -1,191 +0,0 @@
|
||||||
package pricing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const defaultPricesURL = "https://raw.githubusercontent.com/pydantic/genai-prices/main/prices/data_slim.json"
|
|
||||||
|
|
||||||
// Provider represents a provider entry in genai_prices.json.
type Provider struct {
	ID     string  `json:"id"`     // provider identifier, used as the first half of the "provider:model" key
	Models []Model `json:"models"` // models this provider offers, each with its own pricing
}
|
|
||||||
|
|
||||||
// Model represents a model entry with pricing.
type Model struct {
	ID string `json:"id"`
	// Prices is kept raw because its shape varies (flat object, tiered
	// object, or time-of-day array); parsePrices handles decoding.
	Prices json.RawMessage `json:"prices"`
}
|
|
||||||
|
|
||||||
// Lookup provides pricing data fetched from genai-prices.
type Lookup struct {
	mu     sync.RWMutex          // guards prices; refresh swaps the whole map under mu
	prices map[string][2]float64 // "provider:model" -> [input price, output price] (per 1M tokens)
	url    string                // source URL for the pricing JSON
	stopCh chan struct{}         // closed by Close to stop the background refresh goroutine
}
|
|
||||||
|
|
||||||
// NewLookup creates a Lookup that fetches pricing data immediately and refreshes every interval.
// If url is empty, uses the default genai-prices URL.
// Returns a usable Lookup even if the initial fetch fails (prices will be empty until next refresh).
func NewLookup(url string, interval time.Duration) *Lookup {
	if url == "" {
		url = defaultPricesURL
	}
	l := &Lookup{
		prices: make(map[string][2]float64),
		url:    url,
		stopCh: make(chan struct{}),
	}

	// Initial fetch — synchronous, so the constructor can block for up to the
	// HTTP client timeout used by refresh.
	l.refresh()

	// Background refresh; stops when Close closes stopCh.
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				l.refresh()
			case <-l.stopCh:
				return
			}
		}
	}()

	return l
}
|
|
||||||
|
|
||||||
// Close stops the background refresh goroutine.
|
|
||||||
func (l *Lookup) Close() {
|
|
||||||
close(l.stopCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns (inputPer1M, outputPer1M) for a provider:model pair.
|
|
||||||
// Returns (0, 0) if not found.
|
|
||||||
func (l *Lookup) Get(provider, model string) (float64, float64) {
|
|
||||||
if l == nil {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
l.mu.RLock()
|
|
||||||
defer l.mu.RUnlock()
|
|
||||||
key := fmt.Sprintf("%s:%s", provider, model)
|
|
||||||
if p, ok := l.prices[key]; ok {
|
|
||||||
return p[0], p[1]
|
|
||||||
}
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// FillMissing fills in zero-value pricing from the lookup data.
// It reports whether the lookup had any data for the pair (a bool, even
// though up to two fields may be filled).
func (l *Lookup) FillMissing(provider, model string, input, output *float64) bool {
	// Nothing to do when both prices are already known, or there is no lookup.
	if l == nil || (*input > 0 && *output > 0) {
		return false
	}
	i, o := l.Get(provider, model)
	if i == 0 && o == 0 {
		return false // pair not present in the pricing table
	}
	// Only overwrite values the caller left at zero.
	if *input == 0 {
		*input = i
	}
	if *output == 0 {
		*output = o
	}
	return true
}
|
|
||||||
|
|
||||||
// refresh fetches and parses the pricing JSON, then atomically swaps the
// price map under the write lock. Any failure is logged and the previous
// prices are kept (best-effort; the next tick retries).
func (l *Lookup) refresh() {
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(l.url)
	if err != nil {
		log.Printf("WARNING: failed to fetch pricing data: %v", err)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("WARNING: pricing data fetch returned %d", resp.StatusCode)
		return
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Printf("WARNING: failed to read pricing data: %v", err)
		return
	}

	var providers []Provider
	if err := json.Unmarshal(body, &providers); err != nil {
		log.Printf("WARNING: failed to parse pricing data: %v", err)
		return
	}

	// Build the new map fully before taking the lock, keeping the critical
	// section to a single pointer swap.
	prices := make(map[string][2]float64)
	for _, p := range providers {
		for _, m := range p.Models {
			input, output := parsePrices(m.Prices)
			// Skip models with no usable pricing at all.
			if input > 0 || output > 0 {
				key := fmt.Sprintf("%s:%s", p.ID, m.ID)
				prices[key] = [2]float64{input, output}
			}
		}
	}

	l.mu.Lock()
	l.prices = prices
	l.mu.Unlock()

	log.Printf("Loaded pricing data: %d model prices from genai-prices", len(prices))
}
|
|
||||||
|
|
||||||
// parsePrices handles the different shapes of the "prices" field:
|
|
||||||
// - object: {"input_mtok": 0.5, "output_mtok": 1.0}
|
|
||||||
// - array: [{"prices": {"input_mtok": 0.5, ...}}, ...] (time-of-day; use first entry)
|
|
||||||
func parsePrices(raw json.RawMessage) (input, output float64) {
|
|
||||||
if len(raw) == 0 {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try as object first (most common)
|
|
||||||
var obj map[string]any
|
|
||||||
if json.Unmarshal(raw, &obj) == nil {
|
|
||||||
return extractPrice(obj, "input_mtok"), extractPrice(obj, "output_mtok")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try as array (time-of-day pricing) — use first entry
|
|
||||||
var arr []struct {
|
|
||||||
Prices map[string]any `json:"prices"`
|
|
||||||
}
|
|
||||||
if json.Unmarshal(raw, &arr) == nil && len(arr) > 0 {
|
|
||||||
return extractPrice(arr[0].Prices, "input_mtok"), extractPrice(arr[0].Prices, "output_mtok")
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractPrice handles both simple float and tiered pricing (uses base price).
// Missing keys and unrecognized value shapes return 0.
func extractPrice(prices map[string]any, key string) float64 {
	raw, present := prices[key]
	if !present {
		return 0
	}
	if flat, ok := raw.(float64); ok {
		return flat
	}
	if tiered, ok := raw.(map[string]any); ok {
		if base, ok := tiered["base"].(float64); ok {
			return base
		}
	}
	return 0
}
|
|
||||||
|
|
@ -1,144 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"sort"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LoadBalancer reorders routes for load distribution.
// Implementations in this package only rearrange routes within a group of
// equal priority; the relative order of priority groups is preserved.
type LoadBalancer interface {
	Reorder(routes []Route) []Route
}
|
|
||||||
|
|
||||||
// NewLoadBalancer creates a load balancer by strategy name.
|
|
||||||
func NewLoadBalancer(strategy string) LoadBalancer {
|
|
||||||
switch strategy {
|
|
||||||
case "round-robin":
|
|
||||||
return &RoundRobinBalancer{}
|
|
||||||
case "random":
|
|
||||||
return &RandomBalancer{}
|
|
||||||
case "least-cost":
|
|
||||||
return &LeastCostBalancer{}
|
|
||||||
default:
|
|
||||||
return &FirstBalancer{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FirstBalancer is a no-op that preserves original order.
|
|
||||||
type FirstBalancer struct{}
|
|
||||||
|
|
||||||
func (b *FirstBalancer) Reorder(routes []Route) []Route {
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoundRobinBalancer rotates routes within same-priority groups.
|
|
||||||
type RoundRobinBalancer struct {
|
|
||||||
counter atomic.Uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *RoundRobinBalancer) Reorder(routes []Route) []Route {
|
|
||||||
if len(routes) <= 1 {
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]Route, len(routes))
|
|
||||||
copy(result, routes)
|
|
||||||
|
|
||||||
// Group by priority and rotate within each group
|
|
||||||
groups := groupByPriority(result)
|
|
||||||
idx := 0
|
|
||||||
count := b.counter.Add(1)
|
|
||||||
for _, group := range groups {
|
|
||||||
if len(group) > 1 {
|
|
||||||
offset := int(count) % len(group)
|
|
||||||
for j := 0; j < len(group); j++ {
|
|
||||||
result[idx] = group[(j+offset)%len(group)]
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result[idx] = group[0]
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// RandomBalancer shuffles routes within same-priority groups.
|
|
||||||
type RandomBalancer struct{}
|
|
||||||
|
|
||||||
func (b *RandomBalancer) Reorder(routes []Route) []Route {
|
|
||||||
if len(routes) <= 1 {
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]Route, len(routes))
|
|
||||||
copy(result, routes)
|
|
||||||
|
|
||||||
groups := groupByPriority(result)
|
|
||||||
idx := 0
|
|
||||||
for _, group := range groups {
|
|
||||||
rand.Shuffle(len(group), func(i, j int) {
|
|
||||||
group[i], group[j] = group[j], group[i]
|
|
||||||
})
|
|
||||||
for _, r := range group {
|
|
||||||
result[idx] = r
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// LeastCostBalancer sorts by price within same-priority groups.
|
|
||||||
type LeastCostBalancer struct{}
|
|
||||||
|
|
||||||
func (b *LeastCostBalancer) Reorder(routes []Route) []Route {
|
|
||||||
if len(routes) <= 1 {
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]Route, len(routes))
|
|
||||||
copy(result, routes)
|
|
||||||
|
|
||||||
groups := groupByPriority(result)
|
|
||||||
idx := 0
|
|
||||||
for _, group := range groups {
|
|
||||||
sort.Slice(group, func(i, j int) bool {
|
|
||||||
costI := group[i].InputPrice + group[i].OutputPrice
|
|
||||||
costJ := group[j].InputPrice + group[j].OutputPrice
|
|
||||||
return costI < costJ
|
|
||||||
})
|
|
||||||
for _, r := range group {
|
|
||||||
result[idx] = r
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// groupByPriority splits routes into groups of same priority, preserving order.
|
|
||||||
func groupByPriority(routes []Route) [][]Route {
|
|
||||||
if len(routes) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var groups [][]Route
|
|
||||||
currentPriority := routes[0].Priority
|
|
||||||
currentGroup := []Route{routes[0]}
|
|
||||||
|
|
||||||
for i := 1; i < len(routes); i++ {
|
|
||||||
if routes[i].Priority == currentPriority {
|
|
||||||
currentGroup = append(currentGroup, routes[i])
|
|
||||||
} else {
|
|
||||||
groups = append(groups, currentGroup)
|
|
||||||
currentPriority = routes[i].Priority
|
|
||||||
currentGroup = []Route{routes[i]}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
groups = append(groups, currentGroup)
|
|
||||||
|
|
||||||
return groups
|
|
||||||
}
|
|
||||||
|
|
@ -1,294 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// routeSpec is a compact description of a Route used to build test fixtures.
type routeSpec struct {
	name     string  // provider name; also used to derive the model name
	priority int     // route priority tier
	input    float64 // input price
	output   float64 // output price
}
|
|
||||||
|
|
||||||
// makeRoutes builds Route fixtures (backed by mockProvider) from the specs.
func makeRoutes(specs ...routeSpec) []Route {
	routes := make([]Route, len(specs))
	for i, s := range specs {
		routes[i] = Route{
			Provider:      &mockProvider{name: s.name},
			ProviderModel: s.name + "-model",
			Priority:      s.priority,
			InputPrice:    s.input,
			OutputPrice:   s.output,
		}
	}
	return routes
}
|
|
||||||
|
|
||||||
// routeNames extracts provider names in order, for easy ordering assertions.
func routeNames(routes []Route) []string {
	names := make([]string, len(routes))
	for i, r := range routes {
		names[i] = r.Provider.Name()
	}
	return names
}
|
|
||||||
|
|
||||||
// TestFirstBalancer_PreservesOrder verifies the no-op balancer returns routes
// in their original order.
func TestFirstBalancer_PreservesOrder(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 2.0, 2.0},
		routeSpec{"c", 1, 3.0, 3.0},
	)

	b := &FirstBalancer{}
	result := b.Reorder(routes)

	names := routeNames(result)
	if names[0] != "a" || names[1] != "b" || names[2] != "c" {
		t.Fatalf("expected [a b c], got %v", names)
	}
}
|
|
||||||
|
|
||||||
// TestRoundRobinBalancer_RotatesWithinPriorityGroup verifies that repeated
// Reorder calls rotate which route of a same-priority group comes first.
func TestRoundRobinBalancer_RotatesWithinPriorityGroup(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 1, 1.0, 1.0},
	)

	b := &RoundRobinBalancer{}

	// Collect the first element from multiple calls
	seen := make(map[string]bool)
	for i := 0; i < 6; i++ {
		result := b.Reorder(routes)
		seen[result[0].Provider.Name()] = true
	}

	// All routes should have appeared as first at some point
	for _, name := range []string{"a", "b", "c"} {
		if !seen[name] {
			t.Errorf("expected %q to appear as first element in rotation", name)
		}
	}
}
|
|
||||||
|
|
||||||
// TestRoundRobinBalancer_PreservesPriorityOrder verifies rotation never moves
// a lower-priority route ahead of a higher-priority one.
func TestRoundRobinBalancer_PreservesPriorityOrder(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 2, 1.0, 1.0},
	)

	b := &RoundRobinBalancer{}

	// Priority 2 route should always be last
	for i := 0; i < 5; i++ {
		result := b.Reorder(routes)
		if result[2].Provider.Name() != "c" {
			t.Fatalf("expected priority-2 route 'c' at the end, got %q", result[2].Provider.Name())
		}
	}
}
|
|
||||||
|
|
||||||
// TestRandomBalancer_AllRoutesPresent verifies shuffling is a permutation:
// every input route still appears exactly once in the output.
func TestRandomBalancer_AllRoutesPresent(t *testing.T) {
	routes := makeRoutes(
		routeSpec{"a", 1, 1.0, 1.0},
		routeSpec{"b", 1, 1.0, 1.0},
		routeSpec{"c", 1, 1.0, 1.0},
	)

	b := &RandomBalancer{}

	for i := 0; i < 10; i++ {
		result := b.Reorder(routes)
		if len(result) != 3 {
			t.Fatalf("expected 3 routes, got %d", len(result))
		}

		names := make(map[string]bool)
		for _, r := range result {
			names[r.Provider.Name()] = true
		}
		for _, want := range []string{"a", "b", "c"} {
			if !names[want] {
				t.Errorf("missing route %q in result", want)
			}
		}
	}
}
|
|
||||||
|
|
||||||
func TestRandomBalancer_PreservesPriorityOrder(t *testing.T) {
|
|
||||||
routes := makeRoutes(
|
|
||||||
routeSpec{"a", 1, 1.0, 1.0},
|
|
||||||
routeSpec{"b", 1, 1.0, 1.0},
|
|
||||||
routeSpec{"c", 2, 1.0, 1.0},
|
|
||||||
)
|
|
||||||
|
|
||||||
b := &RandomBalancer{}
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
result := b.Reorder(routes)
|
|
||||||
if result[2].Provider.Name() != "c" {
|
|
||||||
t.Fatalf("expected priority-2 route 'c' last, got %q", result[2].Provider.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLeastCostBalancer_SortsByCost(t *testing.T) {
|
|
||||||
routes := makeRoutes(
|
|
||||||
routeSpec{"expensive", 1, 10.0, 10.0},
|
|
||||||
routeSpec{"cheap", 1, 1.0, 1.0},
|
|
||||||
routeSpec{"medium", 1, 5.0, 5.0},
|
|
||||||
)
|
|
||||||
|
|
||||||
b := &LeastCostBalancer{}
|
|
||||||
result := b.Reorder(routes)
|
|
||||||
|
|
||||||
names := routeNames(result)
|
|
||||||
expected := []string{"cheap", "medium", "expensive"}
|
|
||||||
for i, want := range expected {
|
|
||||||
if names[i] != want {
|
|
||||||
t.Errorf("position %d: got %q, want %q", i, names[i], want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLeastCostBalancer_PreservesPriorityOrder(t *testing.T) {
|
|
||||||
routes := makeRoutes(
|
|
||||||
routeSpec{"expensive-p1", 1, 10.0, 10.0},
|
|
||||||
routeSpec{"cheap-p1", 1, 1.0, 1.0},
|
|
||||||
routeSpec{"cheap-p2", 2, 0.5, 0.5},
|
|
||||||
)
|
|
||||||
|
|
||||||
b := &LeastCostBalancer{}
|
|
||||||
result := b.Reorder(routes)
|
|
||||||
|
|
||||||
names := routeNames(result)
|
|
||||||
// Within priority 1, cheap should come first; priority 2 always last
|
|
||||||
if names[0] != "cheap-p1" {
|
|
||||||
t.Errorf("expected cheap-p1 first, got %q", names[0])
|
|
||||||
}
|
|
||||||
if names[1] != "expensive-p1" {
|
|
||||||
t.Errorf("expected expensive-p1 second, got %q", names[1])
|
|
||||||
}
|
|
||||||
if names[2] != "cheap-p2" {
|
|
||||||
t.Errorf("expected cheap-p2 last, got %q", names[2])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupByPriority checks that groupByPriority partitions routes into
// consecutive groups of equal priority, preserving input order, and that an
// empty input yields nil rather than an empty slice.
func TestGroupByPriority(t *testing.T) {
	tests := []struct {
		name       string
		priorities []int
		wantGroups [][]int
	}{
		{
			name:       "empty",
			priorities: nil,
			wantGroups: nil,
		},
		{
			name:       "single",
			priorities: []int{1},
			wantGroups: [][]int{{1}},
		},
		{
			name:       "all same",
			priorities: []int{1, 1, 1},
			wantGroups: [][]int{{1, 1, 1}},
		},
		{
			name:       "two groups",
			priorities: []int{1, 1, 2, 2},
			wantGroups: [][]int{{1, 1}, {2, 2}},
		},
		{
			name:       "three groups",
			priorities: []int{1, 2, 2, 3},
			wantGroups: [][]int{{1}, {2, 2}, {3}},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build routes carrying only the priority under test.
			var routes []Route
			for _, p := range tt.priorities {
				routes = append(routes, Route{Priority: p})
			}

			groups := groupByPriority(routes)

			// nil expectation is checked separately so the length/element
			// comparisons below can assume non-nil groups.
			if tt.wantGroups == nil {
				if groups != nil {
					t.Fatalf("expected nil groups, got %v", groups)
				}
				return
			}

			if len(groups) != len(tt.wantGroups) {
				t.Fatalf("expected %d groups, got %d", len(tt.wantGroups), len(groups))
			}

			for i, wg := range tt.wantGroups {
				if len(groups[i]) != len(wg) {
					t.Errorf("group %d: expected %d routes, got %d", i, len(wg), len(groups[i]))
					continue
				}
				for j, wp := range wg {
					if groups[i][j].Priority != wp {
						t.Errorf("group %d, route %d: expected priority %d, got %d", i, j, wp, groups[i][j].Priority)
					}
				}
			}
		})
	}
}
|
|
||||||
|
|
||||||
func TestBalancer_SingleRoute(t *testing.T) {
|
|
||||||
routes := makeRoutes(routeSpec{"only", 1, 1.0, 1.0})
|
|
||||||
|
|
||||||
balancers := []struct {
|
|
||||||
name string
|
|
||||||
balancer LoadBalancer
|
|
||||||
}{
|
|
||||||
{"first", &FirstBalancer{}},
|
|
||||||
{"round-robin", &RoundRobinBalancer{}},
|
|
||||||
{"random", &RandomBalancer{}},
|
|
||||||
{"least-cost", &LeastCostBalancer{}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, bb := range balancers {
|
|
||||||
t.Run(bb.name, func(t *testing.T) {
|
|
||||||
result := bb.balancer.Reorder(routes)
|
|
||||||
if len(result) != 1 || result[0].Provider.Name() != "only" {
|
|
||||||
t.Fatalf("expected single route 'only', got %v", routeNames(result))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewLoadBalancer(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
strategy string
|
|
||||||
wantType string
|
|
||||||
}{
|
|
||||||
{"round-robin", "*provider.RoundRobinBalancer"},
|
|
||||||
{"random", "*provider.RandomBalancer"},
|
|
||||||
{"least-cost", "*provider.LeastCostBalancer"},
|
|
||||||
{"first", "*provider.FirstBalancer"},
|
|
||||||
{"unknown", "*provider.FirstBalancer"},
|
|
||||||
{"", "*provider.FirstBalancer"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.strategy, func(t *testing.T) {
|
|
||||||
b := NewLoadBalancer(tt.strategy)
|
|
||||||
got := fmt.Sprintf("%T", b)
|
|
||||||
if got != tt.wantType {
|
|
||||||
t.Errorf("NewLoadBalancer(%q) = %s, want %s", tt.strategy, got, tt.wantType)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,264 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CircuitState represents the state of a circuit breaker.
type CircuitState int

const (
	CircuitClosed   CircuitState = iota // normal operation: requests flow through
	CircuitOpen                         // blocking requests until the cooldown elapses
	CircuitHalfOpen                     // testing with probe request
)
|
|
||||||
|
|
||||||
func (s CircuitState) String() string {
|
|
||||||
switch s {
|
|
||||||
case CircuitClosed:
|
|
||||||
return "closed"
|
|
||||||
case CircuitOpen:
|
|
||||||
return "open"
|
|
||||||
case CircuitHalfOpen:
|
|
||||||
return "half-open"
|
|
||||||
default:
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProviderCircuit tracks circuit breaker state for a single provider.
type ProviderCircuit struct {
	State     CircuitState // current breaker state
	OpenedAt  time.Time    // when the breaker last transitioned to open
	LastProbe time.Time    // when the last half-open probe was admitted
}
|
|
||||||
|
|
||||||
// HealthEvent represents a single request outcome for a provider.
type HealthEvent struct {
	Timestamp time.Time // when the outcome was recorded
	LatencyMS int64     // request latency in milliseconds
	IsError   bool      // whether the request failed
	ErrorMsg  string    // error text when IsError is true, otherwise empty
}
|
|
||||||
|
|
||||||
// ProviderHealth is the computed health status for a provider, aggregated
// over the tracker's sliding window.
type ProviderHealth struct {
	Provider     string  `json:"provider"`
	Status       string  `json:"status"` // healthy, degraded, down
	ErrorRate    float64 `json:"error_rate"`
	AvgLatency   float64 `json:"avg_latency_ms"`
	Total        int     `json:"total"`  // events counted in the window
	Errors       int     `json:"errors"` // failed events in the window
	CircuitState string  `json:"circuit_state"`
}
|
|
||||||
|
|
||||||
// HealthTracker tracks per-provider health using a sliding window of request
// outcomes and, when enabled, drives a per-provider circuit breaker.
type HealthTracker struct {
	mu       sync.RWMutex
	windows  map[string][]HealthEvent // per-provider events, oldest first
	windowDu time.Duration            // sliding-window length
	circuits map[string]*ProviderCircuit
	cbConfig config.CircuitBreakerConfig
	// OnStateChange, if set, is invoked (in a fresh goroutine) whenever a
	// provider's circuit transitions between states.
	OnStateChange func(provider string, from, to CircuitState)
}
|
|
||||||
|
|
||||||
// NewHealthTracker creates a health tracker with the given window duration.
|
|
||||||
func NewHealthTracker(window time.Duration, cbCfg config.CircuitBreakerConfig) *HealthTracker {
|
|
||||||
if window == 0 {
|
|
||||||
window = 5 * time.Minute
|
|
||||||
}
|
|
||||||
return &HealthTracker{
|
|
||||||
windows: make(map[string][]HealthEvent),
|
|
||||||
circuits: make(map[string]*ProviderCircuit),
|
|
||||||
windowDu: window,
|
|
||||||
cbConfig: cbCfg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsAvailable returns true if the provider's circuit breaker allows requests.
|
|
||||||
func (h *HealthTracker) IsAvailable(provider string) bool {
|
|
||||||
if !h.cbConfig.Enabled {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mu.RLock()
|
|
||||||
defer h.mu.RUnlock()
|
|
||||||
|
|
||||||
circuit, ok := h.circuits[provider]
|
|
||||||
if !ok {
|
|
||||||
return true // no circuit = closed = available
|
|
||||||
}
|
|
||||||
|
|
||||||
switch circuit.State {
|
|
||||||
case CircuitOpen:
|
|
||||||
// Check if cooldown has elapsed -> transition to half-open
|
|
||||||
if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
|
|
||||||
return true // will transition to half-open on next record
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
case CircuitHalfOpen:
|
|
||||||
return true // allow probe
|
|
||||||
default:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record adds a health event for a provider and evaluates circuit transitions.
|
|
||||||
func (h *HealthTracker) Record(provider string, latencyMS int64, err error) {
|
|
||||||
event := HealthEvent{
|
|
||||||
Timestamp: time.Now(),
|
|
||||||
LatencyMS: latencyMS,
|
|
||||||
IsError: err != nil,
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
event.ErrorMsg = err.Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
h.mu.Lock()
|
|
||||||
defer h.mu.Unlock()
|
|
||||||
|
|
||||||
h.windows[provider] = append(h.windows[provider], event)
|
|
||||||
h.prune(provider)
|
|
||||||
|
|
||||||
if h.cbConfig.Enabled {
|
|
||||||
h.evaluateCircuit(provider, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// evaluateCircuit transitions circuit breaker state. Must be called with lock held.
//
// Transitions:
//   - closed -> open when the windowed error rate reaches ErrorThreshold and
//     at least MinRequests events are in the window;
//   - open -> half-open once CooldownDuration has elapsed; the triggering
//     request acts as the probe and immediately decides closed vs re-open;
//   - half-open -> closed on success, half-open -> open on failure.
//
// NOTE(review): after a successful probe closes the circuit, the error events
// that originally tripped it can still be inside the sliding window, so the
// very next recorded event may re-open it — confirm this flapping is intended.
func (h *HealthTracker) evaluateCircuit(providerName string, lastErr error) {
	circuit, ok := h.circuits[providerName]
	if !ok {
		// First event for this provider: start from a closed circuit.
		circuit = &ProviderCircuit{State: CircuitClosed}
		h.circuits[providerName] = circuit
	}

	prevState := circuit.State

	switch circuit.State {
	case CircuitClosed:
		// Check if error threshold exceeded
		errorRate, total := h.errorRateUnlocked(providerName)
		if total >= h.cbConfig.MinRequests && errorRate >= h.cbConfig.ErrorThreshold {
			circuit.State = CircuitOpen
			circuit.OpenedAt = time.Now()
		}
	case CircuitOpen:
		// Check if cooldown elapsed -> half-open
		if time.Since(circuit.OpenedAt) >= h.cbConfig.CooldownDuration {
			circuit.State = CircuitHalfOpen
			circuit.LastProbe = time.Now()
			// Evaluate the probe result immediately
			if lastErr == nil {
				circuit.State = CircuitClosed
			} else {
				circuit.State = CircuitOpen
				circuit.OpenedAt = time.Now()
			}
		}
	case CircuitHalfOpen:
		if lastErr == nil {
			circuit.State = CircuitClosed
		} else {
			circuit.State = CircuitOpen
			circuit.OpenedAt = time.Now()
		}
	}

	// Notify listeners of the net transition (intermediate half-open hops in
	// the open case are not reported separately).
	if circuit.State != prevState && h.OnStateChange != nil {
		cb := h.OnStateChange
		from, to := prevState, circuit.State
		// Call outside lock to avoid deadlocks
		go cb(providerName, from, to)
	}
}
|
|
||||||
|
|
||||||
// errorRateUnlocked computes error rate within window. Must be called with lock held.
|
|
||||||
func (h *HealthTracker) errorRateUnlocked(provider string) (float64, int) {
|
|
||||||
cutoff := time.Now().Add(-h.windowDu)
|
|
||||||
events := h.windows[provider]
|
|
||||||
var total, errors int
|
|
||||||
for _, e := range events {
|
|
||||||
if e.Timestamp.Before(cutoff) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
total++
|
|
||||||
if e.IsError {
|
|
||||||
errors++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if total == 0 {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
return float64(errors) / float64(total), total
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status returns computed health for all tracked providers.
//
// Only events inside the sliding window are counted; providers whose events
// have all expired are omitted entirely. Thresholds on the windowed error
// rate: >= 50% is "down", >= 10% is "degraded", otherwise "healthy".
func (h *HealthTracker) Status() []ProviderHealth {
	h.mu.RLock()
	defer h.mu.RUnlock()

	cutoff := time.Now().Add(-h.windowDu)
	var results []ProviderHealth

	for provider, events := range h.windows {
		var total, errors int
		var totalLatency int64

		// Aggregate only the events still inside the window.
		for _, e := range events {
			if e.Timestamp.Before(cutoff) {
				continue
			}
			total++
			totalLatency += e.LatencyMS
			if e.IsError {
				errors++
			}
		}

		if total == 0 {
			continue
		}

		errorRate := float64(errors) / float64(total)
		status := "healthy"
		if errorRate >= 0.5 {
			status = "down"
		} else if errorRate >= 0.1 {
			status = "degraded"
		}

		// A provider with no circuit entry has never tripped: report closed.
		circuitState := "closed"
		if circuit, ok := h.circuits[provider]; ok {
			circuitState = circuit.State.String()
		}

		results = append(results, ProviderHealth{
			Provider:     provider,
			Status:       status,
			ErrorRate:    errorRate,
			AvgLatency:   float64(totalLatency) / float64(total),
			Total:        total,
			Errors:       errors,
			CircuitState: circuitState,
		})
	}

	return results
}
|
|
||||||
|
|
||||||
// prune removes events outside the window. Must be called with lock held.
|
|
||||||
func (h *HealthTracker) prune(provider string) {
|
|
||||||
cutoff := time.Now().Add(-h.windowDu)
|
|
||||||
events := h.windows[provider]
|
|
||||||
i := 0
|
|
||||||
for i < len(events) && events[i].Timestamp.Before(cutoff) {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i > 0 {
|
|
||||||
h.windows[provider] = events[i:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,345 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestTracker(window time.Duration, cb config.CircuitBreakerConfig) *HealthTracker {
|
|
||||||
return NewHealthTracker(window, cb)
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultCBConfig() config.CircuitBreakerConfig {
|
|
||||||
return config.CircuitBreakerConfig{
|
|
||||||
Enabled: true,
|
|
||||||
ErrorThreshold: 0.5,
|
|
||||||
MinRequests: 3,
|
|
||||||
CooldownDuration: 100 * time.Millisecond,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_Record(t *testing.T) {
|
|
||||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
|
|
||||||
|
|
||||||
ht.Record("provA", 100, nil)
|
|
||||||
ht.Record("provA", 200, errors.New("fail"))
|
|
||||||
ht.Record("provB", 50, nil)
|
|
||||||
|
|
||||||
ht.mu.RLock()
|
|
||||||
defer ht.mu.RUnlock()
|
|
||||||
|
|
||||||
if len(ht.windows["provA"]) != 2 {
|
|
||||||
t.Fatalf("expected 2 events for provA, got %d", len(ht.windows["provA"]))
|
|
||||||
}
|
|
||||||
if len(ht.windows["provB"]) != 1 {
|
|
||||||
t.Fatalf("expected 1 event for provB, got %d", len(ht.windows["provB"]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify event fields
|
|
||||||
ev := ht.windows["provA"][1]
|
|
||||||
if !ev.IsError || ev.ErrorMsg != "fail" || ev.LatencyMS != 200 {
|
|
||||||
t.Fatalf("unexpected event fields: %+v", ev)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHealthTracker_Status drives the status thresholds through a table of
// success/error mixes: < 10% errors is "healthy", [10%, 50%) is "degraded",
// >= 50% is "down". Counts and the error rate are checked alongside.
func TestHealthTracker_Status(t *testing.T) {
	tests := []struct {
		name          string
		successCount  int
		errorCount    int
		wantStatus    string
		wantErrorRate float64
		wantTotal     int
		wantErrors    int
	}{
		{
			name:          "healthy - no errors",
			successCount:  10,
			errorCount:    0,
			wantStatus:    "healthy",
			wantErrorRate: 0.0,
			wantTotal:     10,
			wantErrors:    0,
		},
		{
			name:          "healthy - below 10% errors",
			successCount:  19,
			errorCount:    1,
			wantStatus:    "healthy",
			wantErrorRate: 0.05,
			wantTotal:     20,
			wantErrors:    1,
		},
		{
			name:          "degraded - 20% errors",
			successCount:  8,
			errorCount:    2,
			wantStatus:    "degraded",
			wantErrorRate: 0.2,
			wantTotal:     10,
			wantErrors:    2,
		},
		{
			name:          "degraded - exactly 10% errors",
			successCount:  9,
			errorCount:    1,
			wantStatus:    "degraded",
			wantErrorRate: 0.1,
			wantTotal:     10,
			wantErrors:    1,
		},
		{
			name:          "down - 50% errors",
			successCount:  5,
			errorCount:    5,
			wantStatus:    "down",
			wantErrorRate: 0.5,
			wantTotal:     10,
			wantErrors:    5,
		},
		{
			name:          "down - all errors",
			successCount:  0,
			errorCount:    5,
			wantStatus:    "down",
			wantErrorRate: 1.0,
			wantTotal:     5,
			wantErrors:    5,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Circuit breaker disabled: only the windowed stats matter here.
			ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})

			for i := 0; i < tt.successCount; i++ {
				ht.Record("prov", 100, nil)
			}
			for i := 0; i < tt.errorCount; i++ {
				ht.Record("prov", 100, errors.New("err"))
			}

			statuses := ht.Status()
			if len(statuses) != 1 {
				t.Fatalf("expected 1 status, got %d", len(statuses))
			}

			s := statuses[0]
			if s.Status != tt.wantStatus {
				t.Errorf("status = %q, want %q", s.Status, tt.wantStatus)
			}
			if s.Total != tt.wantTotal {
				t.Errorf("total = %d, want %d", s.Total, tt.wantTotal)
			}
			if s.Errors != tt.wantErrors {
				t.Errorf("errors = %d, want %d", s.Errors, tt.wantErrors)
			}
			// Allow small float tolerance
			if diff := s.ErrorRate - tt.wantErrorRate; diff > 0.001 || diff < -0.001 {
				t.Errorf("error_rate = %f, want %f", s.ErrorRate, tt.wantErrorRate)
			}
		})
	}
}
|
|
||||||
|
|
||||||
func TestHealthTracker_CircuitBreaker_ClosedToOpen(t *testing.T) {
|
|
||||||
cb := defaultCBConfig()
|
|
||||||
cb.MinRequests = 3
|
|
||||||
cb.ErrorThreshold = 0.5
|
|
||||||
|
|
||||||
ht := newTestTracker(5*time.Minute, cb)
|
|
||||||
|
|
||||||
// Record errors to exceed threshold (3 errors out of 3 = 100% > 50%)
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
|
|
||||||
ht.mu.RLock()
|
|
||||||
state := ht.circuits["prov"].State
|
|
||||||
ht.mu.RUnlock()
|
|
||||||
|
|
||||||
if state != CircuitOpen {
|
|
||||||
t.Fatalf("expected CircuitOpen, got %s", state)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ht.IsAvailable("prov") {
|
|
||||||
t.Fatal("expected IsAvailable=false when circuit is open")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_CircuitBreaker_OpenToHalfOpenOnCooldown(t *testing.T) {
|
|
||||||
cb := defaultCBConfig()
|
|
||||||
cb.CooldownDuration = 50 * time.Millisecond
|
|
||||||
|
|
||||||
ht := newTestTracker(5*time.Minute, cb)
|
|
||||||
|
|
||||||
// Trip the circuit
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if ht.IsAvailable("prov") {
|
|
||||||
t.Fatal("expected circuit open, IsAvailable should be false")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for cooldown
|
|
||||||
time.Sleep(60 * time.Millisecond)
|
|
||||||
|
|
||||||
// After cooldown, IsAvailable should return true (will transition to half-open)
|
|
||||||
if !ht.IsAvailable("prov") {
|
|
||||||
t.Fatal("expected IsAvailable=true after cooldown")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_CircuitBreaker_HalfOpenToClosedOnSuccess(t *testing.T) {
|
|
||||||
cb := defaultCBConfig()
|
|
||||||
cb.CooldownDuration = 10 * time.Millisecond
|
|
||||||
|
|
||||||
ht := newTestTracker(5*time.Minute, cb)
|
|
||||||
|
|
||||||
// Trip the circuit
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for cooldown so next Record transitions through Open->HalfOpen
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
|
|
||||||
// A successful record should transition: Open -> HalfOpen -> Closed
|
|
||||||
ht.Record("prov", 100, nil)
|
|
||||||
|
|
||||||
ht.mu.RLock()
|
|
||||||
state := ht.circuits["prov"].State
|
|
||||||
ht.mu.RUnlock()
|
|
||||||
|
|
||||||
if state != CircuitClosed {
|
|
||||||
t.Fatalf("expected CircuitClosed after success in half-open, got %s", state)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ht.IsAvailable("prov") {
|
|
||||||
t.Fatal("expected IsAvailable=true after circuit closed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_CircuitBreaker_HalfOpenToOpenOnFailure(t *testing.T) {
|
|
||||||
cb := defaultCBConfig()
|
|
||||||
cb.CooldownDuration = 10 * time.Millisecond
|
|
||||||
|
|
||||||
ht := newTestTracker(5*time.Minute, cb)
|
|
||||||
|
|
||||||
// Trip the circuit
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for cooldown
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
|
|
||||||
// A failed record should transition: Open -> HalfOpen -> Open
|
|
||||||
ht.Record("prov", 100, errors.New("still failing"))
|
|
||||||
|
|
||||||
ht.mu.RLock()
|
|
||||||
state := ht.circuits["prov"].State
|
|
||||||
ht.mu.RUnlock()
|
|
||||||
|
|
||||||
if state != CircuitOpen {
|
|
||||||
t.Fatalf("expected CircuitOpen after failure in half-open, got %s", state)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_IsAvailable_NoCircuitBreaker(t *testing.T) {
|
|
||||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{Enabled: false})
|
|
||||||
|
|
||||||
// Even with errors, IsAvailable should return true when CB is disabled
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ht.IsAvailable("prov") {
|
|
||||||
t.Fatal("expected IsAvailable=true when circuit breaker disabled")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_IsAvailable_UnknownProvider(t *testing.T) {
|
|
||||||
ht := newTestTracker(5*time.Minute, defaultCBConfig())
|
|
||||||
|
|
||||||
if !ht.IsAvailable("unknown") {
|
|
||||||
t.Fatal("expected IsAvailable=true for unknown provider (no circuit)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_WindowPruning(t *testing.T) {
|
|
||||||
// Use a tiny window so events expire quickly
|
|
||||||
ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
|
|
||||||
|
|
||||||
ht.Record("prov", 100, nil)
|
|
||||||
ht.Record("prov", 200, nil)
|
|
||||||
|
|
||||||
// Wait for events to expire
|
|
||||||
time.Sleep(60 * time.Millisecond)
|
|
||||||
|
|
||||||
// Record a new event to trigger pruning
|
|
||||||
ht.Record("prov", 300, nil)
|
|
||||||
|
|
||||||
ht.mu.RLock()
|
|
||||||
count := len(ht.windows["prov"])
|
|
||||||
ht.mu.RUnlock()
|
|
||||||
|
|
||||||
if count != 1 {
|
|
||||||
t.Fatalf("expected 1 event after pruning, got %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_Status_EmptyAfterPruning(t *testing.T) {
|
|
||||||
ht := newTestTracker(50*time.Millisecond, config.CircuitBreakerConfig{})
|
|
||||||
|
|
||||||
ht.Record("prov", 100, nil)
|
|
||||||
|
|
||||||
// Wait for events to expire
|
|
||||||
time.Sleep(60 * time.Millisecond)
|
|
||||||
|
|
||||||
statuses := ht.Status()
|
|
||||||
if len(statuses) != 0 {
|
|
||||||
t.Fatalf("expected 0 statuses after window expiry, got %d", len(statuses))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_Status_AvgLatency(t *testing.T) {
|
|
||||||
ht := newTestTracker(5*time.Minute, config.CircuitBreakerConfig{})
|
|
||||||
|
|
||||||
ht.Record("prov", 100, nil)
|
|
||||||
ht.Record("prov", 200, nil)
|
|
||||||
ht.Record("prov", 300, nil)
|
|
||||||
|
|
||||||
statuses := ht.Status()
|
|
||||||
if len(statuses) != 1 {
|
|
||||||
t.Fatalf("expected 1 status, got %d", len(statuses))
|
|
||||||
}
|
|
||||||
|
|
||||||
want := 200.0
|
|
||||||
if diff := statuses[0].AvgLatency - want; diff > 0.001 || diff < -0.001 {
|
|
||||||
t.Errorf("avg_latency = %f, want %f", statuses[0].AvgLatency, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthTracker_Status_CircuitStateReported(t *testing.T) {
|
|
||||||
cb := defaultCBConfig()
|
|
||||||
ht := newTestTracker(5*time.Minute, cb)
|
|
||||||
|
|
||||||
// Trip the circuit
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
ht.Record("prov", 100, errors.New("err"))
|
|
||||||
}
|
|
||||||
|
|
||||||
statuses := ht.Status()
|
|
||||||
if len(statuses) != 1 {
|
|
||||||
t.Fatalf("expected 1 status, got %d", len(statuses))
|
|
||||||
}
|
|
||||||
|
|
||||||
if statuses[0].CircuitState != "open" {
|
|
||||||
t.Errorf("circuit_state = %q, want %q", statuses[0].CircuitState, "open")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,178 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OpenAIProvider is a generic OpenAI-compatible HTTP client.
type OpenAIProvider struct {
	name    string       // logical provider name, reported in errors
	baseURL string       // API root; paths like "/chat/completions" are appended
	apiKey  string       // bearer token sent on every request
	client  *http.Client // shared client carrying the configured timeout
}
|
|
||||||
|
|
||||||
func NewOpenAIProvider(name, baseURL, apiKey string, timeout time.Duration) *OpenAIProvider {
|
|
||||||
return &OpenAIProvider{
|
|
||||||
name: name,
|
|
||||||
baseURL: baseURL,
|
|
||||||
apiKey: apiKey,
|
|
||||||
client: &http.Client{
|
|
||||||
Timeout: timeout,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the provider's configured name.
func (p *OpenAIProvider) Name() string {
	return p.name
}
|
|
||||||
|
|
||||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error) {
|
|
||||||
reqCopy := *req
|
|
||||||
reqCopy.Model = model
|
|
||||||
reqCopy.Stream = false
|
|
||||||
|
|
||||||
body, err := json.Marshal(reqCopy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating request: %w", err)
|
|
||||||
}
|
|
||||||
p.setHeaders(httpReq)
|
|
||||||
|
|
||||||
resp, err := p.client.Do(httpReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("sending request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("reading response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, &ProviderError{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
Body: string(respBody),
|
|
||||||
Provider: p.name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var chatResp ChatResponse
|
|
||||||
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &chatResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error) {
|
|
||||||
reqCopy := *req
|
|
||||||
reqCopy.Model = model
|
|
||||||
reqCopy.Stream = true
|
|
||||||
|
|
||||||
body, err := json.Marshal(reqCopy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/chat/completions", bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating request: %w", err)
|
|
||||||
}
|
|
||||||
p.setHeaders(httpReq)
|
|
||||||
|
|
||||||
resp, err := p.client.Do(httpReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("sending request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
defer resp.Body.Close()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
return nil, &ProviderError{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
Body: string(respBody),
|
|
||||||
Provider: p.name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp.Body, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OpenAIProvider) Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
|
||||||
reqCopy := *req
|
|
||||||
reqCopy.Model = model
|
|
||||||
|
|
||||||
body, err := json.Marshal(reqCopy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshaling request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.baseURL+"/embeddings", bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("creating request: %w", err)
|
|
||||||
}
|
|
||||||
p.setHeaders(httpReq)
|
|
||||||
|
|
||||||
resp, err := p.client.Do(httpReq)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("sending request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("reading response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, &ProviderError{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
Body: string(respBody),
|
|
||||||
Provider: p.name,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var embResp EmbeddingResponse
|
|
||||||
if err := json.Unmarshal(respBody, &embResp); err != nil {
|
|
||||||
return nil, fmt.Errorf("unmarshaling response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &embResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setHeaders applies the JSON content type, bearer authorization, and — when
// present in the request context — a propagated X-Request-ID header.
func (p *OpenAIProvider) setHeaders(req *http.Request) {
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+p.apiKey)
	// Forward request ID if present in context
	// NOTE(review): a bare string context key trips staticcheck SA1029 and can
	// collide with other packages; whichever middleware stores "requestID"
	// must use the same untyped string key — consider a shared typed key.
	if reqID := req.Context().Value("requestID"); reqID != nil {
		if id, ok := reqID.(string); ok && id != "" {
			req.Header.Set("X-Request-ID", id)
		}
	}
}
|
|
||||||
|
|
||||||
// ProviderError represents a non-200 response from a provider.
type ProviderError struct {
	StatusCode int    // HTTP status returned by the provider
	Body       string // raw response body, included for diagnostics
	Provider   string // name of the provider that produced the response
}
|
|
||||||
|
|
||||||
func (e *ProviderError) Error() string {
|
|
||||||
return fmt.Sprintf("provider %s returned %d: %s", e.Provider, e.StatusCode, e.Body)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsRetryable returns true if the error is a server-side error worth retrying with another provider.
|
|
||||||
func (e *ProviderError) IsRetryable() bool {
|
|
||||||
return e.StatusCode >= 500 || e.StatusCode == 429
|
|
||||||
}
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ChatRequest is the OpenAI-compatible chat completion request.
// Field names and JSON tags mirror the upstream API; optional numeric
// knobs are pointers so "unset" is distinguishable from zero.
type ChatRequest struct {
	Model            string    `json:"model"`
	Messages         []Message `json:"messages"`
	Temperature      *float64  `json:"temperature,omitempty"`
	MaxTokens        *int      `json:"max_tokens,omitempty"`
	TopP             *float64  `json:"top_p,omitempty"`
	Stream           bool      `json:"stream"`
	Stop             any       `json:"stop,omitempty"` // left untyped; forwarded verbatim
	FrequencyPenalty *float64  `json:"frequency_penalty,omitempty"`
	PresencePenalty  *float64  `json:"presence_penalty,omitempty"`
	N                *int      `json:"n,omitempty"`
	Tools            []any     `json:"tools,omitempty"`
	ToolChoice       any       `json:"tool_choice,omitempty"`
	ResponseFormat   any       `json:"response_format,omitempty"`
	// NOTE(review): Extra is tagged json:"-", so the stdlib marshaller
	// drops it; "pass through" must be implemented by custom
	// (un)marshalling elsewhere — confirm.
	Extra map[string]any `json:"-"` // pass through unknown fields
}

// Message is a single chat turn in the OpenAI format.
type Message struct {
	Role       string `json:"role"`
	Content    any    `json:"content"` // string or []ContentPart
	Name       string `json:"name,omitempty"`
	ToolCalls  []any  `json:"tool_calls,omitempty"`
	ToolCallID string `json:"tool_call_id,omitempty"`
}

// ChatResponse is the OpenAI-compatible chat completion response.
type ChatResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"`
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   *Usage   `json:"usage,omitempty"` // pointer: absent when provider omits usage
}

// Choice is one completion candidate within a ChatResponse.
type Choice struct {
	Index        int     `json:"index"`
	Message      Message `json:"message"`
	FinishReason string  `json:"finish_reason"`
}

// Usage reports token consumption for a chat completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
||||||
|
|
||||||
// EmbeddingRequest is the OpenAI-compatible embedding request.
type EmbeddingRequest struct {
	Model          string `json:"model"`
	Input          any    `json:"input"` // string or []string
	EncodingFormat string `json:"encoding_format,omitempty"`
}

// EmbeddingResponse is the OpenAI-compatible embedding response.
type EmbeddingResponse struct {
	Object string          `json:"object"`
	Data   []EmbeddingData `json:"data"`
	Model  string          `json:"model"`
	Usage  *EmbeddingUsage `json:"usage,omitempty"` // pointer: absent when provider omits usage
}

// EmbeddingData holds a single embedding vector.
type EmbeddingData struct {
	Object    string    `json:"object"`
	Embedding []float64 `json:"embedding"`
	Index     int       `json:"index"` // position of the corresponding input
}

// EmbeddingUsage reports token usage for embeddings.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
|
|
||||||
|
|
||||||
// Provider sends requests to an LLM API.
type Provider interface {
	// Name returns the provider's configured identifier.
	Name() string
	// ChatCompletion performs a blocking completion for the given
	// provider-side model name.
	ChatCompletion(ctx context.Context, model string, req *ChatRequest) (*ChatResponse, error)
	// ChatCompletionStream starts a streaming completion; the caller is
	// responsible for closing the returned body.
	ChatCompletionStream(ctx context.Context, model string, req *ChatRequest) (io.ReadCloser, error)
	// Embedding computes embeddings for the given provider-side model name.
	Embedding(ctx context.Context, model string, req *EmbeddingRequest) (*EmbeddingResponse, error)
}
|
|
||||||
|
|
@ -1,214 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ModelTimeouts holds per-model timeout overrides.
type ModelTimeouts struct {
	RequestTimeout   time.Duration // non-streaming override; zero means not set
	StreamingTimeout time.Duration // streaming override; zero means not set
}

// Route maps a model to a specific provider with pricing.
type Route struct {
	Provider      Provider
	ProviderModel string // model identifier understood by this provider
	Priority      int    // lower value = tried earlier (sorted in buildFromConfig)
	InputPrice    float64 // per 1M tokens
	OutputPrice   float64 // per 1M tokens
}

// Registry maps model names to provider routes.
// All maps are guarded by mu: lookups take the read lock, and rebuilds
// swap whole freshly-built maps under the write lock.
type Registry struct {
	mu        sync.RWMutex
	routes    map[string][]Route       // canonical model name -> priority-sorted routes
	balancers map[string]LoadBalancer  // canonical model name -> route ordering strategy
	aliases   map[string]string        // alias -> canonical name
	order     []string                 // preserves config order (canonical names only)
	timeouts  map[string]*ModelTimeouts // canonical name -> overrides; absent when none set
}
|
|
||||||
|
|
||||||
func NewRegistry(cfg *config.Config) (*Registry, error) {
|
|
||||||
r := &Registry{}
|
|
||||||
if err := r.buildFromConfig(cfg); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildFromConfig rebuilds all routing tables from cfg. Everything is
// assembled into fresh local maps first and only swapped into the
// Registry under the write lock at the end, so a failed build leaves
// the previous tables intact and concurrent readers never observe a
// partially built state.
func (r *Registry) buildFromConfig(cfg *config.Config) error {
	// Build providers: one client per configured provider, keyed by name.
	providers := make(map[string]Provider)
	for _, pc := range cfg.Providers {
		providers[pc.Name] = NewOpenAIProvider(pc.Name, pc.BaseURL, pc.APIKey, pc.Timeout)
	}

	// Build routes into fresh tables (see swap at the bottom).
	routes := make(map[string][]Route)
	balancers := make(map[string]LoadBalancer)
	aliases := make(map[string]string)
	order := make([]string, 0, len(cfg.Models))
	timeouts := make(map[string]*ModelTimeouts)

	for _, mc := range cfg.Models {
		var modelRoutes []Route
		for _, rc := range mc.Routes {
			p, ok := providers[rc.Provider]
			if !ok {
				return fmt.Errorf("model %s: unknown provider %s", mc.Name, rc.Provider)
			}
			// Priority comes from the provider block, not the route block.
			pc := cfg.ProviderByName(rc.Provider)
			priority := pc.Priority
			modelRoutes = append(modelRoutes, Route{
				Provider:      p,
				ProviderModel: rc.Model,
				Priority:      priority,
				InputPrice:    rc.Pricing.Input,
				OutputPrice:   rc.Pricing.Output,
			})
		}
		// Sort by priority (lower = higher priority)
		sort.Slice(modelRoutes, func(i, j int) bool {
			return modelRoutes[i].Priority < modelRoutes[j].Priority
		})
		routes[mc.Name] = modelRoutes
		// order records canonical names as they appear in the config.
		order = append(order, mc.Name)

		// Load balancer
		balancers[mc.Name] = NewLoadBalancer(mc.LoadBalancing)

		// Per-model timeouts: only recorded when at least one override is set,
		// so a nil entry in the map means "use the defaults".
		if mc.RequestTimeout > 0 || mc.StreamingTimeout > 0 {
			timeouts[mc.Name] = &ModelTimeouts{
				RequestTimeout:   mc.RequestTimeout,
				StreamingTimeout: mc.StreamingTimeout,
			}
		}

		// Register aliases
		for _, alias := range mc.Aliases {
			aliases[alias] = mc.Name
		}
	}

	// Atomic swap of all tables under the write lock.
	r.mu.Lock()
	r.routes = routes
	r.balancers = balancers
	r.aliases = aliases
	r.order = order
	r.timeouts = timeouts
	r.mu.Unlock()

	return nil
}
|
|
||||||
|
|
||||||
// Reload rebuilds routes from new config. Used for hot-reload.
// It delegates to buildFromConfig, which constructs fresh tables and
// swaps them in under the write lock, so concurrent Lookups never see
// a partially built registry.
func (r *Registry) Reload(cfg *config.Config) error {
	return r.buildFromConfig(cfg)
}
|
|
||||||
|
|
||||||
// Lookup returns the routes for a model name (resolving aliases).
|
|
||||||
func (r *Registry) Lookup(model string) ([]Route, bool) {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
// Resolve alias
|
|
||||||
canonical := model
|
|
||||||
if alias, ok := r.aliases[model]; ok {
|
|
||||||
canonical = alias
|
|
||||||
}
|
|
||||||
|
|
||||||
routes, ok := r.routes[canonical]
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply load balancer
|
|
||||||
if balancer, ok := r.balancers[canonical]; ok {
|
|
||||||
routes = balancer.Reorder(routes)
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelNames returns all registered model names in config order (including aliases).
|
|
||||||
func (r *Registry) ModelNames() []string {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
var names []string
|
|
||||||
for _, name := range r.order {
|
|
||||||
names = append(names, name)
|
|
||||||
}
|
|
||||||
// Add aliases
|
|
||||||
for alias := range r.aliases {
|
|
||||||
names = append(names, alias)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelTimeoutsFor returns per-model timeout overrides, resolving aliases. Returns nil if none set.
|
|
||||||
func (r *Registry) ModelTimeoutsFor(model string) *ModelTimeouts {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
canonical := model
|
|
||||||
if alias, ok := r.aliases[model]; ok {
|
|
||||||
canonical = alias
|
|
||||||
}
|
|
||||||
return r.timeouts[canonical]
|
|
||||||
}
|
|
||||||
|
|
||||||
// RouteInfo exposes route details for dashboard display.
type RouteInfo struct {
	ProviderName  string  `json:"provider_name"`
	ProviderModel string  `json:"provider_model"` // model id as the provider knows it
	Priority      int     `json:"priority"`       // lower = tried earlier
	InputPrice    float64 `json:"input_price"`    // per 1M tokens
	OutputPrice   float64 `json:"output_price"`   // per 1M tokens
}

// ModelRouteInfo exposes a model and its routes for dashboard display.
type ModelRouteInfo struct {
	Name    string      `json:"name"`              // canonical model name
	Aliases []string    `json:"aliases,omitempty"` // alternative names, if any
	Routes  []RouteInfo `json:"routes"`
}
|
|
||||||
|
|
||||||
// AllRoutes returns all models and their routes in config order.
|
|
||||||
func (r *Registry) AllRoutes() []ModelRouteInfo {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
|
|
||||||
// Build reverse alias map
|
|
||||||
modelAliases := make(map[string][]string)
|
|
||||||
for alias, canonical := range r.aliases {
|
|
||||||
modelAliases[canonical] = append(modelAliases[canonical], alias)
|
|
||||||
}
|
|
||||||
|
|
||||||
results := make([]ModelRouteInfo, 0, len(r.order))
|
|
||||||
for _, name := range r.order {
|
|
||||||
routes := r.routes[name]
|
|
||||||
info := ModelRouteInfo{
|
|
||||||
Name: name,
|
|
||||||
Aliases: modelAliases[name],
|
|
||||||
}
|
|
||||||
for _, rt := range routes {
|
|
||||||
info.Routes = append(info.Routes, RouteInfo{
|
|
||||||
ProviderName: rt.Provider.Name(),
|
|
||||||
ProviderModel: rt.ProviderModel,
|
|
||||||
Priority: rt.Priority,
|
|
||||||
InputPrice: rt.InputPrice,
|
|
||||||
OutputPrice: rt.OutputPrice,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
results = append(results, info)
|
|
||||||
}
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
@ -1,286 +0,0 @@
|
||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// mockProvider implements the Provider interface for testing.
// Only Name carries state; all request methods are inert stubs that
// return (nil, nil), since registry tests exercise routing only.
type mockProvider struct {
	name string
}

// Name returns the fixed provider name given at construction.
func (m *mockProvider) Name() string { return m.name }

// ChatCompletion is a stub; never exercised by registry tests.
func (m *mockProvider) ChatCompletion(_ context.Context, _ string, _ *ChatRequest) (*ChatResponse, error) {
	return nil, nil
}

// ChatCompletionStream is a stub; never exercised by registry tests.
func (m *mockProvider) ChatCompletionStream(_ context.Context, _ string, _ *ChatRequest) (io.ReadCloser, error) {
	return nil, nil
}

// Embedding is a stub; never exercised by registry tests.
func (m *mockProvider) Embedding(_ context.Context, _ string, _ *EmbeddingRequest) (*EmbeddingResponse, error) {
	return nil, nil
}
|
|
||||||
|
|
||||||
// newTestRegistry builds a Registry directly without going through config parsing.
|
|
||||||
func newTestRegistry(models []testModel) *Registry {
|
|
||||||
r := &Registry{
|
|
||||||
routes: make(map[string][]Route),
|
|
||||||
balancers: make(map[string]LoadBalancer),
|
|
||||||
aliases: make(map[string]string),
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range models {
|
|
||||||
r.routes[m.name] = m.routes
|
|
||||||
r.balancers[m.name] = &FirstBalancer{}
|
|
||||||
r.order = append(r.order, m.name)
|
|
||||||
for _, alias := range m.aliases {
|
|
||||||
r.aliases[alias] = m.name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// testModel describes one model registration for newTestRegistry.
type testModel struct {
	name    string   // canonical model name
	aliases []string // alternative names that resolve to name
	routes  []Route  // stored as-is; no priority sorting is applied
}
|
|
||||||
|
|
||||||
func TestRegistry_Lookup_Canonical(t *testing.T) {
|
|
||||||
reg := newTestRegistry([]testModel{
|
|
||||||
{
|
|
||||||
name: "gpt-4",
|
|
||||||
routes: []Route{
|
|
||||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
routes, ok := reg.Lookup("gpt-4")
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("expected Lookup to find gpt-4")
|
|
||||||
}
|
|
||||||
if len(routes) != 1 {
|
|
||||||
t.Fatalf("expected 1 route, got %d", len(routes))
|
|
||||||
}
|
|
||||||
if routes[0].Provider.Name() != "openai" {
|
|
||||||
t.Errorf("expected provider 'openai', got %q", routes[0].Provider.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRegistry_Lookup_Alias verifies that Lookup resolves both the
// canonical model name and every registered alias to the same route
// set, and reports not-found for an unregistered name.
func TestRegistry_Lookup_Alias(t *testing.T) {
	reg := newTestRegistry([]testModel{
		{
			name:    "gpt-4",
			aliases: []string{"gpt4", "gpt-4-latest"},
			routes: []Route{
				{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
			},
		},
	})

	tests := []struct {
		name  string
		model string
		found bool
	}{
		{"canonical", "gpt-4", true},
		{"alias1", "gpt4", true},
		{"alias2", "gpt-4-latest", true},
		{"unknown", "gpt-5", false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			routes, ok := reg.Lookup(tt.model)
			if ok != tt.found {
				t.Fatalf("Lookup(%q) found=%v, want %v", tt.model, ok, tt.found)
			}
			// Found names (canonical or alias) must yield the one route.
			if tt.found && len(routes) != 1 {
				t.Fatalf("expected 1 route, got %d", len(routes))
			}
		})
	}
}
|
|
||||||
|
|
||||||
// TestRegistry_ModelNames_IncludesAliases checks that ModelNames lists
// both canonical names and aliases exactly once each (set equality;
// ordering is not asserted).
func TestRegistry_ModelNames_IncludesAliases(t *testing.T) {
	reg := newTestRegistry([]testModel{
		{
			name:    "gpt-4",
			aliases: []string{"gpt4"},
			routes: []Route{
				{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
			},
		},
		{
			name: "claude-3",
			routes: []Route{
				{Provider: &mockProvider{name: "anthropic"}, ProviderModel: "claude-3", Priority: 1},
			},
		},
	})

	names := reg.ModelNames()

	// Compare as sets: two canonical names plus one alias.
	want := map[string]bool{"gpt-4": true, "gpt4": true, "claude-3": true}
	got := make(map[string]bool)
	for _, n := range names {
		got[n] = true
	}

	for name := range want {
		if !got[name] {
			t.Errorf("expected %q in ModelNames, not found", name)
		}
	}

	// Same length guards against duplicates or extras.
	if len(names) != len(want) {
		t.Errorf("expected %d names, got %d: %v", len(want), len(names), names)
	}
}
|
|
||||||
|
|
||||||
// TestRegistry_AllRoutes_ShowsAliases checks that AllRoutes reports a
// model's aliases (as a set) and its routes in stored order.
func TestRegistry_AllRoutes_ShowsAliases(t *testing.T) {
	reg := newTestRegistry([]testModel{
		{
			name:    "gpt-4",
			aliases: []string{"gpt4", "gpt-4-latest"},
			routes: []Route{
				{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
				{Provider: &mockProvider{name: "azure"}, ProviderModel: "gpt-4", Priority: 2},
			},
		},
	})

	allRoutes := reg.AllRoutes()
	if len(allRoutes) != 1 {
		t.Fatalf("expected 1 model, got %d", len(allRoutes))
	}

	m := allRoutes[0]
	if m.Name != "gpt-4" {
		t.Errorf("expected name 'gpt-4', got %q", m.Name)
	}

	// Aliases compared as a set — their relative order is not asserted.
	aliasSet := make(map[string]bool)
	for _, a := range m.Aliases {
		aliasSet[a] = true
	}
	if !aliasSet["gpt4"] || !aliasSet["gpt-4-latest"] {
		t.Errorf("expected aliases [gpt4, gpt-4-latest], got %v", m.Aliases)
	}

	// Routes must come back in stored order (openai first, azure second).
	if len(m.Routes) != 2 {
		t.Fatalf("expected 2 routes, got %d", len(m.Routes))
	}
	if m.Routes[0].ProviderName != "openai" {
		t.Errorf("expected first route provider 'openai', got %q", m.Routes[0].ProviderName)
	}
	if m.Routes[1].ProviderName != "azure" {
		t.Errorf("expected second route provider 'azure', got %q", m.Routes[1].ProviderName)
	}
}
|
|
||||||
|
|
||||||
func TestRegistry_AllRoutes_ConfigOrder(t *testing.T) {
|
|
||||||
reg := newTestRegistry([]testModel{
|
|
||||||
{
|
|
||||||
name: "model-b",
|
|
||||||
routes: []Route{
|
|
||||||
{Provider: &mockProvider{name: "prov"}, ProviderModel: "b", Priority: 1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model-a",
|
|
||||||
routes: []Route{
|
|
||||||
{Provider: &mockProvider{name: "prov"}, ProviderModel: "a", Priority: 1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
allRoutes := reg.AllRoutes()
|
|
||||||
if len(allRoutes) != 2 {
|
|
||||||
t.Fatalf("expected 2 models, got %d", len(allRoutes))
|
|
||||||
}
|
|
||||||
if allRoutes[0].Name != "model-b" {
|
|
||||||
t.Errorf("expected first model 'model-b', got %q", allRoutes[0].Name)
|
|
||||||
}
|
|
||||||
if allRoutes[1].Name != "model-a" {
|
|
||||||
t.Errorf("expected second model 'model-a', got %q", allRoutes[1].Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRegistry_PrioritySorting checks that routes registered with mixed
// priorities are all preserved by AllRoutes. It asserts presence only,
// not order — see the note below.
func TestRegistry_PrioritySorting(t *testing.T) {
	reg := newTestRegistry([]testModel{
		{
			name: "multi-provider",
			routes: []Route{
				{Provider: &mockProvider{name: "low-priority"}, ProviderModel: "m", Priority: 3},
				{Provider: &mockProvider{name: "high-priority"}, ProviderModel: "m", Priority: 1},
				{Provider: &mockProvider{name: "mid-priority"}, ProviderModel: "m", Priority: 2},
			},
		},
	})

	// Note: routes are stored as given (sorting happens during buildFromConfig).
	// For this test we verify AllRoutes returns them in stored order.
	allRoutes := reg.AllRoutes()
	if len(allRoutes) != 1 {
		t.Fatalf("expected 1 model, got %d", len(allRoutes))
	}

	routes := allRoutes[0].Routes
	if len(routes) != 3 {
		t.Fatalf("expected 3 routes, got %d", len(routes))
	}

	// Verify the priorities are present
	priorities := make(map[int]bool)
	for _, r := range routes {
		priorities[r.Priority] = true
	}
	for _, p := range []int{1, 2, 3} {
		if !priorities[p] {
			t.Errorf("expected priority %d in routes", p)
		}
	}
}
|
|
||||||
|
|
||||||
func TestRegistry_NewRegistry_UnknownProvider(t *testing.T) {
|
|
||||||
cfg := &config.Config{
|
|
||||||
Models: []config.ModelConfig{
|
|
||||||
{
|
|
||||||
Name: "test-model",
|
|
||||||
Routes: []config.RouteConfig{
|
|
||||||
{Provider: "nonexistent", Model: "m"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := NewRegistry(cfg)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("expected error for unknown provider, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_Lookup_NotFound(t *testing.T) {
|
|
||||||
reg := newTestRegistry([]testModel{
|
|
||||||
{
|
|
||||||
name: "gpt-4",
|
|
||||||
routes: []Route{
|
|
||||||
{Provider: &mockProvider{name: "openai"}, ProviderModel: "gpt-4", Priority: 1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
_, ok := reg.Lookup("nonexistent")
|
|
||||||
if ok {
|
|
||||||
t.Fatal("expected Lookup to return false for nonexistent model")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AuthMiddleware authenticates proxy requests against the token store.
type AuthMiddleware struct {
	authStore *auth.Store // backing store used to validate bearer keys
}

// NewAuthMiddleware wraps the given auth store in a middleware instance.
func NewAuthMiddleware(authStore *auth.Store) *AuthMiddleware {
	return &AuthMiddleware{authStore: authStore}
}
|
|
||||||
|
|
||||||
// Authenticate validates the bearer token against the DB and sets token info in context.
|
|
||||||
func (a *AuthMiddleware) Authenticate(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
hdr := r.Header.Get("Authorization")
|
|
||||||
if !strings.HasPrefix(hdr, "Bearer ") {
|
|
||||||
writeError(w, http.StatusUnauthorized, "missing or invalid Authorization header")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
key := strings.TrimPrefix(hdr, "Bearer ")
|
|
||||||
|
|
||||||
token, err := a.authStore.LookupAPIToken(key)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusUnauthorized, "invalid API key")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update last used asynchronously (skip for static tokens)
|
|
||||||
if token.ID > 0 {
|
|
||||||
go a.authStore.UpdateAPITokenLastUsed(token.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := withTokenName(r.Context(), token.Name)
|
|
||||||
ctx = withAPIToken(ctx, token)
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,51 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConcurrencyLimiter enforces per-token concurrent request limits.
// One atomic in-flight counter is kept per token name; the map itself
// is guarded by mu while the counters are mutated lock-free.
type ConcurrencyLimiter struct {
	mu       sync.Mutex
	counters map[string]*atomic.Int64
}

// NewConcurrencyLimiter returns a limiter with an empty counter table;
// counters are created lazily on first use of each token.
func NewConcurrencyLimiter() *ConcurrencyLimiter {
	return &ConcurrencyLimiter{counters: map[string]*atomic.Int64{}}
}

// getCounter returns the in-flight counter for tokenName, creating it
// on first access. Safe for concurrent use.
func (cl *ConcurrencyLimiter) getCounter(tokenName string) *atomic.Int64 {
	cl.mu.Lock()
	defer cl.mu.Unlock()

	if c, ok := cl.counters[tokenName]; ok {
		return c
	}
	c := new(atomic.Int64)
	cl.counters[tokenName] = c
	return c
}
|
|
||||||
|
|
||||||
func (cl *ConcurrencyLimiter) Check(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
apiToken := getAPIToken(r.Context())
|
|
||||||
if apiToken == nil || apiToken.MaxConcurrent <= 0 {
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
counter := cl.getCounter(apiToken.Name)
|
|
||||||
current := counter.Add(1)
|
|
||||||
defer counter.Add(-1)
|
|
||||||
|
|
||||||
if current > int64(apiToken.MaxConcurrent) {
|
|
||||||
writeError(w, http.StatusTooManyRequests, "concurrent request limit exceeded")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,317 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestConcurrencyLimiter_AllowsWithinLimit holds N concurrent requests
// inside the handler (via the gate channel) and checks that all are
// admitted when N does not exceed the token's MaxConcurrent.
func TestConcurrencyLimiter_AllowsWithinLimit(t *testing.T) {
	tests := []struct {
		name          string
		maxConcurrent int
		numRequests   int
		wantAllowed   int
	}{
		{
			name:          "single request within limit",
			maxConcurrent: 5,
			numRequests:   1,
			wantAllowed:   1,
		},
		{
			name:          "all requests within limit",
			maxConcurrent: 5,
			numRequests:   5,
			wantAllowed:   5,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cl := NewConcurrencyLimiter()

			token := &auth.APIToken{
				Name:          "conc-token",
				MaxConcurrent: tt.maxConcurrent,
			}

			var allowed atomic.Int64
			var wg sync.WaitGroup
			// Use a channel to hold all goroutines inside the handler simultaneously.
			gate := make(chan struct{})

			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				allowed.Add(1)
				<-gate // Block until released.
				w.WriteHeader(http.StatusOK)
			}))

			for i := 0; i < tt.numRequests; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					rec := httptest.NewRecorder()
					req := httptest.NewRequest(http.MethodGet, "/", nil)
					ctx := withAPIToken(req.Context(), token)
					req = req.WithContext(ctx)
					handler.ServeHTTP(rec, req)
				}()
			}

			// Wait for goroutines to enter the handler.
			// NOTE(review): sleep-based synchronization; a very slow machine
			// could release the gate before all goroutines arrive — confirm
			// this is acceptable or replace with a WaitGroup/channel handshake.
			time.Sleep(50 * time.Millisecond)
			close(gate)
			wg.Wait()

			if int(allowed.Load()) != tt.wantAllowed {
				t.Errorf("allowed = %d, want %d", allowed.Load(), tt.wantAllowed)
			}
		})
	}
}
|
|
||||||
|
|
||||||
// TestConcurrencyLimiter_DeniesOverLimit launches more concurrent
// requests than the token allows; admitted ones block on the gate so
// the surplus must be rejected with 429.
func TestConcurrencyLimiter_DeniesOverLimit(t *testing.T) {
	tests := []struct {
		name          string
		maxConcurrent int
		numRequests   int
		wantDenied    int
	}{
		{
			name:          "one over limit",
			maxConcurrent: 2,
			numRequests:   3,
			wantDenied:    1,
		},
		{
			name:          "many over limit",
			maxConcurrent: 1,
			numRequests:   5,
			wantDenied:    4,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cl := NewConcurrencyLimiter()

			token := &auth.APIToken{
				Name:          "conc-token",
				MaxConcurrent: tt.maxConcurrent,
			}

			var denied atomic.Int64
			var wg sync.WaitGroup
			gate := make(chan struct{})

			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				<-gate
				w.WriteHeader(http.StatusOK)
			}))

			// results keeps per-request status codes; each goroutine writes
			// only its own index, so no extra synchronization is needed.
			results := make([]int, tt.numRequests)
			for i := 0; i < tt.numRequests; i++ {
				wg.Add(1)
				go func(idx int) {
					defer wg.Done()
					rec := httptest.NewRecorder()
					req := httptest.NewRequest(http.MethodGet, "/", nil)
					ctx := withAPIToken(req.Context(), token)
					req = req.WithContext(ctx)
					handler.ServeHTTP(rec, req)
					results[idx] = rec.Code
					if rec.Code == http.StatusTooManyRequests {
						denied.Add(1)
					}
				}(i)
			}

			// Wait for goroutines to reach the handler or be rejected.
			// NOTE(review): a rejected request frees its slot immediately, so
			// a late-starting goroutine could slip under the limit and make
			// the denied count smaller than expected — timing dependent;
			// confirm this has not flaked in CI.
			time.Sleep(50 * time.Millisecond)
			close(gate)
			wg.Wait()

			if int(denied.Load()) != tt.wantDenied {
				t.Errorf("denied = %d, want %d", denied.Load(), tt.wantDenied)
			}
		})
	}
}
|
|
||||||
|
|
||||||
func TestConcurrencyLimiter_CounterDecrementsAfterCompletion(t *testing.T) {
|
|
||||||
cl := NewConcurrencyLimiter()
|
|
||||||
|
|
||||||
token := &auth.APIToken{
|
|
||||||
Name: "decrement-token",
|
|
||||||
MaxConcurrent: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
// First request should succeed and complete, decrementing the counter.
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
ctx := withAPIToken(req.Context(), token)
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
handler.ServeHTTP(rec, req)
|
|
||||||
|
|
||||||
if rec.Code != http.StatusOK {
|
|
||||||
t.Fatalf("first request: status = %d, want %d", rec.Code, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Counter should have decremented. A second request should also succeed.
|
|
||||||
rec2 := httptest.NewRecorder()
|
|
||||||
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
ctx2 := withAPIToken(req2.Context(), token)
|
|
||||||
req2 = req2.WithContext(ctx2)
|
|
||||||
handler.ServeHTTP(rec2, req2)
|
|
||||||
|
|
||||||
if rec2.Code != http.StatusOK {
|
|
||||||
t.Errorf("second request after first completed: status = %d, want %d", rec2.Code, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the internal counter is back to 0.
|
|
||||||
counter := cl.getCounter(token.Name)
|
|
||||||
val := counter.Load()
|
|
||||||
if val != 0 {
|
|
||||||
t.Errorf("counter = %d, want 0 after all requests completed", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited checks that a
// zero or negative MaxConcurrent disables limiting entirely: every one
// of many simultaneously-blocked requests must be admitted.
func TestConcurrencyLimiter_ZeroMaxConcurrentMeansUnlimited(t *testing.T) {
	tests := []struct {
		name          string
		maxConcurrent int
		numRequests   int
	}{
		{
			name:          "zero allows unlimited concurrent requests",
			maxConcurrent: 0,
			numRequests:   50,
		},
		{
			name:          "negative allows unlimited concurrent requests",
			maxConcurrent: -1,
			numRequests:   50,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cl := NewConcurrencyLimiter()

			token := &auth.APIToken{
				Name:          "unlimited-token",
				MaxConcurrent: tt.maxConcurrent,
			}

			var allowed atomic.Int64
			var wg sync.WaitGroup
			// gate holds all admitted requests inside the handler at once.
			gate := make(chan struct{})

			handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				allowed.Add(1)
				<-gate
				w.WriteHeader(http.StatusOK)
			}))

			for i := 0; i < tt.numRequests; i++ {
				wg.Add(1)
				go func() {
					defer wg.Done()
					rec := httptest.NewRecorder()
					req := httptest.NewRequest(http.MethodGet, "/", nil)
					ctx := withAPIToken(req.Context(), token)
					req = req.WithContext(ctx)
					handler.ServeHTTP(rec, req)
				}()
			}

			// Give goroutines time to enter the handler.
			// NOTE(review): sleep-based synchronization — timing dependent.
			time.Sleep(100 * time.Millisecond)
			close(gate)
			wg.Wait()

			if int(allowed.Load()) != tt.numRequests {
				t.Errorf("allowed = %d, want %d (zero/negative maxConcurrent should be unlimited)", allowed.Load(), tt.numRequests)
			}
		})
	}
}
|
|
||||||
|
|
||||||
func TestConcurrencyLimiter_NoToken(t *testing.T) {
|
|
||||||
cl := NewConcurrencyLimiter()
|
|
||||||
|
|
||||||
handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
// No API token in context.
|
|
||||||
handler.ServeHTTP(rec, req)
|
|
||||||
|
|
||||||
if rec.Code != http.StatusOK {
|
|
||||||
t.Errorf("status = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestConcurrencyLimiter_PerTokenIsolation verifies that tokens are limited
// independently: token A holding its single concurrency slot must not block
// token B's request.
func TestConcurrencyLimiter_PerTokenIsolation(t *testing.T) {
	cl := NewConcurrencyLimiter()

	tokenA := &auth.APIToken{
		Name:          "token-a",
		MaxConcurrent: 1,
	}
	tokenB := &auth.APIToken{
		Name:          "token-b",
		MaxConcurrent: 1,
	}

	gateA := make(chan struct{})
	var wg sync.WaitGroup

	handler := cl.Check(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tok := getAPIToken(r.Context())
		if tok.Name == "token-a" {
			<-gateA // Block token A's request.
		}
		w.WriteHeader(http.StatusOK)
	}))

	// Start a request for token A that blocks.
	wg.Add(1)
	go func() {
		defer wg.Done()
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		ctx := withAPIToken(req.Context(), tokenA)
		req = req.WithContext(ctx)
		handler.ServeHTTP(rec, req)
	}()

	// Give token A's goroutine time to enter handler.
	// NOTE(review): sleep-based synchronization; may flake under load.
	time.Sleep(50 * time.Millisecond)

	// Token B should not be affected by token A's in-flight request.
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	ctx := withAPIToken(req.Context(), tokenB)
	req = req.WithContext(ctx)
	handler.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("token-b status = %d, want %d (should not be affected by token-a)", rec.Code, http.StatusOK)
	}

	// Release token A and wait for its request to finish.
	close(gateA)
	wg.Wait()
}
|
|
||||||
|
|
@ -1,107 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// inflight represents an in-progress deduplicated request.
// The leader fills result/statusCode and then closes done; followers block on
// done and afterwards read the fields (the writes happen-before the close).
type inflight struct {
	done       chan struct{} // closed exactly once, by Complete or by cleanup
	result     []byte        // response body captured by the leader
	statusCode int           // HTTP status captured by the leader
	createdAt  time.Time     // used by cleanup to detect abandoned entries
}
|
|
||||||
|
|
||||||
// Deduplicator coalesces identical concurrent non-streaming requests.
// The first caller for a key (the leader) executes the request; later callers
// with the same key (followers) wait on the shared inflight entry and reuse
// the leader's result.
type Deduplicator struct {
	mu      sync.Mutex           // guards flights
	flights map[string]*inflight // DedupKey -> in-progress request
	window  time.Duration        // cleanup sweep interval; stale cutoff is 2*window
	done    chan struct{}        // closed by Close to stop the cleanup goroutine
}
|
|
||||||
|
|
||||||
// NewDeduplicator creates a new request deduplicator.
|
|
||||||
func NewDeduplicator(window time.Duration) *Deduplicator {
|
|
||||||
if window == 0 {
|
|
||||||
window = 30 * time.Second
|
|
||||||
}
|
|
||||||
d := &Deduplicator{
|
|
||||||
flights: make(map[string]*inflight),
|
|
||||||
window: window,
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
|
||||||
go d.cleanup()
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
// DedupKey computes a dedup key from model name and request body.
// A NUL byte separates the two inputs so that, e.g., ("ab","c") and
// ("a","bc") cannot hash to the same key.
func DedupKey(model string, body []byte) string {
	hasher := sha256.New()
	hasher.Write(append([]byte(model), 0))
	hasher.Write(body)
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
||||||
|
|
||||||
// TryJoin attempts to join an in-flight request. Returns the inflight entry and
|
|
||||||
// whether this caller is the leader (true) or a follower (false).
|
|
||||||
func (d *Deduplicator) TryJoin(key string) (*inflight, bool) {
|
|
||||||
d.mu.Lock()
|
|
||||||
defer d.mu.Unlock()
|
|
||||||
|
|
||||||
if f, ok := d.flights[key]; ok {
|
|
||||||
return f, false // follower
|
|
||||||
}
|
|
||||||
|
|
||||||
f := &inflight{
|
|
||||||
done: make(chan struct{}),
|
|
||||||
createdAt: time.Now(),
|
|
||||||
}
|
|
||||||
d.flights[key] = f
|
|
||||||
return f, true // leader
|
|
||||||
}
|
|
||||||
|
|
||||||
// Complete signals completion of a deduplicated request.
|
|
||||||
func (d *Deduplicator) Complete(key string, result []byte, statusCode int) {
|
|
||||||
d.mu.Lock()
|
|
||||||
f, ok := d.flights[key]
|
|
||||||
delete(d.flights, key)
|
|
||||||
d.mu.Unlock()
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
f.result = result
|
|
||||||
f.statusCode = statusCode
|
|
||||||
close(f.done)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops the background cleanup goroutine.
// Must be called at most once; a second call would panic on the closed channel.
func (d *Deduplicator) Close() {
	close(d.done)
}
|
|
||||||
|
|
||||||
// cleanup periodically removes stale in-flight entries.
// An entry older than 2*window is assumed abandoned (e.g. the leader exited
// without calling Complete); it is deleted and its done channel closed so any
// waiting followers unblock — they will observe a zero-value result.
// Runs until Close closes d.done.
func (d *Deduplicator) cleanup() {
	ticker := time.NewTicker(d.window)
	defer ticker.Stop()

	for {
		select {
		case <-d.done:
			return
		case <-ticker.C:
			// Delete-before-close under the lock: a concurrent Complete for
			// the same key will miss the map entry and skip its own close,
			// so done is never closed twice.
			d.mu.Lock()
			now := time.Now()
			for key, f := range d.flights {
				if now.Sub(f.createdAt) > d.window*2 {
					delete(d.flights, key)
					close(f.done) // unblock any waiting followers
				}
			}
			d.mu.Unlock()
		}
	}
}
|
|
||||||
|
|
@ -1,74 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDedupKey(t *testing.T) {
|
|
||||||
k1 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
|
||||||
k2 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hi"}]}`))
|
|
||||||
k3 := DedupKey("gpt-4", []byte(`{"messages":[{"role":"user","content":"hello"}]}`))
|
|
||||||
|
|
||||||
if k1 != k2 {
|
|
||||||
t.Error("identical requests should produce the same key")
|
|
||||||
}
|
|
||||||
if k1 == k3 {
|
|
||||||
t.Error("different requests should produce different keys")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDeduplicator_LeaderFollower exercises the full coalescing cycle: the
// first TryJoin for a key is the leader, a second join on the same key is a
// follower sharing the same entry, and Complete publishes the leader's
// result/status to the follower.
func TestDeduplicator_LeaderFollower(t *testing.T) {
	d := NewDeduplicator(5 * time.Second)
	defer d.Close()

	key := DedupKey("gpt-4", []byte(`test`))

	// First call is leader
	f1, isLeader := d.TryJoin(key)
	if !isLeader {
		t.Fatal("first caller should be leader")
	}

	// Second call with same key is follower
	f2, isLeader := d.TryJoin(key)
	if isLeader {
		t.Fatal("second caller should be follower")
	}
	if f1 != f2 {
		t.Fatal("follower should get same inflight entry")
	}

	// Complete the request; the follower goroutine blocks on done and then
	// reads the published fields. (t.Error is safe from other goroutines.)
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		<-f2.done
		if string(f2.result) != "response" {
			t.Error("follower should receive leader's result")
		}
		if f2.statusCode != 200 {
			t.Error("follower should receive leader's status code")
		}
	}()

	d.Complete(key, []byte("response"), 200)
	wg.Wait()
}
|
|
||||||
|
|
||||||
func TestDeduplicator_DifferentKeys(t *testing.T) {
|
|
||||||
d := NewDeduplicator(5 * time.Second)
|
|
||||||
defer d.Close()
|
|
||||||
|
|
||||||
_, isLeader1 := d.TryJoin("key1")
|
|
||||||
_, isLeader2 := d.TryJoin("key2")
|
|
||||||
|
|
||||||
if !isLeader1 || !isLeader2 {
|
|
||||||
t.Error("different keys should both be leaders")
|
|
||||||
}
|
|
||||||
|
|
||||||
d.Complete("key1", []byte("r1"), 200)
|
|
||||||
d.Complete("key2", []byte("r2"), 200)
|
|
||||||
}
|
|
||||||
|
|
@ -1,514 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
"llm-gateway/internal/cache"
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
"llm-gateway/internal/metrics"
|
|
||||||
"llm-gateway/internal/provider"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
// contextKey is a private type for request-context keys so values stored by
// this package cannot collide with keys set by other packages.
type contextKey string

const tokenNameKey contextKey = "token_name"
const apiTokenKey contextKey = "api_token"

// withTokenName returns a child context carrying the authenticated token's name.
func withTokenName(ctx context.Context, name string) context.Context {
	return context.WithValue(ctx, tokenNameKey, name)
}

// getTokenName extracts the token name; returns "" when absent.
func getTokenName(ctx context.Context) string {
	name, _ := ctx.Value(tokenNameKey).(string)
	return name
}

// withAPIToken returns a child context carrying the full API token record.
func withAPIToken(ctx context.Context, token *auth.APIToken) context.Context {
	return context.WithValue(ctx, apiTokenKey, token)
}

// getAPIToken extracts the API token; returns nil when absent.
func getAPIToken(ctx context.Context) *auth.APIToken {
	t, _ := ctx.Value(apiTokenKey).(*auth.APIToken)
	return t
}
|
|
||||||
|
|
||||||
// Handler is the core proxy handler: it routes OpenAI-compatible requests to
// upstream providers with failover, optional caching, metrics, request
// logging and optional request deduplication.
type Handler struct {
	registry      *provider.Registry      // model name -> candidate routes
	logger        *storage.AsyncLogger    // async request audit log
	cache         *cache.Cache            // response cache; nil disables caching
	metrics       *metrics.Metrics        // NOTE(review): ChatCompletions nil-checks this but handleNonStream does not — confirm it is always non-nil in production
	cfg           *config.Config          // server/retry/debug configuration
	healthTracker *provider.HealthTracker // circuit-breaker state; may be nil
	debugLogger   *storage.DebugLogger    // optional; attached via SetDebugLogger
	dedup         *Deduplicator           // optional; attached via SetDeduplicator
}
|
|
||||||
|
|
||||||
func NewHandler(registry *provider.Registry, logger *storage.AsyncLogger, c *cache.Cache, m *metrics.Metrics, cfg *config.Config, ht *provider.HealthTracker) *Handler {
|
|
||||||
return &Handler{
|
|
||||||
registry: registry,
|
|
||||||
logger: logger,
|
|
||||||
cache: c,
|
|
||||||
metrics: m,
|
|
||||||
cfg: cfg,
|
|
||||||
healthTracker: ht,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDebugLogger attaches an optional debug logger that records full
// request/response bodies. Call before serving traffic; not synchronized.
func (h *Handler) SetDebugLogger(dl *storage.DebugLogger) {
	h.debugLogger = dl
}
|
|
||||||
|
|
||||||
// SetDeduplicator attaches an optional request deduplicator used for
// non-streaming chat completions. Call before serving traffic; not synchronized.
func (h *Handler) SetDeduplicator(d *Deduplicator) {
	h.dedup = d
}
|
|
||||||
|
|
||||||
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusBadRequest, "failed to read request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req provider.ChatRequest
|
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Model == "" {
|
|
||||||
writeError(w, http.StatusBadRequest, "model is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
routes, ok := h.registry.Lookup(req.Model)
|
|
||||||
if !ok {
|
|
||||||
writeError(w, http.StatusNotFound, "model not found: "+req.Model)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter healthy routes (circuit breaker)
|
|
||||||
routes = h.filterHealthyRoutes(routes)
|
|
||||||
|
|
||||||
tokenName := getTokenName(r.Context())
|
|
||||||
requestID := middleware.GetReqID(r.Context())
|
|
||||||
|
|
||||||
// Check cache for non-streaming requests
|
|
||||||
if !req.Stream && h.cache != nil {
|
|
||||||
if cached, err := h.cache.Get(r.Context(), req.Model, body); err == nil && cached != nil {
|
|
||||||
h.logRequest(requestID, tokenName, req.Model, "cache", "", 0, 0, 0, 0, "cached", "", false, true)
|
|
||||||
if h.metrics != nil {
|
|
||||||
h.metrics.RecordCacheHit()
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("X-Cache", "HIT")
|
|
||||||
w.Header().Set("X-Request-ID", requestID)
|
|
||||||
w.Write(cached)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if h.metrics != nil {
|
|
||||||
h.metrics.RecordCacheMiss()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply per-model timeout for non-streaming requests
|
|
||||||
modelTimeouts := h.registry.ModelTimeoutsFor(req.Model)
|
|
||||||
|
|
||||||
if req.Stream {
|
|
||||||
h.handleStream(w, r, &req, routes, tokenName, requestID, modelTimeouts)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request deduplication for non-streaming requests
|
|
||||||
if h.dedup != nil {
|
|
||||||
dedupKey := DedupKey(req.Model, body)
|
|
||||||
flight, isLeader := h.dedup.TryJoin(dedupKey)
|
|
||||||
if !isLeader {
|
|
||||||
// Wait for the leader to complete
|
|
||||||
select {
|
|
||||||
case <-flight.done:
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("X-Request-ID", requestID)
|
|
||||||
w.Header().Set("X-Dedup", "HIT")
|
|
||||||
w.WriteHeader(flight.statusCode)
|
|
||||||
w.Write(flight.result)
|
|
||||||
return
|
|
||||||
case <-r.Context().Done():
|
|
||||||
writeError(w, http.StatusGatewayTimeout, "request cancelled while waiting for dedup")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Leader: proceed normally, but capture response for followers
|
|
||||||
defer func() {
|
|
||||||
// If we haven't completed yet (e.g., panic), clean up
|
|
||||||
}()
|
|
||||||
h.handleNonStreamDedup(w, r, &req, routes, tokenName, body, requestID, modelTimeouts, dedupKey)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
|
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
|
|
||||||
defer cancel()
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
h.handleNonStream(w, r, &req, routes, tokenName, body, requestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonStream services a non-streaming chat completion by trying each
// route in priority order until one succeeds. Non-retryable provider errors
// are surfaced to the client immediately with the provider's own error body;
// retryable failures fall through to the next route after a backoff. On
// success the response is re-labelled with the gateway model name, cached,
// optionally debug-logged, and written out.
func (h *Handler) handleNonStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string) {
	var lastErr error

	for i, route := range routes {
		// Retry backoff between attempts (not before first attempt)
		if i > 0 {
			backoff := backoffDuration(i, h.cfg.Retry)
			select {
			case <-time.After(backoff):
			case <-r.Context().Done():
				writeError(w, http.StatusGatewayTimeout, "request cancelled")
				return
			}
		}

		start := time.Now()
		resp, err := route.Provider.ChatCompletion(r.Context(), route.ProviderModel, req)
		latency := time.Since(start).Milliseconds()

		if err != nil {
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				// Non-retryable (e.g. client error from upstream): record and
				// forward the provider's error verbatim; do not try others.
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				w.Header().Set("X-Request-ID", requestID)
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}
			// Retryable failure: remember it and fall through to the next route.
			lastErr = err
			log.Printf("Provider %s failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
			h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), false, false)
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}

		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}

		inputTokens, outputTokens := 0, 0
		if resp.Usage != nil {
			inputTokens = resp.Usage.PromptTokens
			outputTokens = resp.Usage.CompletionTokens
		}
		cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)

		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
		h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", false, false)

		// Report the gateway-level model name, not the provider-specific one.
		resp.Model = req.Model

		respBytes, err := json.Marshal(resp)
		if err != nil {
			writeError(w, http.StatusInternalServerError, "failed to marshal response")
			return
		}

		if h.cache != nil {
			h.cache.Set(r.Context(), req.Model, rawBody, respBytes)
		}

		// Debug logging (bodies truncated to the configured cap, if any).
		if h.debugLogger != nil && h.debugLogger.IsEnabled() {
			reqBody := string(rawBody)
			respBody := string(respBytes)
			if h.cfg.Debug.MaxBodyBytes > 0 {
				if len(reqBody) > h.cfg.Debug.MaxBodyBytes {
					reqBody = reqBody[:h.cfg.Debug.MaxBodyBytes]
				}
				if len(respBody) > h.cfg.Debug.MaxBodyBytes {
					respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
				}
			}
			h.debugLogger.Log(storage.DebugLogEntry{
				RequestID:      requestID,
				TokenName:      tokenName,
				Model:          req.Model,
				Provider:       route.Provider.Name(),
				RequestBody:    reqBody,
				ResponseBody:   respBody,
				RequestHeaders: formatHeaders(r.Header),
				ResponseStatus: http.StatusOK,
			})
		}

		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Cache", "MISS")
		w.Header().Set("X-Request-ID", requestID)
		w.Write(respBytes)
		return
	}

	// Every route failed (or none were available).
	if lastErr != nil {
		w.Header().Set("X-Request-ID", requestID)
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		w.Header().Set("X-Request-ID", requestID)
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}
|
|
||||||
|
|
||||||
// handleNonStreamDedup wraps handleNonStream to capture the response for dedup followers.
|
|
||||||
func (h *Handler) Embeddings(w http.ResponseWriter, r *http.Request) {
|
|
||||||
body, err := io.ReadAll(io.LimitReader(r.Body, int64(h.cfg.Server.MaxRequestBodyMB)<<20))
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusBadRequest, "failed to read request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var req provider.EmbeddingRequest
|
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
writeError(w, http.StatusBadRequest, "invalid JSON: "+err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Model == "" {
|
|
||||||
writeError(w, http.StatusBadRequest, "model is required")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
routes, ok := h.registry.Lookup(req.Model)
|
|
||||||
if !ok {
|
|
||||||
writeError(w, http.StatusNotFound, "model not found: "+req.Model)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
routes = h.filterHealthyRoutes(routes)
|
|
||||||
tokenName := getTokenName(r.Context())
|
|
||||||
requestID := middleware.GetReqID(r.Context())
|
|
||||||
|
|
||||||
var lastErr error
|
|
||||||
for i, route := range routes {
|
|
||||||
if i > 0 {
|
|
||||||
backoff := backoffDuration(i, h.cfg.Retry)
|
|
||||||
select {
|
|
||||||
case <-time.After(backoff):
|
|
||||||
case <-r.Context().Done():
|
|
||||||
writeError(w, http.StatusGatewayTimeout, "request cancelled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
resp, err := route.Provider.Embedding(r.Context(), route.ProviderModel, &req)
|
|
||||||
latency := time.Since(start).Milliseconds()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
var pe *provider.ProviderError
|
|
||||||
if errors.As(err, &pe) && !pe.IsRetryable() {
|
|
||||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
|
|
||||||
h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
|
|
||||||
if h.healthTracker != nil {
|
|
||||||
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
|
||||||
}
|
|
||||||
w.Header().Set("X-Request-ID", requestID)
|
|
||||||
writeErrorRaw(w, pe.StatusCode, pe.Body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lastErr = err
|
|
||||||
log.Printf("Provider %s embedding failed for %s: %v", route.Provider.Name(), req.Model, err)
|
|
||||||
h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, latency, "error", err.Error())
|
|
||||||
if h.healthTracker != nil {
|
|
||||||
h.healthTracker.Record(route.Provider.Name(), latency, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if h.healthTracker != nil {
|
|
||||||
h.healthTracker.Record(route.Provider.Name(), latency, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
promptTokens := 0
|
|
||||||
if resp.Usage != nil {
|
|
||||||
promptTokens = resp.Usage.PromptTokens
|
|
||||||
}
|
|
||||||
cost := float64(promptTokens) / 1_000_000.0 * route.InputPrice
|
|
||||||
|
|
||||||
h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, promptTokens, 0, cost)
|
|
||||||
h.logEmbeddingRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, promptTokens, cost, latency, "success", "")
|
|
||||||
|
|
||||||
resp.Model = req.Model
|
|
||||||
|
|
||||||
respBytes, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
writeError(w, http.StatusInternalServerError, "failed to marshal response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.Header().Set("X-Request-ID", requestID)
|
|
||||||
w.Write(respBytes)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("X-Request-ID", requestID)
|
|
||||||
if lastErr != nil {
|
|
||||||
writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
|
|
||||||
} else {
|
|
||||||
writeError(w, http.StatusBadGateway, "all providers failed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) logEmbeddingRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens int, cost float64, latencyMS int64, status, errMsg string) {
|
|
||||||
h.logger.Log(storage.RequestLog{
|
|
||||||
RequestID: requestID,
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
TokenName: tokenName,
|
|
||||||
Model: model,
|
|
||||||
Provider: providerName,
|
|
||||||
ProviderModel: providerModel,
|
|
||||||
InputTokens: inputTokens,
|
|
||||||
CostUSD: cost,
|
|
||||||
LatencyMS: latencyMS,
|
|
||||||
Status: status,
|
|
||||||
ErrorMessage: errMsg,
|
|
||||||
RequestType: "embedding",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonStreamDedup runs the dedup leader's request through
// handleNonStream while recording the response, then publishes the captured
// body and status to any followers waiting on dedupKey.
// NOTE(review): if handleNonStream panicked before Complete ran, followers
// would stay blocked until the Deduplicator's cleanup sweep — confirm a
// recovery middleware sits above this handler.
func (h *Handler) handleNonStreamDedup(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName string, rawBody []byte, requestID string, modelTimeouts *provider.ModelTimeouts, dedupKey string) {
	// Apply the per-model timeout, mirroring the non-dedup path.
	if modelTimeouts != nil && modelTimeouts.RequestTimeout > 0 {
		ctx, cancel := context.WithTimeout(r.Context(), modelTimeouts.RequestTimeout)
		defer cancel()
		r = r.WithContext(ctx)
	}

	// Tee the response so followers can replay it.
	rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
	h.handleNonStream(rec, r, req, routes, tokenName, rawBody, requestID)
	h.dedup.Complete(dedupKey, rec.body, rec.statusCode)
}
|
|
||||||
|
|
||||||
// responseRecorder captures the response for dedup.
// It tees status and body while still writing through to the wrapped
// ResponseWriter, so the leader's own client receives the response normally.
type responseRecorder struct {
	http.ResponseWriter
	statusCode int    // last status passed to WriteHeader; defaults to 200
	body       []byte // accumulated response body
}
|
|
||||||
|
|
||||||
// WriteHeader records the status code before delegating to the wrapped writer.
func (r *responseRecorder) WriteHeader(code int) {
	r.statusCode = code
	r.ResponseWriter.WriteHeader(code)
}
|
|
||||||
|
|
||||||
// Write appends to the captured body and writes through to the client.
// The whole body is buffered in memory; this recorder is only used on the
// non-streaming dedup path.
func (r *responseRecorder) Write(b []byte) (int, error) {
	r.body = append(r.body, b...)
	return r.ResponseWriter.Write(b)
}
|
|
||||||
|
|
||||||
// filterHealthyRoutes removes providers with open circuit breakers.
|
|
||||||
// If all are filtered out, returns original routes as fallback.
|
|
||||||
func (h *Handler) filterHealthyRoutes(routes []provider.Route) []provider.Route {
|
|
||||||
if h.healthTracker == nil {
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
var healthy []provider.Route
|
|
||||||
for _, r := range routes {
|
|
||||||
if h.healthTracker.IsAvailable(r.Provider.Name()) {
|
|
||||||
healthy = append(healthy, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(healthy) == 0 {
|
|
||||||
return routes // all-down fallback
|
|
||||||
}
|
|
||||||
return healthy
|
|
||||||
}
|
|
||||||
|
|
||||||
// backoffDuration computes exponential backoff for the given attempt.
|
|
||||||
func backoffDuration(attempt int, cfg config.RetryConfig) time.Duration {
|
|
||||||
d := cfg.InitialBackoff
|
|
||||||
for i := 1; i < attempt; i++ {
|
|
||||||
d = time.Duration(float64(d) * cfg.Multiplier)
|
|
||||||
if d > cfg.MaxBackoff {
|
|
||||||
d = cfg.MaxBackoff
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
|
|
||||||
// logRequest enqueues one chat-completion request record on the async logger.
// cached=true marks a cache hit (providerName is then "cache"); streaming
// marks a streamed (SSE) response.
func (h *Handler) logRequest(requestID, tokenName, model, providerName, providerModel string, inputTokens, outputTokens int, cost float64, latencyMS int64, status, errMsg string, streaming, cached bool) {
	h.logger.Log(storage.RequestLog{
		RequestID:     requestID,
		Timestamp:     time.Now().Unix(),
		TokenName:     tokenName,
		Model:         model,
		Provider:      providerName,
		ProviderModel: providerModel,
		InputTokens:   inputTokens,
		OutputTokens:  outputTokens,
		CostUSD:       cost,
		LatencyMS:     latencyMS,
		Status:        status,
		ErrorMessage:  errMsg,
		Streaming:     streaming,
		Cached:        cached,
	})
}
|
|
||||||
|
|
||||||
// computeCost returns the USD cost of a request given token counts and
// per-million-token input/output prices.
func computeCost(inputTokens, outputTokens int, inputPrice, outputPrice float64) float64 {
	const tokensPerPriceUnit = 1_000_000.0
	inputCost := float64(inputTokens) / tokensPerPriceUnit * inputPrice
	outputCost := float64(outputTokens) / tokensPerPriceUnit * outputPrice
	return inputCost + outputCost
}
|
|
||||||
|
|
||||||
func writeError(w http.ResponseWriter, code int, msg string) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(code)
|
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"error": map[string]any{
|
|
||||||
"message": msg,
|
|
||||||
"type": "error",
|
|
||||||
"code": code,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeErrorRaw(w http.ResponseWriter, code int, body string) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
w.WriteHeader(code)
|
|
||||||
w.Write([]byte(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// formatHeaders serializes HTTP headers to a readable string, sorted by key.
|
|
||||||
// Sensitive headers (Authorization) are redacted.
|
|
||||||
func formatHeaders(h http.Header) string {
|
|
||||||
keys := make([]string, 0, len(h))
|
|
||||||
for k := range h {
|
|
||||||
keys = append(keys, k)
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
for _, k := range keys {
|
|
||||||
val := strings.Join(h[k], ", ")
|
|
||||||
if strings.EqualFold(k, "Authorization") {
|
|
||||||
val = "[REDACTED]"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(&b, "%s: %s\n", k, val)
|
|
||||||
}
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
@ -1,75 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
"llm-gateway/internal/provider"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ModelsHandler serves the OpenAI-compatible /v1/models listing, enriched
// with per-provider routing, pricing and health details.
type ModelsHandler struct {
	registry      *provider.Registry
	healthTracker *provider.HealthTracker // may be nil; providers then report healthy
	cfg           *config.Config          // consulted for per-model load_balancing
}
|
|
||||||
|
|
||||||
func NewModelsHandler(registry *provider.Registry, healthTracker *provider.HealthTracker, cfg *config.Config) *ModelsHandler {
|
|
||||||
return &ModelsHandler{
|
|
||||||
registry: registry,
|
|
||||||
healthTracker: healthTracker,
|
|
||||||
cfg: cfg,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *ModelsHandler) ListModels(w http.ResponseWriter, r *http.Request) {
|
|
||||||
allRoutes := h.registry.AllRoutes()
|
|
||||||
models := make([]map[string]any, 0, len(allRoutes))
|
|
||||||
|
|
||||||
for _, m := range allRoutes {
|
|
||||||
providers := make([]map[string]any, 0, len(m.Routes))
|
|
||||||
for _, rt := range m.Routes {
|
|
||||||
healthy := true
|
|
||||||
if h.healthTracker != nil {
|
|
||||||
healthy = h.healthTracker.IsAvailable(rt.ProviderName)
|
|
||||||
}
|
|
||||||
providers = append(providers, map[string]any{
|
|
||||||
"name": rt.ProviderName,
|
|
||||||
"model": rt.ProviderModel,
|
|
||||||
"input_price": rt.InputPrice,
|
|
||||||
"output_price": rt.OutputPrice,
|
|
||||||
"priority": rt.Priority,
|
|
||||||
"healthy": healthy,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find load balancing strategy from config
|
|
||||||
loadBalancing := "first"
|
|
||||||
for _, mc := range h.cfg.Models {
|
|
||||||
if mc.Name == m.Name {
|
|
||||||
if mc.LoadBalancing != "" {
|
|
||||||
loadBalancing = mc.LoadBalancing
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
models = append(models, map[string]any{
|
|
||||||
"id": m.Name,
|
|
||||||
"object": "model",
|
|
||||||
"created": time.Now().Unix(),
|
|
||||||
"owned_by": "llm-gateway",
|
|
||||||
"providers": providers,
|
|
||||||
"provider_count": len(providers),
|
|
||||||
"load_balancing": loadBalancing,
|
|
||||||
"aliases": m.Aliases,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"object": "list",
|
|
||||||
"data": models,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,169 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
"llm-gateway/internal/webhook"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RateLimiter enforces per-token request-rate limits (token bucket) and
// per-token daily spend budgets, optionally emitting webhook alerts when a
// budget threshold is crossed.
type RateLimiter struct {
	db             *storage.DB // source of daily spend totals (TodaySpend)
	mu             sync.Mutex  // guards buckets
	buckets        map[string]*tokenBucket // keyed by token name
	notifier       *webhook.Notifier       // nil disables alerting; set via SetNotifier
	budgetNotified sync.Map // tracks which token+budget combos have been notified
}
|
|
||||||
|
|
||||||
// tokenBucket is classic token-bucket state: capacity maxTokens, refilled
// continuously at refillRate. Access is serialized by RateLimiter.mu.
type tokenBucket struct {
	tokens float64 // currently available tokens; may be fractional
	maxTokens float64 // bucket capacity (equals the configured RPM)
	refillRate float64 // tokens per second
	lastRefill time.Time // last time the refill was applied
}
|
|
||||||
|
|
||||||
func NewRateLimiter(db *storage.DB) *RateLimiter {
|
|
||||||
return &RateLimiter{
|
|
||||||
db: db,
|
|
||||||
buckets: make(map[string]*tokenBucket),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotifier sets the webhook notifier used for budget threshold alerts.
// A nil notifier (the default) disables notifications entirely.
func (rl *RateLimiter) SetNotifier(n *webhook.Notifier) {
	rl.notifier = n
}
|
|
||||||
|
|
||||||
// Check is HTTP middleware that enforces, in order: per-token RPM rate
// limits, the daily spend budget, then the monthly spend budget. Requests
// with no API token in the context pass through untouched. A zero
// RateLimitRPM or budget disables the corresponding check.
func (rl *RateLimiter) Check(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		apiToken := getAPIToken(r.Context())
		if apiToken == nil {
			// Unauthenticated/internal request: nothing to limit.
			next.ServeHTTP(w, r)
			return
		}

		tokenName := apiToken.Name

		// Rate limit (token bucket). Headers are set on every response so
		// clients can pace themselves before hitting 429.
		if apiToken.RateLimitRPM > 0 {
			allowed, remaining, resetAt := rl.allow(tokenName, apiToken.RateLimitRPM)

			// Set rate limit headers on all responses
			w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", apiToken.RateLimitRPM))
			w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", remaining))
			w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", resetAt))

			if !allowed {
				// Retry-After is derived from resetAt (Unix seconds) and
				// clamped to at least 1 second.
				retryAfter := resetAt - time.Now().Unix()
				if retryAfter < 1 {
					retryAfter = 1
				}
				w.Header().Set("Retry-After", fmt.Sprintf("%d", retryAfter))
				writeError(w, http.StatusTooManyRequests, "rate limit exceeded")
				return
			}
		}

		// Daily budget. A spend-query error fails open (request allowed) —
		// presumably deliberate so DB trouble doesn't block traffic.
		if apiToken.DailyBudgetUSD > 0 {
			spent, err := rl.db.TodaySpend(tokenName)
			if err == nil {
				if spent >= apiToken.DailyBudgetUSD {
					writeError(w, http.StatusTooManyRequests, "daily budget exceeded")
					return
				}
				rl.checkBudgetThreshold(tokenName, "daily", spent, apiToken.DailyBudgetUSD)
			}
		}

		// Monthly budget; same fail-open behavior as the daily check.
		if apiToken.MonthlyBudgetUSD > 0 {
			spent, err := rl.db.MonthSpend(tokenName)
			if err == nil {
				if spent >= apiToken.MonthlyBudgetUSD {
					writeError(w, http.StatusTooManyRequests, "monthly budget exceeded")
					return
				}
				rl.checkBudgetThreshold(tokenName, "monthly", spent, apiToken.MonthlyBudgetUSD)
			}
		}

		next.ServeHTTP(w, r)
	})
}
|
|
||||||
|
|
||||||
// checkBudgetThreshold fires a webhook notification the first time spend
// reaches 80% of the given budget for a token. It is a no-op when no
// notifier is configured or budget is non-positive.
//
// NOTE(review): budgetNotified entries are never cleared, so once a
// token+budgetType crosses 80% the alert never fires again for the lifetime
// of the process — even after the daily/monthly window rolls over. Confirm
// whether that is intended.
func (rl *RateLimiter) checkBudgetThreshold(tokenName, budgetType string, spent, budget float64) {
	if rl.notifier == nil || budget <= 0 {
		return
	}
	if spent/budget < 0.8 {
		return
	}
	// LoadOrStore makes the alert fire at most once per token+type.
	key := tokenName + ":" + budgetType
	if _, loaded := rl.budgetNotified.LoadOrStore(key, true); loaded {
		return // already notified
	}
	rl.notifier.Notify(webhook.Event{
		Type: webhook.EventBudgetThreshold,
		Data: map[string]any{
			"token": tokenName,
			"budget_type": budgetType,
			"spent": spent,
			"budget": budget,
			"percent": spent / budget * 100,
		},
	})
}
|
|
||||||
|
|
||||||
func (rl *RateLimiter) allow(tokenName string, rateLimitRPM int) (bool, int, int64) {
|
|
||||||
rl.mu.Lock()
|
|
||||||
defer rl.mu.Unlock()
|
|
||||||
|
|
||||||
bucket, ok := rl.buckets[tokenName]
|
|
||||||
if !ok {
|
|
||||||
bucket = &tokenBucket{
|
|
||||||
tokens: float64(rateLimitRPM),
|
|
||||||
maxTokens: float64(rateLimitRPM),
|
|
||||||
refillRate: float64(rateLimitRPM) / 60.0,
|
|
||||||
lastRefill: time.Now(),
|
|
||||||
}
|
|
||||||
rl.buckets[tokenName] = bucket
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
elapsed := now.Sub(bucket.lastRefill).Seconds()
|
|
||||||
bucket.tokens += elapsed * bucket.refillRate
|
|
||||||
if bucket.tokens > bucket.maxTokens {
|
|
||||||
bucket.tokens = bucket.maxTokens
|
|
||||||
}
|
|
||||||
bucket.lastRefill = now
|
|
||||||
|
|
||||||
remaining := int(math.Floor(bucket.tokens))
|
|
||||||
if remaining < 0 {
|
|
||||||
remaining = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute reset time: when bucket would be full again
|
|
||||||
deficit := bucket.maxTokens - bucket.tokens
|
|
||||||
var resetAt int64
|
|
||||||
if deficit > 0 && bucket.refillRate > 0 {
|
|
||||||
resetAt = now.Add(time.Duration(deficit/bucket.refillRate) * time.Second).Unix()
|
|
||||||
} else {
|
|
||||||
resetAt = now.Unix()
|
|
||||||
}
|
|
||||||
|
|
||||||
if bucket.tokens < 1 {
|
|
||||||
return false, 0, resetAt
|
|
||||||
}
|
|
||||||
bucket.tokens--
|
|
||||||
remaining = int(math.Floor(bucket.tokens))
|
|
||||||
if remaining < 0 {
|
|
||||||
remaining = 0
|
|
||||||
}
|
|
||||||
return true, remaining, resetAt
|
|
||||||
}
|
|
||||||
|
|
@ -1,374 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
|
|
||||||
"llm-gateway/internal/auth"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
// newTestDB creates an in-memory SQLite database wrapped in storage.DB.
|
|
||||||
// It creates the request_logs table needed by TodaySpend.
|
|
||||||
func newTestDB(t *testing.T) *storage.DB {
|
|
||||||
t.Helper()
|
|
||||||
sqlDB, err := sql.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("opening in-memory sqlite: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { sqlDB.Close() })
|
|
||||||
|
|
||||||
// Create the minimal table needed for TodaySpend queries.
|
|
||||||
_, err = sqlDB.Exec(`CREATE TABLE IF NOT EXISTS request_logs (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
token_name TEXT,
|
|
||||||
cost_usd REAL,
|
|
||||||
timestamp INTEGER
|
|
||||||
)`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("creating request_logs table: %v", err)
|
|
||||||
}
|
|
||||||
return &storage.DB{DB: sqlDB}
|
|
||||||
}
|
|
||||||
|
|
||||||
// okHandler replies 200 OK with an empty body; it stands in for the
// downstream handler in middleware tests.
var okHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
})
|
|
||||||
|
|
||||||
func TestRateLimiter_Allow(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rateLimitRPM int
|
|
||||||
numRequests int
|
|
||||||
wantAllowed int
|
|
||||||
wantDenied int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "allows requests within limit",
|
|
||||||
rateLimitRPM: 10,
|
|
||||||
numRequests: 5,
|
|
||||||
wantAllowed: 5,
|
|
||||||
wantDenied: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "denies requests over limit",
|
|
||||||
rateLimitRPM: 3,
|
|
||||||
numRequests: 6,
|
|
||||||
wantAllowed: 3,
|
|
||||||
wantDenied: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "allows exactly up to limit",
|
|
||||||
rateLimitRPM: 5,
|
|
||||||
numRequests: 5,
|
|
||||||
wantAllowed: 5,
|
|
||||||
wantDenied: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
allowed := 0
|
|
||||||
denied := 0
|
|
||||||
for i := 0; i < tt.numRequests; i++ {
|
|
||||||
ok, _, _ := rl.allow("test-token", tt.rateLimitRPM)
|
|
||||||
if ok {
|
|
||||||
allowed++
|
|
||||||
} else {
|
|
||||||
denied++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if allowed != tt.wantAllowed {
|
|
||||||
t.Errorf("allowed = %d, want %d", allowed, tt.wantAllowed)
|
|
||||||
}
|
|
||||||
if denied != tt.wantDenied {
|
|
||||||
t.Errorf("denied = %d, want %d", denied, tt.wantDenied)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_TokenRefillsOverTime(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
rpm := 60 // 1 token per second refill rate
|
|
||||||
|
|
||||||
// Exhaust all tokens.
|
|
||||||
for i := 0; i < rpm; i++ {
|
|
||||||
ok, _, _ := rl.allow("refill-token", rpm)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("request %d should have been allowed", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next request should be denied.
|
|
||||||
ok, _, _ := rl.allow("refill-token", rpm)
|
|
||||||
if ok {
|
|
||||||
t.Fatal("request should have been denied after exhausting tokens")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manually advance the bucket's lastRefill to simulate time passing.
|
|
||||||
rl.mu.Lock()
|
|
||||||
bucket := rl.buckets["refill-token"]
|
|
||||||
bucket.lastRefill = bucket.lastRefill.Add(-2 * time.Second)
|
|
||||||
rl.mu.Unlock()
|
|
||||||
|
|
||||||
// After 2 seconds at 1 token/sec, we should have ~2 tokens refilled.
|
|
||||||
ok, remaining, _ := rl.allow("refill-token", rpm)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("request should have been allowed after token refill")
|
|
||||||
}
|
|
||||||
// We consumed 1 of the ~2 refilled tokens, so remaining should be >= 0.
|
|
||||||
if remaining < 0 {
|
|
||||||
t.Errorf("remaining = %d, want >= 0", remaining)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_AllowReturnValues(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rateLimitRPM int
|
|
||||||
numRequests int
|
|
||||||
wantLastAllowed bool
|
|
||||||
wantLastRemaining int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "remaining decrements correctly",
|
|
||||||
rateLimitRPM: 5,
|
|
||||||
numRequests: 1,
|
|
||||||
wantLastAllowed: true,
|
|
||||||
wantLastRemaining: 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remaining is zero at limit",
|
|
||||||
rateLimitRPM: 3,
|
|
||||||
numRequests: 3,
|
|
||||||
wantLastAllowed: true,
|
|
||||||
wantLastRemaining: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "denied returns zero remaining",
|
|
||||||
rateLimitRPM: 2,
|
|
||||||
numRequests: 3,
|
|
||||||
wantLastAllowed: false,
|
|
||||||
wantLastRemaining: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
var allowed bool
|
|
||||||
var remaining int
|
|
||||||
for i := 0; i < tt.numRequests; i++ {
|
|
||||||
allowed, remaining, _ = rl.allow("test-token", tt.rateLimitRPM)
|
|
||||||
}
|
|
||||||
|
|
||||||
if allowed != tt.wantLastAllowed {
|
|
||||||
t.Errorf("allowed = %v, want %v", allowed, tt.wantLastAllowed)
|
|
||||||
}
|
|
||||||
if remaining != tt.wantLastRemaining {
|
|
||||||
t.Errorf("remaining = %d, want %d", remaining, tt.wantLastRemaining)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_CheckMiddleware_Headers(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rateLimitRPM int
|
|
||||||
numRequests int
|
|
||||||
wantStatusCode int
|
|
||||||
wantLimitHeader string
|
|
||||||
wantRetryAfter bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "sets rate limit headers on allowed request",
|
|
||||||
rateLimitRPM: 10,
|
|
||||||
numRequests: 1,
|
|
||||||
wantStatusCode: http.StatusOK,
|
|
||||||
wantLimitHeader: "10",
|
|
||||||
wantRetryAfter: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sets Retry-After header on 429",
|
|
||||||
rateLimitRPM: 2,
|
|
||||||
numRequests: 3,
|
|
||||||
wantStatusCode: http.StatusTooManyRequests,
|
|
||||||
wantLimitHeader: "2",
|
|
||||||
wantRetryAfter: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
token := &auth.APIToken{
|
|
||||||
Name: "header-test-token",
|
|
||||||
RateLimitRPM: tt.rateLimitRPM,
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := rl.Check(okHandler)
|
|
||||||
|
|
||||||
var rec *httptest.ResponseRecorder
|
|
||||||
for i := 0; i < tt.numRequests; i++ {
|
|
||||||
rec = httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
ctx := withAPIToken(req.Context(), token)
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
handler.ServeHTTP(rec, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the last response.
|
|
||||||
if rec.Code != tt.wantStatusCode {
|
|
||||||
t.Errorf("status code = %d, want %d", rec.Code, tt.wantStatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// X-RateLimit-Limit header.
|
|
||||||
limitHeader := rec.Header().Get("X-RateLimit-Limit")
|
|
||||||
if limitHeader != tt.wantLimitHeader {
|
|
||||||
t.Errorf("X-RateLimit-Limit = %q, want %q", limitHeader, tt.wantLimitHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
// X-RateLimit-Remaining header must be present and numeric.
|
|
||||||
remainingHeader := rec.Header().Get("X-RateLimit-Remaining")
|
|
||||||
if remainingHeader == "" {
|
|
||||||
t.Error("X-RateLimit-Remaining header is missing")
|
|
||||||
} else if _, err := strconv.Atoi(remainingHeader); err != nil {
|
|
||||||
t.Errorf("X-RateLimit-Remaining = %q, not a valid integer", remainingHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
// X-RateLimit-Reset header must be present and numeric.
|
|
||||||
resetHeader := rec.Header().Get("X-RateLimit-Reset")
|
|
||||||
if resetHeader == "" {
|
|
||||||
t.Error("X-RateLimit-Reset header is missing")
|
|
||||||
} else if _, err := strconv.ParseInt(resetHeader, 10, 64); err != nil {
|
|
||||||
t.Errorf("X-RateLimit-Reset = %q, not a valid integer", resetHeader)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retry-After header.
|
|
||||||
retryAfter := rec.Header().Get("Retry-After")
|
|
||||||
if tt.wantRetryAfter && retryAfter == "" {
|
|
||||||
t.Error("Retry-After header is missing on 429 response")
|
|
||||||
}
|
|
||||||
if !tt.wantRetryAfter && retryAfter != "" {
|
|
||||||
t.Errorf("Retry-After header should not be present, got %q", retryAfter)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_CheckMiddleware_NoToken(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
handler := rl.Check(okHandler)
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
// No API token in context.
|
|
||||||
handler.ServeHTTP(rec, req)
|
|
||||||
|
|
||||||
if rec.Code != http.StatusOK {
|
|
||||||
t.Errorf("status code = %d, want %d (should pass through without token)", rec.Code, http.StatusOK)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_CheckMiddleware_ZeroRPM(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
token := &auth.APIToken{
|
|
||||||
Name: "unlimited-token",
|
|
||||||
RateLimitRPM: 0, // zero means unlimited
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := rl.Check(okHandler)
|
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
ctx := withAPIToken(req.Context(), token)
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
handler.ServeHTTP(rec, req)
|
|
||||||
|
|
||||||
if rec.Code != http.StatusOK {
|
|
||||||
t.Fatalf("request %d: status code = %d, want %d (zero RPM should be unlimited)", i, rec.Code, http.StatusOK)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_PerTokenIsolation(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
rpm := 2
|
|
||||||
|
|
||||||
// Exhaust token A.
|
|
||||||
for i := 0; i < rpm; i++ {
|
|
||||||
rl.allow("token-a", rpm)
|
|
||||||
}
|
|
||||||
ok, _, _ := rl.allow("token-a", rpm)
|
|
||||||
if ok {
|
|
||||||
t.Fatal("token-a should be rate limited")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Token B should still have its own bucket.
|
|
||||||
ok, _, _ = rl.allow("token-b", rpm)
|
|
||||||
if !ok {
|
|
||||||
t.Fatal("token-b should not be affected by token-a's rate limit")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_ResetAtIsFuture(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
// Consume one token so there's a deficit.
|
|
||||||
_, _, resetAt := rl.allow("reset-token", 10)
|
|
||||||
now := time.Now().Unix()
|
|
||||||
|
|
||||||
if resetAt < now {
|
|
||||||
t.Errorf("resetAt = %d, want >= %d (should be now or in the future)", resetAt, now)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRateLimiter_CheckMiddleware_ContextCancelled(t *testing.T) {
|
|
||||||
db := newTestDB(t)
|
|
||||||
rl := NewRateLimiter(db)
|
|
||||||
|
|
||||||
token := &auth.APIToken{
|
|
||||||
Name: "ctx-token",
|
|
||||||
RateLimitRPM: 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := rl.Check(okHandler)
|
|
||||||
rec := httptest.NewRecorder()
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
ctx, cancel := context.WithCancel(req.Context())
|
|
||||||
ctx = withAPIToken(ctx, token)
|
|
||||||
cancel() // Cancel immediately.
|
|
||||||
req = req.WithContext(ctx)
|
|
||||||
|
|
||||||
// Should still process (rate limiter does not check context cancellation).
|
|
||||||
handler.ServeHTTP(rec, req)
|
|
||||||
// The handler itself may or may not respect cancelled context;
|
|
||||||
// the key point is no panic occurs.
|
|
||||||
}
|
|
||||||
|
|
@ -1,195 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/provider"
|
|
||||||
"llm-gateway/internal/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
// handleStream serves a streaming (SSE) chat completion. It tries each
// route in order with exponential backoff between attempts; once a provider
// stream opens, the response is committed (200 sent) and no further
// failover is possible. Usage tokens are scraped from the passing chunks
// for cost accounting, and the provider-reported model name is rewritten to
// the gateway-facing model name.
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, req *provider.ChatRequest, routes []provider.Route, tokenName, requestID string, modelTimeouts *provider.ModelTimeouts) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		// SSE requires incremental flushing; without it we cannot stream.
		writeError(w, http.StatusInternalServerError, "streaming not supported")
		return
	}

	var lastErr error

	for i, route := range routes {
		// Retry backoff between attempts (none before the first attempt).
		if i > 0 {
			backoff := backoffDuration(i, h.cfg.Retry)
			select {
			case <-time.After(backoff):
			case <-r.Context().Done():
				writeError(w, http.StatusGatewayTimeout, "request cancelled")
				return
			}
		}

		start := time.Now()
		body, err := route.Provider.ChatCompletionStream(r.Context(), route.ProviderModel, req)

		if err != nil {
			var pe *provider.ProviderError
			if errors.As(err, &pe) && !pe.IsRetryable() {
				// Non-retryable provider error (e.g. bad request): surface
				// the provider's own status/body to the client and stop.
				latency := time.Since(start).Milliseconds()
				h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "error", latency, 0, 0, 0)
				h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
				if h.healthTracker != nil {
					h.healthTracker.Record(route.Provider.Name(), latency, err)
				}
				w.Header().Set("X-Request-ID", requestID)
				writeErrorRaw(w, pe.StatusCode, pe.Body)
				return
			}
			// Retryable failure: record it and move to the next route.
			// NOTE(review): unlike the non-retryable branch, this path does
			// not call h.metrics.RecordRequest — confirm whether retryable
			// failures should also be counted in metrics.
			lastErr = err
			latency := time.Since(start).Milliseconds()
			log.Printf("Provider %s stream failed for %s: %v", route.Provider.Name(), req.Model, err)
			h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, 0, 0, 0, latency, "error", err.Error(), true, false)
			if h.healthTracker != nil {
				h.healthTracker.Record(route.Provider.Name(), latency, err)
			}
			continue
		}

		// Apply streaming timeout (per-model override takes precedence).
		streamingTimeout := h.cfg.Server.StreamingTimeout
		if modelTimeouts != nil && modelTimeouts.StreamingTimeout > 0 {
			streamingTimeout = modelTimeouts.StreamingTimeout
		}
		var streamCtx context.Context
		var streamCancel context.CancelFunc
		if streamingTimeout > 0 {
			streamCtx, streamCancel = context.WithTimeout(r.Context(), streamingTimeout)
		} else {
			streamCtx, streamCancel = context.WithCancel(r.Context())
		}

		// Commit the SSE response. From here on the status is fixed at 200.
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("X-Accel-Buffering", "no")
		w.Header().Set("X-Request-ID", requestID)
		w.WriteHeader(http.StatusOK)

		inputTokens, outputTokens := 0, 0
		scanner := bufio.NewScanner(body)
		// Allow SSE lines up to 256 KiB (default Scanner limit is 64 KiB).
		scanner.Buffer(make([]byte, 64*1024), 256*1024)

		// Capture streamed lines for debug logging.
		debugEnabled := h.debugLogger != nil && h.debugLogger.IsEnabled()
		var debugLines []string

		// Pump provider SSE lines to the client in a goroutine so the main
		// goroutine can race it against the streaming timeout below.
		scanDone := make(chan struct{})
		go func() {
			defer close(scanDone)
			for scanner.Scan() {
				// Stop promptly once the timeout/cancel fires.
				select {
				case <-streamCtx.Done():
					return
				default:
				}

				line := scanner.Text()

				if strings.HasPrefix(line, "data: ") {
					data := strings.TrimPrefix(line, "data: ")
					if data != "[DONE]" {
						var chunk streamChunk
						if json.Unmarshal([]byte(data), &chunk) == nil {
							// Usage (when present) drives cost accounting.
							if chunk.Usage != nil {
								inputTokens = chunk.Usage.PromptTokens
								outputTokens = chunk.Usage.CompletionTokens
							}
							// Rewrite the provider model name to the
							// gateway-facing model name before relaying.
							if chunk.Model != "" {
								chunk.Model = req.Model
								if rewritten, err := json.Marshal(chunk); err == nil {
									line = "data: " + string(rewritten)
								}
							}
						}
					}
				}

				if debugEnabled {
					debugLines = append(debugLines, line)
				}

				w.Write([]byte(line + "\n"))
				flusher.Flush()
			}
			// NOTE(review): scanner.Err() is never checked, so a stream cut
			// short by a read error is still recorded as "success" below —
			// confirm whether that is acceptable.
		}()

		select {
		case <-scanDone:
			// Normal completion.
		case <-streamCtx.Done():
			// Timeout/cancel: the pump goroutine exits on its next
			// iteration. NOTE(review): it may still issue one w.Write after
			// this point, overlapping the accounting below — verify this
			// cannot race with handler teardown.
			log.Printf("Stream timeout for %s via %s", req.Model, route.Provider.Name())
		}

		body.Close()
		streamCancel()

		// Record the stream as a successful request with whatever usage was
		// observed (zeros if the provider never sent a usage chunk).
		latency := time.Since(start).Milliseconds()
		cost := computeCost(inputTokens, outputTokens, route.InputPrice, route.OutputPrice)
		h.metrics.RecordRequest(req.Model, route.Provider.Name(), tokenName, "success", latency, inputTokens, outputTokens, cost)
		h.logRequest(requestID, tokenName, req.Model, route.Provider.Name(), route.ProviderModel, inputTokens, outputTokens, cost, latency, "success", "", true, false)
		if h.healthTracker != nil {
			h.healthTracker.Record(route.Provider.Name(), latency, nil)
		}

		// Debug logging for streaming requests, truncated to MaxBodyBytes
		// when a limit is configured (0 = unlimited).
		if debugEnabled && len(debugLines) > 0 {
			respBody := strings.Join(debugLines, "\n")
			reqBody, _ := json.Marshal(req)
			reqBodyStr := string(reqBody)
			if h.cfg.Debug.MaxBodyBytes > 0 {
				if len(reqBodyStr) > h.cfg.Debug.MaxBodyBytes {
					reqBodyStr = reqBodyStr[:h.cfg.Debug.MaxBodyBytes]
				}
				if len(respBody) > h.cfg.Debug.MaxBodyBytes {
					respBody = respBody[:h.cfg.Debug.MaxBodyBytes]
				}
			}
			h.debugLogger.Log(storage.DebugLogEntry{
				RequestID: requestID,
				TokenName: tokenName,
				Model: req.Model,
				Provider: route.Provider.Name(),
				RequestBody: reqBodyStr,
				ResponseBody: respBody,
				RequestHeaders: formatHeaders(r.Header),
				ResponseStatus: http.StatusOK,
			})
		}

		return
	}

	// All providers failed before any stream opened.
	w.Header().Set("X-Request-ID", requestID)
	if lastErr != nil {
		writeError(w, http.StatusBadGateway, "all providers failed: "+lastErr.Error())
	} else {
		writeError(w, http.StatusBadGateway, "all providers failed")
	}
}
|
|
||||||
|
|
||||||
// streamChunk is the subset of an SSE chat-completion chunk payload the
// gateway inspects: Usage feeds cost accounting and Model is rewritten to
// the gateway-facing model name. Choices pass through untyped.
type streamChunk struct {
	ID string `json:"id,omitempty"`
	Object string `json:"object,omitempty"`
	Created int64 `json:"created,omitempty"`
	Model string `json:"model,omitempty"`
	Choices []any `json:"choices,omitempty"`
	Usage *provider.Usage `json:"usage,omitempty"`
}
|
|
||||||
|
|
@ -1,105 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AuditEntry is one row of the audit_log table: who (user) did what
// (action) to which object (target), when, and from where.
type AuditEntry struct {
	ID int64 `json:"id"`
	Timestamp int64 `json:"timestamp"` // Unix seconds
	UserID int64 `json:"user_id"`
	Username string `json:"username"`
	Action string `json:"action"`
	TargetType string `json:"target_type"`
	TargetID string `json:"target_id"`
	Details string `json:"details"`
	IPAddress string `json:"ip_address"`
	RequestID string `json:"request_id"`
}
|
|
||||||
|
|
||||||
// AuditLogger persists audit entries to the database and can signal an
// observer after each successful write.
type AuditLogger struct {
	db *DB
	OnWrite func() // optional hook invoked after every successful insert
}
|
|
||||||
|
|
||||||
func NewAuditLogger(db *DB) *AuditLogger {
|
|
||||||
return &AuditLogger{db: db}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *AuditLogger) Log(entry AuditEntry) {
|
|
||||||
if entry.Timestamp == 0 {
|
|
||||||
entry.Timestamp = time.Now().Unix()
|
|
||||||
}
|
|
||||||
_, err := a.db.Exec(`INSERT INTO audit_log
|
|
||||||
(timestamp, user_id, username, action, target_type, target_id, details, ip_address, request_id)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
||||||
entry.Timestamp, entry.UserID, entry.Username, entry.Action,
|
|
||||||
entry.TargetType, entry.TargetID, entry.Details, entry.IPAddress, entry.RequestID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("ERROR: audit log: %v", err)
|
|
||||||
} else if a.OnWrite != nil {
|
|
||||||
a.OnWrite()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuditQueryResult is one page of audit entries plus pagination metadata.
type AuditQueryResult struct {
	Entries []AuditEntry `json:"entries"`
	Page int `json:"page"`
	TotalPages int `json:"total_pages"`
	Total int `json:"total"`
}
|
|
||||||
|
|
||||||
func (a *AuditLogger) Query(since int64, action string, page, limit int) *AuditQueryResult {
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
if limit <= 0 {
|
|
||||||
limit = 50
|
|
||||||
}
|
|
||||||
offset := (page - 1) * limit
|
|
||||||
|
|
||||||
where := "WHERE timestamp >= ?"
|
|
||||||
args := []any{since}
|
|
||||||
|
|
||||||
if action != "" {
|
|
||||||
where += " AND action = ?"
|
|
||||||
args = append(args, action)
|
|
||||||
}
|
|
||||||
|
|
||||||
var total int
|
|
||||||
countArgs := make([]any, len(args))
|
|
||||||
copy(countArgs, args)
|
|
||||||
a.db.QueryRow("SELECT COUNT(*) FROM audit_log "+where, countArgs...).Scan(&total)
|
|
||||||
|
|
||||||
totalPages := (total + limit - 1) / limit
|
|
||||||
if totalPages < 1 {
|
|
||||||
totalPages = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
query := `SELECT id, timestamp, COALESCE(user_id, 0), username, action,
|
|
||||||
COALESCE(target_type, ''), COALESCE(target_id, ''), COALESCE(details, ''),
|
|
||||||
COALESCE(ip_address, ''), COALESCE(request_id, '')
|
|
||||||
FROM audit_log ` + where + ` ORDER BY timestamp DESC LIMIT ? OFFSET ?`
|
|
||||||
args = append(args, limit, offset)
|
|
||||||
|
|
||||||
rows, err := a.db.Query(query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return &AuditQueryResult{Entries: []AuditEntry{}, Page: page, TotalPages: totalPages, Total: total}
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var entries []AuditEntry
|
|
||||||
for rows.Next() {
|
|
||||||
var e AuditEntry
|
|
||||||
rows.Scan(&e.ID, &e.Timestamp, &e.UserID, &e.Username, &e.Action,
|
|
||||||
&e.TargetType, &e.TargetID, &e.Details, &e.IPAddress, &e.RequestID)
|
|
||||||
entries = append(entries, e)
|
|
||||||
}
|
|
||||||
if entries == nil {
|
|
||||||
entries = []AuditEntry{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &AuditQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
|
|
||||||
}
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"path/filepath"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-migrate/migrate/v4"
|
|
||||||
"github.com/golang-migrate/migrate/v4/database/sqlite"
|
|
||||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
|
|
||||||
"llm-gateway/internal/storage/migrations"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DB wraps *sql.DB so gateway-specific query helpers (TodaySpend, etc.) can
// hang off the handle while callers keep the full database/sql API via the
// embedded field.
type DB struct {
	*sql.DB
}
|
|
||||||
|
|
||||||
func Open(path string) (*DB, error) {
|
|
||||||
dir := filepath.Dir(path)
|
|
||||||
if dir != "." && dir != "" {
|
|
||||||
// Ensure directory exists — caller should create it if needed
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := sql.Open("sqlite", path+"?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000&_cache_size=-20000")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("opening database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Performance pragmas
|
|
||||||
for _, pragma := range []string{
|
|
||||||
"PRAGMA foreign_keys = ON",
|
|
||||||
"PRAGMA temp_store = MEMORY",
|
|
||||||
"PRAGMA mmap_size = 268435456",
|
|
||||||
} {
|
|
||||||
if _, err := db.Exec(pragma); err != nil {
|
|
||||||
return nil, fmt.Errorf("setting pragma %s: %w", pragma, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetMaxOpenConns(1) // SQLite is single-writer
|
|
||||||
db.SetMaxIdleConns(1)
|
|
||||||
|
|
||||||
if err := runMigrations(db); err != nil {
|
|
||||||
return nil, fmt.Errorf("running migrations: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DB{db}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runMigrations(db *sql.DB) error {
|
|
||||||
sourceDriver, err := iofs.New(migrations.FS, ".")
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating migration source: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dbDriver, err := sqlite.WithInstance(db, &sqlite.Config{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating migration db driver: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite", dbDriver)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("creating migrator: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
|
||||||
return fmt.Errorf("applying migrations: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanupOldRecords deletes records older than retentionDays.
|
|
||||||
func (db *DB) CleanupOldRecords(retentionDays int) error {
|
|
||||||
cutoff := time.Now().AddDate(0, 0, -retentionDays).Unix()
|
|
||||||
result, err := db.Exec("DELETE FROM request_logs WHERE timestamp < ?", cutoff)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
affected, _ := result.RowsAffected()
|
|
||||||
if affected > 0 {
|
|
||||||
log.Printf("Cleaned up %d old request log records", affected)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TodaySpend returns the total cost in USD for a given token today.
|
|
||||||
func (db *DB) TodaySpend(tokenName string) (float64, error) {
|
|
||||||
startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
|
|
||||||
var total sql.NullFloat64
|
|
||||||
err := db.QueryRow(
|
|
||||||
"SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
|
|
||||||
tokenName, startOfDay,
|
|
||||||
).Scan(&total)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return total.Float64, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MonthSpend returns the total cost in USD for a given token this month.
|
|
||||||
func (db *DB) MonthSpend(tokenName string) (float64, error) {
|
|
||||||
now := time.Now()
|
|
||||||
startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()).Unix()
|
|
||||||
var total sql.NullFloat64
|
|
||||||
err := db.QueryRow(
|
|
||||||
"SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
|
|
||||||
tokenName, startOfMonth,
|
|
||||||
).Scan(&total)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return total.Float64, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TodaySpendAll returns today's spend for all tokens as a map.
|
|
||||||
func (db *DB) TodaySpendAll() (map[string]float64, error) {
|
|
||||||
startOfDay := time.Now().Truncate(24 * time.Hour).Unix()
|
|
||||||
rows, err := db.Query(
|
|
||||||
"SELECT token_name, SUM(cost_usd) FROM request_logs WHERE timestamp >= ? GROUP BY token_name",
|
|
||||||
startOfDay,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
result := make(map[string]float64)
|
|
||||||
for rows.Next() {
|
|
||||||
var name string
|
|
||||||
var total float64
|
|
||||||
if err := rows.Scan(&name, &total); err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result[name] = total
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,253 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DebugLogEntry is one captured request/response pair. Metadata lives
// in the debug_log table; the bodies and headers are stored in a JSON
// file on disk referenced by FilePath (see debugFile and Log).
type DebugLogEntry struct {
	ID             int64  `json:"id"`
	RequestID      string `json:"request_id"`
	Timestamp      int64  `json:"timestamp"` // unix seconds
	TokenName      string `json:"token_name"`
	Model          string `json:"model"`
	Provider       string `json:"provider"`
	RequestBody    string `json:"request_body"`    // loaded from file by populateFromFile
	ResponseBody   string `json:"response_body"`   // loaded from file by populateFromFile
	RequestHeaders string `json:"request_headers"` // loaded from file by populateFromFile
	ResponseStatus int    `json:"response_status"`
	FilePath       string `json:"-"` // on-disk JSON file holding the bodies; empty for legacy rows
}
|
|
||||||
|
|
||||||
// debugFile is the JSON structure written to disk. It holds the bulky
// request/response payloads that are kept out of the SQLite metadata row.
type debugFile struct {
	RequestHeaders string `json:"request_headers"`
	RequestBody    string `json:"request_body"`
	ResponseBody   string `json:"response_body"`
}
|
|
||||||
|
|
||||||
// DebugLogger captures full request/response payloads for debugging.
// Bodies are written to per-day JSON files under dataDir; only the
// metadata is inserted into the database.
type DebugLogger struct {
	db      *DB
	enabled atomic.Bool // toggled at runtime via SetEnabled
	dataDir string      // base for debug-logs/<YYYY-MM-DD>/<request-id>.json
	OnWrite func()      // optional hook invoked after each successful DB insert
}
|
|
||||||
|
|
||||||
func NewDebugLogger(db *DB, enabled bool, dataDir string) *DebugLogger {
|
|
||||||
dl := &DebugLogger{db: db, dataDir: dataDir}
|
|
||||||
dl.enabled.Store(enabled)
|
|
||||||
return dl
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DebugLogger) SetEnabled(v bool) {
|
|
||||||
d.enabled.Store(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DebugLogger) IsEnabled() bool {
|
|
||||||
return d.enabled.Load()
|
|
||||||
}
|
|
||||||
|
|
||||||
// debugLogDir returns the base directory for debug log files:
// <dataDir>/debug-logs. Per-day subdirectories live beneath it
// (see debugFilePath).
func (d *DebugLogger) debugLogDir() string {
	return filepath.Join(d.dataDir, "debug-logs")
}
|
|
||||||
|
|
||||||
// debugFilePath builds the file path for a debug log entry.
|
|
||||||
func (d *DebugLogger) debugFilePath(requestID string, ts time.Time) string {
|
|
||||||
date := ts.Format("2006-01-02")
|
|
||||||
return filepath.Join(d.debugLogDir(), date, requestID+".json")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log persists one debug entry: the bulky bodies go to a JSON file on
// disk first, then a metadata-only row is inserted into debug_log. All
// failures are logged and swallowed — debug capture must never break
// request handling.
func (d *DebugLogger) Log(entry DebugLogEntry) {
	if !d.IsEnabled() {
		return
	}
	// Default missing timestamps to "now" so the file lands in today's
	// date directory.
	if entry.Timestamp == 0 {
		entry.Timestamp = time.Now().Unix()
	}

	ts := time.Unix(entry.Timestamp, 0)
	fp := d.debugFilePath(entry.RequestID, ts)

	// Write body file
	if err := os.MkdirAll(filepath.Dir(fp), 0755); err != nil {
		log.Printf("ERROR: debug log mkdir: %v", err)
		return
	}

	df := debugFile{
		RequestHeaders: entry.RequestHeaders,
		RequestBody:    entry.RequestBody,
		ResponseBody:   entry.ResponseBody,
	}
	data, err := json.Marshal(df)
	if err != nil {
		log.Printf("ERROR: debug log marshal: %v", err)
		return
	}
	if err := os.WriteFile(fp, data, 0644); err != nil {
		log.Printf("ERROR: debug log write: %v", err)
		return
	}

	// Insert metadata into DB (no bodies)
	_, err = d.db.Exec(`INSERT INTO debug_log
		(request_id, timestamp, token_name, model, provider, response_status, file_path)
		VALUES (?, ?, ?, ?, ?, ?, ?)`,
		entry.RequestID, entry.Timestamp, entry.TokenName, entry.Model,
		entry.Provider, entry.ResponseStatus, fp,
	)
	if err != nil {
		log.Printf("ERROR: debug log db insert: %v", err)
	} else if d.OnWrite != nil {
		// Fire the observer hook only after a successful insert.
		d.OnWrite()
	}
}
|
|
||||||
|
|
||||||
// DebugLogQueryResult is one page of debug log entries plus pagination
// bookkeeping.
type DebugLogQueryResult struct {
	Entries    []DebugLogEntry `json:"entries"` // never nil; empty slice when no rows
	Page       int             `json:"page"`
	TotalPages int             `json:"total_pages"` // always >= 1
	Total      int             `json:"total"`
}
|
|
||||||
|
|
||||||
// Query returns paginated debug log metadata (no bodies — fast).
|
|
||||||
func (d *DebugLogger) Query(page, limit int) *DebugLogQueryResult {
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
if limit <= 0 {
|
|
||||||
limit = 50
|
|
||||||
}
|
|
||||||
offset := (page - 1) * limit
|
|
||||||
|
|
||||||
var total int
|
|
||||||
d.db.QueryRow("SELECT COUNT(*) FROM debug_log").Scan(&total)
|
|
||||||
|
|
||||||
totalPages := (total + limit - 1) / limit
|
|
||||||
if totalPages < 1 {
|
|
||||||
totalPages = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := d.db.Query(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
|
|
||||||
COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
|
|
||||||
FROM debug_log ORDER BY timestamp DESC LIMIT ? OFFSET ?`, limit, offset)
|
|
||||||
if err != nil {
|
|
||||||
return &DebugLogQueryResult{Entries: []DebugLogEntry{}, Page: page, TotalPages: totalPages, Total: total}
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var entries []DebugLogEntry
|
|
||||||
for rows.Next() {
|
|
||||||
var e DebugLogEntry
|
|
||||||
rows.Scan(&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
|
|
||||||
&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
|
|
||||||
entries = append(entries, e)
|
|
||||||
}
|
|
||||||
if entries == nil {
|
|
||||||
entries = []DebugLogEntry{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DebugLogQueryResult{Entries: entries, Page: page, TotalPages: totalPages, Total: total}
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryFull returns paginated debug log entries including request/response bodies read from files.
|
|
||||||
func (d *DebugLogger) QueryFull(page, limit int) *DebugLogQueryResult {
|
|
||||||
result := d.Query(page, limit)
|
|
||||||
for i := range result.Entries {
|
|
||||||
d.populateFromFile(&result.Entries[i])
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetByRequestID returns a single debug log entry with bodies read from file.
|
|
||||||
func (d *DebugLogger) GetByRequestID(requestID string) *DebugLogEntry {
|
|
||||||
var e DebugLogEntry
|
|
||||||
err := d.db.QueryRow(`SELECT id, request_id, timestamp, COALESCE(token_name, ''),
|
|
||||||
COALESCE(model, ''), COALESCE(provider, ''), COALESCE(response_status, 0), COALESCE(file_path, '')
|
|
||||||
FROM debug_log WHERE request_id = ?`, requestID).Scan(
|
|
||||||
&e.ID, &e.RequestID, &e.Timestamp, &e.TokenName,
|
|
||||||
&e.Model, &e.Provider, &e.ResponseStatus, &e.FilePath)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
d.populateFromFile(&e)
|
|
||||||
return &e
|
|
||||||
}
|
|
||||||
|
|
||||||
// populateFromFile reads body data from the debug file on disk.
|
|
||||||
// Falls back to DB columns for pre-migration entries that have no file_path.
|
|
||||||
func (d *DebugLogger) populateFromFile(e *DebugLogEntry) {
|
|
||||||
if e.FilePath == "" {
|
|
||||||
// Legacy entry: try reading bodies from DB columns
|
|
||||||
d.db.QueryRow(`SELECT COALESCE(request_body, ''), COALESCE(response_body, ''), COALESCE(request_headers, '')
|
|
||||||
FROM debug_log WHERE id = ?`, e.ID).Scan(&e.RequestBody, &e.ResponseBody, &e.RequestHeaders)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
data, err := os.ReadFile(e.FilePath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("WARN: debug log read file %s: %v", e.FilePath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var df debugFile
|
|
||||||
if err := json.Unmarshal(data, &df); err != nil {
|
|
||||||
log.Printf("WARN: debug log parse file %s: %v", e.FilePath, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
e.RequestHeaders = df.RequestHeaders
|
|
||||||
e.RequestBody = df.RequestBody
|
|
||||||
e.ResponseBody = df.ResponseBody
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup removes debug log entries and files older than retentionDays.
|
|
||||||
func (d *DebugLogger) Cleanup(retentionDays int) error {
|
|
||||||
cutoff := time.Now().AddDate(0, 0, -retentionDays)
|
|
||||||
cutoffUnix := cutoff.Unix()
|
|
||||||
|
|
||||||
// Delete old DB rows
|
|
||||||
result, err := d.db.Exec("DELETE FROM debug_log WHERE timestamp < ?", cutoffUnix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("delete old debug rows: %w", err)
|
|
||||||
}
|
|
||||||
affected, _ := result.RowsAffected()
|
|
||||||
if affected > 0 {
|
|
||||||
log.Printf("Cleaned up %d old debug log entries", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove old date directories
|
|
||||||
baseDir := d.debugLogDir()
|
|
||||||
dirs, err := os.ReadDir(baseDir)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("read debug log dir: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cutoffDate := cutoff.Format("2006-01-02")
|
|
||||||
sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
|
|
||||||
|
|
||||||
for _, dir := range dirs {
|
|
||||||
if !dir.IsDir() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Date directories are named YYYY-MM-DD; string comparison works
|
|
||||||
if strings.Compare(dir.Name(), cutoffDate) < 0 {
|
|
||||||
dirPath := filepath.Join(baseDir, dir.Name())
|
|
||||||
if err := os.RemoveAll(dirPath); err != nil {
|
|
||||||
log.Printf("WARN: failed to remove debug log dir %s: %v", dirPath, err)
|
|
||||||
} else {
|
|
||||||
log.Printf("Removed old debug log directory: %s", dir.Name())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,138 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RequestLog is one completed proxy request destined for the
// request_logs table (written in batches by AsyncLogger.flush).
type RequestLog struct {
	RequestID     string
	Timestamp     int64 // unix seconds
	TokenName     string
	Model         string
	Provider      string
	ProviderModel string
	InputTokens   int
	OutputTokens  int
	CostUSD       float64
	LatencyMS     int64
	Status        string // success, error, cached
	ErrorMessage  string
	Streaming     bool
	Cached        bool
	RequestType   string // "chat" or "embedding"; flush defaults empty to "chat"
}
|
|
||||||
|
|
||||||
// AsyncLogger batches request logs in a background goroutine so the
// request path never blocks on database writes.
type AsyncLogger struct {
	db      *DB
	ch      chan RequestLog // buffered queue; Log drops entries when full
	done    chan struct{}   // closed when the run goroutine exits
	OnFlush func()          // called after successful flush, if set
}
|
|
||||||
|
|
||||||
func NewAsyncLogger(db *DB, bufferSize int) *AsyncLogger {
|
|
||||||
if bufferSize == 0 {
|
|
||||||
bufferSize = 1000
|
|
||||||
}
|
|
||||||
l := &AsyncLogger{
|
|
||||||
db: db,
|
|
||||||
ch: make(chan RequestLog, bufferSize),
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
|
||||||
go l.run()
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *AsyncLogger) Log(r RequestLog) {
|
|
||||||
select {
|
|
||||||
case l.ch <- r:
|
|
||||||
default:
|
|
||||||
log.Println("WARNING: request log buffer full, dropping entry")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *AsyncLogger) Close() {
|
|
||||||
close(l.ch)
|
|
||||||
<-l.done
|
|
||||||
}
|
|
||||||
|
|
||||||
// run is the background loop: it accumulates entries into a batch and
// flushes either when the batch reaches 100 entries or on a one-second
// tick, whichever comes first. When Close closes the channel it flushes
// the remainder and exits, signalling completion via l.done.
func (l *AsyncLogger) run() {
	defer close(l.done)

	batch := make([]RequestLog, 0, 100)
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case r, ok := <-l.ch:
			if !ok {
				// Channel closed, flush remaining
				if len(batch) > 0 {
					l.flush(batch)
				}
				return
			}
			batch = append(batch, r)
			// Size-triggered flush keeps latency bounded under load.
			if len(batch) >= 100 {
				l.flush(batch)
				batch = batch[:0]
			}
		case <-ticker.C:
			// Time-triggered flush keeps latency bounded when idle.
			if len(batch) > 0 {
				l.flush(batch)
				batch = batch[:0]
			}
		}
	}
}
|
|
||||||
|
|
||||||
// flush writes a batch of request logs in one transaction using a
// prepared INSERT. Per-row failures are logged but do not abort the
// batch; the transaction is committed at the end regardless. Errors are
// never returned — logging must not take down the caller.
func (l *AsyncLogger) flush(batch []RequestLog) {
	tx, err := l.db.Begin()
	if err != nil {
		log.Printf("ERROR: starting log transaction: %v", err)
		return
	}

	stmt, err := tx.Prepare(`INSERT INTO request_logs
		(request_id, timestamp, token_name, model, provider, provider_model, input_tokens, output_tokens, cost_usd, latency_ms, status, error_message, streaming, cached, request_type)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
	if err != nil {
		log.Printf("ERROR: preparing log statement: %v", err)
		tx.Rollback()
		return
	}
	defer stmt.Close()

	for _, r := range batch {
		// SQLite has no native bool; flags are stored as 0/1 integers.
		streaming := 0
		if r.Streaming {
			streaming = 1
		}
		cached := 0
		if r.Cached {
			cached = 1
		}
		// Default the request type for callers that never set it.
		reqType := r.RequestType
		if reqType == "" {
			reqType = "chat"
		}
		_, err := stmt.Exec(
			r.RequestID, r.Timestamp, r.TokenName, r.Model, r.Provider, r.ProviderModel,
			r.InputTokens, r.OutputTokens, r.CostUSD, r.LatencyMS,
			r.Status, r.ErrorMessage, streaming, cached, reqType,
		)
		if err != nil {
			log.Printf("ERROR: inserting log: %v", err)
		}
	}

	if err := tx.Commit(); err != nil {
		log.Printf("ERROR: committing log batch: %v", err)
		return
	}

	// Notify observers (e.g. stats refresh) only after a successful commit.
	if l.OnFlush != nil {
		l.OnFlush()
	}
}
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration: drop the request_logs table created by the matching up step.
DROP TABLE IF EXISTS request_logs;
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
-- Up migration: per-request accounting log. One row per proxied request
-- with token/model/provider attribution, token counts, cost, and latency.
CREATE TABLE IF NOT EXISTS request_logs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp INTEGER NOT NULL,      -- unix seconds
    token_name TEXT NOT NULL,
    model TEXT NOT NULL,
    provider TEXT NOT NULL,
    provider_model TEXT NOT NULL,
    input_tokens INTEGER DEFAULT 0,
    output_tokens INTEGER DEFAULT 0,
    cost_usd REAL DEFAULT 0,
    latency_ms INTEGER DEFAULT 0,
    status TEXT NOT NULL,            -- success, error, cached
    error_message TEXT DEFAULT '',
    streaming INTEGER DEFAULT 0,     -- boolean stored as 0/1
    cached INTEGER DEFAULT 0         -- boolean stored as 0/1
);

-- Indexes on the columns the storage layer filters and groups by.
CREATE INDEX IF NOT EXISTS idx_timestamp ON request_logs(timestamp);
CREATE INDEX IF NOT EXISTS idx_token ON request_logs(token_name);
CREATE INDEX IF NOT EXISTS idx_model ON request_logs(model);
|
|
||||||
|
|
@ -1,3 +0,0 @@
|
||||||
-- Down migration: remove the auth tables in dependency order
-- (tables referencing users(id) are dropped first).
DROP TABLE IF EXISTS api_tokens;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS users;
|
|
||||||
|
|
@ -1,33 +0,0 @@
|
||||||
-- Up migration: authentication schema (users, sessions, api_tokens).

-- Local user accounts; booleans (is_admin, totp_enabled) stored as 0/1,
-- timestamps as unix integers.
CREATE TABLE users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    username TEXT NOT NULL UNIQUE,
    password_hash TEXT NOT NULL,
    is_admin INTEGER NOT NULL DEFAULT 0,
    totp_secret TEXT DEFAULT '',
    totp_enabled INTEGER NOT NULL DEFAULT 0,
    created_at INTEGER NOT NULL,
    updated_at INTEGER NOT NULL
);

-- Browser sessions; rows are removed with their owning user via
-- ON DELETE CASCADE.
CREATE TABLE sessions (
    id TEXT PRIMARY KEY,
    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    created_at INTEGER NOT NULL,
    expires_at INTEGER NOT NULL
);
CREATE INDEX idx_sessions_user ON sessions(user_id);
CREATE INDEX idx_sessions_expires ON sessions(expires_at);

-- API tokens; only key_hash is unique — the raw key is never stored.
CREATE TABLE api_tokens (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    key_hash TEXT NOT NULL,
    key_prefix TEXT NOT NULL,
    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    rate_limit_rpm INTEGER DEFAULT 60,
    daily_budget_usd REAL DEFAULT 0,
    created_at INTEGER NOT NULL,
    last_used_at INTEGER DEFAULT 0
);
CREATE UNIQUE INDEX idx_api_tokens_hash ON api_tokens(key_hash);
CREATE INDEX idx_api_tokens_user ON api_tokens(user_id);
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
-- Down migration: SQLite doesn't support DROP COLUMN before 3.35.0, so we
-- recreate users without the email column added by the matching up step.
CREATE TABLE users_backup AS SELECT id, username, password_hash, is_admin, totp_secret, totp_enabled, created_at, updated_at FROM users;
DROP TABLE users;
ALTER TABLE users_backup RENAME TO users;
-- Recreate the uniqueness constraint lost when the table was rebuilt.
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Up migration: add an email column to users (empty string by default).
ALTER TABLE users ADD COLUMN email TEXT DEFAULT '';
|
|
||||||
|
|
@ -1,4 +0,0 @@
|
||||||
-- Down migration: SQLite doesn't support DROP COLUMN in older versions, so we
-- recreate api_tokens without the max_concurrent column added by the up step.
CREATE TABLE api_tokens_backup AS SELECT id, name, key_hash, key_prefix, user_id, rate_limit_rpm, daily_budget_usd, created_at, last_used_at FROM api_tokens;
DROP TABLE api_tokens;
ALTER TABLE api_tokens_backup RENAME TO api_tokens;
-- NOTE(review): unlike the users down migration, the indexes created in
-- 0002 (unique idx_api_tokens_hash, idx_api_tokens_user) are lost with the
-- dropped table and are NOT recreated here — verify whether they should be.
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Up migration: add a max_concurrent column to api_tokens (default 0).
ALTER TABLE api_tokens ADD COLUMN max_concurrent INTEGER DEFAULT 0;
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration: remove the index added by the matching up step
-- (the request_id column itself is left in place).
DROP INDEX IF EXISTS idx_request_logs_request_id;
|
|
||||||
|
|
@ -1,2 +0,0 @@
|
||||||
-- Up migration: attach a request_id to each request log row and index it
-- for per-request lookups.
ALTER TABLE request_logs ADD COLUMN request_id TEXT DEFAULT '';
CREATE INDEX idx_request_logs_request_id ON request_logs(request_id);
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration: drop the audit_log table created by the matching up step.
DROP TABLE IF EXISTS audit_log;
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
-- Up migration: administrative audit trail; one row per recorded action.
CREATE TABLE audit_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp INTEGER NOT NULL,
    user_id INTEGER, -- nullable: an action may be recorded without a known user
    username TEXT NOT NULL DEFAULT '',
    action TEXT NOT NULL,
    target_type TEXT DEFAULT '',
    target_id TEXT DEFAULT '',
    details TEXT DEFAULT '',
    ip_address TEXT DEFAULT '',
    request_id TEXT DEFAULT ''
);
CREATE INDEX idx_audit_timestamp ON audit_log(timestamp);
CREATE INDEX idx_audit_action ON audit_log(action);
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration: drop the debug_log table created by the matching up step.
DROP TABLE IF EXISTS debug_log;
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
-- Up migration: debug capture table. This version stores full bodies
-- inline; a later migration adds file_path and bodies move to disk.
CREATE TABLE debug_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    request_id TEXT NOT NULL,
    timestamp INTEGER NOT NULL, -- unix seconds
    token_name TEXT DEFAULT '',
    model TEXT DEFAULT '',
    provider TEXT DEFAULT '',
    request_body TEXT DEFAULT '',
    response_body TEXT DEFAULT '',
    request_headers TEXT DEFAULT '',
    response_status INTEGER DEFAULT 0
);
CREATE INDEX idx_debug_request_id ON debug_log(request_id);
CREATE INDEX idx_debug_timestamp ON debug_log(timestamp);
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration intentionally a no-op: the file_path column added by the
-- matching up step is harmless to keep.
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Up migration: record the on-disk file path holding each debug entry's
-- request/response bodies.
ALTER TABLE debug_log ADD COLUMN file_path TEXT DEFAULT '';
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration: remove monthly_budget_usd. Uses DROP COLUMN, which
-- requires SQLite 3.35.0 or newer.
ALTER TABLE api_tokens DROP COLUMN monthly_budget_usd;
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Up migration: per-token monthly budget in USD (0 by default).
ALTER TABLE api_tokens ADD COLUMN monthly_budget_usd REAL NOT NULL DEFAULT 0;
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Down migration: remove request_type. Uses DROP COLUMN, which requires
-- SQLite 3.35.0 or newer.
ALTER TABLE request_logs DROP COLUMN request_type;
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
-- Up migration: tag each request log row with its request type
-- (existing rows default to 'chat').
ALTER TABLE request_logs ADD COLUMN request_type TEXT NOT NULL DEFAULT 'chat';
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
package migrations
|
|
||||||
|
|
||||||
import "embed"
|
|
||||||
|
|
||||||
// FS embeds every SQL migration file in this directory so migrations
// ship inside the binary (consumed via the iofs migration source).
//
//go:embed *.sql
var FS embed.FS
|
|
||||||
|
|
@ -1,123 +0,0 @@
|
||||||
package webhook
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"llm-gateway/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Event types. Each constant is the wire value carried in Event.Type
// and matched against a webhook's configured event filter.
const (
	EventCircuitBreakerOpen   = "circuit_breaker.open"
	EventCircuitBreakerClosed = "circuit_breaker.closed"
	EventBudgetThreshold      = "budget.threshold"
)
|
|
||||||
|
|
||||||
// Event represents a webhook notification payload; it is serialized as
// JSON and POSTed to each matching webhook URL.
type Event struct {
	Type      string         `json:"type"`      // one of the Event* constants
	Timestamp time.Time      `json:"timestamp"` // stamped by Notify when zero
	Data      map[string]any `json:"data"`      // event-specific details
}
|
|
||||||
|
|
||||||
// Notifier sends webhook notifications. Events are queued on a buffered
// channel and delivered sequentially by a single background goroutine.
type Notifier struct {
	webhooks []config.WebhookConfig
	ch       chan Event    // buffered queue; Notify drops events when full
	done     chan struct{} // closed when the delivery goroutine exits
	client   *http.Client  // shared client with a 10s request timeout
}
|
|
||||||
|
|
||||||
// NewNotifier creates a webhook notifier from config.
|
|
||||||
func NewNotifier(webhooks []config.WebhookConfig) *Notifier {
|
|
||||||
n := &Notifier{
|
|
||||||
webhooks: webhooks,
|
|
||||||
ch: make(chan Event, 100),
|
|
||||||
done: make(chan struct{}),
|
|
||||||
client: &http.Client{Timeout: 10 * time.Second},
|
|
||||||
}
|
|
||||||
go n.run()
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify queues an event for delivery (non-blocking).
|
|
||||||
func (n *Notifier) Notify(evt Event) {
|
|
||||||
if evt.Timestamp.IsZero() {
|
|
||||||
evt.Timestamp = time.Now()
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case n.ch <- evt:
|
|
||||||
default:
|
|
||||||
log.Printf("WARNING: webhook channel full, dropping event %s", evt.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close drains pending events and shuts down.
|
|
||||||
func (n *Notifier) Close() {
|
|
||||||
close(n.ch)
|
|
||||||
<-n.done
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) run() {
|
|
||||||
defer close(n.done)
|
|
||||||
for evt := range n.ch {
|
|
||||||
for _, wh := range n.webhooks {
|
|
||||||
if !n.shouldSend(wh, evt.Type) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
n.send(wh, evt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Notifier) shouldSend(wh config.WebhookConfig, eventType string) bool {
|
|
||||||
if len(wh.Events) == 0 {
|
|
||||||
return true // no filter = send all
|
|
||||||
}
|
|
||||||
for _, e := range wh.Events {
|
|
||||||
if e == eventType {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// send delivers one event to a single webhook as an HTTP POST with a
// JSON body. When the webhook has a shared secret configured, the exact
// payload bytes are signed with HMAC-SHA256 and the hex digest is sent
// as "X-Webhook-Signature: sha256=<digest>". All failures are logged
// and swallowed — delivery is best-effort.
//
// NOTE(review): the response body is closed without being drained,
// which can prevent HTTP keep-alive connection reuse; consider
// io.Copy(io.Discard, resp.Body) before Close.
func (n *Notifier) send(wh config.WebhookConfig, evt Event) {
	body, err := json.Marshal(evt)
	if err != nil {
		log.Printf("ERROR: webhook marshal: %v", err)
		return
	}

	req, err := http.NewRequest(http.MethodPost, wh.URL, bytes.NewReader(body))
	if err != nil {
		log.Printf("ERROR: webhook request: %v", err)
		return
	}
	req.Header.Set("Content-Type", "application/json")

	if wh.Secret != "" {
		// Sign the exact bytes being sent so receivers can verify them.
		mac := hmac.New(sha256.New, []byte(wh.Secret))
		mac.Write(body)
		sig := hex.EncodeToString(mac.Sum(nil))
		req.Header.Set("X-Webhook-Signature", "sha256="+sig)
	}

	resp, err := n.client.Do(req)
	if err != nil {
		log.Printf("WARNING: webhook delivery to %s failed: %v", wh.URL, err)
		return
	}
	resp.Body.Close()

	if resp.StatusCode >= 400 {
		log.Printf("WARNING: webhook %s returned %d", wh.URL, resp.StatusCode)
	}
}
|
|
||||||
Loading…
Reference in a new issue