83 lines
1.9 KiB
Go
83 lines
1.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
)
|
|
|
|
// contextKey is the package-private type used for request-context keys.
// Context keys compare by both type and value, so a key of this type can
// never collide with a plain-string key set by another package.
type contextKey string

// userContextKey is the key under which RequireAuth stores the
// authenticated user; read it back with UserFromContext.
const userContextKey = contextKey("auth_user")
|
|
|
|
// sessionCookieName is the cookie RequireAuth reads; its value is passed to
// Store.GetSession as the session identifier.
const sessionCookieName = "llmgw_session"

// sessionTTLDays is presumably the lifetime, in days, of a newly issued
// session. It is not referenced by the middleware here.
// NOTE(review): confirm where sessions are created and that this is its unit.
const sessionTTLDays = 7
|
|
|
|
type Middleware struct {
|
|
store *Store
|
|
}
|
|
|
|
func NewMiddleware(store *Store) *Middleware {
|
|
return &Middleware{store: store}
|
|
}
|
|
|
|
func UserFromContext(ctx context.Context) *User {
|
|
u, _ := ctx.Value(userContextKey).(*User)
|
|
return u
|
|
}
|
|
|
|
func (m *Middleware) RequireAuth(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
cookie, err := r.Cookie(sessionCookieName)
|
|
if err != nil || cookie.Value == "" {
|
|
m.unauthorized(w, r)
|
|
return
|
|
}
|
|
|
|
sess, err := m.store.GetSession(cookie.Value)
|
|
if err != nil {
|
|
m.unauthorized(w, r)
|
|
return
|
|
}
|
|
|
|
user, err := m.store.GetUserByID(sess.UserID)
|
|
if err != nil {
|
|
m.unauthorized(w, r)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userContextKey, user)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
|
|
func (m *Middleware) RequireAdmin(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user := UserFromContext(r.Context())
|
|
if user == nil || !user.IsAdmin {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusForbidden)
|
|
json.NewEncoder(w).Encode(map[string]string{"error": "admin access required"})
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (m *Middleware) unauthorized(w http.ResponseWriter, r *http.Request) {
|
|
if r.Header.Get("HX-Request") == "true" {
|
|
w.Header().Set("HX-Redirect", "/login")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if strings.HasPrefix(r.URL.Path, "/api/") {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
json.NewEncoder(w).Encode(map[string]string{"error": "authentication required"})
|
|
return
|
|
}
|
|
http.Redirect(w, r, "/login", http.StatusFound)
|
|
}
|