267 lines
7.5 KiB
Go
267 lines
7.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
)
|
|
|
|
// Permission names an access level that a route can demand through
// Authenticator.Require. The levels are cumulative: permissionsForRole grants
// each known role a leading run of the constants below.
type Permission string

// Access levels, ordered from least to most privileged.
const (
	// PermissionPublic is satisfied by every request, authenticated or not.
	PermissionPublic Permission = "public"

	// PermissionBasic is granted to the "user" role and every role above it.
	PermissionBasic Permission = "basic"

	// PermissionCreat is the content-creation level granted from the
	// "creator" role upward. Both the identifier and its wire value are
	// spelled "creat".
	PermissionCreat Permission = "creat"

	// PermissionPower is granted to the "operator", "manager" and "admin"
	// roles.
	PermissionPower Permission = "power"

	// PermissionManager is the top level, granted only to "manager" and
	// "admin".
	PermissionManager Permission = "manager"
)
|
|
|
|
// User is the authenticated principal that Require stores on the request
// context. Its JSON tags match the claim names verifyJWT reads, and the same
// shape is decoded from server-main's API-key verification response.
type User struct {
	// Identity and roles.
	ID       string   `json:"sub"`
	Username string   `json:"username"`
	Roles    []string `json:"role,omitempty"`

	// Tenant scoping.
	TenantID        string `json:"tenantId,omitempty"`
	GatewayTenantID string `json:"gatewayTenantId,omitempty"`
	TenantKey       string `json:"tenantKey,omitempty"`

	// Origin of the identity. verifyJWT defaults an empty Source to
	// "gateway" and verifyAPIKey defaults it to "server-main".
	SSOID         string `json:"sso_id,omitempty"`
	Source        string `json:"source,omitempty"`
	GatewayUserID string `json:"gatewayUserId,omitempty"`

	// Group membership.
	UserGroupID   string   `json:"userGroupId,omitempty"`
	UserGroupKey  string   `json:"userGroupKey,omitempty"`
	UserGroupKeys []string `json:"userGroupKeys,omitempty"`

	// API-key metadata, filled from the apiKey* JWT claims or from
	// server-main's response.
	// NOTE(review): APIKeySecret is emitted by encoding/json under this tag;
	// confirm a User is never serialized back to clients or into logs.
	APIKeyID     string `json:"apiKeyId,omitempty"`
	APIKeySecret string `json:"apiKeySecret,omitempty"`
	APIKeyName   string `json:"apiKeyName,omitempty"`
}
|
|
|
|
// contextKey is an unexported key type, so values this package stores in a
// context cannot collide with keys defined by other packages.
type contextKey string

const (
	// userContextKey is the key under which Require stores the *User.
	userContextKey contextKey = "easyai-auth-user"
)

// ErrUnauthorized is the sentinel error returned when a request carries no
// usable credential or its credential is rejected.
var ErrUnauthorized = errors.New("unauthorized")
|
|
|
|
// Authenticator verifies request credentials. Bearer JWTs are checked locally
// against JWTSecret, and "sk-" API keys are forwarded to server-main. Build
// one with New so the base URL is normalized and an HTTP client with a
// timeout is installed.
type Authenticator struct {
	// JWTSecret is the HMAC key used to verify (verifyJWT) and to sign
	// (SignJWT) tokens.
	JWTSecret string

	// ServerMainBaseURL is the server-main origin used for API-key checks.
	// New strips any trailing slash.
	ServerMainBaseURL string

	// ServerMainInternalToken is sent as a bearer token on server-main calls.
	ServerMainInternalToken string

	// HTTPClient performs the server-main requests.
	HTTPClient *http.Client
}
|
|
|
|
func New(jwtSecret string, serverMainBaseURL string, internalToken string) *Authenticator {
|
|
return &Authenticator{
|
|
JWTSecret: jwtSecret,
|
|
ServerMainBaseURL: strings.TrimRight(serverMainBaseURL, "/"),
|
|
ServerMainInternalToken: internalToken,
|
|
HTTPClient: &http.Client{
|
|
Timeout: 10 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
func UserFromContext(ctx context.Context) (*User, bool) {
|
|
user, ok := ctx.Value(userContextKey).(*User)
|
|
return user, ok
|
|
}
|
|
|
|
func (a *Authenticator) Require(permission Permission, next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
user, err := a.Authenticate(r)
|
|
if err != nil {
|
|
if permission == PermissionPublic {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if !hasPermission(user.Roles, permission) {
|
|
http.Error(w, "forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), userContextKey, user)))
|
|
})
|
|
}
|
|
|
|
func (a *Authenticator) Authenticate(r *http.Request) (*User, error) {
|
|
token := extractBearer(r.Header.Get("Authorization"))
|
|
if token == "" {
|
|
token = strings.TrimSpace(r.Header.Get("x-comfy-api-key"))
|
|
}
|
|
if token == "" {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
if strings.HasPrefix(token, "sk-") {
|
|
return a.verifyAPIKey(r.Context(), token)
|
|
}
|
|
return a.verifyJWT(token)
|
|
}
|
|
|
|
func (a *Authenticator) verifyJWT(tokenString string) (*User, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, jwt.MapClaims{}, func(token *jwt.Token) (any, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(a.JWTSecret), nil
|
|
})
|
|
if err != nil || !token.Valid {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
|
|
user := &User{
|
|
ID: stringClaim(claims, "sub"),
|
|
Username: stringClaim(claims, "username"),
|
|
Roles: stringSliceClaim(claims, "role"),
|
|
TenantID: stringClaim(claims, "tenantId"),
|
|
GatewayTenantID: stringClaim(claims, "gatewayTenantId"),
|
|
TenantKey: stringClaim(claims, "tenantKey"),
|
|
SSOID: stringClaim(claims, "sso_id"),
|
|
Source: stringClaim(claims, "source"),
|
|
GatewayUserID: stringClaim(claims, "gatewayUserId"),
|
|
UserGroupID: stringClaim(claims, "userGroupId"),
|
|
UserGroupKey: stringClaim(claims, "userGroupKey"),
|
|
UserGroupKeys: stringSliceClaim(claims, "userGroupKeys"),
|
|
APIKeyID: stringClaim(claims, "apiKeyId"),
|
|
APIKeySecret: stringClaim(claims, "apiKeySecret"),
|
|
APIKeyName: stringClaim(claims, "apiKeyName"),
|
|
}
|
|
if user.Source == "" {
|
|
user.Source = "gateway"
|
|
}
|
|
if user.ID == "" {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (a *Authenticator) SignJWT(user *User, ttl time.Duration) (string, error) {
|
|
if ttl <= 0 {
|
|
ttl = time.Hour
|
|
}
|
|
now := time.Now()
|
|
claims := jwt.MapClaims{
|
|
"sub": user.ID,
|
|
"username": user.Username,
|
|
"role": user.Roles,
|
|
"tenantId": user.TenantID,
|
|
"gatewayTenantId": user.GatewayTenantID,
|
|
"tenantKey": user.TenantKey,
|
|
"source": user.Source,
|
|
"gatewayUserId": user.GatewayUserID,
|
|
"userGroupId": user.UserGroupID,
|
|
"userGroupKey": user.UserGroupKey,
|
|
"userGroupKeys": user.UserGroupKeys,
|
|
"iat": now.Unix(),
|
|
"exp": now.Add(ttl).Unix(),
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString([]byte(a.JWTSecret))
|
|
}
|
|
|
|
func (a *Authenticator) verifyAPIKey(ctx context.Context, apiKey string) (*User, error) {
|
|
if a.ServerMainBaseURL == "" || a.ServerMainInternalToken == "" {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
body, _ := json.Marshal(map[string]string{"apiKey": apiKey})
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.ServerMainBaseURL+"/internal/platform/auth/verify-api-key", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+a.ServerMainInternalToken)
|
|
|
|
resp, err := a.HTTPClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
var user User
|
|
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
|
return nil, err
|
|
}
|
|
if user.ID == "" {
|
|
return nil, ErrUnauthorized
|
|
}
|
|
if user.Source == "" {
|
|
user.Source = "server-main"
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// extractBearer returns the token from an Authorization header value of the
// form "Bearer <token>". The scheme is matched case-insensitively, and
// surrounding or separating whitespace is ignored. Any other shape (wrong
// scheme, missing token, extra fields, empty input) yields "".
func extractBearer(value string) string {
	parts := strings.Fields(value)
	if len(parts) != 2 {
		return ""
	}
	if !strings.EqualFold(parts[0], "bearer") {
		return ""
	}
	return parts[1]
}
|
|
|
|
func hasPermission(roles []string, required Permission) bool {
|
|
if required == PermissionPublic {
|
|
return true
|
|
}
|
|
granted := map[Permission]bool{PermissionPublic: true}
|
|
for _, role := range roles {
|
|
for _, permission := range permissionsForRole(role) {
|
|
granted[permission] = true
|
|
}
|
|
}
|
|
return granted[required]
|
|
}
|
|
|
|
func permissionsForRole(role string) []Permission {
|
|
switch role {
|
|
case "admin", "manager":
|
|
return []Permission{PermissionPublic, PermissionBasic, PermissionCreat, PermissionPower, PermissionManager}
|
|
case "operator":
|
|
return []Permission{PermissionPublic, PermissionBasic, PermissionCreat, PermissionPower}
|
|
case "creator":
|
|
return []Permission{PermissionPublic, PermissionBasic, PermissionCreat}
|
|
case "user":
|
|
return []Permission{PermissionPublic, PermissionBasic}
|
|
default:
|
|
return []Permission{PermissionPublic}
|
|
}
|
|
}
|
|
|
|
func stringClaim(claims jwt.MapClaims, key string) string {
|
|
value, _ := claims[key].(string)
|
|
return value
|
|
}
|
|
|
|
func stringSliceClaim(claims jwt.MapClaims, key string) []string {
|
|
value := claims[key]
|
|
switch typed := value.(type) {
|
|
case []string:
|
|
return typed
|
|
case []any:
|
|
out := make([]string, 0, len(typed))
|
|
for _, item := range typed {
|
|
if s, ok := item.(string); ok && s != "" {
|
|
out = append(out, s)
|
|
}
|
|
}
|
|
return out
|
|
case string:
|
|
if typed == "" {
|
|
return nil
|
|
}
|
|
return []string{typed}
|
|
default:
|
|
return nil
|
|
}
|
|
}
|