easyai-ai-gateway/apps/api/internal/auth/auth.go

267 lines
7.5 KiB
Go

package auth
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v5"
)
// Permission is an access level a route can demand via Authenticator.Require.
// A caller's roles are translated into a set of permissions by
// permissionsForRole.
type Permission string

// Access levels, ordered from least to most privileged.
const (
	// PermissionPublic is granted to every caller, including anonymous ones.
	PermissionPublic Permission = "public"
	// PermissionBasic is the lowest level that requires a recognised role.
	PermissionBasic Permission = "basic"
	// PermissionCreat is the creator level.
	// NOTE(review): "creat" looks like a truncated "create"; the string value
	// is kept unchanged in case it is persisted or compared elsewhere.
	PermissionCreat Permission = "creat"
	// PermissionPower is the operator level.
	PermissionPower Permission = "power"
	// PermissionManager is the highest level (admin / manager roles).
	PermissionManager Permission = "manager"
)
type User struct {
ID string `json:"sub"`
Username string `json:"username"`
Roles []string `json:"role,omitempty"`
TenantID string `json:"tenantId,omitempty"`
GatewayTenantID string `json:"gatewayTenantId,omitempty"`
TenantKey string `json:"tenantKey,omitempty"`
SSOID string `json:"sso_id,omitempty"`
Source string `json:"source,omitempty"`
GatewayUserID string `json:"gatewayUserId,omitempty"`
UserGroupID string `json:"userGroupId,omitempty"`
UserGroupKey string `json:"userGroupKey,omitempty"`
UserGroupKeys []string `json:"userGroupKeys,omitempty"`
APIKeyID string `json:"apiKeyId,omitempty"`
APIKeySecret string `json:"apiKeySecret,omitempty"`
APIKeyName string `json:"apiKeyName,omitempty"`
}
// contextKey is an unexported key type, so values this package stores in a
// context.Context cannot collide with keys from other packages.
type contextKey string

// userContextKey is the key under which Require stores the authenticated *User.
const userContextKey contextKey = "easyai-auth-user"

var (
	// ErrUnauthorized signals missing or rejected credentials.
	ErrUnauthorized = errors.New("unauthorized")
)
// Authenticator verifies request credentials: JWTs locally with JWTSecret,
// and "sk-" API keys remotely through server-main. Use New to get one with a
// trimmed base URL and a timeout-bounded HTTP client.
type Authenticator struct {
	// JWTSecret is the HMAC key used to verify and sign JWTs.
	JWTSecret string
	// ServerMainBaseURL is the server-main origin used for API-key checks.
	ServerMainBaseURL string
	// ServerMainInternalToken is sent as a bearer token to server-main.
	ServerMainInternalToken string
	// HTTPClient performs the server-main calls.
	HTTPClient *http.Client
}
// New builds an Authenticator.
//
// jwtSecret is the HMAC key for JWTs; serverMainBaseURL is stored with any
// trailing slashes removed so endpoint paths can be appended directly; and
// internalToken authenticates the gateway to server-main. Calls to
// server-main go through a dedicated client with a 10-second timeout.
func New(jwtSecret string, serverMainBaseURL string, internalToken string) *Authenticator {
	a := new(Authenticator)
	a.JWTSecret = jwtSecret
	a.ServerMainBaseURL = strings.TrimRight(serverMainBaseURL, "/")
	a.ServerMainInternalToken = internalToken
	a.HTTPClient = &http.Client{Timeout: 10 * time.Second}
	return a
}
// UserFromContext returns the *User that Require stored in ctx and whether
// one was present. Requests that reached a PermissionPublic handler without
// valid credentials carry no user, so they yield (nil, false).
func UserFromContext(ctx context.Context) (*User, bool) {
	if u, found := ctx.Value(userContextKey).(*User); found {
		return u, true
	}
	return nil, false
}
// Require wraps next so it only runs for callers holding permission.
//
// Outcomes:
//   - credentials accepted and permission granted: next runs with the user
//     stored in the request context (see UserFromContext);
//   - credentials accepted but permission missing: 403 "forbidden";
//   - credentials missing or rejected (ErrUnauthorized): 401 "unauthorized"
//     with a Bearer challenge;
//   - the credential check itself failed (any other error, e.g. server-main
//     unreachable): 503, so an outage is not reported as a bad credential.
//
// On PermissionPublic routes any authentication failure is ignored and next
// runs without a user in the context.
func (a *Authenticator) Require(permission Permission, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		user, err := a.Authenticate(r)
		if err != nil {
			if permission == PermissionPublic {
				next.ServeHTTP(w, r)
				return
			}
			if !errors.Is(err, ErrUnauthorized) {
				http.Error(w, "authentication unavailable", http.StatusServiceUnavailable)
				return
			}
			// RFC 9110 requires a 401 response to carry at least one challenge.
			w.Header().Set("WWW-Authenticate", "Bearer")
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		if !hasPermission(user.Roles, permission) {
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
		next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userContextKey, user)))
	})
}
// Authenticate identifies the caller of r.
//
// The credential is the bearer token from the Authorization header, or the
// trimmed x-comfy-api-key header when Authorization holds no bearer token.
// Credentials prefixed with "sk-" are verified as API keys via server-main;
// anything else is verified as a JWT. No credential yields ErrUnauthorized.
func (a *Authenticator) Authenticate(r *http.Request) (*User, error) {
	credential := extractBearer(r.Header.Get("Authorization"))
	if credential == "" {
		credential = strings.TrimSpace(r.Header.Get("x-comfy-api-key"))
	}
	switch {
	case credential == "":
		return nil, ErrUnauthorized
	case strings.HasPrefix(credential, "sk-"):
		return a.verifyAPIKey(r.Context(), credential)
	default:
		return a.verifyJWT(credential)
	}
}
// verifyJWT validates an HMAC-signed JWT against a.JWTSecret and maps its
// claims onto a User.
//
// Every failure collapses to ErrUnauthorized: no secret configured, parse or
// signature error, a non-HMAC "alg", non-map claims, or a missing "sub".
// A missing "source" claim defaults to "gateway".
//
// NOTE(review): tokens without an "exp" claim are not rejected here; if every
// issuer sets exp, consider passing jwt.WithExpirationRequired().
func (a *Authenticator) verifyJWT(tokenString string) (*User, error) {
	// An empty HMAC key makes signatures trivially forgeable, so an
	// authenticator without a secret must reject every token.
	if a.JWTSecret == "" {
		return nil, ErrUnauthorized
	}
	key := []byte(a.JWTSecret)
	token, err := jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (any, error) {
		// Pin the algorithm family: only accept tokens that claim an HMAC alg,
		// so the shared secret is never used with an attacker-chosen method.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return key, nil
	})
	if err != nil || !token.Valid {
		return nil, ErrUnauthorized
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, ErrUnauthorized
	}
	user := &User{
		ID:              stringClaim(claims, "sub"),
		Username:        stringClaim(claims, "username"),
		Roles:           stringSliceClaim(claims, "role"),
		TenantID:        stringClaim(claims, "tenantId"),
		GatewayTenantID: stringClaim(claims, "gatewayTenantId"),
		TenantKey:       stringClaim(claims, "tenantKey"),
		SSOID:           stringClaim(claims, "sso_id"),
		Source:          stringClaim(claims, "source"),
		GatewayUserID:   stringClaim(claims, "gatewayUserId"),
		UserGroupID:     stringClaim(claims, "userGroupId"),
		UserGroupKey:    stringClaim(claims, "userGroupKey"),
		UserGroupKeys:   stringSliceClaim(claims, "userGroupKeys"),
		APIKeyID:        stringClaim(claims, "apiKeyId"),
		// NOTE(review): JWT payloads are signed, not encrypted; anyone holding
		// the token can read this claim. Confirm issuers really embed it.
		APIKeySecret: stringClaim(claims, "apiKeySecret"),
		APIKeyName:   stringClaim(claims, "apiKeyName"),
	}
	if user.Source == "" {
		user.Source = "gateway"
	}
	if user.ID == "" {
		return nil, ErrUnauthorized
	}
	return user, nil
}
// SignJWT issues an HS256 token for user, signed with a.JWTSecret and valid
// for ttl (one hour when ttl <= 0). "iat" and "exp" are set from the current
// time.
//
// Claim names mirror those read by verifyJWT, so a signed token round-trips
// the user's identity, tenant, group and API-key attribution. The optional
// sso_id / apiKeyId / apiKeyName claims are only included when set.
// APIKeySecret is deliberately never embedded, because JWT payloads are
// readable by anyone who holds the token.
//
// It returns an error when user is nil, when no signing secret is configured
// (an empty HMAC key would make the token forgeable), or when signing fails.
func (a *Authenticator) SignJWT(user *User, ttl time.Duration) (string, error) {
	if user == nil {
		return "", errors.New("sign jwt: nil user")
	}
	if a.JWTSecret == "" {
		return "", errors.New("sign jwt: jwt secret is not configured")
	}
	if ttl <= 0 {
		ttl = time.Hour
	}
	now := time.Now()
	claims := jwt.MapClaims{
		"sub":             user.ID,
		"username":        user.Username,
		"role":            user.Roles,
		"tenantId":        user.TenantID,
		"gatewayTenantId": user.GatewayTenantID,
		"tenantKey":       user.TenantKey,
		"source":          user.Source,
		"gatewayUserId":   user.GatewayUserID,
		"userGroupId":     user.UserGroupID,
		"userGroupKey":    user.UserGroupKey,
		"userGroupKeys":   user.UserGroupKeys,
		"iat":             now.Unix(),
		"exp":             now.Add(ttl).Unix(),
	}
	if user.SSOID != "" {
		claims["sso_id"] = user.SSOID
	}
	if user.APIKeyID != "" {
		claims["apiKeyId"] = user.APIKeyID
	}
	if user.APIKeyName != "" {
		claims["apiKeyName"] = user.APIKeyName
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(a.JWTSecret))
	if err != nil {
		return "", fmt.Errorf("sign jwt: %w", err)
	}
	return signed, nil
}
// verifyAPIKey asks server-main to resolve an API key into a User.
//
// It POSTs {"apiKey": apiKey} to
// <ServerMainBaseURL>/internal/platform/auth/verify-api-key, authenticated
// with ServerMainInternalToken. A 200 body is decoded as a User, and a
// missing "source" defaults to "server-main".
//
// It returns ErrUnauthorized when server-main integration is not configured,
// when server-main answers with a non-200 status below 500, or when the
// returned user has no ID. Transport failures, 5xx responses and undecodable
// bodies are returned as wrapped errors rather than ErrUnauthorized, so
// callers can tell an outage from a bad key. The API key itself never appears
// in returned errors.
func (a *Authenticator) verifyAPIKey(ctx context.Context, apiKey string) (*User, error) {
	// Upper bound on how much of server-main's response is read or drained.
	const maxResponseBytes = 1 << 20

	if a.ServerMainBaseURL == "" || a.ServerMainInternalToken == "" {
		return nil, ErrUnauthorized
	}
	body, err := json.Marshal(map[string]string{"apiKey": apiKey})
	if err != nil {
		return nil, fmt.Errorf("verify api key: encode request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.ServerMainBaseURL+"/internal/platform/auth/verify-api-key", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("verify api key: build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+a.ServerMainInternalToken)

	client := a.HTTPClient
	if client == nil {
		// An Authenticator built as a struct literal has no client; never fall
		// back to http.DefaultClient, which has no timeout.
		client = &http.Client{Timeout: 10 * time.Second}
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("verify api key: %w", err)
	}
	defer func() {
		// Drain the remainder (bounded) before closing so the transport can
		// reuse the keep-alive connection.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxResponseBytes))
		_ = resp.Body.Close()
	}()

	switch {
	case resp.StatusCode == http.StatusOK:
	case resp.StatusCode >= http.StatusInternalServerError:
		return nil, fmt.Errorf("verify api key: server-main returned status %d", resp.StatusCode)
	default:
		return nil, ErrUnauthorized
	}

	var user User
	if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(&user); err != nil {
		return nil, fmt.Errorf("verify api key: decode response: %w", err)
	}
	if user.ID == "" {
		return nil, ErrUnauthorized
	}
	if user.Source == "" {
		user.Source = "server-main"
	}
	return &user, nil
}
// extractBearer returns the credential from an Authorization header value of
// the form "Bearer <token>". The scheme is matched case-insensitively and
// surrounding or repeated whitespace is tolerated. Any other shape (missing
// token, extra fields, different scheme) yields "".
func extractBearer(value string) string {
	parts := strings.Fields(value)
	if len(parts) != 2 {
		return ""
	}
	scheme, credential := parts[0], parts[1]
	if !strings.EqualFold(scheme, "bearer") {
		return ""
	}
	return credential
}
// hasPermission reports whether any of roles grants required.
// PermissionPublic is always granted, even to a caller with no roles.
func hasPermission(roles []string, required Permission) bool {
	if required == PermissionPublic {
		return true
	}
	for _, role := range roles {
		for _, granted := range permissionsForRole(role) {
			if granted == required {
				return true
			}
		}
	}
	return false
}
// permissionsForRole lists the permissions one role grants. Levels are
// cumulative: each role holds everything the roles below it hold
// ("user" < "creator" < "operator" < "admin"/"manager"). Role names are
// matched exactly, and unrecognised roles get only PermissionPublic.
// Each call returns a fresh slice.
func permissionsForRole(role string) []Permission {
	ladder := []Permission{PermissionPublic, PermissionBasic, PermissionCreat, PermissionPower, PermissionManager}
	var levels int
	switch role {
	case "admin", "manager":
		levels = len(ladder)
	case "operator":
		levels = 4
	case "creator":
		levels = 3
	case "user":
		levels = 2
	default:
		levels = 1
	}
	return ladder[:levels:levels]
}
// stringClaim returns claims[key] when it holds a string, and "" when the
// claim is absent or has any other type.
func stringClaim(claims jwt.MapClaims, key string) string {
	if s, isString := claims[key].(string); isString {
		return s
	}
	return ""
}
// stringSliceClaim reads claims[key] as a list of non-empty strings.
//
// It accepts an array of values (the []any form JSON arrays decode to), a
// []string (claims built in code), or a single string, so a claim such as
// "role" may be either "admin" or ["admin", "user"]. Non-string elements and
// empty strings are dropped in every form. The result is a fresh slice that
// never aliases the claim's own storage. It is nil when nothing usable
// remains, or when the claim is absent or of another type.
func stringSliceClaim(claims jwt.MapClaims, key string) []string {
	switch typed := claims[key].(type) {
	case string:
		if typed == "" {
			return nil
		}
		return []string{typed}
	case []string:
		var out []string
		for _, s := range typed {
			if s != "" {
				out = append(out, s)
			}
		}
		return out
	case []any:
		var out []string
		for _, item := range typed {
			if s, ok := item.(string); ok && s != "" {
				out = append(out, s)
			}
		}
		return out
	default:
		return nil
	}
}