easyai-ai-gateway/apps/api/internal/store/access_rules.go

659 lines
20 KiB
Go

package store
import (
"context"
"encoding/json"
"errors"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/jackc/pgx/v5"
)
type (
	// AccessRuleInput is the payload accepted by CreateAccessRule and
	// UpdateAccessRule to describe one row of gateway_access_rules.
	AccessRuleInput struct {
		SubjectType        string         `json:"subjectType"`
		SubjectID          string         `json:"subjectId"`
		ResourceType       string         `json:"resourceType"`
		ResourceID         string         `json:"resourceId"`
		Effect             string         `json:"effect"`
		Priority           int            `json:"priority"`
		MinPermissionLevel int            `json:"minPermissionLevel"`
		Conditions         map[string]any `json:"conditions"`
		Metadata           map[string]any `json:"metadata"`
		Status             string         `json:"status"`
	}

	// AccessRuleResourceInput is a single resource entry inside an
	// AccessRuleBatchInput; subject and effect come from the batch itself.
	AccessRuleResourceInput struct {
		ResourceType       string         `json:"resourceType"`
		ResourceID         string         `json:"resourceId"`
		Priority           int            `json:"priority"`
		MinPermissionLevel int            `json:"minPermissionLevel"`
		Conditions         map[string]any `json:"conditions"`
		Metadata           map[string]any `json:"metadata"`
		Status             string         `json:"status"`
	}

	// AccessRuleBatchInput applies one subject/effect pair to a list of
	// resources to upsert and a list of resources to delete.
	AccessRuleBatchInput struct {
		SubjectType     string                    `json:"subjectType"`
		SubjectID       string                    `json:"subjectId"`
		Effect          string                    `json:"effect"`
		UpsertResources []AccessRuleResourceInput `json:"upsertResources"`
		DeleteResources []AccessRuleResourceInput `json:"deleteResources"`
	}

	// accessRuleResource is a (resource type, resource id) pair used when
	// looking up the active rules that target a set of resources.
	accessRuleResource struct {
		Type string
		ID   string
	}
)
// ListAccessRules returns every row of gateway_access_rules ordered by
// resource type, priority and subject type, newest first within ties.
// The result is a non-nil (possibly empty) slice on success.
func (s *Store) ListAccessRules(ctx context.Context) ([]AccessRule, error) {
	const query = `
SELECT ` + accessRuleColumns + `
FROM gateway_access_rules
ORDER BY resource_type ASC, priority ASC, subject_type ASC, created_at DESC`

	rows, err := s.pool.Query(ctx, query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rules := []AccessRule{}
	for rows.Next() {
		rule, scanErr := scanAccessRule(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		rules = append(rules, rule)
	}
	return rules, rows.Err()
}
// ListAPIKeyAccessRules returns the access rules whose subject is one of the
// caller's non-deleted API keys.
//
// It returns ErrLocalUserRequired when user has no local gateway user ID.
func (s *Store) ListAPIKeyAccessRules(ctx context.Context, user *auth.User) ([]AccessRule, error) {
	ownerID := localGatewayUserID(user)
	if ownerID == "" {
		return nil, ErrLocalUserRequired
	}

	const query = `
SELECT ` + apiKeyAccessRuleColumns + `
FROM gateway_access_rules ar
JOIN gateway_api_keys k ON k.id = ar.subject_id
WHERE ar.subject_type = 'api_key'
AND k.gateway_user_id = $1::uuid
AND k.deleted_at IS NULL
ORDER BY ar.resource_type ASC, ar.priority ASC, ar.created_at DESC`

	rows, err := s.pool.Query(ctx, query, ownerID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rules := []AccessRule{}
	for rows.Next() {
		rule, scanErr := scanAccessRule(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		rules = append(rules, rule)
	}
	return rules, rows.Err()
}
// CreateAccessRule normalizes input and inserts it as a new row of
// gateway_access_rules, returning the stored rule.
//
// Conditions and Metadata are stored as JSONB. Nil maps go through
// emptyObjectIfNil first (presumably substituting an empty object). An error
// is returned if either value cannot be JSON-encoded.
func (s *Store) CreateAccessRule(ctx context.Context, input AccessRuleInput) (AccessRule, error) {
	input = normalizeAccessRuleInput(input)
	// Surface encoding failures here. Ignoring them would send an empty
	// string to the ::jsonb cast and fail with an unrelated SQL error.
	conditions, err := json.Marshal(emptyObjectIfNil(input.Conditions))
	if err != nil {
		return AccessRule{}, err
	}
	metadata, err := json.Marshal(emptyObjectIfNil(input.Metadata))
	if err != nil {
		return AccessRule{}, err
	}
	return scanAccessRule(s.pool.QueryRow(ctx, `
INSERT INTO gateway_access_rules (
subject_type, subject_id, resource_type, resource_id, effect, priority,
min_permission_level, conditions, metadata, status
)
VALUES ($1, $2::uuid, $3, $4::uuid, $5, $6, $7, $8::jsonb, $9::jsonb, $10)
RETURNING `+accessRuleColumns,
		input.SubjectType, input.SubjectID, input.ResourceType, input.ResourceID, input.Effect,
		input.Priority, input.MinPermissionLevel, string(conditions), string(metadata), input.Status,
	))
}
// UpdateAccessRule normalizes input and overwrites every mutable column of
// the access rule identified by id, returning the updated rule. When no row
// matches, the error comes from scanning an empty QueryRow result.
//
// An error is returned if Conditions or Metadata cannot be JSON-encoded.
func (s *Store) UpdateAccessRule(ctx context.Context, id string, input AccessRuleInput) (AccessRule, error) {
	input = normalizeAccessRuleInput(input)
	// Surface encoding failures here. Ignoring them would send an empty
	// string to the ::jsonb cast and fail with an unrelated SQL error.
	conditions, err := json.Marshal(emptyObjectIfNil(input.Conditions))
	if err != nil {
		return AccessRule{}, err
	}
	metadata, err := json.Marshal(emptyObjectIfNil(input.Metadata))
	if err != nil {
		return AccessRule{}, err
	}
	return scanAccessRule(s.pool.QueryRow(ctx, `
UPDATE gateway_access_rules
SET subject_type = $2,
subject_id = $3::uuid,
resource_type = $4,
resource_id = $5::uuid,
effect = $6,
priority = $7,
min_permission_level = $8,
conditions = $9::jsonb,
metadata = $10::jsonb,
status = $11,
updated_at = now()
WHERE id = $1::uuid
RETURNING `+accessRuleColumns,
		id, input.SubjectType, input.SubjectID, input.ResourceType, input.ResourceID, input.Effect,
		input.Priority, input.MinPermissionLevel, string(conditions), string(metadata), input.Status,
	))
}
// DeleteAccessRule removes the access rule with the given id.
// It returns pgx.ErrNoRows when no row was deleted.
func (s *Store) DeleteAccessRule(ctx context.Context, id string) error {
	tag, err := s.pool.Exec(ctx, `DELETE FROM gateway_access_rules WHERE id = $1::uuid`, id)
	switch {
	case err != nil:
		return err
	case tag.RowsAffected() == 0:
		return pgx.ErrNoRows
	default:
		return nil
	}
}
// BatchAccessRules applies a batch of rule changes for one subject and effect
// in a single transaction, then returns every access rule in the store.
//
// DeleteResources are removed first, then UpsertResources are inserted or
// updated (see applyAccessRuleBatch).
func (s *Store) BatchAccessRules(ctx context.Context, input AccessRuleBatchInput) ([]AccessRule, error) {
	if err := s.applyAccessRuleBatch(ctx, normalizeAccessRuleBatchInput(input)); err != nil {
		return nil, err
	}
	return s.ListAccessRules(ctx)
}

// BatchAPIKeyAccessRules applies a rule batch whose subject must be one of
// the caller's own non-deleted API keys, then returns the caller's API-key
// rules.
//
// Errors:
//   - ErrLocalUserRequired when user has no local gateway user ID.
//   - pgx.ErrNoRows when the subject is not an api_key owned by the caller.
//   - ErrAccessRuleResourceDenied (from ensureAPIKeyAccessRuleResourcesAllowed)
//     when an upserted resource is not accessible to the caller.
func (s *Store) BatchAPIKeyAccessRules(ctx context.Context, input AccessRuleBatchInput, user *auth.User) ([]AccessRule, error) {
	gatewayUserID := localGatewayUserID(user)
	if gatewayUserID == "" {
		return nil, ErrLocalUserRequired
	}
	input = normalizeAccessRuleBatchInput(input)
	if input.SubjectType != "api_key" {
		return nil, pgx.ErrNoRows
	}
	var exists bool
	if err := s.pool.QueryRow(ctx, `
SELECT EXISTS (
SELECT 1
FROM gateway_api_keys
WHERE id = $1::uuid
AND gateway_user_id = $2::uuid
AND deleted_at IS NULL
)`, input.SubjectID, gatewayUserID).Scan(&exists); err != nil {
		return nil, err
	}
	if !exists {
		return nil, pgx.ErrNoRows
	}
	if err := s.ensureAPIKeyAccessRuleResourcesAllowed(ctx, user, input.UpsertResources); err != nil {
		return nil, err
	}
	// Apply directly instead of going through BatchAccessRules. That path
	// would also load, and then discard, every access rule in the store.
	if err := s.applyAccessRuleBatch(ctx, input); err != nil {
		return nil, err
	}
	return s.ListAPIKeyAccessRules(ctx, user)
}

// applyAccessRuleBatch executes the deletes and upserts of an
// already-normalized batch inside one transaction.
//
// For every upserted resource, any rule with the opposite effect for the
// same subject/resource is deleted first, so contradictory allow and deny
// rules are not both left in place by a batch.
func (s *Store) applyAccessRuleBatch(ctx context.Context, input AccessRuleBatchInput) error {
	oppositeEffect := "deny"
	if input.Effect == "deny" {
		oppositeEffect = "allow"
	}

	tx, err := s.pool.Begin(ctx)
	if err != nil {
		return err
	}
	// Rollback after a successful Commit only reports that the transaction
	// is already closed, so its error is deliberately ignored.
	defer func() { _ = tx.Rollback(ctx) }()

	deleteRules := func(effect string, resource AccessRuleResourceInput) error {
		_, err := tx.Exec(ctx, `
DELETE FROM gateway_access_rules
WHERE subject_type = $1
AND subject_id = $2::uuid
AND effect = $3
AND resource_type = $4
AND resource_id = $5::uuid`,
			input.SubjectType, input.SubjectID, effect, resource.ResourceType, resource.ResourceID,
		)
		return err
	}

	for _, resource := range dedupeAccessRuleResources(input.DeleteResources) {
		resource = normalizeAccessRuleResource(resource, input.Effect)
		if resource.ResourceType == "" || resource.ResourceID == "" {
			continue
		}
		if err := deleteRules(input.Effect, resource); err != nil {
			return err
		}
	}

	for _, resource := range dedupeAccessRuleResources(input.UpsertResources) {
		resource = normalizeAccessRuleResource(resource, input.Effect)
		if resource.ResourceType == "" || resource.ResourceID == "" {
			continue
		}
		if err := deleteRules(oppositeEffect, resource); err != nil {
			return err
		}
		conditions, err := json.Marshal(emptyObjectIfNil(resource.Conditions))
		if err != nil {
			return err
		}
		metadata, err := json.Marshal(emptyObjectIfNil(resource.Metadata))
		if err != nil {
			return err
		}
		if _, err := tx.Exec(ctx, `
INSERT INTO gateway_access_rules (
subject_type, subject_id, resource_type, resource_id, effect, priority,
min_permission_level, conditions, metadata, status
)
VALUES ($1, $2::uuid, $3, $4::uuid, $5, $6, $7, $8::jsonb, $9::jsonb, $10)
ON CONFLICT (subject_type, subject_id, resource_type, resource_id, effect)
DO UPDATE SET priority = EXCLUDED.priority,
min_permission_level = EXCLUDED.min_permission_level,
conditions = EXCLUDED.conditions,
metadata = EXCLUDED.metadata,
status = EXCLUDED.status,
updated_at = now()`,
			input.SubjectType, input.SubjectID, resource.ResourceType, resource.ResourceID, input.Effect,
			resource.Priority, resource.MinPermissionLevel, string(conditions), string(metadata), resource.Status,
		); err != nil {
			return err
		}
	}
	return tx.Commit(ctx)
}
// filterCandidatesByAccessRules keeps only the runtime candidates that the
// active access rules allow for user.
//
// The filter works in place: the returned slice shares candidates' backing
// array, so the caller's slice contents may be overwritten. The input is
// returned unchanged when it is empty, when it yields no addressable
// resources, or when no active rule targets those resources.
func (s *Store) filterCandidatesByAccessRules(ctx context.Context, user *auth.User, candidates []RuntimeModelCandidate) ([]RuntimeModelCandidate, error) {
	if len(candidates) == 0 {
		return candidates, nil
	}
	resources := candidateAccessResources(candidates)
	if len(resources) == 0 {
		return candidates, nil
	}
	rules, err := s.listActiveAccessRulesForResources(ctx, resources)
	switch {
	case err != nil:
		return nil, err
	case len(rules) == 0:
		return candidates, nil
	}

	subjects := accessRuleSubjects(user)
	var permissionLevel int
	if user != nil {
		permissionLevel = auth.PermissionLevel(user.Roles)
	}

	kept := candidates[:0]
	for i := range candidates {
		if candidateAllowedByAccessRules(candidates[i], rules, subjects, permissionLevel) {
			kept = append(kept, candidates[i])
		}
	}
	return kept, nil
}
// ListAccessiblePlatformModels returns the enabled platform models on
// enabled platforms that the access rules allow for user. The user is first
// refreshed via resolveCurrentAccessUser so its current user group is taken
// into account.
func (s *Store) ListAccessiblePlatformModels(ctx context.Context, user *auth.User) ([]PlatformModel, error) {
	accessUser, err := s.resolveCurrentAccessUser(ctx, user)
	if err != nil {
		return nil, err
	}
	models, err := s.ListModels(ctx)
	if err != nil {
		return nil, err
	}
	platforms, err := s.ListPlatforms(ctx)
	if err != nil {
		return nil, err
	}

	platformEnabled := make(map[string]bool, len(platforms))
	for _, p := range platforms {
		if p.Status == "enabled" {
			platformEnabled[p.ID] = true
		}
	}

	visible := make([]PlatformModel, 0, len(models))
	for _, m := range models {
		if !m.Enabled || !platformEnabled[m.PlatformID] {
			continue
		}
		visible = append(visible, m)
	}
	return s.filterPlatformModelsByAccessRules(ctx, accessUser, visible)
}
// ensureAPIKeyAccessRuleResourcesAllowed verifies that every (deduplicated)
// resource is one the user can currently access. It returns
// ErrAccessRuleResourceDenied on the first inaccessible resource, and
// performs no lookup when resources is empty.
func (s *Store) ensureAPIKeyAccessRuleResourcesAllowed(ctx context.Context, user *auth.User, resources []AccessRuleResourceInput) error {
	requested := dedupeAccessRuleResources(resources)
	if len(requested) == 0 {
		return nil
	}
	accessible, err := s.accessibleAccessRuleResources(ctx, user)
	if err != nil {
		return err
	}
	for _, r := range requested {
		if key := r.ResourceType + ":" + r.ResourceID; !accessible[key] {
			return ErrAccessRuleResourceDenied
		}
	}
	return nil
}
// accessibleAccessRuleResources returns the set of "type:id" resource keys
// (platform, platform_model and, when present, base_model) reachable through
// the platform models the user may access.
func (s *Store) accessibleAccessRuleResources(ctx context.Context, user *auth.User) (map[string]bool, error) {
	models, err := s.ListAccessiblePlatformModels(ctx, user)
	if err != nil {
		return nil, err
	}
	resources := make(map[string]bool, len(models)*3)
	for _, m := range models {
		for _, key := range []string{"platform:" + m.PlatformID, "platform_model:" + m.ID} {
			resources[key] = true
		}
		if m.BaseModelID != "" {
			resources["base_model:"+m.BaseModelID] = true
		}
	}
	return resources, nil
}
// resolveCurrentAccessUser returns a copy of user whose UserGroupID reflects
// the database.
//
// The group is resolved as follows:
//   - If the request carries an API key, the key's user_group_id is used,
//     falling back to the user's default_user_group_id.
//   - Otherwise the user's default_user_group_id is used.
//
// Returns:
//   - nil when user is nil.
//   - user itself (not a copy) when it has no local gateway user ID.
//   - An unmodified copy when no active, non-deleted user/key row matches.
//     NOTE(review): such callers keep their original group; confirm that
//     inactive users are rejected elsewhere.
func (s *Store) resolveCurrentAccessUser(ctx context.Context, user *auth.User) (*auth.User, error) {
	if user == nil {
		return nil, nil
	}
	gatewayUserID := localGatewayUserID(user)
	if gatewayUserID == "" {
		return user, nil
	}
	resolved := *user
	// Use the trimmed key for both the branch decision and the query.
	// PostgreSQL's uuid input rejects surrounding whitespace, and
	// accessRuleSubjects trims the same field.
	apiKeyID := strings.TrimSpace(user.APIKeyID)
	var userGroupID string
	var err error
	if apiKeyID != "" {
		err = s.pool.QueryRow(ctx, `
SELECT COALESCE(k.user_group_id::text, u.default_user_group_id::text, '')
FROM gateway_users u
JOIN gateway_api_keys k ON k.gateway_user_id = u.id
WHERE u.id = $1::uuid
AND k.id = $2::uuid
AND u.status = 'active'
AND u.deleted_at IS NULL
AND k.status = 'active'
AND k.deleted_at IS NULL`, gatewayUserID, apiKeyID).Scan(&userGroupID)
	} else {
		err = s.pool.QueryRow(ctx, `
SELECT COALESCE(default_user_group_id::text, '')
FROM gateway_users
WHERE id = $1::uuid
AND status = 'active'
AND deleted_at IS NULL`, gatewayUserID).Scan(&userGroupID)
	}
	if errors.Is(err, pgx.ErrNoRows) {
		return &resolved, nil
	}
	if err != nil {
		return nil, err
	}
	resolved.UserGroupID = userGroupID
	return &resolved, nil
}
// filterPlatformModelsByAccessRules keeps only the platform models that the
// active access rules allow for user.
//
// The filter works in place: the returned slice shares models' backing array,
// so the caller's slice contents may be overwritten. The input is returned
// unchanged when it is empty, when it yields no addressable resources, or
// when no active rule targets those resources.
func (s *Store) filterPlatformModelsByAccessRules(ctx context.Context, user *auth.User, models []PlatformModel) ([]PlatformModel, error) {
	if len(models) == 0 {
		return models, nil
	}
	resources := platformModelAccessResources(models)
	if len(resources) == 0 {
		return models, nil
	}
	rules, err := s.listActiveAccessRulesForResources(ctx, resources)
	switch {
	case err != nil:
		return nil, err
	case len(rules) == 0:
		return models, nil
	}

	subjects := accessRuleSubjects(user)
	var permissionLevel int
	if user != nil {
		permissionLevel = auth.PermissionLevel(user.Roles)
	}

	kept := models[:0]
	for i := range models {
		if platformModelAllowedByAccessRules(models[i], rules, subjects, permissionLevel) {
			kept = append(kept, models[i])
		}
	}
	return kept, nil
}
// listActiveAccessRulesForResources loads every active access rule whose
// "resource_type:resource_id" key matches one of resources, ordered by
// priority and then creation time.
//
// Resources with an empty type or id are skipped. If none remain, it returns
// (nil, nil) without querying.
func (s *Store) listActiveAccessRulesForResources(ctx context.Context, resources []accessRuleResource) ([]AccessRule, error) {
	keys := make([]string, 0, len(resources))
	for _, r := range resources {
		if r.Type != "" && r.ID != "" {
			keys = append(keys, r.Type+":"+r.ID)
		}
	}
	if len(keys) == 0 {
		return nil, nil
	}

	// NOTE(review): matching on a concatenated expression presumably cannot
	// use a plain index on (resource_type, resource_id) unless an expression
	// index exists; revisit if gateway_access_rules grows large.
	rows, err := s.pool.Query(ctx, `
SELECT `+accessRuleColumns+`
FROM gateway_access_rules
WHERE status = 'active'
AND (resource_type || ':' || resource_id::text) = ANY($1)
ORDER BY priority ASC, created_at ASC`, keys)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	rules := []AccessRule{}
	for rows.Next() {
		rule, scanErr := scanAccessRule(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		rules = append(rules, rule)
	}
	return rules, rows.Err()
}
// candidateAccessResources lists the unique platform, platform_model and
// base_model resources referenced by candidates, in first-seen order.
// Resources with an empty id are omitted. The result is never nil.
func candidateAccessResources(candidates []RuntimeModelCandidate) []accessRuleResource {
	resources := make([]accessRuleResource, 0, len(candidates)*3)
	seen := make(map[string]bool)
	for _, c := range candidates {
		for _, r := range [...]accessRuleResource{
			{Type: "platform", ID: c.PlatformID},
			{Type: "platform_model", ID: c.PlatformModelID},
			{Type: "base_model", ID: c.BaseModelID},
		} {
			key := r.Type + ":" + r.ID
			if r.ID == "" || seen[key] {
				continue
			}
			seen[key] = true
			resources = append(resources, r)
		}
	}
	return resources
}
// platformModelAccessResources lists the unique platform, platform_model and
// base_model resources referenced by models, in first-seen order.
// Resources with an empty id are omitted. The result is never nil.
func platformModelAccessResources(models []PlatformModel) []accessRuleResource {
	resources := make([]accessRuleResource, 0, len(models)*3)
	seen := make(map[string]bool)
	for _, m := range models {
		for _, r := range [...]accessRuleResource{
			{Type: "platform", ID: m.PlatformID},
			{Type: "platform_model", ID: m.ID},
			{Type: "base_model", ID: m.BaseModelID},
		} {
			key := r.Type + ":" + r.ID
			if r.ID == "" || seen[key] {
				continue
			}
			seen[key] = true
			resources = append(resources, r)
		}
	}
	return resources
}
// candidateAllowedByAccessRules reports whether the caller described by
// subjects and permissionLevel may use a runtime candidate. The decision is
// made by accessRulesPermit over the candidate's platform, platform model
// and (if set) base model.
func candidateAllowedByAccessRules(candidate RuntimeModelCandidate, rules []AccessRule, subjects map[string]bool, permissionLevel int) bool {
	return accessRulesPermit(
		accessRuleResourceKeys(candidate.PlatformID, candidate.PlatformModelID, candidate.BaseModelID),
		rules, subjects, permissionLevel,
	)
}

// platformModelAllowedByAccessRules reports whether the caller described by
// subjects and permissionLevel may use a platform model. The decision is made
// by accessRulesPermit over the model's platform, the model itself and (if
// set) its base model.
func platformModelAllowedByAccessRules(model PlatformModel, rules []AccessRule, subjects map[string]bool, permissionLevel int) bool {
	return accessRulesPermit(
		accessRuleResourceKeys(model.PlatformID, model.ID, model.BaseModelID),
		rules, subjects, permissionLevel,
	)
}

// accessRuleResourceKeys builds the set of "type:id" keys through which a
// model can be targeted by rules. The base_model key is only enabled when
// baseModelID is non-empty.
func accessRuleResourceKeys(platformID, platformModelID, baseModelID string) map[string]bool {
	return map[string]bool{
		"platform:" + platformID:            true,
		"platform_model:" + platformModelID: true,
		"base_model:" + baseModelID:         baseModelID != "",
	}
}

// accessRulesPermit is the single evaluator shared by candidate and
// platform-model filtering. Only rules targeting a key in resourceKeys are
// considered:
//   - A "deny" rule whose subject is in subjects rejects immediately. Its
//     priority and MinPermissionLevel are not consulted.
//   - For every resource with at least one "allow" rule, some allow rule for
//     that resource must name one of subjects with MinPermissionLevel <=
//     permissionLevel. Otherwise the resource is restricted to others.
//   - Resources with no rules impose no restriction.
func accessRulesPermit(resourceKeys map[string]bool, rules []AccessRule, subjects map[string]bool, permissionLevel int) bool {
	restricted := map[string]bool{}
	granted := map[string]bool{}
	for _, rule := range rules {
		resourceKey := rule.ResourceType + ":" + rule.ResourceID
		if !resourceKeys[resourceKey] {
			continue
		}
		subjectMatches := subjects[rule.SubjectType+":"+rule.SubjectID]
		switch rule.Effect {
		case "deny":
			if subjectMatches {
				return false
			}
		case "allow":
			restricted[resourceKey] = true
			if subjectMatches && permissionLevel >= rule.MinPermissionLevel {
				granted[resourceKey] = true
			}
		}
	}
	for resourceKey := range restricted {
		if !granted[resourceKey] {
			return false
		}
	}
	return true
}
// accessRuleSubjects returns the set of "subject_type:id" keys that identify
// user when matching access rules.
//
// The set covers the user, tenant, API key, user group ID and each user
// group key. Gateway-specific IDs take precedence over generic ones via
// firstNonEmpty. Blank IDs are skipped after trimming. A nil user yields an
// empty, non-nil set.
func accessRuleSubjects(user *auth.User) map[string]bool {
	subjects := map[string]bool{}
	if user == nil {
		return subjects
	}
	type subjectRef struct{ kind, id string }
	refs := []subjectRef{
		{"user", firstNonEmpty(user.GatewayUserID, user.ID)},
		{"tenant", firstNonEmpty(user.GatewayTenantID, user.TenantID)},
		{"api_key", user.APIKeyID},
		{"user_group", user.UserGroupID},
	}
	for _, key := range user.UserGroupKeys {
		refs = append(refs, subjectRef{"user_group", key})
	}
	for _, ref := range refs {
		if id := strings.TrimSpace(ref.id); id != "" {
			subjects[ref.kind+":"+id] = true
		}
	}
	return subjects
}
// Column lists for queries whose rows are decoded by scanAccessRule; their
// order must match the Scan destinations there.
const (
	// accessRuleColumns selects from an unaliased gateway_access_rules table.
	accessRuleColumns = `
id::text, subject_type, subject_id::text, resource_type, resource_id::text, effect,
priority, min_permission_level, conditions, metadata, status, created_at, updated_at`

	// apiKeyAccessRuleColumns is the same list qualified with the "ar" alias
	// used when gateway_access_rules is joined with gateway_api_keys.
	apiKeyAccessRuleColumns = `
ar.id::text, ar.subject_type, ar.subject_id::text, ar.resource_type, ar.resource_id::text, ar.effect,
ar.priority, ar.min_permission_level, ar.conditions, ar.metadata, ar.status, ar.created_at, ar.updated_at`
)
// scanAccessRule decodes one row selected with accessRuleColumns (or
// apiKeyAccessRuleColumns) into an AccessRule. The JSONB conditions and
// metadata columns are read as raw bytes and converted with decodeObject.
// On a scan error the zero AccessRule is returned.
func scanAccessRule(row scanner) (AccessRule, error) {
	var (
		rule          AccessRule
		rawConditions []byte
		rawMetadata   []byte
	)
	err := row.Scan(
		&rule.ID, &rule.SubjectType, &rule.SubjectID,
		&rule.ResourceType, &rule.ResourceID, &rule.Effect,
		&rule.Priority, &rule.MinPermissionLevel,
		&rawConditions, &rawMetadata,
		&rule.Status, &rule.CreatedAt, &rule.UpdatedAt,
	)
	if err != nil {
		return AccessRule{}, err
	}
	rule.Conditions = decodeObject(rawConditions)
	rule.Metadata = decodeObject(rawMetadata)
	return rule, nil
}
// normalizeAccessRuleInput trims the subject and resource identifiers and
// fills in defaults:
//   - Effect falls back to "allow" and Status to "active" when blank.
//   - A zero Priority becomes 100; negative priorities are kept.
//   - A negative MinPermissionLevel is clamped to 0.
func normalizeAccessRuleInput(input AccessRuleInput) AccessRuleInput {
	out := input
	for _, field := range []*string{&out.SubjectType, &out.SubjectID, &out.ResourceType, &out.ResourceID} {
		*field = strings.TrimSpace(*field)
	}
	out.Effect = firstNonEmpty(strings.TrimSpace(out.Effect), "allow")
	out.Status = firstNonEmpty(strings.TrimSpace(out.Status), "active")
	if out.Priority == 0 {
		out.Priority = 100
	}
	if out.MinPermissionLevel < 0 {
		out.MinPermissionLevel = 0
	}
	return out
}
// normalizeAccessRuleBatchInput trims the batch subject identifiers and
// defaults a blank Effect to "allow". The resource lists are left untouched;
// they are normalized per entry by normalizeAccessRuleResource.
func normalizeAccessRuleBatchInput(input AccessRuleBatchInput) AccessRuleBatchInput {
	normalized := input
	normalized.SubjectType = strings.TrimSpace(input.SubjectType)
	normalized.SubjectID = strings.TrimSpace(input.SubjectID)
	normalized.Effect = firstNonEmpty(strings.TrimSpace(input.Effect), "allow")
	return normalized
}
// normalizeAccessRuleResource trims a batch resource's identifiers and fills
// in defaults:
//   - Status falls back to "active" when blank.
//   - A zero Priority becomes 10 for "deny" batches and 100 otherwise.
//   - A negative MinPermissionLevel is clamped to 0.
func normalizeAccessRuleResource(input AccessRuleResourceInput, effect string) AccessRuleResourceInput {
	input.ResourceType = strings.TrimSpace(input.ResourceType)
	input.ResourceID = strings.TrimSpace(input.ResourceID)
	input.Status = firstNonEmpty(strings.TrimSpace(input.Status), "active")
	switch {
	case input.Priority != 0:
		// An explicit priority is kept as-is.
	case effect == "deny":
		input.Priority = 10
	default:
		input.Priority = 100
	}
	if input.MinPermissionLevel < 0 {
		input.MinPermissionLevel = 0
	}
	return input
}
// dedupeAccessRuleResources returns copies of resources with trimmed type and
// id. Entries with a blank type or id are dropped, and only the first entry
// for each "type:id" pair is kept. The input slice is not modified, and the
// result is never nil.
func dedupeAccessRuleResources(resources []AccessRuleResourceInput) []AccessRuleResourceInput {
	unique := make([]AccessRuleResourceInput, 0, len(resources))
	seen := make(map[string]bool, len(resources))
	for i := range resources {
		r := resources[i] // copy, so the caller's element stays untouched
		r.ResourceType = strings.TrimSpace(r.ResourceType)
		r.ResourceID = strings.TrimSpace(r.ResourceID)
		if r.ResourceType == "" || r.ResourceID == "" {
			continue
		}
		key := r.ResourceType + ":" + r.ResourceID
		if seen[key] {
			continue
		}
		seen[key] = true
		unique = append(unique, r)
	}
	return unique
}