easyai-ai-gateway/apps/api/internal/httpapi/server.go

172 lines
13 KiB
Go

package httpapi
import (
"log/slog"
"net/http"
"strings"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/auth"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/config"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/runner"
"github.com/easyai/easyai-ai-gateway/apps/api/internal/store"
)
// Server bundles the dependencies shared by the HTTP handlers of the gateway
// API. It is assembled by NewServer; handlers and middleware are methods on it.
type Server struct {
	// cfg is the configuration value handed to NewServer (the CORS middleware
	// reads its CORSAllowedOrigin field).
	cfg config.Config
	// store is the *store.Store passed to NewServer.
	store *store.Store
	// auth guards routes through its Require method.
	auth *auth.Authenticator
	// runner is the service created by runner.New in NewServer.
	runner *runner.Service
	// logger receives structured log output, e.g. recovered panics.
	logger *slog.Logger
}
// NewServer builds a Server from cfg, db and logger, registers every route on
// a fresh http.ServeMux, and returns that mux wrapped first in the CORS
// middleware and then in the panic-recovery middleware.
//
// Routes fall into four guard classes:
//   - unguarded: health/readiness probes and static simulation assets;
//   - auth.PermissionPublic / auth.PermissionBasic via s.auth.Require;
//   - auth.PermissionPower / auth.PermissionManager via requireAdmin, which
//     additionally rejects API-key credentials.
func NewServer(cfg config.Config, db *store.Store, logger *slog.Logger) http.Handler {
	srv := &Server{
		cfg:    cfg,
		store:  db,
		auth:   auth.New(cfg.JWTSecret, cfg.ServerMainBaseURL, cfg.ServerMainInternalToken),
		runner: runner.New(cfg, db, logger),
		logger: logger,
	}
	srv.auth.LocalAPIKeyVerifier = db.VerifyLocalAPIKey

	mux := http.NewServeMux()

	// Registration helpers, one per guard class. Each keeps the permission
	// constant at the call to Require/requireAdmin so constant typing is
	// exactly as if the route were registered inline.
	public := func(pattern string, fn http.HandlerFunc) {
		mux.Handle(pattern, srv.auth.Require(auth.PermissionPublic, fn))
	}
	basicHandler := func(pattern string, h http.Handler) {
		mux.Handle(pattern, srv.auth.Require(auth.PermissionBasic, h))
	}
	basic := func(pattern string, fn http.HandlerFunc) {
		basicHandler(pattern, fn)
	}
	power := func(pattern string, fn http.HandlerFunc) {
		mux.Handle(pattern, srv.requireAdmin(auth.PermissionPower, fn))
	}
	manager := func(pattern string, fn http.HandlerFunc) {
		mux.Handle(pattern, srv.requireAdmin(auth.PermissionManager, fn))
	}

	// Probes and static assets (no auth wrapper).
	mux.HandleFunc("GET /healthz", srv.health)
	mux.HandleFunc("GET /readyz", srv.ready)
	mux.HandleFunc("GET /static/simulation/{asset}", serveSimulationAsset)

	// Account endpoints and public catalog.
	public("POST /api/v1/auth/register", srv.register)
	public("POST /api/v1/auth/login", srv.login)
	basic("GET /api/v1/me", srv.me)
	public("GET /api/v1/public/catalog/providers", srv.listCatalogProviders)
	public("GET /api/v1/public/catalog/base-models", srv.listBaseModels)

	// Admin: catalog providers.
	power("GET /api/admin/catalog/providers", srv.listCatalogProviders)
	manager("POST /api/admin/catalog/providers", srv.createCatalogProvider)
	manager("PATCH /api/admin/catalog/providers/{providerID}", srv.updateCatalogProvider)
	manager("DELETE /api/admin/catalog/providers/{providerID}", srv.deleteCatalogProvider)

	// Admin: catalog base models.
	power("GET /api/admin/catalog/base-models", srv.listBaseModels)
	manager("POST /api/admin/catalog/base-models", srv.createBaseModel)
	manager("POST /api/admin/catalog/base-models/reset-all", srv.resetAllBaseModels)
	manager("PATCH /api/admin/catalog/base-models/{baseModelID}", srv.updateBaseModel)
	manager("POST /api/admin/catalog/base-models/{baseModelID}/reset", srv.resetBaseModel)
	manager("DELETE /api/admin/catalog/base-models/{baseModelID}", srv.deleteBaseModel)

	// Admin: tenants.
	power("GET /api/admin/tenants", srv.listTenants)
	manager("POST /api/admin/tenants", srv.createTenant)
	manager("PATCH /api/admin/tenants/{tenantID}", srv.updateTenant)
	manager("DELETE /api/admin/tenants/{tenantID}", srv.deleteTenant)

	// Admin: users and audit logs.
	power("GET /api/admin/users", srv.listUsers)
	manager("POST /api/admin/users", srv.createGatewayUser)
	manager("PATCH /api/admin/users/{userID}", srv.updateGatewayUser)
	manager("PATCH /api/admin/users/{userID}/wallet", srv.setUserWalletBalance)
	manager("DELETE /api/admin/users/{userID}", srv.deleteGatewayUser)
	power("GET /api/admin/audit-logs", srv.listAuditLogs)

	// Admin: user groups.
	power("GET /api/admin/user-groups", srv.listUserGroups)
	manager("POST /api/admin/user-groups", srv.createUserGroup)
	manager("PATCH /api/admin/user-groups/{groupID}", srv.updateUserGroup)
	manager("DELETE /api/admin/user-groups/{groupID}", srv.deleteUserGroup)

	// Admin: access rules.
	power("GET /api/admin/access-rules", srv.listAccessRules)
	manager("POST /api/admin/access-rules", srv.createAccessRule)
	manager("POST /api/admin/access-rules/batch", srv.batchAccessRules)
	manager("PATCH /api/admin/access-rules/{ruleID}", srv.updateAccessRule)
	manager("DELETE /api/admin/access-rules/{ruleID}", srv.deleteAccessRule)

	// Self-service API keys.
	basic("GET /api/v1/api-keys", srv.listAPIKeys)
	basic("POST /api/v1/api-keys", srv.createAPIKey)
	basic("GET /api/v1/api-keys/access-rules", srv.listAPIKeyAccessRules)
	basic("POST /api/v1/api-keys/access-rules/batch", srv.batchAPIKeyAccessRules)
	basic("PATCH /api/v1/api-keys/{apiKeyID}/disable", srv.disableAPIKey)
	basic("DELETE /api/v1/api-keys/{apiKeyID}", srv.deleteAPIKey)
	basic("GET /api/playground/api-keys", srv.listPlayableAPIKeys)

	// Pricing.
	power("GET /api/admin/pricing/rules", srv.listPricingRules)
	power("GET /api/admin/pricing/rule-sets", srv.listPricingRuleSets)
	manager("POST /api/admin/pricing/rule-sets", srv.createPricingRuleSet)
	manager("PATCH /api/admin/pricing/rule-sets/{ruleSetID}", srv.updatePricingRuleSet)
	manager("DELETE /api/admin/pricing/rule-sets/{ruleSetID}", srv.deletePricingRuleSet)
	basic("POST /api/v1/pricing/estimate", srv.estimatePricing)

	// Admin: runtime policy sets.
	power("GET /api/admin/runtime/policy-sets", srv.listRuntimePolicySets)
	manager("POST /api/admin/runtime/policy-sets", srv.createRuntimePolicySet)
	manager("PATCH /api/admin/runtime/policy-sets/{policySetID}", srv.updateRuntimePolicySet)
	manager("DELETE /api/admin/runtime/policy-sets/{policySetID}", srv.deleteRuntimePolicySet)

	// Admin: platforms and platform models. createPlatformModel is reachable
	// both nested under a platform and through the flat platform-models path.
	power("GET /api/admin/platforms", srv.listPlatforms)
	manager("POST /api/admin/platforms", srv.createPlatform)
	manager("PATCH /api/admin/platforms/{platformID}", srv.updatePlatform)
	manager("DELETE /api/admin/platforms/{platformID}", srv.deletePlatform)
	manager("PUT /api/admin/platforms/{platformID}/models", srv.replacePlatformModels)
	manager("POST /api/admin/platforms/{platformID}/models", srv.createPlatformModel)
	manager("POST /api/admin/platform-models", srv.createPlatformModel)
	manager("DELETE /api/admin/platform-models/{modelID}", srv.deletePlatformModel)

	// Model listings and rate-limit windows.
	power("GET /api/admin/models", srv.listModels)
	basic("GET /api/v1/model-catalog", srv.listModelCatalog)
	basic("GET /api/v1/platforms", srv.listPlayablePlatforms)
	basic("GET /api/v1/models", srv.listPlayableModels)
	basic("GET /api/v1/playground/models", srv.listPlayableModels)
	power("GET /api/admin/runtime/rate-limit-windows", srv.listRateLimitWindows)

	// Task creation under /api/v1 (second createTask argument is false).
	basicHandler("POST /api/v1/chat/completions", srv.createTask("chat.completions", false))
	basicHandler("POST /api/v1/responses", srv.createTask("responses", false))
	basicHandler("POST /api/v1/images/generations", srv.createTask("images.generations", false))
	basicHandler("POST /api/v1/images/edits", srv.createTask("images.edits", false))
	basicHandler("POST /api/v1/videos/generations", srv.createTask("videos.generations", false))

	// Task inspection.
	basic("GET /api/v1/tasks", srv.listTasks)
	basic("GET /api/v1/tasks/{taskID}", srv.getTask)
	basic("GET /api/v1/tasks/{taskID}/events", srv.taskEvents)

	// Unprefixed and /v1 aliases of the task endpoints. These pass true as
	// createTask's second argument (presumably a compatibility-mode flag;
	// confirm against createTask).
	basicHandler("POST /chat/completions", srv.createTask("chat.completions", true))
	basicHandler("POST /v1/chat/completions", srv.createTask("chat.completions", true))
	basicHandler("POST /responses", srv.createTask("responses", true))
	basicHandler("POST /v1/responses", srv.createTask("responses", true))
	basicHandler("POST /images/generations", srv.createTask("images.generations", true))
	basicHandler("POST /v1/images/generations", srv.createTask("images.generations", true))
	basicHandler("POST /images/edits", srv.createTask("images.edits", true))
	basicHandler("POST /v1/images/edits", srv.createTask("images.edits", true))

	// recover is outermost so it also covers panics raised inside cors.
	return srv.recover(srv.cors(mux))
}
// requireAdmin guards an admin route. The request must first pass
// s.auth.Require for the given permission; after that, a request whose context
// user carries a non-blank APIKeyID is rejected with 403 Forbidden, and all
// other requests are forwarded to next.
func (s *Server) requireAdmin(permission auth.Permission, next http.Handler) http.Handler {
	rejectAPIKeys := func(w http.ResponseWriter, r *http.Request) {
		// The second result of UserFromContext is deliberately discarded; only
		// a non-nil user with an API key id triggers the rejection.
		caller, _ := auth.UserFromContext(r.Context())
		if caller == nil || strings.TrimSpace(caller.APIKeyID) == "" {
			next.ServeHTTP(w, r)
			return
		}
		writeError(w, http.StatusForbidden, "admin api does not accept api key credentials")
	}
	return s.auth.Require(permission, http.HandlerFunc(rejectAPIKeys))
}
// cors is middleware that applies the gateway's CORS policy.
//
// When the request's Origin header is non-empty and accepted by originAllowed
// against s.cfg.CORSAllowedOrigin, the origin is echoed back together with the
// credential, header and method allowances. Every OPTIONS request is answered
// directly with 204 No Content and never reaches next.
//
// NOTE(review): originAllowed treats "*" as match-all, so with a wildcard
// configuration any origin is echoed back alongside
// Access-Control-Allow-Credentials: true. Confirm that is intended outside of
// development setups.
func (s *Server) cors(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The response depends on the Origin header even when no CORS headers
		// end up being written, so caches must always key on it. Otherwise a
		// cached header-less response could be replayed to an allowed origin.
		// Add (not Set) keeps any other Vary values intact.
		w.Header().Add("Vary", "Origin")

		origin := r.Header.Get("Origin")
		if origin != "" && originAllowed(origin, s.cfg.CORSAllowedOrigin) {
			h := w.Header()
			h.Set("Access-Control-Allow-Origin", origin)
			h.Set("Access-Control-Allow-Credentials", "true")
			h.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Comfy-Api-Key")
			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
		}

		// Preflight (and any other OPTIONS) requests end here; the mux only
		// registers method-specific patterns and would not match them.
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// originAllowed reports whether origin matches the comma-separated allow-list
// in allowed. Entries are trimmed of surrounding whitespace; an entry of "*"
// matches every origin, and any other entry matches case-insensitively.
//
// Blank entries (an empty list, or a stray leading/trailing/double comma) are
// ignored so that they can never match an empty origin.
func originAllowed(origin string, allowed string) bool {
	for _, item := range strings.Split(allowed, ",") {
		item = strings.TrimSpace(item)
		if item == "" {
			continue
		}
		if item == "*" || strings.EqualFold(origin, item) {
			return true
		}
	}
	return false
}
// recover is middleware that converts a panic raised by next into a logged
// error and a 500 "internal server error" response.
//
// http.ErrAbortHandler is re-panicked untouched: it is net/http's sentinel for
// deliberately aborting a response, and the server must see it to abort the
// response without treating it as an application failure.
func (s *Server) recover(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			// Inside this closure, recover is the builtin; the method name is
			// not in scope as a bare identifier.
			err := recover()
			if err == nil {
				return
			}
			if err == http.ErrAbortHandler {
				panic(err)
			}
			s.logger.Error("panic recovered", "error", err, "path", r.URL.Path)
			writeError(w, http.StatusInternalServerError, "internal server error")
		}()
		next.ServeHTTP(w, r)
	})
}