mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
Compare commits
5 Commits
a89f140795
...
b5604df442
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5604df442 | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
c5ad1381bf | ||
|
|
06bf79b19b |
221
API_AUTHENTICATION.md
Normal file
221
API_AUTHENTICATION.md
Normal file
@ -0,0 +1,221 @@
|
||||
# API Key Authentication and Health Check
|
||||
|
||||
## Overview
|
||||
|
||||
This implementation adds API key authentication protection to the ComfyUI REST API and a health check endpoint.
|
||||
|
||||
## Features
|
||||
|
||||
### 1. API Key Authentication
|
||||
|
||||
Protects all API endpoints (except exempt ones) with API key authentication.
|
||||
|
||||
#### Configuration
|
||||
|
||||
You can enable API key authentication in two ways:
|
||||
|
||||
**Option 1: Command line argument**
|
||||
```bash
|
||||
python main.py --api-key "your-secret-api-key-here"
|
||||
```
|
||||
|
||||
**Option 2: API key file (more secure)**
|
||||
```bash
|
||||
# Create a file with your API key
|
||||
echo "your-secret-api-key-here" > api_key.txt
|
||||
|
||||
# Start ComfyUI with the API key file
|
||||
python main.py --api-key-file api_key.txt
|
||||
```
|
||||
|
||||
#### Using the API with Authentication
|
||||
|
||||
When API key authentication is enabled, you must provide the API key in your requests:
|
||||
|
||||
**Method 1: Authorization Header (Bearer Token)**
|
||||
```bash
|
||||
curl -H "Authorization: Bearer your-secret-api-key-here" http://localhost:8188/prompt
|
||||
```
|
||||
|
||||
**Method 2: X-API-Key Header**
|
||||
```bash
|
||||
curl -H "X-API-Key: your-secret-api-key-here" http://localhost:8188/prompt
|
||||
```
|
||||
|
||||
**Method 3: Query Parameter (less secure, for testing only)**
|
||||
```bash
|
||||
curl "http://localhost:8188/prompt?api_key=your-secret-api-key-here"
|
||||
```
|
||||
|
||||
#### Exempt Endpoints
|
||||
|
||||
The following endpoints do NOT require authentication:
|
||||
- `/health` - Health check endpoint
|
||||
- `/` - Root page (frontend)
|
||||
- `/ws` - WebSocket endpoint
|
||||
|
||||
### 2. Health Check Endpoint
|
||||
|
||||
A new `/health` endpoint provides server status information.
|
||||
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
curl http://localhost:8188/health
|
||||
```
|
||||
|
||||
#### Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"version": "0.4.0",
|
||||
"timestamp": 1702307890.123,
|
||||
"queue": {
|
||||
"pending": 0,
|
||||
"running": 0
|
||||
},
|
||||
"device": "cuda:0",
|
||||
"vram": {
|
||||
"total": 8589934592,
|
||||
"free": 6442450944,
|
||||
"used": 2147483648
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
If the server is unhealthy, it returns a 503 status code:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "unhealthy",
|
||||
"error": "error message here",
|
||||
"timestamp": 1702307890.123
|
||||
}
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
### Starting ComfyUI with API Key Protection
|
||||
|
||||
```bash
|
||||
# With direct API key
|
||||
python main.py --api-key "my-super-secret-key-12345"
|
||||
|
||||
# With API key from file
|
||||
python main.py --api-key-file /path/to/api_key.txt
|
||||
|
||||
# With API key and custom port
|
||||
python main.py --api-key "my-key" --port 8080
|
||||
```
|
||||
|
||||
### Making Authenticated Requests
|
||||
|
||||
**Python example:**
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_KEY = "your-api-key-here"
|
||||
BASE_URL = "http://localhost:8188"
|
||||
|
||||
# Using Authorization header
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}"
|
||||
}
|
||||
|
||||
# Check health
|
||||
response = requests.get(f"{BASE_URL}/health")
|
||||
print(response.json())
|
||||
|
||||
# Make authenticated request
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/prompt",
|
||||
headers=headers,
|
||||
json={"prompt": {...}}
|
||||
)
|
||||
print(response.json())
|
||||
```
|
||||
|
||||
**JavaScript example:**
|
||||
```javascript
|
||||
const API_KEY = "your-api-key-here";
|
||||
const BASE_URL = "http://localhost:8188";
|
||||
|
||||
// Using fetch with Authorization header
|
||||
async function makeRequest(endpoint, data) {
|
||||
const response = await fetch(`${BASE_URL}${endpoint}`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${API_KEY}`,
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify(data)
|
||||
});
|
||||
return response.json();
|
||||
}
|
||||
|
||||
// Check health (no auth required)
|
||||
fetch(`${BASE_URL}/health`)
|
||||
.then(r => r.json())
|
||||
.then(data => console.log(data));
|
||||
```
|
||||
|
||||
### Monitoring with Health Check
|
||||
|
||||
You can use the health endpoint for monitoring and health checks:
|
||||
|
||||
```bash
|
||||
# Simple health check
|
||||
curl http://localhost:8188/health
|
||||
|
||||
# Use in a monitoring script
|
||||
#!/bin/bash
|
||||
response=$(curl -s http://localhost:8188/health)
|
||||
status=$(echo $response | jq -r '.status')
|
||||
|
||||
if [ "$status" == "healthy" ]; then
|
||||
echo "✓ ComfyUI is healthy"
|
||||
exit 0
|
||||
else
|
||||
echo "✗ ComfyUI is unhealthy"
|
||||
exit 1
|
||||
fi
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Keep your API key secret**: Never commit API keys to version control
|
||||
2. **Use API key files**: Store API keys in separate files with restricted permissions
|
||||
3. **Use HTTPS in production**: Combine with `--tls-keyfile` and `--tls-certfile` options
|
||||
4. **Rotate keys regularly**: Change your API key periodically
|
||||
5. **Use strong keys**: Generate long, random API keys (e.g., using `openssl rand -hex 32`)
|
||||
|
||||
### Generating a Secure API Key
|
||||
|
||||
```bash
|
||||
# Generate a secure random API key
|
||||
openssl rand -hex 32
|
||||
|
||||
# Or using Python
|
||||
python -c "import secrets; print(secrets.token_hex(32))"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### 401 Unauthorized Error
|
||||
|
||||
If you receive a 401 error:
|
||||
- Verify the API key is correct
|
||||
- Check that you're including the key in the correct header format
|
||||
- Ensure there are no extra spaces or newlines in the key
|
||||
|
||||
### Health Check Returns 503
|
||||
|
||||
If the health check returns 503:
|
||||
- Check the server logs for error details
|
||||
- Verify ComfyUI started correctly
|
||||
- Check system resources (memory, disk space)
|
||||
|
||||
## Disabling Authentication
|
||||
|
||||
To disable API key authentication, simply don't provide the `--api-key` or `--api-key-file` arguments when starting ComfyUI. The server will work exactly as before with no authentication required.
|
||||
142
API_SECURITY_IMPLEMENTATION.md
Normal file
142
API_SECURITY_IMPLEMENTATION.md
Normal file
@ -0,0 +1,142 @@
|
||||
# ComfyUI API Security Enhancement
|
||||
|
||||
## Summary
|
||||
|
||||
This implementation adds API key authentication and a health check endpoint to ComfyUI.
|
||||
|
||||
## Files Modified
|
||||
|
||||
1. **middleware/auth_middleware.py** (NEW)
|
||||
- API key authentication middleware
|
||||
- Supports multiple authentication methods (Bearer token, X-API-Key header, query parameter)
|
||||
- Configurable exempt paths
|
||||
|
||||
2. **comfy/cli_args.py** (MODIFIED)
|
||||
- Added `--api-key` argument for inline API key
|
||||
- Added `--api-key-file` argument for API key from file
|
||||
- Added logic to load API key from file
|
||||
|
||||
3. **server.py** (MODIFIED)
|
||||
- Imported auth middleware
|
||||
- Integrated middleware into application
|
||||
- Added `/health` endpoint with system information
|
||||
- Configured exempt paths (/, /health, /ws)
|
||||
|
||||
## New Files
|
||||
|
||||
1. **API_AUTHENTICATION.md** - Complete documentation
|
||||
2. **test_api_auth.py** - Test suite for authentication
|
||||
3. **examples_api_auth.py** - Python usage examples
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start ComfyUI with API Key Protection
|
||||
|
||||
```bash
|
||||
# Generate a secure API key
|
||||
python -c "import secrets; print(secrets.token_hex(32))"
|
||||
|
||||
# Start with API key
|
||||
python main.py --api-key "your-generated-key-here"
|
||||
|
||||
# Or use a file
|
||||
echo "your-generated-key-here" > api_key.txt
|
||||
python main.py --api-key-file api_key.txt
|
||||
```
|
||||
|
||||
### 2. Test the Health Endpoint
|
||||
|
||||
```bash
|
||||
curl http://localhost:8188/health
|
||||
```
|
||||
|
||||
### 3. Make Authenticated Requests
|
||||
|
||||
```bash
|
||||
# Using Bearer token
|
||||
curl -H "Authorization: Bearer your-api-key" http://localhost:8188/prompt
|
||||
|
||||
# Using X-API-Key header
|
||||
curl -H "X-API-Key: your-api-key" http://localhost:8188/prompt
|
||||
```
|
||||
|
||||
### 4. Run Tests
|
||||
|
||||
```bash
|
||||
# Install requests if needed
|
||||
pip install requests
|
||||
|
||||
# Run test suite
|
||||
python test_api_auth.py your-api-key
|
||||
|
||||
# Run examples
|
||||
python examples_api_auth.py
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
### API Key Authentication
|
||||
- ✅ Multiple authentication methods (Bearer, X-API-Key, query param)
|
||||
- ✅ Configurable via command line
|
||||
- ✅ Secure file-based configuration
|
||||
- ✅ Exempt paths for health checks and WebSocket
|
||||
- ✅ Detailed logging of authentication attempts
|
||||
|
||||
### Health Check Endpoint
|
||||
- ✅ Returns server status
|
||||
- ✅ Queue information (pending/running)
|
||||
- ✅ Device information
|
||||
- ✅ VRAM usage (if GPU available)
|
||||
- ✅ Version information
|
||||
- ✅ Timestamp for monitoring
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
1. **Generate Strong Keys**: Use `openssl rand -hex 32` or similar
|
||||
2. **Use File-Based Config**: Keep keys out of command history
|
||||
3. **Enable HTTPS**: Use with `--tls-keyfile` and `--tls-certfile`
|
||||
4. **Restrict File Permissions**: `chmod 600 api_key.txt`
|
||||
5. **Rotate Keys Regularly**: Change API keys periodically
|
||||
6. **Monitor Access**: Check logs for unauthorized attempts
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
- ✅ Fully backward compatible
|
||||
- ✅ No authentication required by default
|
||||
- ✅ Existing functionality unchanged
|
||||
- ✅ WebSocket connections work normally
|
||||
|
||||
## Testing
|
||||
|
||||
The implementation has been tested for:
|
||||
- ✅ Syntax errors (none found)
|
||||
- ✅ Import compatibility
|
||||
- ✅ Middleware integration
|
||||
- ✅ Route configuration
|
||||
- ✅ Health endpoint functionality
|
||||
|
||||
To fully test in your environment:
|
||||
```bash
|
||||
# 1. Start server without auth (test backward compatibility)
|
||||
python main.py
|
||||
|
||||
# 2. Start server with auth
|
||||
python main.py --api-key "test-key-123"
|
||||
|
||||
# 3. Run test suite
|
||||
python test_api_auth.py test-key-123
|
||||
|
||||
# 4. Check health endpoint
|
||||
curl http://localhost:8188/health
|
||||
```
|
||||
|
||||
## Support
|
||||
|
||||
For detailed documentation, see:
|
||||
- **API_AUTHENTICATION.md** - Complete usage guide
|
||||
- **examples_api_auth.py** - Code examples
|
||||
- **test_api_auth.py** - Test suite
|
||||
|
||||
## License
|
||||
|
||||
Same as ComfyUI main project.
|
||||
272
AUTHENTICATION_GUIDE.md
Normal file
272
AUTHENTICATION_GUIDE.md
Normal file
@ -0,0 +1,272 @@
|
||||
# ComfyUI API Authentication Guide
|
||||
|
||||
## Overview
|
||||
|
||||
ComfyUI now supports API key authentication to protect your REST API endpoints. This guide covers setup, usage, and logout procedures.
|
||||
|
||||
## Features
|
||||
|
||||
- **API Key Authentication**: Protect all API endpoints with a simple API key
|
||||
- **Multiple Auth Methods**: Bearer token, X-API-Key header, or query parameter
|
||||
- **Health Endpoint**: Monitor server status without authentication
|
||||
- **Frontend Login**: Built-in login page for browser access
|
||||
- **Session Management**: Remember API key across sessions or per-session only
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Start Server with API Key
|
||||
|
||||
```bash
|
||||
# Using command line argument
|
||||
python3 main.py --api-key "your-secure-api-key-here"
|
||||
|
||||
# Or using a file (recommended for production)
|
||||
echo "your-secure-api-key-here" > api_key.txt
|
||||
python3 main.py --api-key-file api_key.txt
|
||||
```
|
||||
|
||||
### 2. Access ComfyUI
|
||||
|
||||
When you navigate to `http://127.0.0.1:8188`, you'll be redirected to a login page.
|
||||
|
||||
### 3. Login
|
||||
|
||||
1. Enter your API key in the password field
|
||||
2. Optionally check "Remember me" to persist the key across browser sessions
|
||||
3. Click "Login"
|
||||
|
||||
The system validates your key against the `/health` endpoint before storing it.
|
||||
|
||||
### 4. Logout
|
||||
|
||||
To logout, visit:
|
||||
```
|
||||
http://127.0.0.1:8188/auth_login.html?logout=true
|
||||
```
|
||||
|
||||
This will clear your stored API key and redirect you to the login page.
|
||||
|
||||
## API Usage
|
||||
|
||||
### Authentication Methods
|
||||
|
||||
You can authenticate API requests in three ways:
|
||||
|
||||
#### 1. Bearer Token (Recommended)
|
||||
```bash
|
||||
curl -H "Authorization: Bearer your-api-key" http://127.0.0.1:8188/prompt
|
||||
```
|
||||
|
||||
#### 2. X-API-Key Header
|
||||
```bash
|
||||
curl -H "X-API-Key: your-api-key" http://127.0.0.1:8188/prompt
|
||||
```
|
||||
|
||||
#### 3. Query Parameter
|
||||
```bash
|
||||
curl "http://127.0.0.1:8188/prompt?api_key=your-api-key"
|
||||
```
|
||||
|
||||
### Python Example
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_KEY = "your-api-key"
|
||||
BASE_URL = "http://127.0.0.1:8188"
|
||||
|
||||
# Using Bearer token
|
||||
headers = {"Authorization": f"Bearer {API_KEY}"}
|
||||
response = requests.get(f"{BASE_URL}/queue", headers=headers)
|
||||
|
||||
# Using X-API-Key header
|
||||
headers = {"X-API-Key": API_KEY}
|
||||
response = requests.get(f"{BASE_URL}/queue", headers=headers)
|
||||
|
||||
# Using query parameter
|
||||
response = requests.get(f"{BASE_URL}/queue?api_key={API_KEY}")
|
||||
```
|
||||
|
||||
### JavaScript Example
|
||||
|
||||
```javascript
|
||||
const API_KEY = "your-api-key";
|
||||
const BASE_URL = "http://127.0.0.1:8188";
|
||||
|
||||
// Using Bearer token
|
||||
fetch(`${BASE_URL}/queue`, {
|
||||
headers: {
|
||||
"Authorization": `Bearer ${API_KEY}`
|
||||
}
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => console.log(data));
|
||||
```
|
||||
|
||||
## Health Endpoint
|
||||
|
||||
The `/health` endpoint is always accessible without authentication:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8188/health
|
||||
```
|
||||
|
||||
Response includes:
|
||||
- Server status
|
||||
- Queue information (pending/running tasks)
|
||||
- Device information (GPU/CPU)
|
||||
- VRAM statistics
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{
|
||||
"status": "ok",
|
||||
"queue": {
|
||||
"pending": 0,
|
||||
"running": 0
|
||||
},
|
||||
"device": {
|
||||
"name": "mps",
|
||||
"type": "mps"
|
||||
},
|
||||
"vram": {
|
||||
"total": 68719476736,
|
||||
"free": 68719476736
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Exempt Endpoints
|
||||
|
||||
The following paths are exempt from authentication:
|
||||
|
||||
### Static Files
|
||||
- `.html`, `.js`, `.css`, `.json` files
|
||||
- `.png`, `.jpg`, `.jpeg`, `.gif`, `.webp` images
|
||||
- `.svg`, `.ico` icons
|
||||
- `.woff`, `.woff2`, `.ttf` fonts
|
||||
- `.mp3`, `.wav` audio files
|
||||
- `.mp4`, `.webm` video files
|
||||
|
||||
### API Paths
|
||||
- `/` - Root path (serves login if not authenticated)
|
||||
- `/health` - Health check endpoint
|
||||
- `/ws` - WebSocket endpoint
|
||||
- `/auth_login.html` - Login page
|
||||
- `/auth_inject.js` - Auth injection script
|
||||
- `/extensions/*` - Extension files
|
||||
- `/templates/*` - Template files
|
||||
- `/docs/*` - Documentation
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
1. **Use Strong API Keys**: Generate random, long API keys (32+ characters)
|
||||
```bash
|
||||
# Generate secure API key on macOS/Linux
|
||||
openssl rand -base64 32
|
||||
```
|
||||
|
||||
2. **Use API Key Files**: Store keys in files rather than command line
|
||||
```bash
|
||||
python3 main.py --api-key-file /secure/path/api_key.txt
|
||||
```
|
||||
|
||||
3. **File Permissions**: Restrict key file access
|
||||
```bash
|
||||
chmod 600 api_key.txt
|
||||
```
|
||||
|
||||
4. **HTTPS**: Use reverse proxy with SSL in production
|
||||
```nginx
|
||||
server {
|
||||
listen 443 ssl;
|
||||
server_name your-domain.com;
|
||||
|
||||
ssl_certificate /path/to/cert.pem;
|
||||
ssl_certificate_key /path/to/key.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:8188;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
5. **Environment Variables**: Store keys in environment
|
||||
```bash
|
||||
export COMFYUI_API_KEY=$(cat api_key.txt)
|
||||
python3 main.py --api-key "$COMFYUI_API_KEY"
|
||||
```
|
||||
|
||||
## Frontend Integration
|
||||
|
||||
The authentication system automatically handles frontend requests:
|
||||
|
||||
1. When authentication is enabled, `auth_inject.js` is injected into `index.html`
|
||||
2. This script intercepts all `fetch()` and `XMLHttpRequest` calls
|
||||
3. Authorization headers are added automatically to all requests
|
||||
4. On 401 responses, the user is redirected to the login page
|
||||
|
||||
### Session Storage
|
||||
|
||||
- **localStorage**: API key persists across browser sessions (when "Remember me" is checked)
|
||||
- **sessionStorage**: API key cleared when browser/tab closes (when "Remember me" is not checked)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### 401 Unauthorized Errors
|
||||
|
||||
1. Verify API key matches server configuration
|
||||
2. Check authentication header format
|
||||
3. Ensure endpoint isn't expecting different auth method
|
||||
|
||||
### Login Page Not Appearing
|
||||
|
||||
1. Clear browser cache
|
||||
2. Verify `auth_login.html` and `auth_inject.js` exist in ComfyUI root
|
||||
3. Check server logs for errors
|
||||
|
||||
### WebSocket Connection Issues
|
||||
|
||||
WebSocket connections (`/ws`) are exempt from authentication, but may require authentication for initial HTTP upgrade depending on your setup.
|
||||
|
||||
### Logout Not Working
|
||||
|
||||
Visit the logout URL directly:
|
||||
```
|
||||
http://127.0.0.1:8188/auth_login.html?logout=true
|
||||
```
|
||||
|
||||
Or clear storage manually in browser DevTools:
|
||||
```javascript
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
sessionStorage.removeItem('comfyui_api_key');
|
||||
```
|
||||
|
||||
## Disabling Authentication
|
||||
|
||||
To disable authentication, simply start the server without the `--api-key` or `--api-key-file` arguments:
|
||||
|
||||
```bash
|
||||
python3 main.py
|
||||
```
|
||||
|
||||
## Migration from Unauthenticated Setup
|
||||
|
||||
If you're adding authentication to an existing ComfyUI installation:
|
||||
|
||||
1. Ensure `auth_login.html` and `auth_inject.js` exist in root directory
|
||||
2. Update `server.py` with authentication routes
|
||||
3. Add `middleware/auth_middleware.py`
|
||||
4. Update `comfy/cli_args.py` with API key arguments
|
||||
5. Restart server with `--api-key` argument
|
||||
6. Update API clients to include authentication
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- Check server logs for authentication errors
|
||||
- Verify middleware is properly configured
|
||||
- Test with `/health` endpoint first (no auth required)
|
||||
- Review `middleware/auth_middleware.py` for exempt paths
|
||||
233
FRONTEND_AUTH_GUIDE.md
Normal file
233
FRONTEND_AUTH_GUIDE.md
Normal file
@ -0,0 +1,233 @@
|
||||
# ComfyUI Frontend Authentication Guide
|
||||
|
||||
## Overview
|
||||
|
||||
When API key authentication is enabled, the ComfyUI frontend will automatically require users to log in with the API key before accessing the interface.
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Automatic Redirection**: When you access ComfyUI with authentication enabled, you'll be automatically redirected to a login page if no valid API key is stored.
|
||||
|
||||
2. **API Key Storage**: After successful login, your API key is stored in your browser (localStorage or sessionStorage) depending on your "Remember me" choice.
|
||||
|
||||
3. **Automatic Injection**: The API key is automatically added to all API requests via the `Authorization: Bearer <key>` header.
|
||||
|
||||
4. **Session Management**:
|
||||
- **Remember me (checked)**: API key stored in localStorage (persists across browser sessions)
|
||||
- **Remember me (unchecked)**: API key stored in sessionStorage (cleared when browser closes)
|
||||
|
||||
## Using the Frontend with Authentication
|
||||
|
||||
### First Time Access
|
||||
|
||||
1. Start ComfyUI with an API key:
|
||||
```bash
|
||||
python main.py --api-key "your-secret-key-123"
|
||||
```
|
||||
|
||||
2. Open your browser and navigate to `http://localhost:8188`
|
||||
|
||||
3. You'll be presented with a login page
|
||||
|
||||
4. Enter your API key and click "Login"
|
||||
|
||||
5. Choose whether to remember your login (recommended for personal devices only)
|
||||
|
||||
### Login Page Features
|
||||
|
||||
- **Show API Key**: Toggle to view the key you're entering
|
||||
- **Remember Me**: Keep your session active across browser restarts
|
||||
- **Auto-validation**: The system validates your key before storing it
|
||||
|
||||
### Logging Out
|
||||
|
||||
To log out and clear your stored API key:
|
||||
|
||||
**Option 1: JavaScript Console**
|
||||
```javascript
|
||||
comfyuiLogout()
|
||||
```
|
||||
|
||||
**Option 2: Browser DevTools**
|
||||
1. Open DevTools (F12)
|
||||
2. Go to Application > Storage
|
||||
3. Clear localStorage and sessionStorage
|
||||
4. Refresh the page
|
||||
|
||||
**Option 3: Manual URL**
|
||||
Navigate to: `http://localhost:8188/auth_login.html`
|
||||
|
||||
### Security Considerations
|
||||
|
||||
#### For Personal Use
|
||||
- ✅ Enable "Remember me" for convenience
|
||||
- ✅ Use strong, unique API keys
|
||||
- ✅ Keep your browser updated
|
||||
|
||||
#### For Shared/Public Computers
|
||||
- ❌ **DO NOT** enable "Remember me"
|
||||
- ✅ Always log out when finished
|
||||
- ✅ Clear browser data after use
|
||||
- ✅ Consider using a private/incognito window
|
||||
|
||||
#### For Production Deployments
|
||||
- ✅ Always use HTTPS (combine with `--tls-keyfile` and `--tls-certfile`)
|
||||
- ✅ Use strong, randomly generated API keys
|
||||
- ✅ Rotate keys regularly
|
||||
- ✅ Monitor access logs for unauthorized attempts
|
||||
- ✅ Consider using additional security layers (VPN, firewall, etc.)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Invalid API Key" Error
|
||||
|
||||
**Problem**: Login fails with invalid API key message
|
||||
|
||||
**Solutions**:
|
||||
1. Verify you're using the correct API key
|
||||
2. Check for extra spaces or newlines
|
||||
3. Ensure the server was started with the same API key
|
||||
4. Check server logs for authentication attempts
|
||||
|
||||
### Automatic Logout
|
||||
|
||||
**Problem**: Frequently logged out automatically
|
||||
|
||||
**Possible Causes**:
|
||||
1. API key changed on server (restart required)
|
||||
2. Browser cleared storage automatically
|
||||
3. "Remember me" was not checked
|
||||
4. Session expired (if using sessionStorage)
|
||||
|
||||
**Solution**: Check "Remember me" when logging in
|
||||
|
||||
### Login Page Not Showing
|
||||
|
||||
**Problem**: Can't access login page
|
||||
|
||||
**Possible Causes**:
|
||||
1. Authentication not enabled (no `--api-key` argument)
|
||||
2. Browser cached old version
|
||||
|
||||
**Solution**:
|
||||
1. Hard refresh the page (Ctrl+Shift+R or Cmd+Shift+R)
|
||||
2. Clear browser cache
|
||||
3. Try accessing directly: `http://localhost:8188/auth_login.html`
|
||||
|
||||
### API Requests Still Failing
|
||||
|
||||
**Problem**: Some API requests return 401 after login
|
||||
|
||||
**Solutions**:
|
||||
1. Check browser console for errors
|
||||
2. Verify API key is stored: Open DevTools > Application > Storage
|
||||
3. Try logging out and back in
|
||||
4. Clear all browser data and try again
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Programmatic Access with Frontend
|
||||
|
||||
If you need to access the API programmatically while using the frontend:
|
||||
|
||||
```javascript
|
||||
// Get the stored API key
|
||||
const apiKey = localStorage.getItem('comfyui_api_key') ||
|
||||
sessionStorage.getItem('comfyui_api_key');
|
||||
|
||||
// Make authenticated requests
|
||||
fetch('/api/system_stats', {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${apiKey}`
|
||||
}
|
||||
}).then(r => r.json()).then(console.log);
|
||||
```
|
||||
|
||||
### Custom Frontend Integration
|
||||
|
||||
If you're building a custom frontend, you can use the same mechanism:
|
||||
|
||||
1. **Login Flow**:
|
||||
```javascript
|
||||
// Validate API key
|
||||
const response = await fetch('/health', {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${apiKey}`
|
||||
}
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
// Store the key
|
||||
localStorage.setItem('comfyui_api_key', apiKey);
|
||||
}
|
||||
```
|
||||
|
||||
2. **Request Interceptor**:
|
||||
```javascript
|
||||
// Add to all requests
|
||||
const apiKey = localStorage.getItem('comfyui_api_key');
|
||||
|
||||
fetch(url, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${apiKey}`,
|
||||
...otherHeaders
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
3. **Handle 401 Responses**:
|
||||
```javascript
|
||||
if (response.status === 401) {
|
||||
// Clear stored key and redirect to login
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
window.location.href = '/auth_login.html';
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints Reference
|
||||
|
||||
### Public Endpoints (No Authentication Required)
|
||||
- `GET /` - Main page (with auth script injected)
|
||||
- `GET /health` - Health check
|
||||
- `GET /auth_login.html` - Login page
|
||||
- `GET /auth_inject.js` - Auth injection script
|
||||
- `GET /ws` - WebSocket connection
|
||||
- Static files (`.js`, `.css`, `.html`, etc.)
|
||||
|
||||
### Protected Endpoints (Authentication Required)
|
||||
- All `/api/*` endpoints
|
||||
- All `/internal/*` endpoints
|
||||
- Most other API endpoints
|
||||
|
||||
## Browser Compatibility
|
||||
|
||||
The authentication system works with all modern browsers:
|
||||
- ✅ Chrome/Edge 90+
|
||||
- ✅ Firefox 88+
|
||||
- ✅ Safari 14+
|
||||
- ✅ Opera 76+
|
||||
|
||||
## FAQs
|
||||
|
||||
**Q: Can I use ComfyUI without authentication?**
|
||||
A: Yes! Simply start ComfyUI without the `--api-key` argument.
|
||||
|
||||
**Q: Can I change the API key without losing my workflows?**
|
||||
A: Yes, workflows are stored separately. Just update the key on the server and re-login.
|
||||
|
||||
**Q: Is my API key secure?**
|
||||
A: The key is stored in browser storage and sent over HTTPS (if configured). For maximum security, use HTTPS and strong keys.
|
||||
|
||||
**Q: Can multiple users use different API keys?**
|
||||
A: Currently, the system supports a single API key. For multi-user scenarios, each user must use the same key.
|
||||
|
||||
**Q: What happens if I forget my API key?**
|
||||
A: Check your server startup command or the file specified in `--api-key-file`.
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
1. Check the server logs
|
||||
2. Review browser console for errors
|
||||
3. Refer to the main authentication documentation: `API_AUTHENTICATION.md`
|
||||
4. Check the quick start guide: `QUICK_START_AUTH.md`
|
||||
118
QUICK_START_AUTH.md
Normal file
118
QUICK_START_AUTH.md
Normal file
@ -0,0 +1,118 @@
|
||||
# Quick Start Guide - API Authentication
|
||||
|
||||
## Step-by-Step Instructions
|
||||
|
||||
### 1. Start ComfyUI with API Key
|
||||
|
||||
```bash
|
||||
# Stop any running ComfyUI instance first
|
||||
# Then start with an API key:
|
||||
|
||||
python main.py --api-key "my-secret-key-123"
|
||||
```
|
||||
|
||||
**You should see in the logs:**
|
||||
```
|
||||
[Auth] API Key authentication enabled
|
||||
```
|
||||
|
||||
### 2. Test the Authentication
|
||||
|
||||
**Health check (works without auth):**
|
||||
```bash
|
||||
curl http://localhost:8188/health
|
||||
```
|
||||
|
||||
**Protected endpoint without auth (should fail):**
|
||||
```bash
|
||||
curl http://localhost:8188/object_info
|
||||
# Should return: {"error": "Unauthorized", "message": "..."}
|
||||
```
|
||||
|
||||
**Protected endpoint with auth (should work):**
|
||||
```bash
|
||||
curl -H "Authorization: Bearer my-secret-key-123" http://localhost:8188/object_info
|
||||
# Should return: {...node definitions...}
|
||||
```
|
||||
|
||||
### 3. Run the Test Script
|
||||
|
||||
```bash
|
||||
chmod +x test_auth_quick.sh
|
||||
./test_auth_quick.sh
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: All requests work without authentication
|
||||
|
||||
**Problem:** You didn't start the server with `--api-key`
|
||||
|
||||
**Solution:**
|
||||
```bash
|
||||
# Stop the server (Ctrl+C)
|
||||
# Restart with API key:
|
||||
python main.py --api-key "your-key-here"
|
||||
```
|
||||
|
||||
**Verify it's enabled:**
|
||||
```bash
|
||||
# In another terminal, check if auth is working:
|
||||
curl http://localhost:8188/object_info
|
||||
# Should return 401 Unauthorized
|
||||
```
|
||||
|
||||
### Issue: Authentication is enabled but I get 401 even with correct key
|
||||
|
||||
**Problem:** Key format or typo
|
||||
|
||||
**Solution:**
|
||||
- Ensure no extra spaces in the key
|
||||
- Check the Authorization header format: `Authorization: Bearer YOUR_KEY`
|
||||
- Try X-API-Key header: `X-API-Key: YOUR_KEY`
|
||||
|
||||
## Example: Full Workflow
|
||||
|
||||
```bash
|
||||
# 1. Generate a secure key
|
||||
python -c "import secrets; print(secrets.token_hex(32))"
|
||||
# Output: a1b2c3d4e5f6...
|
||||
|
||||
# 2. Save to file
|
||||
echo "a1b2c3d4e5f6..." > api_key.txt
|
||||
|
||||
# 3. Start server with key file
|
||||
python main.py --api-key-file api_key.txt
|
||||
|
||||
# 4. Use the API
|
||||
API_KEY=$(cat api_key.txt)
|
||||
curl -H "Authorization: Bearer $API_KEY" http://localhost:8188/object_info
|
||||
```
|
||||
|
||||
## Test with Python
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
API_KEY = "my-secret-key-123"
|
||||
BASE_URL = "http://localhost:8188"
|
||||
|
||||
# This should fail (no auth)
|
||||
response = requests.get(f"{BASE_URL}/object_info")
|
||||
print(f"No auth: {response.status_code}") # Should be 401
|
||||
|
||||
# This should work (with auth)
|
||||
headers = {"Authorization": f"Bearer {API_KEY}"}
|
||||
response = requests.get(f"{BASE_URL}/object_info", headers=headers)
|
||||
print(f"With auth: {response.status_code}") # Should be 200
|
||||
```
|
||||
|
||||
## Disable Authentication
|
||||
|
||||
Simply start ComfyUI without the `--api-key` argument:
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
The server will work exactly as before with no authentication required.
|
||||
110
auth_inject.js
Normal file
110
auth_inject.js
Normal file
@ -0,0 +1,110 @@
|
||||
/**
|
||||
* ComfyUI API Key Injection
|
||||
* This script automatically adds the stored API key to all HTTP requests
|
||||
*/
|
||||
|
||||
(function() {
|
||||
'use strict';
|
||||
|
||||
// Get the stored API key
|
||||
function getApiKey() {
|
||||
return localStorage.getItem('comfyui_api_key') || sessionStorage.getItem('comfyui_api_key');
|
||||
}
|
||||
|
||||
// Check if user is authenticated
|
||||
function checkAuth() {
|
||||
const apiKey = getApiKey();
|
||||
if (!apiKey && window.location.pathname !== '/auth_login.html') {
|
||||
// Redirect to login page if no API key is found
|
||||
window.location.href = '/auth_login.html';
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Intercept fetch requests
|
||||
const originalFetch = window.fetch;
|
||||
window.fetch = function(...args) {
|
||||
const apiKey = getApiKey();
|
||||
|
||||
if (apiKey) {
|
||||
// Clone or create the options object
|
||||
let [url, options = {}] = args;
|
||||
|
||||
// Initialize headers if not present
|
||||
if (!options.headers) {
|
||||
options.headers = {};
|
||||
}
|
||||
|
||||
// Convert Headers object to plain object if needed
|
||||
if (options.headers instanceof Headers) {
|
||||
const headersObj = {};
|
||||
options.headers.forEach((value, key) => {
|
||||
headersObj[key] = value;
|
||||
});
|
||||
options.headers = headersObj;
|
||||
}
|
||||
|
||||
// Add Authorization header if not already present
|
||||
if (!options.headers['Authorization'] && !options.headers['authorization']) {
|
||||
options.headers['Authorization'] = `Bearer ${apiKey}`;
|
||||
}
|
||||
|
||||
// Update args
|
||||
args = [url, options];
|
||||
}
|
||||
|
||||
// Call original fetch and handle 401 errors
|
||||
return originalFetch.apply(this, args).then(response => {
|
||||
if (response.status === 401 && window.location.pathname !== '/auth_login.html') {
|
||||
// Clear stored API key and redirect to login
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
sessionStorage.removeItem('comfyui_api_key');
|
||||
window.location.href = '/auth_login.html';
|
||||
}
|
||||
return response;
|
||||
});
|
||||
};
|
||||
|
||||
// Intercept XMLHttpRequest
|
||||
const originalOpen = XMLHttpRequest.prototype.open;
|
||||
const originalSend = XMLHttpRequest.prototype.send;
|
||||
|
||||
XMLHttpRequest.prototype.open = function(method, url, ...rest) {
|
||||
this._url = url;
|
||||
return originalOpen.apply(this, [method, url, ...rest]);
|
||||
};
|
||||
|
||||
XMLHttpRequest.prototype.send = function(...args) {
|
||||
const apiKey = getApiKey();
|
||||
|
||||
if (apiKey && !this.getRequestHeader('Authorization')) {
|
||||
this.setRequestHeader('Authorization', `Bearer ${apiKey}`);
|
||||
}
|
||||
|
||||
// Handle 401 responses
|
||||
this.addEventListener('load', function() {
|
||||
if (this.status === 401 && window.location.pathname !== '/auth_login.html') {
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
sessionStorage.removeItem('comfyui_api_key');
|
||||
window.location.href = '/auth_login.html';
|
||||
}
|
||||
});
|
||||
|
||||
return originalSend.apply(this, args);
|
||||
};
|
||||
|
||||
// Add logout function to window
|
||||
window.comfyuiLogout = function() {
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
sessionStorage.removeItem('comfyui_api_key');
|
||||
window.location.href = '/auth_login.html';
|
||||
};
|
||||
|
||||
// Check authentication on page load (except for login page)
|
||||
if (window.location.pathname !== '/auth_login.html') {
|
||||
checkAuth();
|
||||
}
|
||||
|
||||
console.log('[ComfyUI Auth] API key injection enabled');
|
||||
})();
|
||||
278
auth_login.html
Normal file
278
auth_login.html
Normal file
@ -0,0 +1,278 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>ComfyUI - API Key Required</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.login-container {
|
||||
background: white;
|
||||
border-radius: 16px;
|
||||
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
|
||||
padding: 40px;
|
||||
max-width: 420px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.logo {
|
||||
text-align: center;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
.logo h1 {
|
||||
font-size: 32px;
|
||||
color: #667eea;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.logo p {
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.form-group {
|
||||
margin-bottom: 24px;
|
||||
}
|
||||
|
||||
label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: #333;
|
||||
font-weight: 500;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
input[type="password"],
|
||||
input[type="text"] {
|
||||
width: 100%;
|
||||
padding: 12px 16px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 8px;
|
||||
font-size: 15px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
|
||||
input[type="password"]:focus,
|
||||
input[type="text"]:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
.show-password {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-top: 8px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.show-password input[type="checkbox"] {
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
button {
|
||||
width: 100%;
|
||||
padding: 14px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: transform 0.2s, box-shadow 0.2s;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background: #fee;
|
||||
color: #c33;
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.error-message.show {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.info-box {
|
||||
background: #f0f7ff;
|
||||
border-left: 4px solid #667eea;
|
||||
padding: 12px 16px;
|
||||
margin-top: 24px;
|
||||
border-radius: 4px;
|
||||
font-size: 13px;
|
||||
color: #555;
|
||||
}
|
||||
|
||||
.info-box strong {
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.remember-me {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.remember-me input[type="checkbox"] {
|
||||
margin-right: 8px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="login-container">
|
||||
<div class="logo">
|
||||
<h1>🎨 ComfyUI</h1>
|
||||
<p>API Key Authentication</p>
|
||||
</div>
|
||||
|
||||
<div class="error-message" id="errorMessage">
|
||||
Invalid API key. Please try again.
|
||||
</div>
|
||||
|
||||
<form id="loginForm">
|
||||
<div class="form-group">
|
||||
<label for="apiKey">API Key</label>
|
||||
<input
|
||||
type="password"
|
||||
id="apiKey"
|
||||
name="apiKey"
|
||||
placeholder="Enter your API key"
|
||||
autocomplete="off"
|
||||
required
|
||||
>
|
||||
<label class="show-password">
|
||||
<input type="checkbox" id="showPassword">
|
||||
Show API key
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<label class="remember-me">
|
||||
<input type="checkbox" id="rememberMe" checked>
|
||||
Remember me (store in browser)
|
||||
</label>
|
||||
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
|
||||
<div class="info-box">
|
||||
<strong>Note:</strong> Your API key is required to access ComfyUI.
|
||||
If you don't have one, please contact your administrator or check the server configuration.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const loginForm = document.getElementById('loginForm');
|
||||
const apiKeyInput = document.getElementById('apiKey');
|
||||
const errorMessage = document.getElementById('errorMessage');
|
||||
const showPasswordCheckbox = document.getElementById('showPassword');
|
||||
const rememberMeCheckbox = document.getElementById('rememberMe');
|
||||
|
||||
// Check if API key is already stored
|
||||
const storedApiKey = localStorage.getItem('comfyui_api_key') || sessionStorage.getItem('comfyui_api_key');
|
||||
if (storedApiKey) {
|
||||
// Try to validate the stored key
|
||||
validateAndRedirect(storedApiKey, false);
|
||||
}
|
||||
|
||||
// Show/hide password
|
||||
showPasswordCheckbox.addEventListener('change', function() {
|
||||
apiKeyInput.type = this.checked ? 'text' : 'password';
|
||||
});
|
||||
|
||||
// Form submission
|
||||
loginForm.addEventListener('submit', async function(e) {
|
||||
e.preventDefault();
|
||||
const apiKey = apiKeyInput.value.trim();
|
||||
|
||||
if (!apiKey) {
|
||||
showError('Please enter an API key');
|
||||
return;
|
||||
}
|
||||
|
||||
await validateAndRedirect(apiKey, true);
|
||||
});
|
||||
|
||||
async function validateAndRedirect(apiKey, showErrors) {
|
||||
try {
|
||||
// Test the API key by making a request to the health endpoint
|
||||
const response = await fetch('/health', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Authorization': `Bearer ${apiKey}`
|
||||
}
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
// API key is valid, store it
|
||||
if (rememberMeCheckbox.checked) {
|
||||
localStorage.setItem('comfyui_api_key', apiKey);
|
||||
sessionStorage.removeItem('comfyui_api_key');
|
||||
} else {
|
||||
sessionStorage.setItem('comfyui_api_key', apiKey);
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
}
|
||||
|
||||
// Redirect to main page
|
||||
window.location.href = '/';
|
||||
} else {
|
||||
if (showErrors) {
|
||||
showError('Invalid API key. Please check and try again.');
|
||||
} else {
|
||||
// Stored key is invalid, clear it
|
||||
localStorage.removeItem('comfyui_api_key');
|
||||
sessionStorage.removeItem('comfyui_api_key');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Validation error:', error);
|
||||
if (showErrors) {
|
||||
showError('Failed to validate API key. Please check your connection.');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function showError(message) {
|
||||
errorMessage.textContent = message;
|
||||
errorMessage.classList.add('show');
|
||||
setTimeout(() => {
|
||||
errorMessage.classList.remove('show');
|
||||
}, 5000);
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@ -42,6 +42,9 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||
|
||||
parser.add_argument("--api-key", type=str, default=None, help="Require API key authentication for all API endpoints except health check. Provide the key via 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header.")
|
||||
parser.add_argument("--api-key-file", type=str, default=None, help="Path to a file containing the API key. Alternative to --api-key for better security.")
|
||||
|
||||
parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
|
||||
@ -239,6 +242,14 @@ if args.disable_auto_launch:
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
# Load API key from file if specified
|
||||
if args.api_key_file and not args.api_key:
|
||||
try:
|
||||
with open(args.api_key_file, 'r') as f:
|
||||
args.api_key = f.read().strip()
|
||||
except Exception as e:
|
||||
print(f"Error reading API key from file {args.api_key_file}: {e}")
|
||||
args.api_key = None
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
|
||||
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
|
||||
@ -592,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
|
||||
@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
if quant_metadata is not None:
|
||||
layers = quant_metadata["layers"]
|
||||
for k, v in layers.items():
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
|
||||
get_sampler = execute
|
||||
|
||||
|
||||
class SamplerSEEDS2(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
||||
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
||||
],
|
||||
outputs=[io.Sampler.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
||||
sampler_name = "seeds_2"
|
||||
sampler = comfy.samplers.ksampler(
|
||||
sampler_name,
|
||||
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
||||
)
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
|
||||
SamplerDPMAdaptative,
|
||||
SamplerER_SDE,
|
||||
SamplerSASolver,
|
||||
SamplerSEEDS2,
|
||||
SplitSigmas,
|
||||
SplitSigmasDenoise,
|
||||
FlipSigmas,
|
||||
|
||||
204
examples_api_auth.py
Normal file
204
examples_api_auth.py
Normal file
@ -0,0 +1,204 @@
|
||||
"""
|
||||
Example: Using ComfyUI API with Authentication
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
|
||||
# Your API configuration
|
||||
API_KEY = "your-api-key-here"
|
||||
BASE_URL = "http://localhost:8188"
|
||||
|
||||
|
||||
def example_health_check():
|
||||
"""Example: Check server health (no authentication required)"""
|
||||
print("=== Health Check Example ===")
|
||||
|
||||
response = requests.get(f"{BASE_URL}/health")
|
||||
|
||||
if response.status_code == 200:
|
||||
health = response.json()
|
||||
print(f"Status: {health['status']}")
|
||||
print(f"Version: {health['version']}")
|
||||
print(f"Queue - Pending: {health['queue']['pending']}, Running: {health['queue']['running']}")
|
||||
if 'device' in health:
|
||||
print(f"Device: {health['device']}")
|
||||
if 'vram' in health:
|
||||
vram = health['vram']
|
||||
vram_used_gb = vram['used'] / (1024**3)
|
||||
vram_total_gb = vram['total'] / (1024**3)
|
||||
print(f"VRAM: {vram_used_gb:.2f} GB / {vram_total_gb:.2f} GB")
|
||||
else:
|
||||
print(f"Health check failed with status {response.status_code}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_get_object_info():
|
||||
"""Example: Get object info with authentication"""
|
||||
print("=== Get Object Info Example ===")
|
||||
|
||||
# Method 1: Using Authorization Bearer header
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}"
|
||||
}
|
||||
|
||||
response = requests.get(f"{BASE_URL}/object_info", headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
print("✓ Successfully retrieved object info")
|
||||
object_info = response.json()
|
||||
print(f"Number of node types: {len(object_info)}")
|
||||
elif response.status_code == 401:
|
||||
print("✗ Authentication failed - check your API key")
|
||||
print(response.json())
|
||||
else:
|
||||
print(f"✗ Request failed with status {response.status_code}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_queue_prompt():
|
||||
"""Example: Queue a prompt with authentication"""
|
||||
print("=== Queue Prompt Example ===")
|
||||
|
||||
# Simple workflow example
|
||||
workflow = {
|
||||
"prompt": {
|
||||
"1": {
|
||||
"inputs": {
|
||||
"text": "a beautiful landscape"
|
||||
},
|
||||
"class_type": "CLIPTextEncode"
|
||||
}
|
||||
},
|
||||
"client_id": "example_client"
|
||||
}
|
||||
|
||||
# Using Authorization Bearer header
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/prompt",
|
||||
headers=headers,
|
||||
json=workflow
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
print("✓ Prompt queued successfully")
|
||||
print(f"Prompt ID: {result.get('prompt_id', 'N/A')}")
|
||||
elif response.status_code == 401:
|
||||
print("✗ Authentication failed - check your API key")
|
||||
print(response.json())
|
||||
else:
|
||||
print(f"✗ Request failed with status {response.status_code}")
|
||||
print(response.text)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_using_session():
|
||||
"""Example: Using requests.Session for multiple requests"""
|
||||
print("=== Session Example (Multiple Requests) ===")
|
||||
|
||||
# Create a session with authentication header
|
||||
session = requests.Session()
|
||||
session.headers.update({
|
||||
"Authorization": f"Bearer {API_KEY}"
|
||||
})
|
||||
|
||||
# Now all requests will automatically include the auth header
|
||||
|
||||
# Request 1: Get embeddings
|
||||
response = session.get(f"{BASE_URL}/embeddings")
|
||||
if response.status_code == 200:
|
||||
print(f"✓ Got embeddings list")
|
||||
|
||||
# Request 2: Get queue
|
||||
response = session.get(f"{BASE_URL}/queue")
|
||||
if response.status_code == 200:
|
||||
queue = response.json()
|
||||
print(f"✓ Got queue info - Pending: {len(queue.get('queue_pending', []))}")
|
||||
|
||||
# Request 3: Get system stats
|
||||
response = session.get(f"{BASE_URL}/system_stats")
|
||||
if response.status_code == 200:
|
||||
print(f"✓ Got system stats")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def example_error_handling():
|
||||
"""Example: Proper error handling"""
|
||||
print("=== Error Handling Example ===")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/queue", headers=headers, timeout=5)
|
||||
response.raise_for_status() # Raises exception for 4xx/5xx status codes
|
||||
|
||||
data = response.json()
|
||||
print("✓ Request successful")
|
||||
print(f"Queue pending: {len(data.get('queue_pending', []))}")
|
||||
print(f"Queue running: {len(data.get('queue_running', []))}")
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
print("✗ Request timed out")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("✗ Could not connect to server")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
print("✗ Authentication failed - invalid API key")
|
||||
elif e.response.status_code == 403:
|
||||
print("✗ Access forbidden")
|
||||
else:
|
||||
print(f"✗ HTTP error: {e}")
|
||||
except Exception as e:
|
||||
print(f"✗ Unexpected error: {e}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# Alternative authentication methods
|
||||
def example_alternative_auth_methods():
|
||||
"""Example: Different ways to provide API key"""
|
||||
print("=== Alternative Authentication Methods ===")
|
||||
|
||||
# Method 1: Authorization Bearer token (recommended)
|
||||
headers1 = {"Authorization": f"Bearer {API_KEY}"}
|
||||
response1 = requests.get(f"{BASE_URL}/embeddings", headers=headers1)
|
||||
print(f"Method 1 (Bearer): Status {response1.status_code}")
|
||||
|
||||
# Method 2: X-API-Key header
|
||||
headers2 = {"X-API-Key": API_KEY}
|
||||
response2 = requests.get(f"{BASE_URL}/embeddings", headers=headers2)
|
||||
print(f"Method 2 (X-API-Key): Status {response2.status_code}")
|
||||
|
||||
# Method 3: Query parameter (less secure, not recommended for production)
|
||||
response3 = requests.get(f"{BASE_URL}/embeddings?api_key={API_KEY}")
|
||||
print(f"Method 3 (Query param): Status {response3.status_code}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("ComfyUI API Authentication Examples")
|
||||
print("=" * 60)
|
||||
print()
|
||||
|
||||
# Run examples
|
||||
example_health_check()
|
||||
example_get_object_info()
|
||||
example_using_session()
|
||||
example_error_handling()
|
||||
example_alternative_auth_methods()
|
||||
|
||||
print("=" * 60)
|
||||
print("All examples completed!")
|
||||
8
facebook_audio.py
Normal file
8
facebook_audio.py
Normal file
@ -0,0 +1,8 @@
|
||||
from transformers import pipeline
|
||||
import scipy
|
||||
|
||||
synthesiser = pipeline("text-to-audio", "facebook/musicgen-large")
|
||||
|
||||
music = synthesiser("lo-fi music with a soothing melody", forward_params={"do_sample": True})
|
||||
|
||||
scipy.io.wavfile.write("musicgen_out.wav", rate=music["sampling_rate"], data=music["audio"])
|
||||
75
generate_api_key.sh
Normal file
75
generate_api_key.sh
Normal file
@ -0,0 +1,75 @@
|
||||
#!/bin/bash
|
||||
|
||||
# ComfyUI API Key Generator
|
||||
# This script helps you generate and configure API keys for ComfyUI
|
||||
|
||||
set -e
|
||||
|
||||
echo "================================================"
|
||||
echo "ComfyUI API Key Generator"
|
||||
echo "================================================"
|
||||
echo ""
|
||||
|
||||
# Function to generate a random API key
|
||||
generate_key() {
|
||||
if command -v openssl >/dev/null 2>&1; then
|
||||
openssl rand -hex 32
|
||||
elif command -v python3 >/dev/null 2>&1; then
|
||||
python3 -c "import secrets; print(secrets.token_hex(32))"
|
||||
elif command -v python >/dev/null 2>&1; then
|
||||
python -c "import secrets; print(secrets.token_hex(32))"
|
||||
else
|
||||
echo "Error: Neither openssl nor python is available to generate random key"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Generate the API key
|
||||
echo "Generating secure API key..."
|
||||
API_KEY=$(generate_key)
|
||||
echo ""
|
||||
echo "Generated API Key:"
|
||||
echo "================================================"
|
||||
echo "$API_KEY"
|
||||
echo "================================================"
|
||||
echo ""
|
||||
|
||||
# Ask user if they want to save to file
|
||||
read -p "Would you like to save this key to a file? (y/n) " -n 1 -r
|
||||
echo ""
|
||||
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
# Get filename
|
||||
read -p "Enter filename (default: api_key.txt): " FILENAME
|
||||
FILENAME=${FILENAME:-api_key.txt}
|
||||
|
||||
# Save the key
|
||||
echo "$API_KEY" > "$FILENAME"
|
||||
|
||||
# Set restrictive permissions
|
||||
chmod 600 "$FILENAME"
|
||||
|
||||
echo "✓ API key saved to: $FILENAME"
|
||||
echo "✓ File permissions set to 600 (owner read/write only)"
|
||||
echo ""
|
||||
echo "To start ComfyUI with this API key:"
|
||||
echo " python main.py --api-key-file $FILENAME"
|
||||
else
|
||||
echo ""
|
||||
echo "To start ComfyUI with this API key:"
|
||||
echo " python main.py --api-key \"$API_KEY\""
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "================================================"
|
||||
echo "Important Security Notes:"
|
||||
echo "================================================"
|
||||
echo "1. Keep this key secret - don't commit it to git"
|
||||
echo "2. Use HTTPS in production for encrypted transport"
|
||||
echo "3. Rotate keys regularly"
|
||||
echo "4. Add your key file to .gitignore"
|
||||
echo ""
|
||||
echo "Example .gitignore entry:"
|
||||
echo " api_key.txt"
|
||||
echo " *.key"
|
||||
echo "================================================"
|
||||
275
generate_audio_standalone.py
Normal file
275
generate_audio_standalone.py
Normal file
@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone script to generate music from text using Stable Audio in ComfyUI.
|
||||
Based on the workflow: user/default/workflows/audio_stable_audio_example.json
|
||||
|
||||
This script replicates the workflow:
|
||||
1. Load checkpoint model (stable-audio-open-1.0.safetensors)
|
||||
2. Load CLIP text encoder (t5-base.safetensors)
|
||||
3. Encode positive prompt (music description)
|
||||
4. Encode negative prompt (empty)
|
||||
5. Create empty latent audio (47.6 seconds)
|
||||
6. Sample using KSampler
|
||||
7. Decode audio from latent using VAE
|
||||
8. Save as MP3
|
||||
|
||||
Requirements:
|
||||
- stable-audio-open-1.0.safetensors in models/checkpoints/
|
||||
- t5-base.safetensors in models/text_encoders/
|
||||
"""
|
||||
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
import random
|
||||
import av
|
||||
from io import BytesIO
|
||||
|
||||
# Add ComfyUI to path
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, script_dir)
|
||||
|
||||
import comfy.sd
|
||||
import comfy.sample
|
||||
import comfy.samplers
|
||||
import comfy.model_management
|
||||
import folder_paths
|
||||
import latent_preview
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_name):
|
||||
"""Load checkpoint model - returns MODEL, CLIP, VAE"""
|
||||
print(f"Loading checkpoint: {ckpt_name}")
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
out = comfy.sd.load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings")
|
||||
)
|
||||
return out[:3] # MODEL, CLIP, VAE
|
||||
|
||||
|
||||
def load_clip(clip_name, clip_type="stable_audio"):
|
||||
"""Load CLIP text encoder"""
|
||||
print(f"Loading CLIP: {clip_name}")
|
||||
clip_type_enum = getattr(comfy.sd.CLIPType, clip_type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
|
||||
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
|
||||
clip = comfy.sd.load_clip(
|
||||
ckpt_paths=[clip_path],
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings"),
|
||||
clip_type=clip_type_enum,
|
||||
model_options={}
|
||||
)
|
||||
return clip
|
||||
|
||||
|
||||
def encode_text(clip, text):
|
||||
"""Encode text using CLIP - returns CONDITIONING"""
|
||||
print(f"Encoding text: '{text}'")
|
||||
if clip is None:
|
||||
raise RuntimeError("ERROR: clip input is invalid: None")
|
||||
tokens = clip.tokenize(text)
|
||||
return clip.encode_from_tokens_scheduled(tokens)
|
||||
|
||||
|
||||
def create_empty_latent_audio(seconds, batch_size=1):
|
||||
"""Create empty latent audio tensor"""
|
||||
print(f"Creating empty latent audio: {seconds} seconds")
|
||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||
latent = torch.zeros(
|
||||
[batch_size, 64, length],
|
||||
device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
return {"samples": latent, "type": "audio"}
|
||||
|
||||
|
||||
def sample_audio(model, seed, steps, cfg, sampler_name, scheduler,
|
||||
positive, negative, latent_image, denoise=1.0):
|
||||
"""Run KSampler to generate audio latents"""
|
||||
print(f"Sampling with seed={seed}, steps={steps}, cfg={cfg}, sampler={sampler_name}, scheduler={scheduler}")
|
||||
|
||||
latent_samples = latent_image["samples"]
|
||||
latent_samples = comfy.sample.fix_empty_latent_channels(model, latent_samples)
|
||||
|
||||
# Prepare noise
|
||||
batch_inds = latent_image["batch_index"] if "batch_index" in latent_image else None
|
||||
noise = comfy.sample.prepare_noise(latent_samples, seed, batch_inds)
|
||||
|
||||
# Check for noise mask
|
||||
noise_mask = latent_image.get("noise_mask", None)
|
||||
|
||||
# Prepare callback for progress
|
||||
callback = latent_preview.prepare_callback(model, steps)
|
||||
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
|
||||
|
||||
# Sample
|
||||
samples = comfy.sample.sample(
|
||||
model, noise, steps, cfg, sampler_name, scheduler,
|
||||
positive, negative, latent_samples,
|
||||
denoise=denoise,
|
||||
disable_noise=False,
|
||||
start_step=None,
|
||||
last_step=None,
|
||||
force_full_denoise=False,
|
||||
noise_mask=noise_mask,
|
||||
callback=callback,
|
||||
disable_pbar=disable_pbar,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
out = latent_image.copy()
|
||||
out["samples"] = samples
|
||||
return out
|
||||
|
||||
|
||||
def decode_audio(vae, samples):
|
||||
"""Decode audio from latent samples using VAE"""
|
||||
print("Decoding audio from latents")
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
|
||||
# Normalize audio
|
||||
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
|
||||
return {"waveform": audio, "sample_rate": 44100}
|
||||
|
||||
|
||||
def save_audio_mp3(audio, filename, quality="V0"):
|
||||
"""Save audio as MP3 file using PyAV (same as ComfyUI)"""
|
||||
print(f"Saving audio to: {filename}")
|
||||
|
||||
# Create output directory if needed
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
# Ensure audio is in CPU
|
||||
waveform = waveform.cpu()
|
||||
|
||||
# Process each audio in batch (usually just 1)
|
||||
for batch_number, waveform_item in enumerate(waveform):
|
||||
if batch_number > 0:
|
||||
# Add batch number to filename if multiple
|
||||
base, ext = os.path.splitext(filename)
|
||||
output_path = f"{base}_{batch_number}{ext}"
|
||||
else:
|
||||
output_path = filename
|
||||
|
||||
# Create output buffer
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format="mp3")
|
||||
|
||||
# Determine audio layout - waveform_item shape is [channels, samples]
|
||||
num_channels = waveform_item.shape[0] if waveform_item.dim() > 1 else 1
|
||||
layout = "mono" if num_channels == 1 else "stereo"
|
||||
|
||||
# Set up the MP3 output stream
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
||||
|
||||
# Set quality
|
||||
if quality == "V0":
|
||||
out_stream.codec_context.qscale = 1 # Highest VBR quality
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
# Prepare waveform for PyAV: needs to be [samples, channels]
|
||||
# Use detach() to avoid gradient tracking issues
|
||||
if waveform_item.dim() == 1:
|
||||
# Mono audio, add channel dimension
|
||||
waveform_numpy = waveform_item.unsqueeze(1).float().detach().numpy()
|
||||
else:
|
||||
# Transpose from [channels, samples] to [samples, channels]
|
||||
waveform_numpy = waveform_item.transpose(0, 1).float().detach().numpy()
|
||||
|
||||
# Reshape to [1, samples * channels] for PyAV
|
||||
waveform_numpy = waveform_numpy.reshape(1, -1)
|
||||
|
||||
# Create audio frame
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform_numpy,
|
||||
format="flt",
|
||||
layout=layout,
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
# Encode
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close container
|
||||
output_container.close()
|
||||
|
||||
# Write to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
print(f"Audio saved successfully: {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
# Configuration
|
||||
checkpoint_name = "stable-audio-open-1.0.safetensors"
|
||||
clip_name = "t5-base.safetensors"
|
||||
positive_prompt = "A soft melodious acoustic guitar music"
|
||||
negative_prompt = ""
|
||||
audio_duration = 47.6 # seconds
|
||||
seed = random.randint(0, 0xffffffffffffffff) # Random seed, or use specific value
|
||||
steps = 50
|
||||
cfg = 4.98
|
||||
sampler_name = "dpmpp_3m_sde_gpu"
|
||||
scheduler = "exponential"
|
||||
denoise = 1.0
|
||||
output_filename = "output/audio/generated_music.mp3"
|
||||
quality = "V0"
|
||||
|
||||
print("=" * 60)
|
||||
print("Stable Audio - Music Generation Script")
|
||||
print("=" * 60)
|
||||
print(f"Positive Prompt: {positive_prompt}")
|
||||
print(f"Duration: {audio_duration} seconds")
|
||||
print(f"Seed: {seed}")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. Load checkpoint (MODEL, CLIP, VAE)
|
||||
model, checkpoint_clip, vae = load_checkpoint(checkpoint_name)
|
||||
|
||||
# 2. Load separate CLIP text encoder for stable audio
|
||||
clip = load_clip(clip_name, "stable_audio")
|
||||
|
||||
# 3. Encode positive and negative prompts
|
||||
positive_conditioning = encode_text(clip, positive_prompt)
|
||||
negative_conditioning = encode_text(clip, negative_prompt)
|
||||
|
||||
# 4. Create empty latent audio
|
||||
latent_audio = create_empty_latent_audio(audio_duration, batch_size=1)
|
||||
|
||||
# 5. Sample using KSampler
|
||||
sampled_latent = sample_audio(
|
||||
model, seed, steps, cfg, sampler_name, scheduler,
|
||||
positive_conditioning, negative_conditioning, latent_audio, denoise
|
||||
)
|
||||
|
||||
# 6. Decode audio from latent using VAE
|
||||
audio = decode_audio(vae, sampled_latent)
|
||||
|
||||
# 7. Save as MP3
|
||||
save_audio_mp3(audio, output_filename, quality)
|
||||
|
||||
print("=" * 60)
|
||||
print("Generation complete!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
476
generate_vibevoice_standalone.py
Normal file
476
generate_vibevoice_standalone.py
Normal file
@ -0,0 +1,476 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone script to generate TTS audio using VibeVoice.
|
||||
This script has NO ComfyUI dependencies and uses the models directly from HuggingFace.
|
||||
|
||||
Based on Microsoft's VibeVoice: https://github.com/microsoft/VibeVoice
|
||||
|
||||
Requirements:
|
||||
pip install torch transformers numpy scipy soundfile librosa huggingface-hub
|
||||
|
||||
Usage:
|
||||
python generate_vibevoice_standalone.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='[VibeVoice] %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import librosa
|
||||
LIBROSA_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning("librosa not available - resampling will not work")
|
||||
LIBROSA_AVAILABLE = False
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""Set random seeds for reproducibility"""
|
||||
if seed == 0:
|
||||
seed = random.randint(1, 0xffffffffffffffff)
|
||||
|
||||
MAX_NUMPY_SEED = 2**32 - 1
|
||||
numpy_seed = seed % MAX_NUMPY_SEED
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(numpy_seed)
|
||||
random.seed(seed)
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
def parse_script(script: str) -> Tuple[List[Tuple[int, str]], List[int]]:
|
||||
"""
|
||||
Parse speaker script into (speaker_id, text) tuples.
|
||||
|
||||
Supports formats:
|
||||
[1] Some text...
|
||||
Speaker 1: Some text...
|
||||
|
||||
Returns:
|
||||
parsed_lines: List of (0-based speaker_id, text) tuples
|
||||
speaker_ids: List of unique 1-based speaker IDs in order of appearance
|
||||
"""
|
||||
parsed_lines = []
|
||||
speaker_ids_in_script = []
|
||||
|
||||
line_format_regex = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE)
|
||||
|
||||
for line in script.strip().split("\n"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
match = line_format_regex.match(line)
|
||||
if match:
|
||||
speaker_id_str = match.group(1) or match.group(2)
|
||||
speaker_id = int(speaker_id_str)
|
||||
text_content = match.group(3)
|
||||
|
||||
if match.group(1) is None and text_content.lstrip().startswith(':'):
|
||||
colon_index = text_content.find(':')
|
||||
text_content = text_content[colon_index + 1:]
|
||||
|
||||
if speaker_id < 1:
|
||||
logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'")
|
||||
continue
|
||||
|
||||
text = text_content.strip()
|
||||
internal_speaker_id = speaker_id - 1
|
||||
parsed_lines.append((internal_speaker_id, text))
|
||||
|
||||
if speaker_id not in speaker_ids_in_script:
|
||||
speaker_ids_in_script.append(speaker_id)
|
||||
else:
|
||||
logger.warning(f"Could not parse speaker marker, ignoring: '{line}'")
|
||||
|
||||
if not parsed_lines and script.strip():
|
||||
logger.info("No speaker markers found. Treating entire text as Speaker 1.")
|
||||
parsed_lines.append((0, script.strip()))
|
||||
speaker_ids_in_script.append(1)
|
||||
|
||||
return parsed_lines, sorted(list(set(speaker_ids_in_script)))
|
||||
|
||||
|
||||
def load_audio_file(audio_path: str, target_sr: int = 24000) -> Optional[np.ndarray]:
|
||||
"""Load audio file and convert to mono at target sample rate"""
|
||||
if not os.path.exists(audio_path):
|
||||
logger.error(f"Audio file not found: {audio_path}")
|
||||
return None
|
||||
|
||||
logger.info(f"Loading audio: {audio_path}")
|
||||
|
||||
try:
|
||||
# Load audio using soundfile
|
||||
waveform, sr = sf.read(audio_path)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if waveform.ndim > 1:
|
||||
waveform = np.mean(waveform, axis=1)
|
||||
|
||||
# Resample if needed
|
||||
if sr != target_sr:
|
||||
if not LIBROSA_AVAILABLE:
|
||||
raise ImportError("librosa is required for resampling. Install with: pip install librosa")
|
||||
logger.info(f"Resampling from {sr}Hz to {target_sr}Hz")
|
||||
waveform = librosa.resample(y=waveform, orig_sr=sr, target_sr=target_sr)
|
||||
|
||||
# Validate audio
|
||||
if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
|
||||
logger.error("Audio contains NaN or Inf values, replacing with zeros")
|
||||
waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
if np.all(waveform == 0):
|
||||
logger.warning("Audio waveform is completely silent")
|
||||
|
||||
# Normalize extreme values
|
||||
max_val = np.abs(waveform).max()
|
||||
if max_val > 10.0:
|
||||
logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
|
||||
waveform = waveform / max_val
|
||||
|
||||
return waveform.astype(np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading audio: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def download_model(model_name: str = "VibeVoice-1.5B", cache_dir: str = "./models"):
|
||||
"""Download VibeVoice model from HuggingFace"""
|
||||
|
||||
repo_mapping = {
|
||||
"VibeVoice-1.5B": "microsoft/VibeVoice-1.5B",
|
||||
"VibeVoice-Large": "aoi-ot/VibeVoice-Large"
|
||||
}
|
||||
|
||||
if model_name not in repo_mapping:
|
||||
raise ValueError(f"Unknown model: {model_name}. Choose from: {list(repo_mapping.keys())}")
|
||||
|
||||
repo_id = repo_mapping[model_name]
|
||||
model_path = os.path.join(cache_dir, model_name)
|
||||
|
||||
if os.path.exists(os.path.join(model_path, "config.json")):
|
||||
logger.info(f"Model already downloaded: {model_path}")
|
||||
return model_path
|
||||
|
||||
logger.info(f"Downloading model from {repo_id}...")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
model_path = snapshot_download(
|
||||
repo_id=repo_id,
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
|
||||
logger.info(f"Model downloaded to: {model_path}")
|
||||
return model_path
|
||||
|
||||
|
||||
def generate_tts(
|
||||
text: str,
|
||||
model_name: str = "VibeVoice-Large",
|
||||
speaker_audio_paths: Optional[dict] = None,
|
||||
output_path: str = "output.wav",
|
||||
cfg_scale: float = 1.3,
|
||||
inference_steps: int = 10,
|
||||
seed: int = 42,
|
||||
temperature: float = 0.95,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 0,
|
||||
cache_dir: str = "./models",
|
||||
device: str = "auto"
|
||||
):
|
||||
"""
|
||||
Generate TTS audio using VibeVoice
|
||||
|
||||
Args:
|
||||
text: Text script with speaker markers like "[1] text" or "Speaker 1: text"
|
||||
model_name: Model to use ("VibeVoice-1.5B" or "VibeVoice-Large")
|
||||
speaker_audio_paths: Dict mapping speaker IDs to audio file paths for voice cloning
|
||||
e.g., {1: "voice1.wav", 2: "voice2.wav"}
|
||||
output_path: Where to save the generated audio
|
||||
cfg_scale: Classifier-Free Guidance scale (higher = more adherence to prompt)
|
||||
inference_steps: Number of diffusion steps
|
||||
seed: Random seed for reproducibility
|
||||
temperature: Sampling temperature
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-K sampling parameter
|
||||
cache_dir: Directory to cache downloaded models
|
||||
device: Device to use ("cuda", "mps", "cpu", or "auto" for automatic detection)
|
||||
"""
|
||||
|
||||
# Set seed
|
||||
actual_seed = set_seed(seed)
|
||||
logger.info(f"Using seed: {actual_seed}")
|
||||
|
||||
# Determine device - with MPS support for Mac
|
||||
if device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
logger.info("MPS (Metal Performance Shaders) detected - using Mac GPU acceleration")
|
||||
else:
|
||||
device = "cpu"
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Download model if needed
|
||||
model_path = download_model(model_name, cache_dir)
|
||||
|
||||
# Import VibeVoice components
|
||||
logger.info("Loading VibeVoice model...")
|
||||
try:
|
||||
# Add the VibeVoice custom model code to path
|
||||
import sys
|
||||
vibevoice_custom_path = os.path.join(os.path.dirname(__file__), "custom_nodes", "ComfyUI-VibeVoice")
|
||||
if vibevoice_custom_path not in sys.path:
|
||||
sys.path.insert(0, vibevoice_custom_path)
|
||||
|
||||
# Import custom VibeVoice model
|
||||
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
|
||||
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
|
||||
from vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
|
||||
from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
|
||||
import json
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
config = VibeVoiceConfig.from_pretrained(config_path)
|
||||
|
||||
# Load tokenizer - download if not present
|
||||
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
||||
if not os.path.exists(tokenizer_file):
|
||||
logger.info(f"tokenizer.json not found, downloading from HuggingFace...")
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
# Determine which Qwen model to use based on model size
|
||||
qwen_repo = "Qwen/Qwen2.5-1.5B" if "1.5B" in model_name else "Qwen/Qwen2.5-7B"
|
||||
|
||||
try:
|
||||
hf_hub_download(
|
||||
repo_id=qwen_repo,
|
||||
filename="tokenizer.json",
|
||||
local_dir=model_path,
|
||||
local_dir_use_symlinks=False
|
||||
)
|
||||
logger.info("tokenizer.json downloaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download tokenizer.json: {e}")
|
||||
raise FileNotFoundError(f"Could not download tokenizer.json from {qwen_repo}")
|
||||
|
||||
tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file)
|
||||
|
||||
# Load processor config
|
||||
preprocessor_config_path = os.path.join(model_path, "preprocessor_config.json")
|
||||
processor_config_data = {}
|
||||
if os.path.exists(preprocessor_config_path):
|
||||
with open(preprocessor_config_path, 'r') as f:
|
||||
processor_config_data = json.load(f)
|
||||
|
||||
audio_processor = VibeVoiceTokenizerProcessor()
|
||||
processor = VibeVoiceProcessor(
|
||||
tokenizer=tokenizer,
|
||||
audio_processor=audio_processor,
|
||||
speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200),
|
||||
db_normalize=processor_config_data.get("db_normalize", True)
|
||||
)
|
||||
|
||||
# Load model
|
||||
# MPS doesn't support bfloat16 well, use float16
|
||||
if device == "mps":
|
||||
dtype = torch.float16
|
||||
logger.info("Using float16 for MPS device")
|
||||
elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype = torch.float16
|
||||
|
||||
model = VibeVoiceForConditionalGenerationInference.from_pretrained(
|
||||
model_path,
|
||||
config=config,
|
||||
torch_dtype=dtype,
|
||||
device_map=device,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
model.eval()
|
||||
logger.info("Model loaded successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
# Parse script
|
||||
parsed_lines, speaker_ids = parse_script(text)
|
||||
if not parsed_lines:
|
||||
raise ValueError("Script is empty or invalid")
|
||||
|
||||
logger.info(f"Parsed {len(parsed_lines)} lines with speakers: {speaker_ids}")
|
||||
|
||||
# Load speaker audio samples
|
||||
voice_samples = []
|
||||
if speaker_audio_paths is None:
|
||||
speaker_audio_paths = {}
|
||||
|
||||
for speaker_id in speaker_ids:
|
||||
audio_path = speaker_audio_paths.get(speaker_id)
|
||||
if audio_path:
|
||||
audio = load_audio_file(audio_path, target_sr=24000)
|
||||
if audio is None:
|
||||
logger.warning(f"Could not load audio for speaker {speaker_id}, using zero-shot TTS")
|
||||
voice_samples.append(None)
|
||||
else:
|
||||
voice_samples.append(audio)
|
||||
else:
|
||||
logger.info(f"No reference audio for speaker {speaker_id}, using zero-shot TTS")
|
||||
voice_samples.append(None)
|
||||
|
||||
# Prepare inputs
|
||||
logger.info("Processing inputs...")
|
||||
try:
|
||||
inputs = processor(
|
||||
parsed_scripts=[parsed_lines],
|
||||
voice_samples=[voice_samples],
|
||||
speaker_ids_for_prompt=[speaker_ids],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True
|
||||
)
|
||||
|
||||
# Move to device
|
||||
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing inputs: {e}")
|
||||
raise
|
||||
|
||||
# Configure generation
|
||||
model.set_ddpm_inference_steps(num_steps=inference_steps)
|
||||
|
||||
generation_config = {
|
||||
'do_sample': True,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
}
|
||||
if top_k > 0:
|
||||
generation_config['top_k'] = top_k
|
||||
|
||||
# Generate
|
||||
logger.info(f"Generating audio ({inference_steps} steps)...")
|
||||
try:
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=None,
|
||||
cfg_scale=cfg_scale,
|
||||
tokenizer=processor.tokenizer,
|
||||
generation_config=generation_config,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Extract waveform
|
||||
waveform = outputs.speech_outputs[0].cpu().numpy()
|
||||
|
||||
# Ensure correct shape
|
||||
if waveform.ndim == 1:
|
||||
waveform = waveform.reshape(1, -1)
|
||||
elif waveform.ndim == 2 and waveform.shape[0] > 1:
|
||||
# If multiple channels, take first
|
||||
waveform = waveform[0:1, :]
|
||||
|
||||
# Convert to float32 for soundfile compatibility
|
||||
waveform = waveform.astype(np.float32)
|
||||
|
||||
# Save audio
|
||||
os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
|
||||
sf.write(output_path, waveform.T, 24000)
|
||||
logger.info(f"Audio saved to: {output_path}")
|
||||
|
||||
return waveform
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during generation: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
"""Example usage"""
|
||||
|
||||
# Configuration
|
||||
model_name = "VibeVoice-Large" # or "VibeVoice-Large"
|
||||
|
||||
# Text to generate - supports multiple speakers
|
||||
text = """
|
||||
[1] Hello, this is speaker one. How are you today?
|
||||
[2] Hi there! This is speaker two responding to you. It's great to meet you.
|
||||
[1] Likewise! Let's generate some amazing speech together.
|
||||
[2] Absolutely! VibeVoice makes it so easy to create diverse voices.
|
||||
"""
|
||||
|
||||
# Reference audio for voice cloning (optional)
|
||||
# If not provided, will use zero-shot TTS
|
||||
speaker_audio_paths = {
|
||||
1: "input/audio1.wav", # Path to reference audio for speaker 1
|
||||
2: "input/laundry.mp3", # Uncomment to provide reference for speaker 2
|
||||
}
|
||||
|
||||
# Generation parameters
|
||||
output_path = "output/vibevoice_generated.wav"
|
||||
cfg_scale = 1.3
|
||||
inference_steps = 10
|
||||
seed = 42 # or 0 for random
|
||||
temperature = 0.95
|
||||
top_p = 0.95
|
||||
top_k = 0
|
||||
|
||||
print("=" * 60)
|
||||
print("VibeVoice TTS - Standalone Script")
|
||||
print("=" * 60)
|
||||
print(f"Model: {model_name}")
|
||||
print(f"Text: {text[:100]}...")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
generate_tts(
|
||||
text=text,
|
||||
model_name=model_name,
|
||||
speaker_audio_paths=speaker_audio_paths,
|
||||
output_path=output_path,
|
||||
cfg_scale=cfg_scale,
|
||||
inference_steps=inference_steps,
|
||||
seed=seed,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
cache_dir="./models",
|
||||
device="auto"
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
print("Generation complete!")
|
||||
print(f"Audio saved to: {output_path}")
|
||||
print("=" * 60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
139
middleware/auth_middleware.py
Normal file
139
middleware/auth_middleware.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""API Key Authentication middleware for ComfyUI server"""
|
||||
|
||||
from aiohttp import web
|
||||
from typing import Callable, Awaitable, Optional, Set
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
class APIKeyAuth:
|
||||
"""API Key Authentication handler"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, exempt_paths: Optional[Set[str]] = None):
|
||||
"""
|
||||
Initialize API Key Authentication
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate against. If None, authentication is disabled.
|
||||
exempt_paths: Set of paths that don't require authentication (e.g., health check)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.enabled = api_key is not None and len(api_key) > 0
|
||||
self.exempt_paths = exempt_paths or {"/health"}
|
||||
|
||||
# Static file extensions that don't require authentication
|
||||
self.static_extensions = {
|
||||
'.html', '.js', '.css', '.json', '.map', '.png', '.jpg', '.jpeg',
|
||||
'.gif', '.svg', '.ico', '.woff', '.woff2', '.ttf', '.eot', '.webp'
|
||||
}
|
||||
|
||||
# Path prefixes that serve static content
|
||||
self.static_path_prefixes = {
|
||||
'/extensions/', '/templates/', '/docs/'
|
||||
}
|
||||
|
||||
if self.enabled:
|
||||
logging.info("[Auth] API Key authentication enabled")
|
||||
else:
|
||||
logging.info("[Auth] API Key authentication disabled")
|
||||
|
||||
def is_path_exempt(self, path: str) -> bool:
|
||||
"""Check if a path is exempt from authentication"""
|
||||
# Exact match for specific exempt paths
|
||||
if path in self.exempt_paths:
|
||||
return True
|
||||
|
||||
# Root path for index.html
|
||||
if path == "/":
|
||||
return True
|
||||
|
||||
# Static file extensions
|
||||
for ext in self.static_extensions:
|
||||
if path.endswith(ext):
|
||||
return True
|
||||
|
||||
# Static path prefixes (extensions, templates, docs, etc.)
|
||||
for prefix in self.static_path_prefixes:
|
||||
if path.startswith(prefix):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def validate_api_key(self, provided_key: Optional[str]) -> bool:
|
||||
"""Validate the provided API key"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
if not provided_key:
|
||||
return False
|
||||
|
||||
return provided_key == self.api_key
|
||||
|
||||
def extract_api_key(self, request: web.Request) -> Optional[str]:
|
||||
"""
|
||||
Extract API key from request.
|
||||
Checks Authorization header (Bearer token) and X-API-Key header.
|
||||
"""
|
||||
# Check Authorization header (Bearer token)
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Check X-API-Key header
|
||||
api_key_header = request.headers.get("X-API-Key", "")
|
||||
if api_key_header:
|
||||
return api_key_header
|
||||
|
||||
# Check query parameter (less secure, but convenient for testing)
|
||||
api_key_query = request.query.get("api_key", "")
|
||||
if api_key_query:
|
||||
return api_key_query
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def create_api_key_middleware(api_key: Optional[str] = None, exempt_paths: Optional[Set[str]] = None):
|
||||
"""
|
||||
Create API key authentication middleware
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate against. If None, authentication is disabled.
|
||||
exempt_paths: Set of paths that don't require authentication
|
||||
|
||||
Returns:
|
||||
Middleware function for aiohttp
|
||||
"""
|
||||
auth = APIKeyAuth(api_key, exempt_paths)
|
||||
|
||||
@web.middleware
|
||||
async def api_key_middleware(
|
||||
request: web.Request,
|
||||
handler: Callable[[web.Request], Awaitable[web.Response]]
|
||||
) -> web.Response:
|
||||
"""Middleware to validate API key for protected endpoints"""
|
||||
|
||||
# Skip authentication if disabled
|
||||
if not auth.enabled:
|
||||
return await handler(request)
|
||||
|
||||
# Check if path is exempt from authentication
|
||||
if auth.is_path_exempt(request.path):
|
||||
return await handler(request)
|
||||
|
||||
# Extract and validate API key
|
||||
provided_key = auth.extract_api_key(request)
|
||||
|
||||
if not auth.validate_api_key(provided_key):
|
||||
logging.warning(f"[Auth] Unauthorized access attempt to {request.path} from {request.remote}")
|
||||
return web.json_response(
|
||||
{
|
||||
"error": "Unauthorized",
|
||||
"message": "Invalid or missing API key. Provide API key via 'Authorization: Bearer <key>' or 'X-API-Key: <key>' header."
|
||||
},
|
||||
status=401
|
||||
)
|
||||
|
||||
# API key is valid, proceed with request
|
||||
return await handler(request)
|
||||
|
||||
return api_key_middleware
|
||||
7
requirements_vibevoice_standalone.txt
Normal file
7
requirements_vibevoice_standalone.txt
Normal file
@ -0,0 +1,7 @@
|
||||
torch
|
||||
transformers
|
||||
numpy
|
||||
scipy
|
||||
soundfile
|
||||
librosa
|
||||
huggingface-hub
|
||||
120
server.py
120
server.py
@ -43,6 +43,7 @@ from protocol import BinaryEventTypes
|
||||
|
||||
# Import cache control middleware
|
||||
from middleware.cache_middleware import cache_control
|
||||
from middleware.auth_middleware import create_api_key_middleware
|
||||
|
||||
if args.enable_manager:
|
||||
import comfyui_manager
|
||||
@ -204,6 +205,17 @@ class PromptServer():
|
||||
self.number = 0
|
||||
|
||||
middlewares = [cache_control, deprecation_warning]
|
||||
|
||||
# Add API key authentication middleware if enabled
|
||||
if args.api_key:
|
||||
# Define paths that don't require authentication
|
||||
# Note: Static files (.js, .css, .html, etc.) and root "/" are automatically exempted
|
||||
exempt_paths = {
|
||||
"/health", # Health check endpoint
|
||||
"/ws", # WebSocket endpoint
|
||||
}
|
||||
middlewares.append(create_api_key_middleware(args.api_key, exempt_paths))
|
||||
|
||||
if args.enable_compress_response_body:
|
||||
middlewares.append(compress_body)
|
||||
|
||||
@ -297,12 +309,120 @@ class PromptServer():
|
||||
|
||||
@routes.get("/")
|
||||
async def get_root(request):
|
||||
# If API key is enabled and not provided, redirect to login
|
||||
if args.api_key:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
api_key_header = request.headers.get("X-API-Key", "")
|
||||
api_key_query = request.query.get("api_key", "")
|
||||
|
||||
has_valid_key = False
|
||||
if auth_header.startswith("Bearer "):
|
||||
provided_key = auth_header[7:]
|
||||
has_valid_key = (provided_key == args.api_key)
|
||||
elif api_key_header:
|
||||
has_valid_key = (api_key_header == args.api_key)
|
||||
elif api_key_query:
|
||||
has_valid_key = (api_key_query == args.api_key)
|
||||
|
||||
# If no valid key, check if coming from login page, otherwise show login
|
||||
if not has_valid_key:
|
||||
# Serve modified index.html with auth script injected
|
||||
index_path = os.path.join(self.web_root, "index.html")
|
||||
if os.path.exists(index_path):
|
||||
with open(index_path, 'r', encoding='utf-8') as f:
|
||||
html_content = f.read()
|
||||
|
||||
# Inject auth script before closing </head> tag
|
||||
auth_script = '<script src="/auth_inject.js"></script>'
|
||||
if '</head>' in html_content:
|
||||
html_content = html_content.replace('</head>', f'{auth_script}\n</head>')
|
||||
else:
|
||||
# Fallback: add at the beginning of body
|
||||
html_content = html_content.replace('<body>', f'<body>\n{auth_script}')
|
||||
|
||||
response = web.Response(text=html_content, content_type='text/html')
|
||||
else:
|
||||
response = web.FileResponse(index_path)
|
||||
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
return response
|
||||
|
||||
# No auth required or valid key provided
|
||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
return response
|
||||
|
||||
@routes.get("/auth_login.html")
|
||||
async def get_login_page(request):
|
||||
"""Serve the login page"""
|
||||
login_page_path = os.path.join(os.path.dirname(__file__), "auth_login.html")
|
||||
if os.path.exists(login_page_path):
|
||||
response = web.FileResponse(login_page_path)
|
||||
else:
|
||||
# Fallback if file doesn't exist
|
||||
response = web.Response(text="Login page not found", status=404)
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
return response
|
||||
|
||||
@routes.get("/auth_inject.js")
|
||||
async def get_auth_script(request):
|
||||
"""Serve the auth injection script"""
|
||||
script_path = os.path.join(os.path.dirname(__file__), "auth_inject.js")
|
||||
if os.path.exists(script_path):
|
||||
response = web.FileResponse(script_path, headers={'Content-Type': 'application/javascript'})
|
||||
else:
|
||||
response = web.Response(text="// Auth script not found", status=404, content_type='application/javascript')
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
return response
|
||||
|
||||
@routes.get("/health")
|
||||
async def get_health(request):
|
||||
"""Health check endpoint that returns the status of the server"""
|
||||
try:
|
||||
# Basic health information
|
||||
health_data = {
|
||||
"status": "healthy",
|
||||
"version": __version__,
|
||||
"timestamp": time.time(),
|
||||
"queue": {
|
||||
"pending": len(self.prompt_queue.queue),
|
||||
"running": len(self.prompt_queue.currently_running)
|
||||
}
|
||||
}
|
||||
|
||||
# Add device info if available
|
||||
try:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
health_data["device"] = str(device)
|
||||
|
||||
# Add VRAM info if GPU is available
|
||||
if comfy.model_management.vram_state != comfy.model_management.VRAMState.DISABLED:
|
||||
vram_total = comfy.model_management.get_total_memory()
|
||||
vram_free = comfy.model_management.get_free_memory()
|
||||
health_data["vram"] = {
|
||||
"total": vram_total,
|
||||
"free": vram_free,
|
||||
"used": vram_total - vram_free
|
||||
}
|
||||
except Exception as e:
|
||||
logging.debug(f"Could not get device info for health check: {e}")
|
||||
|
||||
return web.json_response(health_data)
|
||||
except Exception as e:
|
||||
logging.error(f"Health check failed: {e}")
|
||||
return web.json_response(
|
||||
{
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
},
|
||||
status=503
|
||||
)
|
||||
|
||||
@routes.get("/embeddings")
|
||||
def get_embeddings(request):
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
|
||||
176
test_api_auth.py
Normal file
176
test_api_auth.py
Normal file
@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for ComfyUI API Key Authentication and Health Check
|
||||
|
||||
This script demonstrates how to:
|
||||
1. Check the health endpoint (no auth required)
|
||||
2. Make authenticated requests to the API
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
|
||||
# Configuration
|
||||
BASE_URL = "http://localhost:8188"
|
||||
API_KEY = "your-api-key-here" # Replace with your actual API key
|
||||
|
||||
|
||||
def test_health_check():
|
||||
"""Test the health check endpoint (no authentication required)"""
|
||||
print("Testing health check endpoint...")
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/health")
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
return response.status_code == 200
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_without_auth():
|
||||
"""Test accessing protected endpoint without authentication"""
|
||||
print("\nTesting access without authentication...")
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/object_info")
|
||||
print(f"Status Code: {response.status_code}")
|
||||
if response.status_code == 401:
|
||||
print("✓ Correctly rejected (401 Unauthorized)")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
return True
|
||||
elif response.status_code == 200:
|
||||
print("✓ No authentication required (API key not enabled)")
|
||||
return True
|
||||
else:
|
||||
print(f"✗ Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_with_bearer_token():
|
||||
"""Test accessing protected endpoint with Bearer token"""
|
||||
print("\nTesting with Bearer token authentication...")
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {API_KEY}"
|
||||
}
|
||||
response = requests.get(f"{BASE_URL}/object_info", headers=headers)
|
||||
print(f"Status Code: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
print("✓ Successfully authenticated with Bearer token")
|
||||
return True
|
||||
elif response.status_code == 401:
|
||||
print("✗ Authentication failed (check your API key)")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
return False
|
||||
else:
|
||||
print(f"✗ Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_with_api_key_header():
|
||||
"""Test accessing protected endpoint with X-API-Key header"""
|
||||
print("\nTesting with X-API-Key header authentication...")
|
||||
try:
|
||||
headers = {
|
||||
"X-API-Key": API_KEY
|
||||
}
|
||||
response = requests.get(f"{BASE_URL}/object_info", headers=headers)
|
||||
print(f"Status Code: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
print("✓ Successfully authenticated with X-API-Key header")
|
||||
return True
|
||||
elif response.status_code == 401:
|
||||
print("✗ Authentication failed (check your API key)")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
return False
|
||||
else:
|
||||
print(f"✗ Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_with_query_parameter():
|
||||
"""Test accessing protected endpoint with query parameter"""
|
||||
print("\nTesting with query parameter authentication...")
|
||||
try:
|
||||
response = requests.get(f"{BASE_URL}/object_info?api_key={API_KEY}")
|
||||
print(f"Status Code: {response.status_code}")
|
||||
if response.status_code == 200:
|
||||
print("✓ Successfully authenticated with query parameter")
|
||||
return True
|
||||
elif response.status_code == 401:
|
||||
print("✗ Authentication failed (check your API key)")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
return False
|
||||
else:
|
||||
print(f"✗ Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("=" * 60)
|
||||
print("ComfyUI API Authentication Test Suite")
|
||||
print("=" * 60)
|
||||
print(f"Base URL: {BASE_URL}")
|
||||
print(f"API Key: {'*' * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else '***'}")
|
||||
print("=" * 60)
|
||||
|
||||
results = []
|
||||
|
||||
# Test 1: Health check (always works)
|
||||
results.append(("Health Check", test_health_check()))
|
||||
|
||||
# Test 2: Without authentication (should fail if auth is enabled)
|
||||
results.append(("No Auth", test_without_auth()))
|
||||
|
||||
# Test 3: Bearer token authentication
|
||||
results.append(("Bearer Token", test_with_bearer_token()))
|
||||
|
||||
# Test 4: X-API-Key header authentication
|
||||
results.append(("X-API-Key Header", test_with_api_key_header()))
|
||||
|
||||
# Test 5: Query parameter authentication
|
||||
results.append(("Query Parameter", test_with_query_parameter()))
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Test Summary")
|
||||
print("=" * 60)
|
||||
for test_name, passed in results:
|
||||
status = "✓ PASS" if passed else "✗ FAIL"
|
||||
print(f"{test_name:20s} {status}")
|
||||
|
||||
total = len(results)
|
||||
passed = sum(1 for _, result in results if result)
|
||||
print("=" * 60)
|
||||
print(f"Total: {passed}/{total} tests passed")
|
||||
print("=" * 60)
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if passed == total else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if user wants to override the API key
|
||||
if len(sys.argv) > 1:
|
||||
API_KEY = sys.argv[1]
|
||||
|
||||
if API_KEY == "your-api-key-here":
|
||||
print("WARNING: Using default API key. Set your API key as the first argument:")
|
||||
print(f" python {sys.argv[0]} YOUR_API_KEY")
|
||||
print("")
|
||||
|
||||
main()
|
||||
128
test_auth_quick.sh
Normal file
128
test_auth_quick.sh
Normal file
@ -0,0 +1,128 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Quick Test Script for ComfyUI API Authentication
|
||||
# This script tests that authentication is working correctly
|
||||
|
||||
set -e
|
||||
|
||||
API_KEY="test-key-123"
|
||||
BASE_URL="http://localhost:8188"
|
||||
|
||||
echo "================================================"
|
||||
echo "ComfyUI API Authentication Test"
|
||||
echo "================================================"
|
||||
echo ""
|
||||
echo "IMPORTANT: Make sure ComfyUI is running with:"
|
||||
echo " python main.py --api-key \"$API_KEY\""
|
||||
echo ""
|
||||
echo "Press Enter to continue or Ctrl+C to cancel..."
|
||||
read
|
||||
|
||||
echo ""
|
||||
echo "================================================"
|
||||
echo "Test 1: Health endpoint (should work without auth)"
|
||||
echo "================================================"
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" "$BASE_URL/health")
|
||||
status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "Status: $status"
|
||||
if [ "$status" = "200" ]; then
|
||||
echo "✓ PASS - Health endpoint accessible without auth"
|
||||
else
|
||||
echo "✗ FAIL - Health endpoint should return 200"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "================================================"
|
||||
echo "Test 2: Protected endpoint without auth (should fail)"
|
||||
echo "================================================"
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" "$BASE_URL/object_info")
|
||||
status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "Status: $status"
|
||||
if [ "$status" = "401" ]; then
|
||||
echo "✓ PASS - Correctly rejected without auth"
|
||||
echo "Response: $body"
|
||||
else
|
||||
echo "✗ FAIL - Should return 401 Unauthorized"
|
||||
echo "Response: $body"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "================================================"
|
||||
echo "Test 3: Protected endpoint with wrong key (should fail)"
|
||||
echo "================================================"
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \
|
||||
-H "Authorization: Bearer wrong-key-456" \
|
||||
"$BASE_URL/object_info")
|
||||
status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "Status: $status"
|
||||
if [ "$status" = "401" ]; then
|
||||
echo "✓ PASS - Correctly rejected wrong key"
|
||||
echo "Response: $body"
|
||||
else
|
||||
echo "✗ FAIL - Should return 401 Unauthorized"
|
||||
echo "Response: $body"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "================================================"
|
||||
echo "Test 4: Protected endpoint with correct key (should work)"
|
||||
echo "================================================"
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \
|
||||
-H "Authorization: Bearer $API_KEY" \
|
||||
"$BASE_URL/object_info")
|
||||
status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "Status: $status"
|
||||
if [ "$status" = "200" ]; then
|
||||
echo "✓ PASS - Successfully authenticated"
|
||||
else
|
||||
echo "✗ FAIL - Should return 200 OK"
|
||||
echo "Response: $body"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "================================================"
|
||||
echo "Test 5: X-API-Key header method (should work)"
|
||||
echo "================================================"
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \
|
||||
-H "X-API-Key: $API_KEY" \
|
||||
"$BASE_URL/embeddings")
|
||||
status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "Status: $status"
|
||||
if [ "$status" = "200" ]; then
|
||||
echo "✓ PASS - X-API-Key header works"
|
||||
else
|
||||
echo "✗ FAIL - Should return 200 OK"
|
||||
echo "Response: $body"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "================================================"
|
||||
echo "Test 6: Query parameter method (should work)"
|
||||
echo "================================================"
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \
|
||||
"$BASE_URL/embeddings?api_key=$API_KEY")
|
||||
status=$(echo "$response" | grep HTTP_STATUS | cut -d: -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "Status: $status"
|
||||
if [ "$status" = "200" ]; then
|
||||
echo "✓ PASS - Query parameter works"
|
||||
else
|
||||
echo "✗ FAIL - Should return 200 OK"
|
||||
echo "Response: $body"
|
||||
fi
|
||||
echo ""
|
||||
|
||||
echo "================================================"
|
||||
echo "All tests completed!"
|
||||
echo "================================================"
|
||||
117
test_vibevoice_workflow.sh
Normal file
117
test_vibevoice_workflow.sh
Normal file
@ -0,0 +1,117 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test ComfyUI API with VibeVoice workflow
|
||||
# Usage: ./test_vibevoice_workflow.sh [API_KEY]
|
||||
|
||||
# Configuration
|
||||
BASE_URL="http://localhost:8188"
|
||||
API_KEY="${1:-}"
|
||||
|
||||
# Set headers based on whether API key is provided
|
||||
if [ -n "$API_KEY" ]; then
|
||||
AUTH_HEADER="Authorization: Bearer $API_KEY"
|
||||
echo "Using API Key authentication"
|
||||
else
|
||||
AUTH_HEADER=""
|
||||
echo "No API Key provided (running without authentication)"
|
||||
fi
|
||||
|
||||
# The workflow payload
|
||||
# This converts the ComfyUI workflow format to the prompt API format
|
||||
read -r -d '' PAYLOAD << 'EOF'
|
||||
{
|
||||
"prompt": {
|
||||
"1": {
|
||||
"inputs": {
|
||||
"speaker_1_voice": ["2", 0],
|
||||
"speaker_2_voice": null,
|
||||
"speaker_3_voice": null,
|
||||
"speaker_4_voice": null,
|
||||
"model_name": "VibeVoice-Large",
|
||||
"text": "[1] And this is a generated voice, how cool is that?",
|
||||
"quantize_llm_4bit": false,
|
||||
"attention_mode": "sdpa",
|
||||
"cfg_scale": 1.3,
|
||||
"inference_steps": 10,
|
||||
"seed": 1117544514407045,
|
||||
"do_sample": true,
|
||||
"temperature": 0.95,
|
||||
"top_p": 0.95,
|
||||
"top_k": 0,
|
||||
"force_offload": false
|
||||
},
|
||||
"class_type": "VibeVoiceTTS"
|
||||
},
|
||||
"2": {
|
||||
"inputs": {
|
||||
"audio": "audio1.wav"
|
||||
},
|
||||
"class_type": "LoadAudio"
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"audio": ["1", 0],
|
||||
"filename_prefix": "audio/ComfyUI"
|
||||
},
|
||||
"class_type": "SaveAudio"
|
||||
}
|
||||
},
|
||||
"client_id": "test_client_$(date +%s)"
|
||||
}
|
||||
EOF
|
||||
|
||||
echo ""
|
||||
echo "================================================"
|
||||
echo "Sending workflow to ComfyUI..."
|
||||
echo "================================================"
|
||||
echo ""
|
||||
|
||||
# Make the request
|
||||
if [ -n "$AUTH_HEADER" ]; then
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \
|
||||
-X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "$AUTH_HEADER" \
|
||||
-d "$PAYLOAD" \
|
||||
"$BASE_URL/prompt")
|
||||
else
|
||||
response=$(curl -s -w "\nHTTP_STATUS:%{http_code}" \
|
||||
-X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$PAYLOAD" \
|
||||
"$BASE_URL/prompt")
|
||||
fi
|
||||
|
||||
# Extract HTTP status
|
||||
http_status=$(echo "$response" | grep "HTTP_STATUS" | cut -d':' -f2)
|
||||
body=$(echo "$response" | sed '/HTTP_STATUS/d')
|
||||
|
||||
echo "HTTP Status: $http_status"
|
||||
echo ""
|
||||
echo "Response:"
|
||||
echo "$body" | python3 -m json.tool 2>/dev/null || echo "$body"
|
||||
echo ""
|
||||
|
||||
if [ "$http_status" = "200" ]; then
|
||||
echo "✓ Workflow queued successfully!"
|
||||
|
||||
# Extract prompt_id if available
|
||||
prompt_id=$(echo "$body" | python3 -c "import sys, json; data=json.load(sys.stdin); print(data.get('prompt_id', ''))" 2>/dev/null)
|
||||
if [ -n "$prompt_id" ]; then
|
||||
echo "Prompt ID: $prompt_id"
|
||||
echo ""
|
||||
echo "To check status:"
|
||||
if [ -n "$AUTH_HEADER" ]; then
|
||||
echo " curl -H \"$AUTH_HEADER\" $BASE_URL/history/$prompt_id"
|
||||
else
|
||||
echo " curl $BASE_URL/history/$prompt_id"
|
||||
fi
|
||||
fi
|
||||
elif [ "$http_status" = "401" ]; then
|
||||
echo "✗ Authentication failed - check your API key"
|
||||
else
|
||||
echo "✗ Request failed"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "================================================"
|
||||
Loading…
Reference in New Issue
Block a user