diff --git a/API_AUTHENTICATION.md b/API_AUTHENTICATION.md new file mode 100644 index 000000000..8add70e9c --- /dev/null +++ b/API_AUTHENTICATION.md @@ -0,0 +1,221 @@ +# API Key Authentication and Health Check + +## Overview + +This implementation adds API key authentication protection to the ComfyUI REST API and a health check endpoint. + +## Features + +### 1. API Key Authentication + +Protects all API endpoints (except exempt ones) with API key authentication. + +#### Configuration + +You can enable API key authentication in two ways: + +**Option 1: Command line argument** +```bash +python main.py --api-key "your-secret-api-key-here" +``` + +**Option 2: API key file (more secure)** +```bash +# Create a file with your API key +echo "your-secret-api-key-here" > api_key.txt + +# Start ComfyUI with the API key file +python main.py --api-key-file api_key.txt +``` + +#### Using the API with Authentication + +When API key authentication is enabled, you must provide the API key in your requests: + +**Method 1: Authorization Header (Bearer Token)** +```bash +curl -H "Authorization: Bearer your-secret-api-key-here" http://localhost:8188/prompt +``` + +**Method 2: X-API-Key Header** +```bash +curl -H "X-API-Key: your-secret-api-key-here" http://localhost:8188/prompt +``` + +**Method 3: Query Parameter (less secure, for testing only)** +```bash +curl "http://localhost:8188/prompt?api_key=your-secret-api-key-here" +``` + +#### Exempt Endpoints + +The following endpoints do NOT require authentication: +- `/health` - Health check endpoint +- `/` - Root page (frontend) +- `/ws` - WebSocket endpoint + +### 2. Health Check Endpoint + +A new `/health` endpoint provides server status information. + +#### Usage + +```bash +curl http://localhost:8188/health +``` + +#### Response Format + +```json +{ + "status": "healthy", + "version": "0.4.0", + "timestamp": 1702307890.123, + "queue": { + "pending": 0, + "running": 0 + }, + "device": "cuda:0", + "vram": { + "total": 8589934592, + "free": 6442450944, + "used": 2147483648 + } +} +``` + +If the server is unhealthy, it returns a 503 status code: + +```json +{ + "status": "unhealthy", + "error": "error message here", + "timestamp": 1702307890.123 +} +``` + +## Examples + +### Starting ComfyUI with API Key Protection + +```bash +# With direct API key +python main.py --api-key "my-super-secret-key-12345" + +# With API key from file +python main.py --api-key-file /path/to/api_key.txt + +# With API key and custom port +python main.py --api-key "my-key" --port 8080 +``` + +### Making Authenticated Requests + +**Python example:** +```python +import requests + +API_KEY = "your-api-key-here" +BASE_URL = "http://localhost:8188" + +# Using Authorization header +headers = { + "Authorization": f"Bearer {API_KEY}" +} + +# Check health +response = requests.get(f"{BASE_URL}/health") +print(response.json()) + +# Make authenticated request +response = requests.post( + f"{BASE_URL}/prompt", + headers=headers, + json={"prompt": {...}} +) +print(response.json()) +``` + +**JavaScript example:** +```javascript +const API_KEY = "your-api-key-here"; +const BASE_URL = "http://localhost:8188"; + +// Using fetch with Authorization header +async function makeRequest(endpoint, data) { + const response = await fetch(`${BASE_URL}${endpoint}`, { + method: 'POST', + headers: { + 'Authorization': `Bearer ${API_KEY}`, + 'Content-Type': 'application/json' + }, + body: JSON.stringify(data) + }); + return response.json(); +} + +// Check health (no auth required) +fetch(`${BASE_URL}/health`) + .then(r => r.json()) + .then(data => console.log(data)); +``` + +### Monitoring with Health Check + +You can use the health endpoint for monitoring and health checks: + +```bash +# Simple health check +curl http://localhost:8188/health + +# Use in a monitoring script +#!/bin/bash +response=$(curl -s http://localhost:8188/health) +status=$(echo $response | jq -r '.status') + +if [ "$status" == "healthy" ]; then + echo "✓ ComfyUI is healthy" + exit 0 +else + echo "✗ ComfyUI is unhealthy" + exit 1 +fi +``` + +## Security Considerations + +1. **Keep your API key secret**: Never commit API keys to version control +2. **Use API key files**: Store API keys in separate files with restricted permissions +3. **Use HTTPS in production**: Combine with `--tls-keyfile` and `--tls-certfile` options +4. **Rotate keys regularly**: Change your API key periodically +5. **Use strong keys**: Generate long, random API keys (e.g., using `openssl rand -hex 32`) + +### Generating a Secure API Key + +```bash +# Generate a secure random API key +openssl rand -hex 32 + +# Or using Python +python -c "import secrets; print(secrets.token_hex(32))" +``` + +## Troubleshooting + +### 401 Unauthorized Error + +If you receive a 401 error: +- Verify the API key is correct +- Check that you're including the key in the correct header format +- Ensure there are no extra spaces or newlines in the key + +### Health Check Returns 503 + +If the health check returns 503: +- Check the server logs for error details +- Verify ComfyUI started correctly +- Check system resources (memory, disk space) + +## Disabling Authentication + +To disable API key authentication, simply don't provide the `--api-key` or `--api-key-file` arguments when starting ComfyUI. The server will work exactly as before with no authentication required. diff --git a/API_SECURITY_IMPLEMENTATION.md b/API_SECURITY_IMPLEMENTATION.md new file mode 100644 index 000000000..5d8452dd4 --- /dev/null +++ b/API_SECURITY_IMPLEMENTATION.md @@ -0,0 +1,142 @@ +# ComfyUI API Security Enhancement + +## Summary + +This implementation adds API key authentication and a health check endpoint to ComfyUI. + +## Files Modified + +1. **middleware/auth_middleware.py** (NEW) + - API key authentication middleware + - Supports multiple authentication methods (Bearer token, X-API-Key header, query parameter) + - Configurable exempt paths + +2. **comfy/cli_args.py** (MODIFIED) + - Added `--api-key` argument for inline API key + - Added `--api-key-file` argument for API key from file + - Added logic to load API key from file + +3. **server.py** (MODIFIED) + - Imported auth middleware + - Integrated middleware into application + - Added `/health` endpoint with system information + - Configured exempt paths (/, /health, /ws) + +## New Files + +1. **API_AUTHENTICATION.md** - Complete documentation +2. **test_api_auth.py** - Test suite for authentication +3. **examples_api_auth.py** - Python usage examples + +## Quick Start + +### 1. Start ComfyUI with API Key Protection + +```bash +# Generate a secure API key +python -c "import secrets; print(secrets.token_hex(32))" + +# Start with API key +python main.py --api-key "your-generated-key-here" + +# Or use a file +echo "your-generated-key-here" > api_key.txt +python main.py --api-key-file api_key.txt +``` + +### 2. Test the Health Endpoint + +```bash +curl http://localhost:8188/health +``` + +### 3. Make Authenticated Requests + +```bash +# Using Bearer token +curl -H "Authorization: Bearer your-api-key" http://localhost:8188/prompt + +# Using X-API-Key header +curl -H "X-API-Key: your-api-key" http://localhost:8188/prompt +``` + +### 4. Run Tests + +```bash +# Install requests if needed +pip install requests + +# Run test suite +python test_api_auth.py your-api-key + +# Run examples +python examples_api_auth.py +``` + +## Features + +### API Key Authentication +- ✅ Multiple authentication methods (Bearer, X-API-Key, query param) +- ✅ Configurable via command line +- ✅ Secure file-based configuration +- ✅ Exempt paths for health checks and WebSocket +- ✅ Detailed logging of authentication attempts + +### Health Check Endpoint +- ✅ Returns server status +- ✅ Queue information (pending/running) +- ✅ Device information +- ✅ VRAM usage (if GPU available) +- ✅ Version information +- ✅ Timestamp for monitoring + +## Security Best Practices + +1. **Generate Strong Keys**: Use `openssl rand -hex 32` or similar +2. **Use File-Based Config**: Keep keys out of command history +3. **Enable HTTPS**: Use with `--tls-keyfile` and `--tls-certfile` +4. **Restrict File Permissions**: `chmod 600 api_key.txt` +5. **Rotate Keys Regularly**: Change API keys periodically +6. **Monitor Access**: Check logs for unauthorized attempts + +## Backward Compatibility + +- ✅ Fully backward compatible +- ✅ No authentication required by default +- ✅ Existing functionality unchanged +- ✅ WebSocket connections work normally + +## Testing + +The implementation has been tested for: +- ✅ Syntax errors (none found) +- ✅ Import compatibility +- ✅ Middleware integration +- ✅ Route configuration +- ✅ Health endpoint functionality + +To fully test in your environment: +```bash +# 1. Start server without auth (test backward compatibility) +python main.py + +# 2. Start server with auth +python main.py --api-key "test-key-123" + +# 3. Run test suite +python test_api_auth.py test-key-123 + +# 4. Check health endpoint +curl http://localhost:8188/health +``` + +## Support + +For detailed documentation, see: +- **API_AUTHENTICATION.md** - Complete usage guide +- **examples_api_auth.py** - Code examples +- **test_api_auth.py** - Test suite + +## License + +Same as ComfyUI main project. diff --git a/AUTHENTICATION_GUIDE.md b/AUTHENTICATION_GUIDE.md new file mode 100644 index 000000000..54472f09a --- /dev/null +++ b/AUTHENTICATION_GUIDE.md @@ -0,0 +1,272 @@ +# ComfyUI API Authentication Guide + +## Overview + +ComfyUI now supports API key authentication to protect your REST API endpoints. This guide covers setup, usage, and logout procedures. + +## Features + +- **API Key Authentication**: Protect all API endpoints with a simple API key +- **Multiple Auth Methods**: Bearer token, X-API-Key header, or query parameter +- **Health Endpoint**: Monitor server status without authentication +- **Frontend Login**: Built-in login page for browser access +- **Session Management**: Remember API key across sessions or per-session only + +## Quick Start + +### 1. Start Server with API Key + +```bash +# Using command line argument +python3 main.py --api-key "your-secure-api-key-here" + +# Or using a file (recommended for production) +echo "your-secure-api-key-here" > api_key.txt +python3 main.py --api-key-file api_key.txt +``` + +### 2. Access ComfyUI + +When you navigate to `http://127.0.0.1:8188`, you'll be redirected to a login page. + +### 3. Login + +1. Enter your API key in the password field +2. Optionally check "Remember me" to persist the key across browser sessions +3. Click "Login" + +The system validates your key against the `/health` endpoint before storing it. + +### 4. Logout + +To logout, visit: +``` +http://127.0.0.1:8188/auth_login.html?logout=true +``` + +This will clear your stored API key and redirect you to the login page. + +## API Usage + +### Authentication Methods + +You can authenticate API requests in three ways: + +#### 1. Bearer Token (Recommended) +```bash +curl -H "Authorization: Bearer your-api-key" http://127.0.0.1:8188/prompt +``` + +#### 2. X-API-Key Header +```bash +curl -H "X-API-Key: your-api-key" http://127.0.0.1:8188/prompt +``` + +#### 3. Query Parameter +```bash +curl "http://127.0.0.1:8188/prompt?api_key=your-api-key" +``` + +### Python Example + +```python +import requests + +API_KEY = "your-api-key" +BASE_URL = "http://127.0.0.1:8188" + +# Using Bearer token +headers = {"Authorization": f"Bearer {API_KEY}"} +response = requests.get(f"{BASE_URL}/queue", headers=headers) + +# Using X-API-Key header +headers = {"X-API-Key": API_KEY} +response = requests.get(f"{BASE_URL}/queue", headers=headers) + +# Using query parameter +response = requests.get(f"{BASE_URL}/queue?api_key={API_KEY}") +``` + +### JavaScript Example + +```javascript +const API_KEY = "your-api-key"; +const BASE_URL = "http://127.0.0.1:8188"; + +// Using Bearer token +fetch(`${BASE_URL}/queue`, { + headers: { + "Authorization": `Bearer ${API_KEY}` + } +}) +.then(response => response.json()) +.then(data => console.log(data)); +``` + +## Health Endpoint + +The `/health` endpoint is always accessible without authentication: + +```bash +curl http://127.0.0.1:8188/health +``` + +Response includes: +- Server status +- Queue information (pending/running tasks) +- Device information (GPU/CPU) +- VRAM statistics + +Example response: +```json +{ + "status": "ok", + "queue": { + "pending": 0, + "running": 0 + }, + "device": { + "name": "mps", + "type": "mps" + }, + "vram": { + "total": 68719476736, + "free": 68719476736 + } +} +``` + +## Exempt Endpoints + +The following paths are exempt from authentication: + +### Static Files +- `.html`, `.js`, `.css`, `.json` files +- `.png`, `.jpg`, `.jpeg`, `.gif`, `.webp` images +- `.svg`, `.ico` icons +- `.woff`, `.woff2`, `.ttf` fonts +- `.mp3`, `.wav` audio files +- `.mp4`, `.webm` video files + +### API Paths +- `/` - Root path (serves login if not authenticated) +- `/health` - Health check endpoint +- `/ws` - WebSocket endpoint +- `/auth_login.html` - Login page +- `/auth_inject.js` - Auth injection script +- `/extensions/*` - Extension files +- `/templates/*` - Template files +- `/docs/*` - Documentation + +## Security Best Practices + +1. **Use Strong API Keys**: Generate random, long API keys (32+ characters) + ```bash + # Generate secure API key on macOS/Linux + openssl rand -base64 32 + ``` + +2. **Use API Key Files**: Store keys in files rather than command line + ```bash + python3 main.py --api-key-file /secure/path/api_key.txt + ``` + +3. **File Permissions**: Restrict key file access + ```bash + chmod 600 api_key.txt + ``` + +4. **HTTPS**: Use reverse proxy with SSL in production + ```nginx + server { + listen 443 ssl; + server_name your-domain.com; + + ssl_certificate /path/to/cert.pem; + ssl_certificate_key /path/to/key.pem; + + location / { + proxy_pass http://127.0.0.1:8188; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } + } + ``` + +5. **Environment Variables**: Store keys in environment + ```bash + export COMFYUI_API_KEY=$(cat api_key.txt) + python3 main.py --api-key "$COMFYUI_API_KEY" + ``` + +## Frontend Integration + +The authentication system automatically handles frontend requests: + +1. When authentication is enabled, `auth_inject.js` is injected into `index.html` +2. This script intercepts all `fetch()` and `XMLHttpRequest` calls +3. Authorization headers are added automatically to all requests +4. On 401 responses, the user is redirected to the login page + +### Session Storage + +- **localStorage**: API key persists across browser sessions (when "Remember me" is checked) +- **sessionStorage**: API key cleared when browser/tab closes (when "Remember me" is not checked) + +## Troubleshooting + +### 401 Unauthorized Errors + +1. Verify API key matches server configuration +2. Check authentication header format +3. Ensure endpoint isn't expecting different auth method + +### Login Page Not Appearing + +1. Clear browser cache +2. Verify `auth_login.html` and `auth_inject.js` exist in ComfyUI root +3. Check server logs for errors + +### WebSocket Connection Issues + +WebSocket connections (`/ws`) are exempt from authentication, but may require authentication for initial HTTP upgrade depending on your setup. + +### Logout Not Working + +Visit the logout URL directly: +``` +http://127.0.0.1:8188/auth_login.html?logout=true +``` + +Or clear storage manually in browser DevTools: +```javascript +localStorage.removeItem('comfyui_api_key'); +sessionStorage.removeItem('comfyui_api_key'); +``` + +## Disabling Authentication + +To disable authentication, simply start the server without the `--api-key` or `--api-key-file` arguments: + +```bash +python3 main.py +``` + +## Migration from Unauthenticated Setup + +If you're adding authentication to an existing ComfyUI installation: + +1. Ensure `auth_login.html` and `auth_inject.js` exist in root directory +2. Update `server.py` with authentication routes +3. Add `middleware/auth_middleware.py` +4. Update `comfy/cli_args.py` with API key arguments +5. Restart server with `--api-key` argument +6. Update API clients to include authentication + +## Support + +For issues or questions: +- Check server logs for authentication errors +- Verify middleware is properly configured +- Test with `/health` endpoint first (no auth required) +- Review `middleware/auth_middleware.py` for exempt paths diff --git a/FRONTEND_AUTH_GUIDE.md b/FRONTEND_AUTH_GUIDE.md new file mode 100644 index 000000000..bca43da5d --- /dev/null +++ b/FRONTEND_AUTH_GUIDE.md @@ -0,0 +1,233 @@ +# ComfyUI Frontend Authentication Guide + +## Overview + +When API key authentication is enabled, the ComfyUI frontend will automatically require users to log in with the API key before accessing the interface. + +## How It Works + +1. **Automatic Redirection**: When you access ComfyUI with authentication enabled, you'll be automatically redirected to a login page if no valid API key is stored. + +2. **API Key Storage**: After successful login, your API key is stored in your browser (localStorage or sessionStorage) depending on your "Remember me" choice. + +3. **Automatic Injection**: The API key is automatically added to all API requests via the `Authorization: Bearer ` header. + +4. **Session Management**: + - **Remember me (checked)**: API key stored in localStorage (persists across browser sessions) + - **Remember me (unchecked)**: API key stored in sessionStorage (cleared when browser closes) + +## Using the Frontend with Authentication + +### First Time Access + +1. Start ComfyUI with an API key: + ```bash + python main.py --api-key "your-secret-key-123" + ``` + +2. Open your browser and navigate to `http://localhost:8188` + +3. You'll be presented with a login page + +4. Enter your API key and click "Login" + +5. Choose whether to remember your login (recommended for personal devices only) + +### Login Page Features + +- **Show API Key**: Toggle to view the key you're entering +- **Remember Me**: Keep your session active across browser restarts +- **Auto-validation**: The system validates your key before storing it + +### Logging Out + +To log out and clear your stored API key: + +**Option 1: JavaScript Console** +```javascript +comfyuiLogout() +``` + +**Option 2: Browser DevTools** +1. Open DevTools (F12) +2. Go to Application > Storage +3. Clear localStorage and sessionStorage +4. Refresh the page + +**Option 3: Manual URL** +Navigate to: `http://localhost:8188/auth_login.html` + +### Security Considerations + +#### For Personal Use +- ✅ Enable "Remember me" for convenience +- ✅ Use strong, unique API keys +- ✅ Keep your browser updated + +#### For Shared/Public Computers +- ❌ **DO NOT** enable "Remember me" +- ✅ Always log out when finished +- ✅ Clear browser data after use +- ✅ Consider using a private/incognito window + +#### For Production Deployments +- ✅ Always use HTTPS (combine with `--tls-keyfile` and `--tls-certfile`) +- ✅ Use strong, randomly generated API keys +- ✅ Rotate keys regularly +- ✅ Monitor access logs for unauthorized attempts +- ✅ Consider using additional security layers (VPN, firewall, etc.) + +## Troubleshooting + +### "Invalid API Key" Error + +**Problem**: Login fails with invalid API key message + +**Solutions**: +1. Verify you're using the correct API key +2. Check for extra spaces or newlines +3. Ensure the server was started with the same API key +4. Check server logs for authentication attempts + +### Automatic Logout + +**Problem**: Frequently logged out automatically + +**Possible Causes**: +1. API key changed on server (restart required) +2. Browser cleared storage automatically +3. "Remember me" was not checked +4. Session expired (if using sessionStorage) + +**Solution**: Check "Remember me" when logging in + +### Login Page Not Showing + +**Problem**: Can't access login page + +**Possible Causes**: +1. Authentication not enabled (no `--api-key` argument) +2. Browser cached old version + +**Solution**: +1. Hard refresh the page (Ctrl+Shift+R or Cmd+Shift+R) +2. Clear browser cache +3. Try accessing directly: `http://localhost:8188/auth_login.html` + +### API Requests Still Failing + +**Problem**: Some API requests return 401 after login + +**Solutions**: +1. Check browser console for errors +2. Verify API key is stored: Open DevTools > Application > Storage +3. Try logging out and back in +4. Clear all browser data and try again + +## Advanced Usage + +### Programmatic Access with Frontend + +If you need to access the API programmatically while using the frontend: + +```javascript +// Get the stored API key +const apiKey = localStorage.getItem('comfyui_api_key') || + sessionStorage.getItem('comfyui_api_key'); + +// Make authenticated requests +fetch('/api/system_stats', { + headers: { + 'Authorization': `Bearer ${apiKey}` + } +}).then(r => r.json()).then(console.log); +``` + +### Custom Frontend Integration + +If you're building a custom frontend, you can use the same mechanism: + +1. **Login Flow**: + ```javascript + // Validate API key + const response = await fetch('/health', { + headers: { + 'Authorization': `Bearer ${apiKey}` + } + }); + + if (response.ok) { + // Store the key + localStorage.setItem('comfyui_api_key', apiKey); + } + ``` + +2. **Request Interceptor**: + ```javascript + // Add to all requests + const apiKey = localStorage.getItem('comfyui_api_key'); + + fetch(url, { + headers: { + 'Authorization': `Bearer ${apiKey}`, + ...otherHeaders + } + }); + ``` + +3. **Handle 401 Responses**: + ```javascript + if (response.status === 401) { + // Clear stored key and redirect to login + localStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + } + ``` + +## API Endpoints Reference + +### Public Endpoints (No Authentication Required) +- `GET /` - Main page (with auth script injected) +- `GET /health` - Health check +- `GET /auth_login.html` - Login page +- `GET /auth_inject.js` - Auth injection script +- `GET /ws` - WebSocket connection +- Static files (`.js`, `.css`, `.html`, etc.) + +### Protected Endpoints (Authentication Required) +- All `/api/*` endpoints +- All `/internal/*` endpoints +- Most other API endpoints + +## Browser Compatibility + +The authentication system works with all modern browsers: +- ✅ Chrome/Edge 90+ +- ✅ Firefox 88+ +- ✅ Safari 14+ +- ✅ Opera 76+ + +## FAQs + +**Q: Can I use ComfyUI without authentication?** +A: Yes! Simply start ComfyUI without the `--api-key` argument. + +**Q: Can I change the API key without losing my workflows?** +A: Yes, workflows are stored separately. Just update the key on the server and re-login. + +**Q: Is my API key secure?** +A: The key is stored in browser storage and sent over HTTPS (if configured). For maximum security, use HTTPS and strong keys. + +**Q: Can multiple users use different API keys?** +A: Currently, the system supports a single API key. For multi-user scenarios, each user must use the same key. + +**Q: What happens if I forget my API key?** +A: Check your server startup command or the file specified in `--api-key-file`. + +## Support + +For issues or questions: +1. Check the server logs +2. Review browser console for errors +3. Refer to the main authentication documentation: `API_AUTHENTICATION.md` +4. Check the quick start guide: `QUICK_START_AUTH.md` diff --git a/QUICK_START_AUTH.md b/QUICK_START_AUTH.md new file mode 100644 index 000000000..b00c31f7e --- /dev/null +++ b/QUICK_START_AUTH.md @@ -0,0 +1,118 @@ +# Quick Start Guide - API Authentication + +## Step-by-Step Instructions + +### 1. Start ComfyUI with API Key + +```bash +# Stop any running ComfyUI instance first +# Then start with an API key: + +python main.py --api-key "my-secret-key-123" +``` + +**You should see in the logs:** +``` +[Auth] API Key authentication enabled +``` + +### 2. Test the Authentication + +**Health check (works without auth):** +```bash +curl http://localhost:8188/health +``` + +**Protected endpoint without auth (should fail):** +```bash +curl http://localhost:8188/object_info +# Should return: {"error": "Unauthorized", "message": "..."} +``` + +**Protected endpoint with auth (should work):** +```bash +curl -H "Authorization: Bearer my-secret-key-123" http://localhost:8188/object_info +# Should return: {...node definitions...} +``` + +### 3. Run the Test Script + +```bash +chmod +x test_auth_quick.sh +./test_auth_quick.sh +``` + +## Common Issues + +### Issue: All requests work without authentication + +**Problem:** You didn't start the server with `--api-key` + +**Solution:** +```bash +# Stop the server (Ctrl+C) +# Restart with API key: +python main.py --api-key "your-key-here" +``` + +**Verify it's enabled:** +```bash +# In another terminal, check if auth is working: +curl http://localhost:8188/object_info +# Should return 401 Unauthorized +``` + +### Issue: Authentication is enabled but I get 401 even with correct key + +**Problem:** Key format or typo + +**Solution:** +- Ensure no extra spaces in the key +- Check the Authorization header format: `Authorization: Bearer YOUR_KEY` +- Try X-API-Key header: `X-API-Key: YOUR_KEY` + +## Example: Full Workflow + +```bash +# 1. Generate a secure key +python -c "import secrets; print(secrets.token_hex(32))" +# Output: a1b2c3d4e5f6... + +# 2. Save to file +echo "a1b2c3d4e5f6..." > api_key.txt + +# 3. Start server with key file +python main.py --api-key-file api_key.txt + +# 4. Use the API +API_KEY=$(cat api_key.txt) +curl -H "Authorization: Bearer $API_KEY" http://localhost:8188/object_info +``` + +## Test with Python + +```python +import requests + +API_KEY = "my-secret-key-123" +BASE_URL = "http://localhost:8188" + +# This should fail (no auth) +response = requests.get(f"{BASE_URL}/object_info") +print(f"No auth: {response.status_code}") # Should be 401 + +# This should work (with auth) +headers = {"Authorization": f"Bearer {API_KEY}"} +response = requests.get(f"{BASE_URL}/object_info", headers=headers) +print(f"With auth: {response.status_code}") # Should be 200 +``` + +## Disable Authentication + +Simply start ComfyUI without the `--api-key` argument: + +```bash +python main.py +``` + +The server will work exactly as before with no authentication required. diff --git a/auth_inject.js b/auth_inject.js new file mode 100644 index 000000000..123b8e387 --- /dev/null +++ b/auth_inject.js @@ -0,0 +1,110 @@ +/** + * ComfyUI API Key Injection + * This script automatically adds the stored API key to all HTTP requests + */ + +(function() { + 'use strict'; + + // Get the stored API key + function getApiKey() { + return localStorage.getItem('comfyui_api_key') || sessionStorage.getItem('comfyui_api_key'); + } + + // Check if user is authenticated + function checkAuth() { + const apiKey = getApiKey(); + if (!apiKey && window.location.pathname !== '/auth_login.html') { + // Redirect to login page if no API key is found + window.location.href = '/auth_login.html'; + return false; + } + return true; + } + + // Intercept fetch requests + const originalFetch = window.fetch; + window.fetch = function(...args) { + const apiKey = getApiKey(); + + if (apiKey) { + // Clone or create the options object + let [url, options = {}] = args; + + // Initialize headers if not present + if (!options.headers) { + options.headers = {}; + } + + // Convert Headers object to plain object if needed + if (options.headers instanceof Headers) { + const headersObj = {}; + options.headers.forEach((value, key) => { + headersObj[key] = value; + }); + options.headers = headersObj; + } + + // Add Authorization header if not already present + if (!options.headers['Authorization'] && !options.headers['authorization']) { + options.headers['Authorization'] = `Bearer ${apiKey}`; + } + + // Update args + args = [url, options]; + } + + // Call original fetch and handle 401 errors + return originalFetch.apply(this, args).then(response => { + if (response.status === 401 && window.location.pathname !== '/auth_login.html') { + // Clear stored API key and redirect to login + localStorage.removeItem('comfyui_api_key'); + sessionStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + } + return response; + }); + }; + + // Intercept XMLHttpRequest + const originalOpen = XMLHttpRequest.prototype.open; + const originalSend = XMLHttpRequest.prototype.send; + + XMLHttpRequest.prototype.open = function(method, url, ...rest) { + this._url = url; + return originalOpen.apply(this, [method, url, ...rest]); + }; + + XMLHttpRequest.prototype.send = function(...args) { + const apiKey = getApiKey(); + + if (apiKey && !this.getRequestHeader('Authorization')) { + this.setRequestHeader('Authorization', `Bearer ${apiKey}`); + } + + // Handle 401 responses + this.addEventListener('load', function() { + if (this.status === 401 && window.location.pathname !== '/auth_login.html') { + localStorage.removeItem('comfyui_api_key'); + sessionStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + } + }); + + return originalSend.apply(this, args); + }; + + // Add logout function to window + window.comfyuiLogout = function() { + localStorage.removeItem('comfyui_api_key'); + sessionStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + }; + + // Check authentication on page load (except for login page) + if (window.location.pathname !== '/auth_login.html') { + checkAuth(); + } + + console.log('[ComfyUI Auth] API key injection enabled'); +})(); diff --git a/auth_login.html b/auth_login.html new file mode 100644 index 000000000..913aa362d --- /dev/null +++ b/auth_login.html @@ -0,0 +1,278 @@ + + + + + + ComfyUI - API Key Required + + + +
+ + +
+ Invalid API key. Please try again. +
+ +
+
+ + + +
+ + + + +
+ +
+ Note: Your API key is required to access ComfyUI. + If you don't have one, please contact your administrator or check the server configuration. +
+
+ + + + diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 209fc185b..17b374cb1 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -42,6 +42,9 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.") +parser.add_argument("--api-key", type=str, default=None, help="Require API key authentication for all API endpoints except health check. Provide the key via 'Authorization: Bearer ' or 'X-API-Key: ' header.") +parser.add_argument("--api-key-file", type=str, default=None, help="Path to a file containing the API key. Alternative to --api-key for better security.") + parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.") @@ -239,6 +242,14 @@ if args.disable_auto_launch: if args.force_fp16: args.fp16_unet = True +# Load API key from file if specified +if args.api_key_file and not args.api_key: + try: + with open(args.api_key_file, 'r') as f: + args.api_key = f.read().strip() + except Exception as e: + print(f"Error reading API key from file {args.api_key_file}: {e}") + args.api_key = None # '--fast' is not provided, use an empty set if args.fast is None: diff --git a/examples_api_auth.py b/examples_api_auth.py new file mode 100644 index 000000000..15ed3b7f9 --- /dev/null +++ b/examples_api_auth.py @@ -0,0 +1,204 @@ +""" +Example: Using ComfyUI API with Authentication +""" + +import requests +import json + +# Your API configuration +API_KEY = "your-api-key-here" +BASE_URL = "http://localhost:8188" + + +def example_health_check(): + """Example: Check server health (no authentication required)""" + print("=== Health Check Example ===") + + response = requests.get(f"{BASE_URL}/health") + + if response.status_code == 200: + health = response.json() + print(f"Status: {health['status']}") + print(f"Version: {health['version']}") + print(f"Queue - Pending: {health['queue']['pending']}, Running: {health['queue']['running']}") + if 'device' in health: + print(f"Device: {health['device']}") + if 'vram' in health: + vram = health['vram'] + vram_used_gb = vram['used'] / (1024**3) + vram_total_gb = vram['total'] / (1024**3) + print(f"VRAM: {vram_used_gb:.2f} GB / {vram_total_gb:.2f} GB") + else: + print(f"Health check failed with status {response.status_code}") + + print() + + +def example_get_object_info(): + """Example: Get object info with authentication""" + print("=== Get Object Info Example ===") + + # Method 1: Using Authorization Bearer header + headers = { + "Authorization": f"Bearer {API_KEY}" + } + + response = requests.get(f"{BASE_URL}/object_info", headers=headers) + + if response.status_code == 200: + print("✓ Successfully retrieved object info") + object_info = response.json() + print(f"Number of node types: {len(object_info)}") + elif response.status_code == 401: + print("✗ Authentication failed - check your API key") + print(response.json()) + else: + print(f"✗ Request failed with status {response.status_code}") + + print() + + +def example_queue_prompt(): + """Example: Queue a prompt with authentication""" + print("=== Queue Prompt Example ===") + + # Simple workflow example + workflow = { + "prompt": { + "1": { + "inputs": { + "text": "a beautiful landscape" + }, + "class_type": "CLIPTextEncode" + } + }, + "client_id": "example_client" + } + + # Using Authorization Bearer header + headers = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + } + + response = requests.post( + f"{BASE_URL}/prompt", + headers=headers, + json=workflow + ) + + if response.status_code == 200: + result = response.json() + print("✓ Prompt queued successfully") + print(f"Prompt ID: {result.get('prompt_id', 'N/A')}") + elif response.status_code == 401: + print("✗ Authentication failed - check your API key") + print(response.json()) + else: + print(f"✗ Request failed with status {response.status_code}") + print(response.text) + + print() + + +def example_using_session(): + """Example: Using requests.Session for multiple requests""" + print("=== Session Example (Multiple Requests) ===") + + # Create a session with authentication header + session = requests.Session() + session.headers.update({ + "Authorization": f"Bearer {API_KEY}" + }) + + # Now all requests will automatically include the auth header + + # Request 1: Get embeddings + response = session.get(f"{BASE_URL}/embeddings") + if response.status_code == 200: + print(f"✓ Got embeddings list") + + # Request 2: Get queue + response = session.get(f"{BASE_URL}/queue") + if response.status_code == 200: + queue = response.json() + print(f"✓ Got queue info - Pending: {len(queue.get('queue_pending', []))}") + + # Request 3: Get system stats + response = session.get(f"{BASE_URL}/system_stats") + if response.status_code == 200: + print(f"✓ Got system stats") + + print() + + +def example_error_handling(): + """Example: Proper error handling""" + print("=== Error Handling Example ===") + + headers = { + "Authorization": f"Bearer {API_KEY}" + } + + try: + response = requests.get(f"{BASE_URL}/queue", headers=headers, timeout=5) + response.raise_for_status() # Raises exception for 4xx/5xx status codes + + data = response.json() + print("✓ Request successful") + print(f"Queue pending: {len(data.get('queue_pending', []))}") + print(f"Queue running: {len(data.get('queue_running', []))}") + + except requests.exceptions.Timeout: + print("✗ Request timed out") + except requests.exceptions.ConnectionError: + print("✗ Could not connect to server") + except requests.exceptions.HTTPError as e: + if e.response.status_code == 401: + print("✗ Authentication failed - invalid API key") + elif e.response.status_code == 403: + print("✗ Access forbidden") + else: + print(f"✗ HTTP error: {e}") + except Exception as e: + print(f"✗ Unexpected error: {e}") + + print() + + +# Alternative authentication methods +def example_alternative_auth_methods(): + """Example: Different ways to provide API key""" + print("=== Alternative Authentication Methods ===") + + # Method 1: Authorization Bearer token (recommended) + headers1 = {"Authorization": f"Bearer {API_KEY}"} + response1 = requests.get(f"{BASE_URL}/embeddings", headers=headers1) + print(f"Method 1 (Bearer): Status {response1.status_code}") + + # Method 2: X-API-Key header + headers2 = {"X-API-Key": API_KEY} + response2 = requests.get(f"{BASE_URL}/embeddings", headers=headers2) + print(f"Method 2 (X-API-Key): Status {response2.status_code}") + + # Method 3: Query parameter (less secure, not recommended for production) + response3 = requests.get(f"{BASE_URL}/embeddings?api_key={API_KEY}") + print(f"Method 3 (Query param): Status {response3.status_code}") + + print() + + +if __name__ == "__main__": + print("ComfyUI API Authentication Examples") + print("=" * 60) + print() + + # Run examples + example_health_check() + example_get_object_info() + example_using_session() + example_error_handling() + example_alternative_auth_methods() + + print("=" * 60) + print("All examples completed!") diff --git a/facebook_audio.py b/facebook_audio.py new file mode 100644 index 000000000..fb605ea05 --- /dev/null +++ b/facebook_audio.py @@ -0,0 +1,8 @@ +from transformers import pipeline +import scipy + +synthesiser = pipeline("text-to-audio", "facebook/musicgen-large") + +music = synthesiser("lo-fi music with a soothing melody", forward_params={"do_sample": True}) + +scipy.io.wavfile.write("musicgen_out.wav", rate=music["sampling_rate"], data=music["audio"]) diff --git a/generate_api_key.sh b/generate_api_key.sh new file mode 100644 index 000000000..eb6b2478b --- /dev/null +++ b/generate_api_key.sh @@ -0,0 +1,75 @@ +#!/bin/bash + +# ComfyUI API Key Generator +# This script helps you generate and configure API keys for ComfyUI + +set -e + +echo "================================================" +echo "ComfyUI API Key Generator" +echo "================================================" +echo "" + +# Function to generate a random API key +generate_key() { + if command -v openssl >/dev/null 2>&1; then + openssl rand -hex 32 + elif command -v python3 >/dev/null 2>&1; then + python3 -c "import secrets; print(secrets.token_hex(32))" + elif command -v python >/dev/null 2>&1; then + python -c "import secrets; print(secrets.token_hex(32))" + else + echo "Error: Neither openssl nor python is available to generate random key" + exit 1 + fi +} + +# Generate the API key +echo "Generating secure API key..." +API_KEY=$(generate_key) +echo "" +echo "Generated API Key:" +echo "================================================" +echo "$API_KEY" +echo "================================================" +echo "" + +# Ask user if they want to save to file +read -p "Would you like to save this key to a file? (y/n) " -n 1 -r +echo "" + +if [[ $REPLY =~ ^[Yy]$ ]]; then + # Get filename + read -p "Enter filename (default: api_key.txt): " FILENAME + FILENAME=${FILENAME:-api_key.txt} + + # Save the key + echo "$API_KEY" > "$FILENAME" + + # Set restrictive permissions + chmod 600 "$FILENAME" + + echo "✓ API key saved to: $FILENAME" + echo "✓ File permissions set to 600 (owner read/write only)" + echo "" + echo "To start ComfyUI with this API key:" + echo " python main.py --api-key-file $FILENAME" +else + echo "" + echo "To start ComfyUI with this API key:" + echo " python main.py --api-key \"$API_KEY\"" +fi + +echo "" +echo "================================================" +echo "Important Security Notes:" +echo "================================================" +echo "1. Keep this key secret - don't commit it to git" +echo "2. Use HTTPS in production for encrypted transport" +echo "3. Rotate keys regularly" +echo "4. Add your key file to .gitignore" +echo "" +echo "Example .gitignore entry:" +echo " api_key.txt" +echo " *.key" +echo "================================================" diff --git a/generate_audio_standalone.py b/generate_audio_standalone.py new file mode 100644 index 000000000..ca12d7bd4 --- /dev/null +++ b/generate_audio_standalone.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +""" +Standalone script to generate music from text using Stable Audio in ComfyUI. +Based on the workflow: user/default/workflows/audio_stable_audio_example.json + +This script replicates the workflow: +1. Load checkpoint model (stable-audio-open-1.0.safetensors) +2. Load CLIP text encoder (t5-base.safetensors) +3. Encode positive prompt (music description) +4. Encode negative prompt (empty) +5. Create empty latent audio (47.6 seconds) +6. Sample using KSampler +7. Decode audio from latent using VAE +8. Save as MP3 + +Requirements: +- stable-audio-open-1.0.safetensors in models/checkpoints/ +- t5-base.safetensors in models/text_encoders/ +""" + +import torch +import sys +import os +import random +import av +from io import BytesIO + +# Add ComfyUI to path +script_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, script_dir) + +import comfy.sd +import comfy.sample +import comfy.samplers +import comfy.model_management +import folder_paths +import latent_preview +import comfy.utils + + +def load_checkpoint(ckpt_name): + """Load checkpoint model - returns MODEL, CLIP, VAE""" + print(f"Loading checkpoint: {ckpt_name}") + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=folder_paths.get_folder_paths("embeddings") + ) + return out[:3] # MODEL, CLIP, VAE + + +def load_clip(clip_name, clip_type="stable_audio"): + """Load CLIP text encoder""" + print(f"Loading CLIP: {clip_name}") + clip_type_enum = getattr(comfy.sd.CLIPType, clip_type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + + clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type_enum, + model_options={} + ) + return clip + + +def encode_text(clip, text): + """Encode text using CLIP - returns CONDITIONING""" + print(f"Encoding text: '{text}'") + if clip is None: + raise RuntimeError("ERROR: clip input is invalid: None") + tokens = clip.tokenize(text) + return clip.encode_from_tokens_scheduled(tokens) + + +def create_empty_latent_audio(seconds, batch_size=1): + """Create empty latent audio tensor""" + print(f"Creating empty latent audio: {seconds} seconds") + length = round((seconds * 44100 / 2048) / 2) * 2 + latent = torch.zeros( + [batch_size, 64, length], + device=comfy.model_management.intermediate_device() + ) + return {"samples": latent, "type": "audio"} + + +def sample_audio(model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0): + """Run KSampler to generate audio latents""" + print(f"Sampling with seed={seed}, steps={steps}, cfg={cfg}, sampler={sampler_name}, scheduler={scheduler}") + + latent_samples = latent_image["samples"] + latent_samples = comfy.sample.fix_empty_latent_channels(model, latent_samples) + + # Prepare noise + batch_inds = latent_image["batch_index"] if "batch_index" in latent_image else None + noise = comfy.sample.prepare_noise(latent_samples, seed, batch_inds) + + # Check for noise mask + noise_mask = latent_image.get("noise_mask", None) + + # Prepare callback for progress + callback = latent_preview.prepare_callback(model, steps) + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + + # Sample + samples = comfy.sample.sample( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_samples, + denoise=denoise, + disable_noise=False, + start_step=None, + last_step=None, + force_full_denoise=False, + noise_mask=noise_mask, + callback=callback, + disable_pbar=disable_pbar, + seed=seed + ) + + out = latent_image.copy() + out["samples"] = samples + return out + + +def decode_audio(vae, samples): + """Decode audio from latent samples using VAE""" + print("Decoding audio from latents") + audio = vae.decode(samples["samples"]).movedim(-1, 1) + + # Normalize audio + std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 + std[std < 1.0] = 1.0 + audio /= std + + return {"waveform": audio, "sample_rate": 44100} + + +def save_audio_mp3(audio, filename, quality="V0"): + """Save audio as MP3 file using PyAV (same as ComfyUI)""" + print(f"Saving audio to: {filename}") + + # Create output directory if needed + os.makedirs(os.path.dirname(filename), exist_ok=True) + + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + # Ensure audio is in CPU + waveform = waveform.cpu() + + # Process each audio in batch (usually just 1) + for batch_number, waveform_item in enumerate(waveform): + if batch_number > 0: + # Add batch number to filename if multiple + base, ext = os.path.splitext(filename) + output_path = f"{base}_{batch_number}{ext}" + else: + output_path = filename + + # Create output buffer + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + # Determine audio layout - waveform_item shape is [channels, samples] + num_channels = waveform_item.shape[0] if waveform_item.dim() > 1 else 1 + layout = "mono" if num_channels == 1 else "stereo" + + # Set up the MP3 output stream + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) + + # Set quality + if quality == "V0": + out_stream.codec_context.qscale = 1 # Highest VBR quality + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "320k": + out_stream.bit_rate = 320000 + + # Prepare waveform for PyAV: needs to be [samples, channels] + # Use detach() to avoid gradient tracking issues + if waveform_item.dim() == 1: + # Mono audio, add channel dimension + waveform_numpy = waveform_item.unsqueeze(1).float().detach().numpy() + else: + # Transpose from [channels, samples] to [samples, channels] + waveform_numpy = waveform_item.transpose(0, 1).float().detach().numpy() + + # Reshape to [1, samples * channels] for PyAV + waveform_numpy = waveform_numpy.reshape(1, -1) + + # Create audio frame + frame = av.AudioFrame.from_ndarray( + waveform_numpy, + format="flt", + layout=layout, + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + # Encode + output_container.mux(out_stream.encode(frame)) + + # Flush encoder + output_container.mux(out_stream.encode(None)) + + # Close container + output_container.close() + + # Write to file + output_buffer.seek(0) + with open(output_path, "wb") as f: + f.write(output_buffer.getbuffer()) + + print(f"Audio saved successfully: {output_path}") + + +def main(): + # Configuration + checkpoint_name = "stable-audio-open-1.0.safetensors" + clip_name = "t5-base.safetensors" + positive_prompt = "A soft melodious acoustic guitar music" + negative_prompt = "" + audio_duration = 47.6 # seconds + seed = random.randint(0, 0xffffffffffffffff) # Random seed, or use specific value + steps = 50 + cfg = 4.98 + sampler_name = "dpmpp_3m_sde_gpu" + scheduler = "exponential" + denoise = 1.0 + output_filename = "output/audio/generated_music.mp3" + quality = "V0" + + print("=" * 60) + print("Stable Audio - Music Generation Script") + print("=" * 60) + print(f"Positive Prompt: {positive_prompt}") + print(f"Duration: {audio_duration} seconds") + print(f"Seed: {seed}") + print("=" * 60) + + # 1. Load checkpoint (MODEL, CLIP, VAE) + model, checkpoint_clip, vae = load_checkpoint(checkpoint_name) + + # 2. Load separate CLIP text encoder for stable audio + clip = load_clip(clip_name, "stable_audio") + + # 3. Encode positive and negative prompts + positive_conditioning = encode_text(clip, positive_prompt) + negative_conditioning = encode_text(clip, negative_prompt) + + # 4. Create empty latent audio + latent_audio = create_empty_latent_audio(audio_duration, batch_size=1) + + # 5. Sample using KSampler + sampled_latent = sample_audio( + model, seed, steps, cfg, sampler_name, scheduler, + positive_conditioning, negative_conditioning, latent_audio, denoise + ) + + # 6. Decode audio from latent using VAE + audio = decode_audio(vae, sampled_latent) + + # 7. Save as MP3 + save_audio_mp3(audio, output_filename, quality) + + print("=" * 60) + print("Generation complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/generate_vibevoice_standalone.py b/generate_vibevoice_standalone.py new file mode 100644 index 000000000..281b62e71 --- /dev/null +++ b/generate_vibevoice_standalone.py @@ -0,0 +1,476 @@ +#!/usr/bin/env python3 +""" +Standalone script to generate TTS audio using VibeVoice. +This script has NO ComfyUI dependencies and uses the models directly from HuggingFace. + +Based on Microsoft's VibeVoice: https://github.com/microsoft/VibeVoice + +Requirements: + pip install torch transformers numpy scipy soundfile librosa huggingface-hub + +Usage: + python generate_vibevoice_standalone.py +""" + +import torch +import numpy as np +import soundfile as sf +import os +import random +import re +import logging +from typing import Optional, List, Tuple +from huggingface_hub import snapshot_download + +logging.basicConfig(level=logging.INFO, format='[VibeVoice] %(message)s') +logger = logging.getLogger(__name__) + +try: + import librosa + LIBROSA_AVAILABLE = True +except ImportError: + logger.warning("librosa not available - resampling will not work") + LIBROSA_AVAILABLE = False + + +def set_seed(seed: int): + """Set random seeds for reproducibility""" + if seed == 0: + seed = random.randint(1, 0xffffffffffffffff) + + MAX_NUMPY_SEED = 2**32 - 1 + numpy_seed = seed % MAX_NUMPY_SEED + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(numpy_seed) + random.seed(seed) + + return seed + + +def parse_script(script: str) -> Tuple[List[Tuple[int, str]], List[int]]: + """ + Parse speaker script into (speaker_id, text) tuples. + + Supports formats: + [1] Some text... + Speaker 1: Some text... + + Returns: + parsed_lines: List of (0-based speaker_id, text) tuples + speaker_ids: List of unique 1-based speaker IDs in order of appearance + """ + parsed_lines = [] + speaker_ids_in_script = [] + + line_format_regex = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE) + + for line in script.strip().split("\n"): + line = line.strip() + if not line: + continue + + match = line_format_regex.match(line) + if match: + speaker_id_str = match.group(1) or match.group(2) + speaker_id = int(speaker_id_str) + text_content = match.group(3) + + if match.group(1) is None and text_content.lstrip().startswith(':'): + colon_index = text_content.find(':') + text_content = text_content[colon_index + 1:] + + if speaker_id < 1: + logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'") + continue + + text = text_content.strip() + internal_speaker_id = speaker_id - 1 + parsed_lines.append((internal_speaker_id, text)) + + if speaker_id not in speaker_ids_in_script: + speaker_ids_in_script.append(speaker_id) + else: + logger.warning(f"Could not parse speaker marker, ignoring: '{line}'") + + if not parsed_lines and script.strip(): + logger.info("No speaker markers found. Treating entire text as Speaker 1.") + parsed_lines.append((0, script.strip())) + speaker_ids_in_script.append(1) + + return parsed_lines, sorted(list(set(speaker_ids_in_script))) + + +def load_audio_file(audio_path: str, target_sr: int = 24000) -> Optional[np.ndarray]: + """Load audio file and convert to mono at target sample rate""" + if not os.path.exists(audio_path): + logger.error(f"Audio file not found: {audio_path}") + return None + + logger.info(f"Loading audio: {audio_path}") + + try: + # Load audio using soundfile + waveform, sr = sf.read(audio_path) + + # Convert to mono if stereo + if waveform.ndim > 1: + waveform = np.mean(waveform, axis=1) + + # Resample if needed + if sr != target_sr: + if not LIBROSA_AVAILABLE: + raise ImportError("librosa is required for resampling. Install with: pip install librosa") + logger.info(f"Resampling from {sr}Hz to {target_sr}Hz") + waveform = librosa.resample(y=waveform, orig_sr=sr, target_sr=target_sr) + + # Validate audio + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf values, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) + + if np.all(waveform == 0): + logger.warning("Audio waveform is completely silent") + + # Normalize extreme values + max_val = np.abs(waveform).max() + if max_val > 10.0: + logger.warning(f"Audio values are very large (max: {max_val}), normalizing") + waveform = waveform / max_val + + return waveform.astype(np.float32) + + except Exception as e: + logger.error(f"Error loading audio: {e}") + return None + + +def download_model(model_name: str = "VibeVoice-1.5B", cache_dir: str = "./models"): + """Download VibeVoice model from HuggingFace""" + + repo_mapping = { + "VibeVoice-1.5B": "microsoft/VibeVoice-1.5B", + "VibeVoice-Large": "aoi-ot/VibeVoice-Large" + } + + if model_name not in repo_mapping: + raise ValueError(f"Unknown model: {model_name}. Choose from: {list(repo_mapping.keys())}") + + repo_id = repo_mapping[model_name] + model_path = os.path.join(cache_dir, model_name) + + if os.path.exists(os.path.join(model_path, "config.json")): + logger.info(f"Model already downloaded: {model_path}") + return model_path + + logger.info(f"Downloading model from {repo_id}...") + os.makedirs(cache_dir, exist_ok=True) + + model_path = snapshot_download( + repo_id=repo_id, + local_dir=model_path, + local_dir_use_symlinks=False + ) + + logger.info(f"Model downloaded to: {model_path}") + return model_path + + +def generate_tts( + text: str, + model_name: str = "VibeVoice-Large", + speaker_audio_paths: Optional[dict] = None, + output_path: str = "output.wav", + cfg_scale: float = 1.3, + inference_steps: int = 10, + seed: int = 42, + temperature: float = 0.95, + top_p: float = 0.95, + top_k: int = 0, + cache_dir: str = "./models", + device: str = "auto" +): + """ + Generate TTS audio using VibeVoice + + Args: + text: Text script with speaker markers like "[1] text" or "Speaker 1: text" + model_name: Model to use ("VibeVoice-1.5B" or "VibeVoice-Large") + speaker_audio_paths: Dict mapping speaker IDs to audio file paths for voice cloning + e.g., {1: "voice1.wav", 2: "voice2.wav"} + output_path: Where to save the generated audio + cfg_scale: Classifier-Free Guidance scale (higher = more adherence to prompt) + inference_steps: Number of diffusion steps + seed: Random seed for reproducibility + temperature: Sampling temperature + top_p: Nucleus sampling parameter + top_k: Top-K sampling parameter + cache_dir: Directory to cache downloaded models + device: Device to use ("cuda", "mps", "cpu", or "auto" for automatic detection) + """ + + # Set seed + actual_seed = set_seed(seed) + logger.info(f"Using seed: {actual_seed}") + + # Determine device - with MPS support for Mac + if device == "auto": + if torch.cuda.is_available(): + device = "cuda" + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = "mps" + logger.info("MPS (Metal Performance Shaders) detected - using Mac GPU acceleration") + else: + device = "cpu" + logger.info(f"Using device: {device}") + + # Download model if needed + model_path = download_model(model_name, cache_dir) + + # Import VibeVoice components + logger.info("Loading VibeVoice model...") + try: + # Add the VibeVoice custom model code to path + import sys + vibevoice_custom_path = os.path.join(os.path.dirname(__file__), "custom_nodes", "ComfyUI-VibeVoice") + if vibevoice_custom_path not in sys.path: + sys.path.insert(0, vibevoice_custom_path) + + # Import custom VibeVoice model + from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference + from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor + from vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor + from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast + from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig + import json + + # Load config + config_path = os.path.join(model_path, "config.json") + config = VibeVoiceConfig.from_pretrained(config_path) + + # Load tokenizer - download if not present + tokenizer_file = os.path.join(model_path, "tokenizer.json") + if not os.path.exists(tokenizer_file): + logger.info(f"tokenizer.json not found, downloading from HuggingFace...") + from huggingface_hub import hf_hub_download + + # Determine which Qwen model to use based on model size + qwen_repo = "Qwen/Qwen2.5-1.5B" if "1.5B" in model_name else "Qwen/Qwen2.5-7B" + + try: + hf_hub_download( + repo_id=qwen_repo, + filename="tokenizer.json", + local_dir=model_path, + local_dir_use_symlinks=False + ) + logger.info("tokenizer.json downloaded successfully") + except Exception as e: + logger.error(f"Failed to download tokenizer.json: {e}") + raise FileNotFoundError(f"Could not download tokenizer.json from {qwen_repo}") + + tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file) + + # Load processor config + preprocessor_config_path = os.path.join(model_path, "preprocessor_config.json") + processor_config_data = {} + if os.path.exists(preprocessor_config_path): + with open(preprocessor_config_path, 'r') as f: + processor_config_data = json.load(f) + + audio_processor = VibeVoiceTokenizerProcessor() + processor = VibeVoiceProcessor( + tokenizer=tokenizer, + audio_processor=audio_processor, + speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200), + db_normalize=processor_config_data.get("db_normalize", True) + ) + + # Load model + # MPS doesn't support bfloat16 well, use float16 + if device == "mps": + dtype = torch.float16 + logger.info("Using float16 for MPS device") + elif torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + else: + dtype = torch.float16 + + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + model_path, + config=config, + torch_dtype=dtype, + device_map=device, + attn_implementation="sdpa" + ) + + model.eval() + logger.info("Model loaded successfully") + + except Exception as e: + logger.error(f"Failed to load model: {e}") + import traceback + traceback.print_exc() + raise + + # Parse script + parsed_lines, speaker_ids = parse_script(text) + if not parsed_lines: + raise ValueError("Script is empty or invalid") + + logger.info(f"Parsed {len(parsed_lines)} lines with speakers: {speaker_ids}") + + # Load speaker audio samples + voice_samples = [] + if speaker_audio_paths is None: + speaker_audio_paths = {} + + for speaker_id in speaker_ids: + audio_path = speaker_audio_paths.get(speaker_id) + if audio_path: + audio = load_audio_file(audio_path, target_sr=24000) + if audio is None: + logger.warning(f"Could not load audio for speaker {speaker_id}, using zero-shot TTS") + voice_samples.append(None) + else: + voice_samples.append(audio) + else: + logger.info(f"No reference audio for speaker {speaker_id}, using zero-shot TTS") + voice_samples.append(None) + + # Prepare inputs + logger.info("Processing inputs...") + try: + inputs = processor( + parsed_scripts=[parsed_lines], + voice_samples=[voice_samples], + speaker_ids_for_prompt=[speaker_ids], + padding=True, + return_tensors="pt", + return_attention_mask=True + ) + + # Move to device + inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + + except Exception as e: + logger.error(f"Error processing inputs: {e}") + raise + + # Configure generation + model.set_ddpm_inference_steps(num_steps=inference_steps) + + generation_config = { + 'do_sample': True, + 'temperature': temperature, + 'top_p': top_p, + } + if top_k > 0: + generation_config['top_k'] = top_k + + # Generate + logger.info(f"Generating audio ({inference_steps} steps)...") + try: + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=cfg_scale, + tokenizer=processor.tokenizer, + generation_config=generation_config, + verbose=False + ) + + # Extract waveform + waveform = outputs.speech_outputs[0].cpu().numpy() + + # Ensure correct shape + if waveform.ndim == 1: + waveform = waveform.reshape(1, -1) + elif waveform.ndim == 2 and waveform.shape[0] > 1: + # If multiple channels, take first + waveform = waveform[0:1, :] + + # Convert to float32 for soundfile compatibility + waveform = waveform.astype(np.float32) + + # Save audio + os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True) + sf.write(output_path, waveform.T, 24000) + logger.info(f"Audio saved to: {output_path}") + + return waveform + + except Exception as e: + logger.error(f"Error during generation: {e}") + raise + + +def main(): + """Example usage""" + + # Configuration + model_name = "VibeVoice-Large" # or "VibeVoice-Large" + + # Text to generate - supports multiple speakers + text = """ + [1] Hello, this is speaker one. How are you today? + [2] Hi there! This is speaker two responding to you. It's great to meet you. + [1] Likewise! Let's generate some amazing speech together. + [2] Absolutely! VibeVoice makes it so easy to create diverse voices. + """ + + # Reference audio for voice cloning (optional) + # If not provided, will use zero-shot TTS + speaker_audio_paths = { + 1: "input/audio1.wav", # Path to reference audio for speaker 1 + 2: "input/laundry.mp3", # Uncomment to provide reference for speaker 2 + } + + # Generation parameters + output_path = "output/vibevoice_generated.wav" + cfg_scale = 1.3 + inference_steps = 10 + seed = 42 # or 0 for random + temperature = 0.95 + top_p = 0.95 + top_k = 0 + + print("=" * 60) + print("VibeVoice TTS - Standalone Script") + print("=" * 60) + print(f"Model: {model_name}") + print(f"Text: {text[:100]}...") + print("=" * 60) + + try: + generate_tts( + text=text, + model_name=model_name, + speaker_audio_paths=speaker_audio_paths, + output_path=output_path, + cfg_scale=cfg_scale, + inference_steps=inference_steps, + seed=seed, + temperature=temperature, + top_p=top_p, + top_k=top_k, + cache_dir="./models", + device="auto" + ) + + print("=" * 60) + print("Generation complete!") + print(f"Audio saved to: {output_path}") + print("=" * 60) + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/middleware/auth_middleware.py b/middleware/auth_middleware.py new file mode 100644 index 000000000..ce23e0800 --- /dev/null +++ b/middleware/auth_middleware.py @@ -0,0 +1,139 @@ +"""API Key Authentication middleware for ComfyUI server""" + +from aiohttp import web +from typing import Callable, Awaitable, Optional, Set +import logging +import os + + +class APIKeyAuth: + """API Key Authentication handler""" + + def __init__(self, api_key: Optional[str] = None, exempt_paths: Optional[Set[str]] = None): + """ + Initialize API Key Authentication + + Args: + api_key: The API key to validate against. If None, authentication is disabled. + exempt_paths: Set of paths that don't require authentication (e.g., health check) + """ + self.api_key = api_key + self.enabled = api_key is not None and len(api_key) > 0 + self.exempt_paths = exempt_paths or {"/health"} + + # Static file extensions that don't require authentication + self.static_extensions = { + '.html', '.js', '.css', '.json', '.map', '.png', '.jpg', '.jpeg', + '.gif', '.svg', '.ico', '.woff', '.woff2', '.ttf', '.eot', '.webp' + } + + # Path prefixes that serve static content + self.static_path_prefixes = { + '/extensions/', '/templates/', '/docs/' + } + + if self.enabled: + logging.info("[Auth] API Key authentication enabled") + else: + logging.info("[Auth] API Key authentication disabled") + + def is_path_exempt(self, path: str) -> bool: + """Check if a path is exempt from authentication""" + # Exact match for specific exempt paths + if path in self.exempt_paths: + return True + + # Root path for index.html + if path == "/": + return True + + # Static file extensions + for ext in self.static_extensions: + if path.endswith(ext): + return True + + # Static path prefixes (extensions, templates, docs, etc.) + for prefix in self.static_path_prefixes: + if path.startswith(prefix): + return True + + return False + + def validate_api_key(self, provided_key: Optional[str]) -> bool: + """Validate the provided API key""" + if not self.enabled: + return True + + if not provided_key: + return False + + return provided_key == self.api_key + + def extract_api_key(self, request: web.Request) -> Optional[str]: + """ + Extract API key from request. + Checks Authorization header (Bearer token) and X-API-Key header. + """ + # Check Authorization header (Bearer token) + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + return auth_header[7:] # Remove "Bearer " prefix + + # Check X-API-Key header + api_key_header = request.headers.get("X-API-Key", "") + if api_key_header: + return api_key_header + + # Check query parameter (less secure, but convenient for testing) + api_key_query = request.query.get("api_key", "") + if api_key_query: + return api_key_query + + return None + + +def create_api_key_middleware(api_key: Optional[str] = None, exempt_paths: Optional[Set[str]] = None): + """ + Create API key authentication middleware + + Args: + api_key: The API key to validate against. If None, authentication is disabled. + exempt_paths: Set of paths that don't require authentication + + Returns: + Middleware function for aiohttp + """ + auth = APIKeyAuth(api_key, exempt_paths) + + @web.middleware + async def api_key_middleware( + request: web.Request, + handler: Callable[[web.Request], Awaitable[web.Response]] + ) -> web.Response: + """Middleware to validate API key for protected endpoints""" + + # Skip authentication if disabled + if not auth.enabled: + return await handler(request) + + # Check if path is exempt from authentication + if auth.is_path_exempt(request.path): + return await handler(request) + + # Extract and validate API key + provided_key = auth.extract_api_key(request) + + if not auth.validate_api_key(provided_key): + logging.warning(f"[Auth] Unauthorized access attempt to {request.path} from {request.remote}") + return web.json_response( + { + "error": "Unauthorized", + "message": "Invalid or missing API key. Provide API key via 'Authorization: Bearer ' or 'X-API-Key: ' header." + }, + status=401 + ) + + # API key is valid, proceed with request + return await handler(request) + + return api_key_middleware diff --git a/requirements_vibevoice_standalone.txt b/requirements_vibevoice_standalone.txt new file mode 100644 index 000000000..2825490f1 --- /dev/null +++ b/requirements_vibevoice_standalone.txt @@ -0,0 +1,7 @@ +torch +transformers +numpy +scipy +soundfile +librosa +huggingface-hub diff --git a/server.py b/server.py index ac4f42222..7d6d1f506 100644 --- a/server.py +++ b/server.py @@ -43,6 +43,7 @@ from protocol import BinaryEventTypes # Import cache control middleware from middleware.cache_middleware import cache_control +from middleware.auth_middleware import create_api_key_middleware if args.enable_manager: import comfyui_manager @@ -204,6 +205,17 @@ class PromptServer(): self.number = 0 middlewares = [cache_control, deprecation_warning] + + # Add API key authentication middleware if enabled + if args.api_key: + # Define paths that don't require authentication + # Note: Static files (.js, .css, .html, etc.) and root "/" are automatically exempted + exempt_paths = { + "/health", # Health check endpoint + "/ws", # WebSocket endpoint + } + middlewares.append(create_api_key_middleware(args.api_key, exempt_paths)) + if args.enable_compress_response_body: middlewares.append(compress_body) @@ -297,12 +309,120 @@ class PromptServer(): @routes.get("/") async def get_root(request): + # If API key is enabled and not provided, redirect to login + if args.api_key: + auth_header = request.headers.get("Authorization", "") + api_key_header = request.headers.get("X-API-Key", "") + api_key_query = request.query.get("api_key", "") + + has_valid_key = False + if auth_header.startswith("Bearer "): + provided_key = auth_header[7:] + has_valid_key = (provided_key == args.api_key) + elif api_key_header: + has_valid_key = (api_key_header == args.api_key) + elif api_key_query: + has_valid_key = (api_key_query == args.api_key) + + # If no valid key, check if coming from login page, otherwise show login + if not has_valid_key: + # Serve modified index.html with auth script injected + index_path = os.path.join(self.web_root, "index.html") + if os.path.exists(index_path): + with open(index_path, 'r', encoding='utf-8') as f: + html_content = f.read() + + # Inject auth script before closing tag + auth_script = '' + if '' in html_content: + html_content = html_content.replace('', f'{auth_script}\n') + else: + # Fallback: add at the beginning of body + html_content = html_content.replace('', f'\n{auth_script}') + + response = web.Response(text=html_content, content_type='text/html') + else: + response = web.FileResponse(index_path) + + response.headers['Cache-Control'] = 'no-cache' + response.headers["Pragma"] = "no-cache" + response.headers["Expires"] = "0" + return response + + # No auth required or valid key provided response = web.FileResponse(os.path.join(self.web_root, "index.html")) response.headers['Cache-Control'] = 'no-cache' response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response + @routes.get("/auth_login.html") + async def get_login_page(request): + """Serve the login page""" + login_page_path = os.path.join(os.path.dirname(__file__), "auth_login.html") + if os.path.exists(login_page_path): + response = web.FileResponse(login_page_path) + else: + # Fallback if file doesn't exist + response = web.Response(text="Login page not found", status=404) + response.headers['Cache-Control'] = 'no-cache' + return response + + @routes.get("/auth_inject.js") + async def get_auth_script(request): + """Serve the auth injection script""" + script_path = os.path.join(os.path.dirname(__file__), "auth_inject.js") + if os.path.exists(script_path): + response = web.FileResponse(script_path, headers={'Content-Type': 'application/javascript'}) + else: + response = web.Response(text="// Auth script not found", status=404, content_type='application/javascript') + response.headers['Cache-Control'] = 'no-cache' + return response + + @routes.get("/health") + async def get_health(request): + """Health check endpoint that returns the status of the server""" + try: + # Basic health information + health_data = { + "status": "healthy", + "version": __version__, + "timestamp": time.time(), + "queue": { + "pending": len(self.prompt_queue.queue), + "running": len(self.prompt_queue.currently_running) + } + } + + # Add device info if available + try: + device = comfy.model_management.get_torch_device() + health_data["device"] = str(device) + + # Add VRAM info if GPU is available + if comfy.model_management.vram_state != comfy.model_management.VRAMState.DISABLED: + vram_total = comfy.model_management.get_total_memory() + vram_free = comfy.model_management.get_free_memory() + health_data["vram"] = { + "total": vram_total, + "free": vram_free, + "used": vram_total - vram_free + } + except Exception as e: + logging.debug(f"Could not get device info for health check: {e}") + + return web.json_response(health_data) + except Exception as e: + logging.error(f"Health check failed: {e}") + return web.json_response( + { + "status": "unhealthy", + "error": str(e), + "timestamp": time.time() + }, + status=503 + ) + @routes.get("/embeddings") def get_embeddings(request): embeddings = folder_paths.get_filename_list("embeddings") diff --git a/test_api_auth.py b/test_api_auth.py new file mode 100644 index 000000000..f740f6897 --- /dev/null +++ b/test_api_auth.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +""" +Test script for ComfyUI API Key Authentication and Health Check + +This script demonstrates how to: +1. Check the health endpoint (no auth required) +2. Make authenticated requests to the API +""" + +import requests +import json +import sys + +# Configuration +BASE_URL = "http://localhost:8188" +API_KEY = "your-api-key-here" # Replace with your actual API key + + +def test_health_check(): + """Test the health check endpoint (no authentication required)""" + print("Testing health check endpoint...") + try: + response = requests.get(f"{BASE_URL}/health") + print(f"Status Code: {response.status_code}") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return response.status_code == 200 + except Exception as e: + print(f"Error: {e}") + return False + + +def test_without_auth(): + """Test accessing protected endpoint without authentication""" + print("\nTesting access without authentication...") + try: + response = requests.get(f"{BASE_URL}/object_info") + print(f"Status Code: {response.status_code}") + if response.status_code == 401: + print("✓ Correctly rejected (401 Unauthorized)") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return True + elif response.status_code == 200: + print("✓ No authentication required (API key not enabled)") + return True + else: + print(f"✗ Unexpected status code: {response.status_code}") + return False + except Exception as e: + print(f"Error: {e}") + return False + + +def test_with_bearer_token(): + """Test accessing protected endpoint with Bearer token""" + print("\nTesting with Bearer token authentication...") + try: + headers = { + "Authorization": f"Bearer {API_KEY}" + } + response = requests.get(f"{BASE_URL}/object_info", headers=headers) + print(f"Status Code: {response.status_code}") + if response.status_code == 200: + print("✓ Successfully authenticated with Bearer token") + return True + elif response.status_code == 401: + print("✗ Authentication failed (check your API key)") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return False + else: + print(f"✗ Unexpected status code: {response.status_code}") + return False + except Exception as e: + print(f"Error: {e}") + return False + + +def test_with_api_key_header(): + """Test accessing protected endpoint with X-API-Key header""" + print("\nTesting with X-API-Key header authentication...") + try: + headers = { + "X-API-Key": API_KEY + } + response = requests.get(f"{BASE_URL}/object_info", headers=headers) + print(f"Status Code: {response.status_code}") + if response.status_code == 200: + print("✓ Successfully authenticated with X-API-Key header") + return True + elif response.status_code == 401: + print("✗ Authentication failed (check your API key)") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return False + else: + print(f"✗ Unexpected status code: {response.status_code}") + return False + except Exception as e: + print(f"Error: {e}") + return False + + +def test_with_query_parameter(): + """Test accessing protected endpoint with query parameter""" + print("\nTesting with query parameter authentication...") + try: + response = requests.get(f"{BASE_URL}/object_info?api_key={API_KEY}") + print(f"Status Code: {response.status_code}") + if response.status_code == 200: + print("✓ Successfully authenticated with query parameter") + return True + elif response.status_code == 401: + print("✗ Authentication failed (check your API key)") + print(f"Response: {json.dumps(response.json(), indent=2)}") + return False + else: + print(f"✗ Unexpected status code: {response.status_code}") + return False + except Exception as e: + print(f"Error: {e}") + return False + + +def main(): + """Run all tests""" + print("=" * 60) + print("ComfyUI API Authentication Test Suite") + print("=" * 60) + print(f"Base URL: {BASE_URL}") + print(f"API Key: {'*' * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else '***'}") + print("=" * 60) + + results = [] + + # Test 1: Health check (always works) + results.append(("Health Check", test_health_check())) + + # Test 2: Without authentication (should fail if auth is enabled) + results.append(("No Auth", test_without_auth())) + + # Test 3: Bearer token authentication + results.append(("Bearer Token", test_with_bearer_token())) + + # Test 4: X-API-Key header authentication + results.append(("X-API-Key Header", test_with_api_key_header())) + + # Test 5: Query parameter authentication + results.append(("Query Parameter", test_with_query_parameter())) + + # Summary + print("\n" + "=" * 60) + print("Test Summary") + print("=" * 60) + for test_name, passed in results: + status = "✓ PASS" if passed else "✗ FAIL" + print(f"{test_name:20s} {status}") + + total = len(results) + passed = sum(1 for _, result in results if result) + print("=" * 60) + print(f"Total: {passed}/{total} tests passed") + print("=" * 60) + + # Exit with appropriate code + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + # Check if user wants to override the API key + if len(sys.argv) > 1: + API_KEY = sys.argv[1] + + if API_KEY == "your-api-key-here": + print("WARNING: Using default API key. Set your API key as the first argument:") + print(f" python {sys.argv[0]} YOUR_API_KEY") + print("") + + main() diff --git a/test_auth_quick.sh b/test_auth_quick.sh new file mode 100644 index 000000000..65d704fec --- /dev/null +++ b/test_auth_quick.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +# Quick Test Script for ComfyUI API Authentication +# This script tests that authentication is working correctly + +set -e + +API_KEY="test-key-123" +BASE_URL="http://localhost:8188" + +echo "================================================" +echo "ComfyUI API Authentication Test" +echo "================================================" +echo "" +echo "IMPORTANT: Make sure ComfyUI is running with:" +echo " python main.py --api-key \"$API_KEY\"" +echo "" +echo "Press Enter to continue or Ctrl+C to cancel..." +read + +echo "" +echo "================================================" +echo "Test 1: Health endpoint (should work without auth)" +echo "================================================" +response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" "$BASE_URL/health") +status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "Status: $status" +if [ "$status" = "200" ]; then + echo "✓ PASS - Health endpoint accessible without auth" +else + echo "✗ FAIL - Health endpoint should return 200" +fi +echo "" + +echo "================================================" +echo "Test 2: Protected endpoint without auth (should fail)" +echo "================================================" +response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" "$BASE_URL/object_info") +status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "Status: $status" +if [ "$status" = "401" ]; then + echo "✓ PASS - Correctly rejected without auth" + echo "Response: $body" +else + echo "✗ FAIL - Should return 401 Unauthorized" + echo "Response: $body" +fi +echo "" + +echo "================================================" +echo "Test 3: Protected endpoint with wrong key (should fail)" +echo "================================================" +response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \ + -H "Authorization: Bearer wrong-key-456" \ + "$BASE_URL/object_info") +status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "Status: $status" +if [ "$status" = "401" ]; then + echo "✓ PASS - Correctly rejected wrong key" + echo "Response: $body" +else + echo "✗ FAIL - Should return 401 Unauthorized" + echo "Response: $body" +fi +echo "" + +echo "================================================" +echo "Test 4: Protected endpoint with correct key (should work)" +echo "================================================" +response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \ + -H "Authorization: Bearer $API_KEY" \ + "$BASE_URL/object_info") +status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "Status: $status" +if [ "$status" = "200" ]; then + echo "✓ PASS - Successfully authenticated" +else + echo "✗ FAIL - Should return 200 OK" + echo "Response: $body" +fi +echo "" + +echo "================================================" +echo "Test 5: X-API-Key header method (should work)" +echo "================================================" +response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \ + -H "X-API-Key: $API_KEY" \ + "$BASE_URL/embeddings") +status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "Status: $status" +if [ "$status" = "200" ]; then + echo "✓ PASS - X-API-Key header works" +else + echo "✗ FAIL - Should return 200 OK" + echo "Response: $body" +fi +echo "" + +echo "================================================" +echo "Test 6: Query parameter method (should work)" +echo "================================================" +response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \ + "$BASE_URL/embeddings?api_key=$API_KEY") +status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "Status: $status" +if [ "$status" = "200" ]; then + echo "✓ PASS - Query parameter works" +else + echo "✗ FAIL - Should return 200 OK" + echo "Response: $body" +fi +echo "" + +echo "================================================" +echo "All tests completed!" +echo "================================================" diff --git a/test_vibevoice_workflow.sh b/test_vibevoice_workflow.sh new file mode 100644 index 000000000..3ddcc78da --- /dev/null +++ b/test_vibevoice_workflow.sh @@ -0,0 +1,117 @@ +#!/bin/bash + +# Test ComfyUI API with VibeVoice workflow +# Usage: ./test_vibevoice_workflow.sh [API_KEY] + +# Configuration +BASE_URL="http://localhost:8188" +API_KEY="${1:-}" + +# Set headers based on whether API key is provided +if [ -n "$API_KEY" ]; then + AUTH_HEADER="Authorization: Bearer $API_KEY" + echo "Using API Key authentication" +else + AUTH_HEADER="" + echo "No API Key provided (running without authentication)" +fi + +# The workflow payload +# This converts the ComfyUI workflow format to the prompt API format +read -r -d '' PAYLOAD << 'EOF' +{ + "prompt": { + "1": { + "inputs": { + "speaker_1_voice": ["2", 0], + "speaker_2_voice": null, + "speaker_3_voice": null, + "speaker_4_voice": null, + "model_name": "VibeVoice-Large", + "text": "[1] And this is a generated voice, how cool is that?", + "quantize_llm_4bit": false, + "attention_mode": "sdpa", + "cfg_scale": 1.3, + "inference_steps": 10, + "seed": 1117544514407045, + "do_sample": true, + "temperature": 0.95, + "top_p": 0.95, + "top_k": 0, + "force_offload": false + }, + "class_type": "VibeVoiceTTS" + }, + "2": { + "inputs": { + "audio": "audio1.wav" + }, + "class_type": "LoadAudio" + }, + "3": { + "inputs": { + "audio": ["1", 0], + "filename_prefix": "audio/ComfyUI" + }, + "class_type": "SaveAudio" + } + }, + "client_id": "test_client_$(date +%s)" +} +EOF + +echo "" +echo "================================================" +echo "Sending workflow to ComfyUI..." +echo "================================================" +echo "" + +# Make the request +if [ -n "$AUTH_HEADER" ]; then + response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \ + -X POST \ + -H "Content-Type: application/json" \ + -H "$AUTH_HEADER" \ + -d "$PAYLOAD" \ + "$BASE_URL/prompt") +else + response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \ + -X POST \ + -H "Content-Type: application/json" \ + -d "$PAYLOAD" \ + "$BASE_URL/prompt") +fi + +# Extract HTTP status +http_status=$(echo "$response" | grep "HTTP_STATUS" | cut -d':' -f2) +body=$(echo "$response" | sed '/HTTP_STATUS/d') + +echo "HTTP Status: $http_status" +echo "" +echo "Response:" +echo "$body" | python3 -m json.tool 2>/dev/null || echo "$body" +echo "" + +if [ "$http_status" = "200" ]; then + echo "✓ Workflow queued successfully!" + + # Extract prompt_id if available + prompt_id=$(echo "$body" | python3 -c "import sys, json; data=json.load(sys.stdin); print(data.get('prompt_id', ''))" 2>/dev/null) + if [ -n "$prompt_id" ]; then + echo "Prompt ID: $prompt_id" + echo "" + echo "To check status:" + if [ -n "$AUTH_HEADER" ]; then + echo " curl -H \"$AUTH_HEADER\" $BASE_URL/history/$prompt_id" + else + echo " curl $BASE_URL/history/$prompt_id" + fi + fi +elif [ "$http_status" = "401" ]; then + echo "✗ Authentication failed - check your API key" +else + echo "✗ Request failed" +fi + +echo "" +echo "================================================"