package httpapi import ( "encoding/json" "errors" "net" "net/http" "strings" "github.com/easyai/easyai-ai-gateway/apps/api/internal/auth" "github.com/easyai/easyai-ai-gateway/apps/api/internal/store" ) type walletBalanceRequest struct { Currency string `json:"currency"` Balance float64 `json:"balance"` Reason string `json:"reason"` IdempotencyKey string `json:"idempotencyKey"` Metadata map[string]any `json:"metadata"` } func (s *Server) setUserWalletBalance(w http.ResponseWriter, r *http.Request) { actor, _ := auth.UserFromContext(r.Context()) var input walletBalanceRequest if err := json.NewDecoder(r.Body).Decode(&input); err != nil { writeError(w, http.StatusBadRequest, "invalid json body") return } if input.Balance < 0 { writeError(w, http.StatusBadRequest, "wallet balance cannot be negative") return } gatewayUserID := strings.TrimSpace(r.PathValue("userID")) reason := strings.TrimSpace(input.Reason) if reason == "" { writeError(w, http.StatusBadRequest, "reason is required") return } var result store.WalletAdjustmentResult var auditLog store.AuditLog err := s.store.InTx(r.Context(), func(tx store.Tx) error { next, err := s.store.SetUserWalletBalanceTx(r.Context(), tx, store.WalletBalanceAdjustmentInput{ GatewayUserID: gatewayUserID, Currency: input.Currency, Balance: input.Balance, Reason: reason, IdempotencyKey: input.IdempotencyKey, Metadata: input.Metadata, }) if err != nil { return err } result = next record, err := s.store.RecordAuditLogTx(r.Context(), tx, walletAdjustmentAuditInput(r, actor, reason, result)) if err != nil { return err } auditLog = record return nil }) if err != nil { switch { case store.IsNotFound(err): writeError(w, http.StatusNotFound, "user not found") case errors.Is(err, store.ErrWalletBalanceUnchanged): writeError(w, http.StatusBadRequest, "wallet balance is unchanged") default: s.logger.Error("set user wallet balance failed", "error", err) writeError(w, http.StatusInternalServerError, "set user wallet balance failed") } return } writeJSON(w, http.StatusOK, map[string]any{ "account": result.Account, "before": result.Before, "transaction": result.Transaction, "auditLog": auditLog, }) } func (s *Server) listAuditLogs(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() limit, err := positiveQueryInt(query.Get("limit"), 100) if err != nil { writeError(w, http.StatusBadRequest, "invalid limit") return } items, err := s.store.ListAuditLogs(r.Context(), store.AuditLogFilter{ Category: query.Get("category"), Action: query.Get("action"), TargetType: query.Get("targetType"), TargetID: query.Get("targetId"), Limit: limit, }) if err != nil { s.logger.Error("list audit logs failed", "error", err) writeError(w, http.StatusInternalServerError, "list audit logs failed") return } writeJSON(w, http.StatusOK, map[string]any{"items": items}) } func walletAdjustmentAuditInput(r *http.Request, actor *auth.User, reason string, result store.WalletAdjustmentResult) store.AuditLogInput { actorGatewayUserID := "" actorUserID := "" actorUsername := "" actorSource := "" actorRoles := []string(nil) if actor != nil { actorGatewayUserID = uuidText(firstNonEmptyText(actor.GatewayUserID, actor.ID)) actorUserID = actor.ID actorUsername = actor.Username actorSource = actor.Source actorRoles = actor.Roles } return store.AuditLogInput{ Category: "billing", Action: "wallet.balance.set", ActorGatewayUserID: actorGatewayUserID, ActorUserID: actorUserID, ActorUsername: actorUsername, ActorSource: actorSource, ActorRoles: actorRoles, TargetType: "gateway_user", TargetID: result.Account.GatewayUserID, TargetGatewayUserID: result.Account.GatewayUserID, TargetGatewayTenantID: result.Account.GatewayTenantID, RequestIP: requestIP(r), UserAgent: r.UserAgent(), BeforeState: map[string]any{ "walletAccount": result.Before, }, AfterState: map[string]any{ "walletAccount": result.Account, "transaction": result.Transaction, }, Metadata: map[string]any{ "reason": reason, "currency": result.Account.Currency, "transactionId": result.Transaction.ID, "amount": result.Transaction.Amount, "direction": result.Transaction.Direction, }, } } func requestIP(r *http.Request) string { if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); forwarded != "" { parts := strings.Split(forwarded, ",") return strings.TrimSpace(parts[0]) } if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" { return realIP } host, _, err := net.SplitHostPort(r.RemoteAddr) if err == nil { return host } return strings.TrimSpace(r.RemoteAddr) } func firstNonEmptyText(values ...string) string { for _, value := range values { if text := strings.TrimSpace(value); text != "" { return text } } return "" } func uuidText(value string) string { value = strings.TrimSpace(value) if len(value) != 36 { return "" } if value[8] != '-' || value[13] != '-' || value[18] != '-' || value[23] != '-' { return "" } return value }