// ai-servers/llm-gateway/internal/storage/db.go

package storage
import (
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/sqlite"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	_ "modernc.org/sqlite"

	"llm-gateway/internal/storage/migrations"
)
// DB is the handle returned by Open. It embeds *sql.DB, so every *sql.DB
// method (Exec, QueryRow, Close, ...) is promoted and can be called on a DB
// directly; the storage-specific helpers are defined as methods on *DB.
type DB struct {
	*sql.DB
}
// Open opens the SQLite database at path, configures the connection pool and
// per-connection pragmas, and applies all pending schema migrations.
//
// For a plain filesystem path the parent directory is created if missing.
// SQLite URI paths (prefix "file:") are passed through untouched, and
// ":memory:" has no parent directory, so neither triggers directory creation.
//
// On any failure after the pool is opened, the pool is closed before the
// error is returned so no connection is leaked.
func Open(path string) (*DB, error) {
	if !strings.HasPrefix(path, "file:") {
		if dir := filepath.Dir(path); dir != "." && dir != "" {
			if err := os.MkdirAll(dir, 0o755); err != nil {
				return nil, fmt.Errorf("creating database directory %q: %w", dir, err)
			}
		}
	}

	// modernc.org/sqlite ignores mattn/go-sqlite3 style DSN keys such as
	// _journal_mode or _busy_timeout; it only honours repeated
	// _pragma=name(value) parameters. Supplying pragmas through the DSN also
	// makes the driver run them on every new connection the pool opens,
	// whereas a one-off db.Exec only configures whichever connection it used.
	pragmas := []string{
		"busy_timeout(5000)", // first, so the journal_mode switch can wait on locks
		"journal_mode(WAL)",
		"synchronous(NORMAL)",
		"cache_size(-20000)",
		"foreign_keys(ON)",
		"temp_store(MEMORY)",
		"mmap_size(268435456)",
	}
	sep := "?"
	if strings.Contains(path, "?") {
		sep = "&" // path already carries query parameters
	}
	dsn := path + sep + "_pragma=" + strings.Join(pragmas, "&_pragma=")

	db, err := sql.Open("sqlite", dsn)
	if err != nil {
		return nil, fmt.Errorf("opening database: %w", err)
	}
	db.SetMaxOpenConns(1) // SQLite is single-writer
	db.SetMaxIdleConns(1)

	// sql.Open is lazy; Ping forces a real connection so open/pragma failures
	// are reported here rather than on first use.
	if err := db.Ping(); err != nil {
		_ = db.Close() // best effort; the Ping error is the one worth reporting
		return nil, fmt.Errorf("connecting to database: %w", err)
	}

	if err := runMigrations(db); err != nil {
		_ = db.Close() // best effort; the migration error is the one worth reporting
		return nil, fmt.Errorf("running migrations: %w", err)
	}
	return &DB{db}, nil
}
// runMigrations applies every pending migration embedded in migrations.FS to
// db using golang-migrate. Already being at the latest version
// (migrate.ErrNoChange) is treated as success.
//
// NOTE(review): the migrator m is deliberately not closed. The sqlite database
// driver appears to close the *sql.DB it wraps, and that handle is still
// owned by the caller. Confirm against golang-migrate before adding m.Close().
func runMigrations(db *sql.DB) error {
	src, err := iofs.New(migrations.FS, ".")
	if err != nil {
		return fmt.Errorf("creating migration source: %w", err)
	}

	driver, err := sqlite.WithInstance(db, &sqlite.Config{})
	if err != nil {
		return fmt.Errorf("creating migration db driver: %w", err)
	}

	m, err := migrate.NewWithInstance("iofs", src, "sqlite", driver)
	if err != nil {
		return fmt.Errorf("creating migrator: %w", err)
	}

	// errors.Is keeps working even if the sentinel is ever returned wrapped.
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("applying migrations: %w", err)
	}
	return nil
}
// CleanupOldRecords deletes request_logs rows whose timestamp is older than
// retentionDays days before now, and logs how many rows were removed.
// Timestamps are compared against time.Unix() seconds, the same unit
// TodaySpend uses. NOTE(review): confirm the column stores Unix seconds.
//
// A retentionDays of zero or less is treated as "retention disabled" and
// deletes nothing. With the raw arithmetic, 0 would put the cutoff at now and
// a negative value would put it in the future. Either would wipe the whole
// table, including today's rows that TodaySpend sums.
func (db *DB) CleanupOldRecords(retentionDays int) error {
	if retentionDays <= 0 {
		return nil
	}

	cutoff := time.Now().AddDate(0, 0, -retentionDays).Unix()
	res, err := db.Exec("DELETE FROM request_logs WHERE timestamp < ?", cutoff)
	if err != nil {
		return fmt.Errorf("deleting old request logs: %w", err)
	}

	// The row count only feeds the log line. If the driver cannot report it,
	// the DELETE itself still succeeded, so the call is not failed.
	if n, err := res.RowsAffected(); err == nil && n > 0 {
		log.Printf("Cleaned up %d old request log records", n)
	}
	return nil
}
// TodaySpend returns the sum of cost_usd in request_logs for tokenName over
// rows with timestamp at or after the start of the current day. It returns 0
// when there are no such rows.
//
// "Today" is the current UTC calendar day: the window starts at 00:00 UTC,
// independent of the server's local time zone.
func (db *DB) TodaySpend(tokenName string) (float64, error) {
	now := time.Now().UTC()
	// Same instant as time.Now().Truncate(24*time.Hour). Truncate measures
	// from Go's zero time, which lands on UTC midnight. Written out here so
	// the UTC day boundary is explicit instead of an implicit side effect.
	startOfDay := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC).Unix()

	var total sql.NullFloat64 // SUM over zero matching rows yields NULL
	err := db.QueryRow(
		"SELECT SUM(cost_usd) FROM request_logs WHERE token_name = ? AND timestamp >= ?",
		tokenName, startOfDay,
	).Scan(&total)
	if err != nil {
		return 0, fmt.Errorf("querying today's spend for token %q: %w", tokenName, err)
	}
	return total.Float64, nil
}