Add complete authentication system with documentation

- Frontend authentication with login page (auth_login.html)
- API key injection script (auth_inject.js)
- Session management (localStorage/sessionStorage)
- Logout via URL: http://127.0.0.1:8188/auth_login.html?logout=true
- Modified server.py to inject auth scripts into index.html
- Added comprehensive documentation:
  * AUTHENTICATION_GUIDE.md - Complete authentication guide
  * FRONTEND_AUTH_GUIDE.md - Frontend-specific guide
- Health endpoint accessible without authentication
- Multiple auth methods: Bearer token, X-API-Key header, query parameter
daverbj 2025-12-11 15:49:23 +03:00
parent 06bf79b19b
commit c5ad1381bf
9 changed files with 1723 additions and 0 deletions

272
AUTHENTICATION_GUIDE.md Normal file

@ -0,0 +1,272 @@
# ComfyUI API Authentication Guide
## Overview
ComfyUI now supports API key authentication to protect your REST API endpoints. This guide covers setup, usage, and logout procedures.
## Features
- **API Key Authentication**: Protect all API endpoints with a simple API key
- **Multiple Auth Methods**: Bearer token, X-API-Key header, or query parameter
- **Health Endpoint**: Monitor server status without authentication
- **Frontend Login**: Built-in login page for browser access
- **Session Management**: Remember API key across sessions or per-session only
## Quick Start
### 1. Start Server with API Key
```bash
# Using command line argument
python3 main.py --api-key "your-secure-api-key-here"
# Or using a file (recommended for production)
echo "your-secure-api-key-here" > api_key.txt
python3 main.py --api-key-file api_key.txt
```
### 2. Access ComfyUI
When you navigate to `http://127.0.0.1:8188`, you'll be redirected to a login page.
### 3. Login
1. Enter your API key in the password field
2. Optionally check "Remember me" to persist the key across browser sessions
3. Click "Login"
The login page sends your key to the `/health` endpoint and stores it only if that request succeeds.
### 4. Logout
To logout, visit:
```
http://127.0.0.1:8188/auth_login.html?logout=true
```
This will clear your stored API key and redirect you to the login page.
## API Usage
### Authentication Methods
You can authenticate API requests in three ways:
#### 1. Bearer Token (Recommended)
```bash
curl -H "Authorization: Bearer your-api-key" http://127.0.0.1:8188/prompt
```
#### 2. X-API-Key Header
```bash
curl -H "X-API-Key: your-api-key" http://127.0.0.1:8188/prompt
```
#### 3. Query Parameter
```bash
curl "http://127.0.0.1:8188/prompt?api_key=your-api-key"
```
### Python Example
```python
import requests
API_KEY = "your-api-key"
BASE_URL = "http://127.0.0.1:8188"
# Using Bearer token
headers = {"Authorization": f"Bearer {API_KEY}"}
response = requests.get(f"{BASE_URL}/queue", headers=headers)
# Using X-API-Key header
headers = {"X-API-Key": API_KEY}
response = requests.get(f"{BASE_URL}/queue", headers=headers)
# Using query parameter
response = requests.get(f"{BASE_URL}/queue?api_key={API_KEY}")
```
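The three methods can also be wrapped in a single helper. The following is a minimal sketch that uses only the standard library; the function name `build_request` and its `method` argument are illustrative and not part of ComfyUI.

```python
import urllib.parse
import urllib.request


def build_request(base_url, path, api_key, method="bearer"):
    """Return a urllib Request that carries the API key (hypothetical helper)."""
    url = f"{base_url}{path}"
    headers = {}
    if method == "bearer":
        headers["Authorization"] = f"Bearer {api_key}"
    elif method == "header":
        headers["X-API-Key"] = api_key
    elif method == "query":
        # URL-encode the key so characters such as "+", "/" or "=" survive
        separator = "&" if "?" in url else "?"
        url = f"{url}{separator}{urllib.parse.urlencode({'api_key': api_key})}"
    else:
        raise ValueError(f"unknown auth method: {method}")
    return urllib.request.Request(url, headers=headers)
```

Pass the result to `urllib.request.urlopen()` to send it. The query-parameter branch encodes the key because keys produced by `openssl rand -base64 32` can contain `+`, `/` and `=`, which have special meaning in URLs.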
### JavaScript Example
```javascript
const API_KEY = "your-api-key";
const BASE_URL = "http://127.0.0.1:8188";
// Using Bearer token
fetch(`${BASE_URL}/queue`, {
headers: {
"Authorization": `Bearer ${API_KEY}`
}
})
.then(response => response.json())
.then(data => console.log(data));
```
## Health Endpoint
The `/health` endpoint is always accessible without authentication:
```bash
curl http://127.0.0.1:8188/health
```
Response includes:
- Server status
- Queue information (pending/running tasks)
- Device information (GPU/CPU)
- VRAM statistics
Example response:
```json
{
"status": "ok",
"queue": {
"pending": 0,
"running": 0
},
"device": {
"name": "mps",
"type": "mps"
},
"vram": {
"total": 68719476736,
"free": 68719476736
}
}
```
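A monitoring script could reduce this response to a few values. The sketch below assumes the response has the shape shown in the example above; `summarize_health` is a hypothetical helper name, and fields missing from the payload fall back to defaults.

```python
import json


def summarize_health(payload):
    """Summarize a /health response shaped like the documented example."""
    queue = payload.get("queue", {})
    vram = payload.get("vram", {})
    total = vram.get("total", 0)
    free = vram.get("free", 0)
    return {
        "ok": payload.get("status") == "ok",
        "idle": queue.get("pending", 0) == 0 and queue.get("running", 0) == 0,
        # Fraction of VRAM in use, or None if the total is missing or zero
        "vram_used_fraction": (total - free) / total if total else None,
    }


example = json.loads("""
{"status": "ok",
 "queue": {"pending": 0, "running": 0},
 "device": {"name": "mps", "type": "mps"},
 "vram": {"total": 68719476736, "free": 68719476736}}
""")
print(summarize_health(example))
# prints {'ok': True, 'idle': True, 'vram_used_fraction': 0.0}
```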
## Exempt Endpoints
The following paths are exempt from authentication:
### Static Files
- `.html`, `.js`, `.css`, `.json` files
- `.png`, `.jpg`, `.jpeg`, `.gif`, `.webp` images
- `.svg`, `.ico` icons
- `.woff`, `.woff2`, `.ttf` fonts
- `.mp3`, `.wav` audio files
- `.mp4`, `.webm` video files
### API Paths
- `/` - Root path (serves login if not authenticated)
- `/health` - Health check endpoint
- `/ws` - WebSocket endpoint
- `/auth_login.html` - Login page
- `/auth_inject.js` - Auth injection script
- `/extensions/*` - Extension files
- `/templates/*` - Template files
- `/docs/*` - Documentation
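The exemption rules listed above can be expressed as a small predicate. This is only a sketch of how such a check might look, assuming exact-match paths, prefix matches for the `/*` entries, and suffix matches for file extensions; the real logic lives in `middleware/auth_middleware.py` and may differ. The names `is_exempt`, `EXEMPT_PATHS`, `EXEMPT_PREFIXES` and `EXEMPT_EXTENSIONS` are hypothetical.

```python
import posixpath

EXEMPT_EXTENSIONS = {
    ".html", ".js", ".css", ".json",
    ".png", ".jpg", ".jpeg", ".gif", ".webp",
    ".svg", ".ico",
    ".woff", ".woff2", ".ttf",
    ".mp3", ".wav",
    ".mp4", ".webm",
}
EXEMPT_PATHS = {"/", "/health", "/ws", "/auth_login.html", "/auth_inject.js"}
EXEMPT_PREFIXES = ("/extensions/", "/templates/", "/docs/")


def is_exempt(path):
    """Return True if `path` matches one of the documented exemptions."""
    if path in EXEMPT_PATHS or path.startswith(EXEMPT_PREFIXES):
        return True
    # Compare the file extension case-insensitively
    extension = posixpath.splitext(path)[1].lower()
    return extension in EXEMPT_EXTENSIONS
```

If matching is done by suffix as sketched here, any route whose path ends in an exempt extension (for example `.json`) is reachable without a key, which is worth keeping in mind when adding new endpoints.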
## Security Best Practices
1. **Use Strong API Keys**: Generate random, long API keys (32+ characters)
```bash
# Generate secure API key on macOS/Linux
openssl rand -base64 32
```
2. **Use API Key Files**: Store keys in files rather than passing them on the command line, where they can end up in shell history and process listings
```bash
python3 main.py --api-key-file /secure/path/api_key.txt
```
3. **File Permissions**: Restrict key file access
```bash
chmod 600 api_key.txt
```
4. **HTTPS**: Use a reverse proxy with TLS in production
```nginx
server {
listen 443 ssl;
server_name your-domain.com;
ssl_certificate /path/to/cert.pem;
ssl_certificate_key /path/to/key.pem;
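location / {
proxy_pass http://127.0.0.1:8188;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
# Forward WebSocket upgrade requests (needed for /ws)
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
}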
}
```
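5. **Environment Variables**: Keep the key in an environment variable instead of hard-coding it in scripts (the value is still passed on the command line here, so prefer `--api-key-file` on shared machines)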
```bash
export COMFYUI_API_KEY=$(cat api_key.txt)
python3 main.py --api-key "$COMFYUI_API_KEY"
```
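Practices 1 to 3 can be combined in one step. The following is a minimal sketch, assuming a POSIX system; `write_api_key` is a hypothetical helper, not something shipped with ComfyUI.

```python
import os
import secrets


def write_api_key(path, num_bytes=32):
    """Generate a random key and write it to `path`, readable only by the owner."""
    key = secrets.token_urlsafe(num_bytes)
    # O_EXCL refuses to overwrite an existing file; 0o600 = owner read/write only
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    with os.fdopen(fd, "w") as handle:
        handle.write(key)  # no trailing newline
    return key
```

After calling `write_api_key("api_key.txt")`, start the server with `python3 main.py --api-key-file api_key.txt`. Creating the file with restrictive permissions from the start avoids the short window in which a key written with `echo` is readable by others before `chmod 600` runs.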
## Frontend Integration
The authentication system automatically handles frontend requests:
1. When authentication is enabled, `auth_inject.js` is injected into `index.html`
2. This script intercepts all `fetch()` and `XMLHttpRequest` calls
3. Authorization headers are added automatically to all requests
4. On 401 responses, the user is redirected to the login page
### Session Storage
- **localStorage**: API key persists across browser sessions (when "Remember me" is checked)
- **sessionStorage**: API key cleared when browser/tab closes (when "Remember me" is not checked)
## Troubleshooting
### 401 Unauthorized Errors
1. Verify API key matches server configuration
2. Check authentication header format
3. Confirm the request uses one of the three supported methods (Bearer token, `X-API-Key` header, or `api_key` query parameter)
### Login Page Not Appearing
1. Clear browser cache
2. Verify `auth_login.html` and `auth_inject.js` exist in ComfyUI root
3. Check server logs for errors
### WebSocket Connection Issues
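The `/ws` endpoint is exempt from API key authentication. If WebSocket connections still fail, check whether a reverse proxy or other layer in front of ComfyUI enforces its own authentication on the initial HTTP upgrade request, and that it forwards the `Upgrade` and `Connection` headers.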
### Logout Not Working
Visit the logout URL directly:
```
http://127.0.0.1:8188/auth_login.html?logout=true
```
Or clear storage manually in browser DevTools:
```javascript
localStorage.removeItem('comfyui_api_key');
sessionStorage.removeItem('comfyui_api_key');
```
## Disabling Authentication
To disable authentication, simply start the server without the `--api-key` or `--api-key-file` arguments:
```bash
python3 main.py
```
## Migration from Unauthenticated Setup
If you're adding authentication to an existing ComfyUI installation:
1. Ensure `auth_login.html` and `auth_inject.js` exist in root directory
2. Update `server.py` with authentication routes
3. Add `middleware/auth_middleware.py`
4. Update `comfy/cli_args.py` with API key arguments
5. Restart server with `--api-key` argument
6. Update API clients to include authentication
## Support
For issues or questions:
- Check server logs for authentication errors
- Verify middleware is properly configured
- Test with `/health` endpoint first (no auth required)
- Review `middleware/auth_middleware.py` for exempt paths

233
FRONTEND_AUTH_GUIDE.md Normal file

@ -0,0 +1,233 @@
# ComfyUI Frontend Authentication Guide
## Overview
When API key authentication is enabled, the ComfyUI frontend will automatically require users to log in with the API key before accessing the interface.
## How It Works
1. **Automatic Redirection**: When you access ComfyUI with authentication enabled, you'll be automatically redirected to a login page if no valid API key is stored.
2. **API Key Storage**: After successful login, your API key is stored in your browser (localStorage or sessionStorage) depending on your "Remember me" choice.
3. **Automatic Injection**: The API key is automatically added to all API requests via the `Authorization: Bearer <key>` header.
4. **Session Management**:
- **Remember me (checked)**: API key stored in localStorage (persists across browser sessions)
- **Remember me (unchecked)**: API key stored in sessionStorage (cleared when the tab or browser closes)
## Using the Frontend with Authentication
### First Time Access
1. Start ComfyUI with an API key:
```bash
python main.py --api-key "your-secret-key-123"
```
2. Open your browser and navigate to `http://localhost:8188`
3. You'll be presented with a login page
4. Enter your API key and choose whether to remember your login (recommended for personal devices only)
5. Click "Login"
### Login Page Features
- **Show API Key**: Toggle to view the key you're entering
- **Remember Me**: Keep your session active across browser restarts
- **Auto-validation**: The system validates your key before storing it
### Logging Out
To log out and clear your stored API key:
**Option 1: JavaScript Console**
```javascript
comfyuiLogout()
```
**Option 2: Browser DevTools**
1. Open DevTools (F12)
2. Go to Application > Storage
3. Clear localStorage and sessionStorage
4. Refresh the page
**Option 3: Manual URL**
Navigate to: `http://localhost:8188/auth_login.html?logout=true`
### Security Considerations
#### For Personal Use
- ✅ Enable "Remember me" for convenience
- ✅ Use strong, unique API keys
- ✅ Keep your browser updated
#### For Shared/Public Computers
- ❌ **DO NOT** enable "Remember me"
- ✅ Always log out when finished
- ✅ Clear browser data after use
- ✅ Consider using a private/incognito window
#### For Production Deployments
- ✅ Always use HTTPS (combine with `--tls-keyfile` and `--tls-certfile`)
- ✅ Use strong, randomly generated API keys
- ✅ Rotate keys regularly
- ✅ Monitor access logs for unauthorized attempts
- ✅ Consider using additional security layers (VPN, firewall, etc.)
## Troubleshooting
### "Invalid API Key" Error
**Problem**: Login fails with invalid API key message
**Solutions**:
1. Verify you're using the correct API key
2. Check for extra spaces or newlines
3. Ensure the server was started with the same API key
4. Check server logs for authentication attempts
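Stray whitespace is easy to miss, for example the trailing newline that `echo "key" > api_key.txt` appends. The sketch below makes it visible; `describe_key` is a hypothetical helper, and whether the server strips whitespace from the key file is not shown in this guide.

```python
def describe_key(raw):
    """Report hidden whitespace that commonly causes login failures."""
    return {
        "length": len(raw),
        "has_surrounding_whitespace": raw != raw.strip(),
        "repr": repr(raw),  # makes "\n" and spaces visible
    }


# Example: a key file written with `echo` ends in a newline
print(describe_key("your-secret-key-123\n"))
```

Run it against the contents of your key file (`open("api_key.txt").read()`) on a trusted machine only, since it prints the key.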
### Automatic Logout
**Problem**: Frequently logged out automatically
**Possible Causes**:
1. API key changed on server (restart required)
2. Browser cleared storage automatically
3. "Remember me" was not checked
4. Session expired (if using sessionStorage)
**Solution**: Check "Remember me" when logging in
### Login Page Not Showing
**Problem**: Can't access login page
**Possible Causes**:
1. Authentication not enabled (no `--api-key` argument)
2. Browser cached old version
**Solution**:
1. Hard refresh the page (Ctrl+Shift+R or Cmd+Shift+R)
2. Clear browser cache
3. Try accessing directly: `http://localhost:8188/auth_login.html`
### API Requests Still Failing
**Problem**: Some API requests return 401 after login
**Solutions**:
1. Check browser console for errors
2. Verify API key is stored: Open DevTools > Application > Storage
3. Try logging out and back in
4. Clear all browser data and try again
## Advanced Usage
### Programmatic Access with Frontend
If you need to access the API programmatically while using the frontend:
```javascript
// Get the stored API key
const apiKey = localStorage.getItem('comfyui_api_key') ||
sessionStorage.getItem('comfyui_api_key');
// Make authenticated requests
fetch('/api/system_stats', {
headers: {
'Authorization': `Bearer ${apiKey}`
}
}).then(r => r.json()).then(console.log);
```
### Custom Frontend Integration
If you're building a custom frontend, you can use the same mechanism:
1. **Login Flow**:
```javascript
// Validate API key
const response = await fetch('/health', {
headers: {
'Authorization': `Bearer ${apiKey}`
}
});
if (response.ok) {
// Store the key
localStorage.setItem('comfyui_api_key', apiKey);
}
```
2. **Request Interceptor**:
```javascript
// Add to all requests
const apiKey = localStorage.getItem('comfyui_api_key');
fetch(url, {
headers: {
'Authorization': `Bearer ${apiKey}`,
...otherHeaders
}
});
```
3. **Handle 401 Responses**:
```javascript
if (response.status === 401) {
// Clear stored key and redirect to login
localStorage.removeItem('comfyui_api_key');
window.location.href = '/auth_login.html';
}
```
## API Endpoints Reference
### Public Endpoints (No Authentication Required)
- `GET /` - Main page (with auth script injected)
- `GET /health` - Health check
- `GET /auth_login.html` - Login page
- `GET /auth_inject.js` - Auth injection script
- `GET /ws` - WebSocket connection
- Static files (`.js`, `.css`, `.html`, etc.)
### Protected Endpoints (Authentication Required)
- All `/api/*` endpoints
- All `/internal/*` endpoints
- Most other API endpoints
## Browser Compatibility
The authentication system works with all modern browsers:
- ✅ Chrome/Edge 90+
- ✅ Firefox 88+
- ✅ Safari 14+
- ✅ Opera 76+
## FAQs
**Q: Can I use ComfyUI without authentication?**
A: Yes! Simply start ComfyUI without the `--api-key` argument.
**Q: Can I change the API key without losing my workflows?**
A: Yes, workflows are stored separately. Just update the key on the server and re-login.
**Q: Is my API key secure?**
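A: The key is stored unencrypted in browser storage (localStorage or sessionStorage), where any script running on the same origin can read it, and it is encrypted in transit only if HTTPS is configured. For maximum security, use HTTPS and strong keys.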
**Q: Can multiple users use different API keys?**
A: Currently, the system supports a single API key, so in multi-user scenarios all users share the same key.
**Q: What happens if I forget my API key?**
A: Check your server startup command or the file specified in `--api-key-file`.
## Support
For issues or questions:
1. Check the server logs
2. Review browser console for errors
3. Refer to the main authentication documentation: `AUTHENTICATION_GUIDE.md`
4. Check the quick start guide: `QUICK_START_AUTH.md`

110
auth_inject.js Normal file

@ -0,0 +1,110 @@
/**
* ComfyUI API Key Injection
* This script automatically adds the stored API key to all HTTP requests
*/
(function() {
'use strict';
// Get the stored API key
function getApiKey() {
return localStorage.getItem('comfyui_api_key') || sessionStorage.getItem('comfyui_api_key');
}
// Check if user is authenticated
function checkAuth() {
const apiKey = getApiKey();
if (!apiKey && window.location.pathname !== '/auth_login.html') {
// Redirect to login page if no API key is found
window.location.href = '/auth_login.html';
return false;
}
return true;
}
// Intercept fetch requests
const originalFetch = window.fetch;
window.fetch = function(...args) {
const apiKey = getApiKey();
if (apiKey) {
// Copy the options object so the caller's object is not mutated
let [url, options = {}] = args;
options = { ...options };
// Normalize headers (plain object, array of pairs, or Headers) into a Headers instance
const headers = new Headers(options.headers || {});
// Add Authorization header if not already present (Headers lookups are case-insensitive)
if (!headers.has('Authorization')) {
headers.set('Authorization', `Bearer ${apiKey}`);
}
options.headers = headers;
// Update args
args = [url, options];
}
// Call original fetch and handle 401 errors
return originalFetch.apply(this, args).then(response => {
if (response.status === 401 && window.location.pathname !== '/auth_login.html') {
// Clear stored API key and redirect to login
localStorage.removeItem('comfyui_api_key');
sessionStorage.removeItem('comfyui_api_key');
window.location.href = '/auth_login.html';
}
return response;
});
};
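// Intercept XMLHttpRequest
const originalOpen = XMLHttpRequest.prototype.open;
const originalSend = XMLHttpRequest.prototype.send;
const originalSetRequestHeader = XMLHttpRequest.prototype.setRequestHeader;
XMLHttpRequest.prototype.open = function(method, url, ...rest) {
this._url = url;
this._hasAuthHeader = false;
return originalOpen.apply(this, [method, url, ...rest]);
};
// XMLHttpRequest has no getRequestHeader(), so record whether the caller set Authorization
XMLHttpRequest.prototype.setRequestHeader = function(name, value) {
if (String(name).toLowerCase() === 'authorization') {
this._hasAuthHeader = true;
}
return originalSetRequestHeader.apply(this, [name, value]);
};
XMLHttpRequest.prototype.send = function(...args) {
const apiKey = getApiKey();
if (apiKey && !this._hasAuthHeader) {
this.setRequestHeader('Authorization', `Bearer ${apiKey}`);
}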
// Handle 401 responses
this.addEventListener('load', function() {
if (this.status === 401 && window.location.pathname !== '/auth_login.html') {
localStorage.removeItem('comfyui_api_key');
sessionStorage.removeItem('comfyui_api_key');
window.location.href = '/auth_login.html';
}
});
return originalSend.apply(this, args);
};
// Add logout function to window
window.comfyuiLogout = function() {
localStorage.removeItem('comfyui_api_key');
sessionStorage.removeItem('comfyui_api_key');
window.location.href = '/auth_login.html';
};
// Check authentication on page load (except for login page)
if (window.location.pathname !== '/auth_login.html') {
checkAuth();
}
console.log('[ComfyUI Auth] API key injection enabled');
})();

278
auth_login.html Normal file

@ -0,0 +1,278 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ComfyUI - API Key Required</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 20px;
}
.login-container {
background: white;
border-radius: 16px;
box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3);
padding: 40px;
max-width: 420px;
width: 100%;
}
.logo {
text-align: center;
margin-bottom: 30px;
}
.logo h1 {
font-size: 32px;
color: #667eea;
margin-bottom: 8px;
}
.logo p {
color: #666;
font-size: 14px;
}
.form-group {
margin-bottom: 24px;
}
label {
display: block;
margin-bottom: 8px;
color: #333;
font-weight: 500;
font-size: 14px;
}
input[type="password"],
input[type="text"] {
width: 100%;
padding: 12px 16px;
border: 2px solid #e0e0e0;
border-radius: 8px;
font-size: 15px;
transition: border-color 0.3s;
}
input[type="password"]:focus,
input[type="text"]:focus {
outline: none;
border-color: #667eea;
}
.show-password {
display: flex;
align-items: center;
margin-top: 8px;
font-size: 14px;
color: #666;
cursor: pointer;
user-select: none;
}
.show-password input[type="checkbox"] {
margin-right: 8px;
}
button {
width: 100%;
padding: 14px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s, box-shadow 0.2s;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.4);
}
button:active {
transform: translateY(0);
}
.error-message {
background: #fee;
color: #c33;
padding: 12px;
border-radius: 8px;
margin-bottom: 20px;
font-size: 14px;
display: none;
}
.error-message.show {
display: block;
}
.info-box {
background: #f0f7ff;
border-left: 4px solid #667eea;
padding: 12px 16px;
margin-top: 24px;
border-radius: 4px;
font-size: 13px;
color: #555;
}
.info-box strong {
color: #667eea;
}
.remember-me {
display: flex;
align-items: center;
margin-bottom: 20px;
font-size: 14px;
color: #666;
cursor: pointer;
user-select: none;
}
.remember-me input[type="checkbox"] {
margin-right: 8px;
}
</style>
</head>
<body>
<div class="login-container">
<div class="logo">
<h1>🎨 ComfyUI</h1>
<p>API Key Authentication</p>
</div>
<div class="error-message" id="errorMessage">
Invalid API key. Please try again.
</div>
<form id="loginForm">
<div class="form-group">
<label for="apiKey">API Key</label>
<input
type="password"
id="apiKey"
name="apiKey"
placeholder="Enter your API key"
autocomplete="off"
required
>
<label class="show-password">
<input type="checkbox" id="showPassword">
Show API key
</label>
</div>
<label class="remember-me">
<input type="checkbox" id="rememberMe" checked>
Remember me (store in browser)
</label>
<button type="submit">Login</button>
</form>
<div class="info-box">
<strong>Note:</strong> Your API key is required to access ComfyUI.
If you don't have one, please contact your administrator or check the server configuration.
</div>
</div>
<script>
const loginForm = document.getElementById('loginForm');
const apiKeyInput = document.getElementById('apiKey');
const errorMessage = document.getElementById('errorMessage');
const showPasswordCheckbox = document.getElementById('showPassword');
const rememberMeCheckbox = document.getElementById('rememberMe');
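// Handle logout requests (?logout=true): clear any stored key and stay on this page
if (new URLSearchParams(window.location.search).get('logout') === 'true') {
localStorage.removeItem('comfyui_api_key');
sessionStorage.removeItem('comfyui_api_key');
}
// Check if API key is already stored
const storedApiKey = localStorage.getItem('comfyui_api_key') || sessionStorage.getItem('comfyui_api_key');
if (storedApiKey) {
// Try to validate the stored key
validateAndRedirect(storedApiKey, false);
}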
// Show/hide password
showPasswordCheckbox.addEventListener('change', function() {
apiKeyInput.type = this.checked ? 'text' : 'password';
});
// Form submission
loginForm.addEventListener('submit', async function(e) {
e.preventDefault();
const apiKey = apiKeyInput.value.trim();
if (!apiKey) {
showError('Please enter an API key');
return;
}
await validateAndRedirect(apiKey, true);
});
async function validateAndRedirect(apiKey, showErrors) {
try {
// Test the API key by making a request to the health endpoint
const response = await fetch('/health', {
method: 'GET',
headers: {
'Authorization': `Bearer ${apiKey}`
}
});
if (response.ok) {
// API key is valid, store it
if (rememberMeCheckbox.checked) {
localStorage.setItem('comfyui_api_key', apiKey);
sessionStorage.removeItem('comfyui_api_key');
} else {
sessionStorage.setItem('comfyui_api_key', apiKey);
localStorage.removeItem('comfyui_api_key');
}
// Redirect to main page
window.location.href = '/';
} else {
if (showErrors) {
showError('Invalid API key. Please check and try again.');
} else {
// Stored key is invalid, clear it
localStorage.removeItem('comfyui_api_key');
sessionStorage.removeItem('comfyui_api_key');
}
}
} catch (error) {
console.error('Validation error:', error);
if (showErrors) {
showError('Failed to validate API key. Please check your connection.');
}
}
}
function showError(message) {
errorMessage.textContent = message;
errorMessage.classList.add('show');
setTimeout(() => {
errorMessage.classList.remove('show');
}, 5000);
}
</script>
</body>
</html>

8
facebook_audio.py Normal file

@ -0,0 +1,8 @@
from transformers import pipeline
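import scipy.io.wavfile  # import the submodule explicitly; a bare "import scipy" does not guarantee it is loaded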
synthesiser = pipeline("text-to-audio", "facebook/musicgen-large")
music = synthesiser("lo-fi music with a soothing melody", forward_params={"do_sample": True})
scipy.io.wavfile.write("musicgen_out.wav", rate=music["sampling_rate"], data=music["audio"])


@ -0,0 +1,275 @@
#!/usr/bin/env python3
"""
Standalone script to generate music from text using Stable Audio in ComfyUI.
Based on the workflow: user/default/workflows/audio_stable_audio_example.json
This script replicates the workflow:
1. Load checkpoint model (stable-audio-open-1.0.safetensors)
2. Load CLIP text encoder (t5-base.safetensors)
3. Encode positive prompt (music description)
4. Encode negative prompt (empty)
5. Create empty latent audio (47.6 seconds)
6. Sample using KSampler
7. Decode audio from latent using VAE
8. Save as MP3
Requirements:
- stable-audio-open-1.0.safetensors in models/checkpoints/
- t5-base.safetensors in models/text_encoders/
"""
import torch
import sys
import os
import random
import av
from io import BytesIO
# Add ComfyUI to path
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, script_dir)
import comfy.sd
import comfy.sample
import comfy.samplers
import comfy.model_management
import folder_paths
import latent_preview
import comfy.utils
def load_checkpoint(ckpt_name):
    """Load checkpoint model - returns MODEL, CLIP, VAE"""
    print(f"Loading checkpoint: {ckpt_name}")
    ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
    out = comfy.sd.load_checkpoint_guess_config(
        ckpt_path,
        output_vae=True,
        output_clip=True,
        embedding_directory=folder_paths.get_folder_paths("embeddings")
    )
    return out[:3]  # MODEL, CLIP, VAE


def load_clip(clip_name, clip_type="stable_audio"):
    """Load CLIP text encoder"""
    print(f"Loading CLIP: {clip_name}")
    clip_type_enum = getattr(comfy.sd.CLIPType, clip_type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
    clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
    clip = comfy.sd.load_clip(
        ckpt_paths=[clip_path],
        embedding_directory=folder_paths.get_folder_paths("embeddings"),
        clip_type=clip_type_enum,
        model_options={}
    )
    return clip


def encode_text(clip, text):
    """Encode text using CLIP - returns CONDITIONING"""
    print(f"Encoding text: '{text}'")
    if clip is None:
        raise RuntimeError("ERROR: clip input is invalid: None")
    tokens = clip.tokenize(text)
    return clip.encode_from_tokens_scheduled(tokens)


def create_empty_latent_audio(seconds, batch_size=1):
    """Create empty latent audio tensor"""
    print(f"Creating empty latent audio: {seconds} seconds")
    length = round((seconds * 44100 / 2048) / 2) * 2
    latent = torch.zeros(
        [batch_size, 64, length],
        device=comfy.model_management.intermediate_device()
    )
    return {"samples": latent, "type": "audio"}
def sample_audio(model, seed, steps, cfg, sampler_name, scheduler,
                 positive, negative, latent_image, denoise=1.0):
    """Run KSampler to generate audio latents"""
    print(f"Sampling with seed={seed}, steps={steps}, cfg={cfg}, sampler={sampler_name}, scheduler={scheduler}")
    latent_samples = latent_image["samples"]
    latent_samples = comfy.sample.fix_empty_latent_channels(model, latent_samples)
    # Prepare noise
    batch_inds = latent_image["batch_index"] if "batch_index" in latent_image else None
    noise = comfy.sample.prepare_noise(latent_samples, seed, batch_inds)
    # Check for noise mask
    noise_mask = latent_image.get("noise_mask", None)
    # Prepare callback for progress
    callback = latent_preview.prepare_callback(model, steps)
    disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
    # Sample
    samples = comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler,
        positive, negative, latent_samples,
        denoise=denoise,
        disable_noise=False,
        start_step=None,
        last_step=None,
        force_full_denoise=False,
        noise_mask=noise_mask,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed
    )
    out = latent_image.copy()
    out["samples"] = samples
    return out


def decode_audio(vae, samples):
    """Decode audio from latent samples using VAE"""
    print("Decoding audio from latents")
    audio = vae.decode(samples["samples"]).movedim(-1, 1)
    # Normalize audio
    std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0
    std[std < 1.0] = 1.0
    audio /= std
    return {"waveform": audio, "sample_rate": 44100}
def save_audio_mp3(audio, filename, quality="V0"):
    """Save audio as MP3 file using PyAV (same as ComfyUI)"""
    print(f"Saving audio to: {filename}")
    # Create output directory if needed (dirname is empty for a bare filename)
    output_dir = os.path.dirname(filename)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    waveform = audio["waveform"]
    sample_rate = audio["sample_rate"]
    # Ensure audio is in CPU
    waveform = waveform.cpu()
    # Process each audio in batch (usually just 1)
    for batch_number, waveform_item in enumerate(waveform):
        if batch_number > 0:
            # Add batch number to filename if multiple
            base, ext = os.path.splitext(filename)
            output_path = f"{base}_{batch_number}{ext}"
        else:
            output_path = filename
        # Create output buffer
        output_buffer = BytesIO()
        output_container = av.open(output_buffer, mode="w", format="mp3")
        # Determine audio layout - waveform_item shape is [channels, samples]
        num_channels = waveform_item.shape[0] if waveform_item.dim() > 1 else 1
        layout = "mono" if num_channels == 1 else "stereo"
        # Set up the MP3 output stream
        out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
        # Set quality
        if quality == "V0":
            out_stream.codec_context.qscale = 1  # Highest VBR quality
        elif quality == "128k":
            out_stream.bit_rate = 128000
        elif quality == "320k":
            out_stream.bit_rate = 320000
        # Prepare waveform for PyAV: needs to be [samples, channels]
        # Use detach() to avoid gradient tracking issues
        if waveform_item.dim() == 1:
            # Mono audio, add channel dimension
            waveform_numpy = waveform_item.unsqueeze(1).float().detach().numpy()
        else:
            # Transpose from [channels, samples] to [samples, channels]
            waveform_numpy = waveform_item.transpose(0, 1).float().detach().numpy()
        # Reshape to [1, samples * channels] for PyAV
        waveform_numpy = waveform_numpy.reshape(1, -1)
        # Create audio frame
        frame = av.AudioFrame.from_ndarray(
            waveform_numpy,
            format="flt",
            layout=layout,
        )
        frame.sample_rate = sample_rate
        frame.pts = 0
        # Encode
        output_container.mux(out_stream.encode(frame))
        # Flush encoder
        output_container.mux(out_stream.encode(None))
        # Close container
        output_container.close()
        # Write to file
        output_buffer.seek(0)
        with open(output_path, "wb") as f:
            f.write(output_buffer.getbuffer())
        print(f"Audio saved successfully: {output_path}")
def main():
    # Configuration
    checkpoint_name = "stable-audio-open-1.0.safetensors"
    clip_name = "t5-base.safetensors"
    positive_prompt = "A soft melodious acoustic guitar music"
    negative_prompt = ""
    audio_duration = 47.6  # seconds
    seed = random.randint(0, 0xffffffffffffffff)  # Random seed, or use specific value
    steps = 50
    cfg = 4.98
    sampler_name = "dpmpp_3m_sde_gpu"
    scheduler = "exponential"
    denoise = 1.0
    output_filename = "output/audio/generated_music.mp3"
    quality = "V0"
    print("=" * 60)
    print("Stable Audio - Music Generation Script")
    print("=" * 60)
    print(f"Positive Prompt: {positive_prompt}")
    print(f"Duration: {audio_duration} seconds")
    print(f"Seed: {seed}")
    print("=" * 60)
    # 1. Load checkpoint (MODEL, CLIP, VAE)
    model, checkpoint_clip, vae = load_checkpoint(checkpoint_name)
    # 2. Load separate CLIP text encoder for stable audio
    clip = load_clip(clip_name, "stable_audio")
    # 3. Encode positive and negative prompts
    positive_conditioning = encode_text(clip, positive_prompt)
    negative_conditioning = encode_text(clip, negative_prompt)
    # 4. Create empty latent audio
    latent_audio = create_empty_latent_audio(audio_duration, batch_size=1)
    # 5. Sample using KSampler
    sampled_latent = sample_audio(
        model, seed, steps, cfg, sampler_name, scheduler,
        positive_conditioning, negative_conditioning, latent_audio, denoise
    )
    # 6. Decode audio from latent using VAE
    audio = decode_audio(vae, sampled_latent)
    # 7. Save as MP3
    save_audio_mp3(audio, output_filename, quality)
    print("=" * 60)
    print("Generation complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()


@ -0,0 +1,476 @@
#!/usr/bin/env python3
"""
Standalone script to generate TTS audio using VibeVoice.
This script has NO ComfyUI dependencies and uses the models directly from HuggingFace.
Based on Microsoft's VibeVoice: https://github.com/microsoft/VibeVoice
Requirements:
pip install torch transformers numpy scipy soundfile librosa huggingface-hub
Usage:
python generate_vibevoice_standalone.py
"""
import torch
import numpy as np
import soundfile as sf
import os
import random
import re
import logging
from typing import Optional, List, Tuple
from huggingface_hub import snapshot_download
logging.basicConfig(level=logging.INFO, format='[VibeVoice] %(message)s')
logger = logging.getLogger(__name__)
try:
    import librosa
    LIBROSA_AVAILABLE = True
except ImportError:
    logger.warning("librosa not available - resampling will not work")
    LIBROSA_AVAILABLE = False


def set_seed(seed: int):
    """Set random seeds for reproducibility"""
    if seed == 0:
        seed = random.randint(1, 0xffffffffffffffff)
    MAX_NUMPY_SEED = 2**32 - 1
    numpy_seed = seed % MAX_NUMPY_SEED
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(numpy_seed)
    random.seed(seed)
    return seed
def parse_script(script: str) -> Tuple[List[Tuple[int, str]], List[int]]:
    """
    Parse speaker script into (speaker_id, text) tuples.

    Supports formats:
        [1] Some text...
        Speaker 1: Some text...

    Returns:
        parsed_lines: List of (0-based speaker_id, text) tuples
        speaker_ids: Sorted list of unique 1-based speaker IDs
    """
    parsed_lines = []
    speaker_ids_in_script = []
    line_format_regex = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE)
    for line in script.strip().split("\n"):
        line = line.strip()
        if not line:
            continue
        match = line_format_regex.match(line)
        if match:
            speaker_id_str = match.group(1) or match.group(2)
            speaker_id = int(speaker_id_str)
            text_content = match.group(3)
            # For the "[1]: text" form, drop the colon that follows the bracket
            if match.group(1) is None and text_content.lstrip().startswith(':'):
                colon_index = text_content.find(':')
                text_content = text_content[colon_index + 1:]
            if speaker_id < 1:
                logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'")
                continue
            text = text_content.strip()
            internal_speaker_id = speaker_id - 1
            parsed_lines.append((internal_speaker_id, text))
            if speaker_id not in speaker_ids_in_script:
                speaker_ids_in_script.append(speaker_id)
        else:
            logger.warning(f"Could not parse speaker marker, ignoring: '{line}'")
    if not parsed_lines and script.strip():
        logger.info("No speaker markers found. Treating entire text as Speaker 1.")
        parsed_lines.append((0, script.strip()))
        speaker_ids_in_script.append(1)
    return parsed_lines, sorted(list(set(speaker_ids_in_script)))
def load_audio_file(audio_path: str, target_sr: int = 24000) -> Optional[np.ndarray]:
    """Load audio file and convert to mono at target sample rate"""
    if not os.path.exists(audio_path):
        logger.error(f"Audio file not found: {audio_path}")
        return None
    logger.info(f"Loading audio: {audio_path}")
    try:
        # Load audio using soundfile
        waveform, sr = sf.read(audio_path)
        # Convert to mono if stereo
        if waveform.ndim > 1:
            waveform = np.mean(waveform, axis=1)
        # Resample if needed
        if sr != target_sr:
            if not LIBROSA_AVAILABLE:
                raise ImportError("librosa is required for resampling. Install with: pip install librosa")
            logger.info(f"Resampling from {sr}Hz to {target_sr}Hz")
            waveform = librosa.resample(y=waveform, orig_sr=sr, target_sr=target_sr)
        # Validate audio
        if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)):
            logger.error("Audio contains NaN or Inf values, replacing with zeros")
            waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)
        if np.all(waveform == 0):
            logger.warning("Audio waveform is completely silent")
        # Normalize extreme values
        max_val = np.abs(waveform).max()
        if max_val > 10.0:
            logger.warning(f"Audio values are very large (max: {max_val}), normalizing")
            waveform = waveform / max_val
        return waveform.astype(np.float32)
    except Exception as e:
        logger.error(f"Error loading audio: {e}")
        return None
def download_model(model_name: str = "VibeVoice-1.5B", cache_dir: str = "./models"):
    """Download VibeVoice model from HuggingFace"""
    repo_mapping = {
        "VibeVoice-1.5B": "microsoft/VibeVoice-1.5B",
        "VibeVoice-Large": "aoi-ot/VibeVoice-Large"
    }
    if model_name not in repo_mapping:
        raise ValueError(f"Unknown model: {model_name}. Choose from: {list(repo_mapping.keys())}")
    repo_id = repo_mapping[model_name]
    model_path = os.path.join(cache_dir, model_name)
    if os.path.exists(os.path.join(model_path, "config.json")):
        logger.info(f"Model already downloaded: {model_path}")
        return model_path
    logger.info(f"Downloading model from {repo_id}...")
    os.makedirs(cache_dir, exist_ok=True)
    model_path = snapshot_download(
        repo_id=repo_id,
        local_dir=model_path,
        local_dir_use_symlinks=False
    )
    logger.info(f"Model downloaded to: {model_path}")
    return model_path
def generate_tts(
    text: str,
    model_name: str = "VibeVoice-Large",
    speaker_audio_paths: Optional[dict] = None,
    output_path: str = "output.wav",
    cfg_scale: float = 1.3,
    inference_steps: int = 10,
    seed: int = 42,
    temperature: float = 0.95,
    top_p: float = 0.95,
    top_k: int = 0,
    cache_dir: str = "./models",
    device: str = "auto"
):
    """
    Generate TTS audio using VibeVoice

    Args:
        text: Text script with speaker markers like "[1] text" or "Speaker 1: text"
        model_name: Model to use ("VibeVoice-1.5B" or "VibeVoice-Large")
        speaker_audio_paths: Dict mapping speaker IDs to audio file paths for voice cloning
            e.g., {1: "voice1.wav", 2: "voice2.wav"}
        output_path: Where to save the generated audio
        cfg_scale: Classifier-Free Guidance scale (higher = more adherence to prompt)
        inference_steps: Number of diffusion steps
        seed: Random seed for reproducibility
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        top_k: Top-K sampling parameter
        cache_dir: Directory to cache downloaded models
        device: Device to use ("cuda", "mps", "cpu", or "auto" for automatic detection)
    """
    # Set seed
    actual_seed = set_seed(seed)
    logger.info(f"Using seed: {actual_seed}")
    # Determine device - with MPS support for Mac
    if device == "auto":
        if torch.cuda.is_available():
            device = "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            device = "mps"
            logger.info("MPS (Metal Performance Shaders) detected - using Mac GPU acceleration")
        else:
            device = "cpu"
    logger.info(f"Using device: {device}")
    # Download model if needed
    model_path = download_model(model_name, cache_dir)
    # Import VibeVoice components
    logger.info("Loading VibeVoice model...")
    try:
        # Add the VibeVoice custom model code to path
        import sys
        vibevoice_custom_path = os.path.join(os.path.dirname(__file__), "custom_nodes", "ComfyUI-VibeVoice")
        if vibevoice_custom_path not in sys.path:
            sys.path.insert(0, vibevoice_custom_path)
        # Import custom VibeVoice model
        from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
        from vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast
        from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
        import json
        # Load config
        config_path = os.path.join(model_path, "config.json")
        config = VibeVoiceConfig.from_pretrained(config_path)
        # Load tokenizer - download if not present
        tokenizer_file = os.path.join(model_path, "tokenizer.json")
        if not os.path.exists(tokenizer_file):
            logger.info("tokenizer.json not found, downloading from HuggingFace...")
            from huggingface_hub import hf_hub_download
            # Determine which Qwen model to use based on model size
            qwen_repo = "Qwen/Qwen2.5-1.5B" if "1.5B" in model_name else "Qwen/Qwen2.5-7B"
            try:
                hf_hub_download(
                    repo_id=qwen_repo,
                    filename="tokenizer.json",
                    local_dir=model_path,
                    local_dir_use_symlinks=False
                )
                logger.info("tokenizer.json downloaded successfully")
            except Exception as e:
                logger.error(f"Failed to download tokenizer.json: {e}")
                # Chain the original exception so the root cause stays in the traceback
                raise FileNotFoundError(f"Could not download tokenizer.json from {qwen_repo}") from e
        tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file)
        # Load processor config
        preprocessor_config_path = os.path.join(model_path, "preprocessor_config.json")
        processor_config_data = {}
        if os.path.exists(preprocessor_config_path):
            with open(preprocessor_config_path, 'r') as f:
                processor_config_data = json.load(f)
        audio_processor = VibeVoiceTokenizerProcessor()
        processor = VibeVoiceProcessor(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200),
            db_normalize=processor_config_data.get("db_normalize", True)
        )
        # Load model
        # MPS doesn't support bfloat16 well, use float16
        if device == "mps":
            dtype = torch.float16
            logger.info("Using float16 for MPS device")
        elif torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            dtype = torch.float16
        model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            model_path,
            config=config,
            torch_dtype=dtype,
            device_map=device,
            attn_implementation="sdpa"
        )
        model.eval()
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        raise
    # Parse script
    parsed_lines, speaker_ids = parse_script(text)
    if not parsed_lines:
        raise ValueError("Script is empty or invalid")
    logger.info(f"Parsed {len(parsed_lines)} lines with speakers: {speaker_ids}")
    # Load speaker audio samples
    voice_samples = []
    if speaker_audio_paths is None:
        speaker_audio_paths = {}
    for speaker_id in speaker_ids:
        audio_path = speaker_audio_paths.get(speaker_id)
        if audio_path:
            audio = load_audio_file(audio_path, target_sr=24000)
            if audio is None:
                logger.warning(f"Could not load audio for speaker {speaker_id}, using zero-shot TTS")
                voice_samples.append(None)
            else:
                voice_samples.append(audio)
        else:
            logger.info(f"No reference audio for speaker {speaker_id}, using zero-shot TTS")
            voice_samples.append(None)
    # Prepare inputs
    logger.info("Processing inputs...")
    try:
        inputs = processor(
            parsed_scripts=[parsed_lines],
            voice_samples=[voice_samples],
            speaker_ids_for_prompt=[speaker_ids],
            padding=True,
            return_tensors="pt",
            return_attention_mask=True
        )
        # Move to device
        inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    except Exception as e:
        logger.error(f"Error processing inputs: {e}")
        raise
    # Configure generation
    model.set_ddpm_inference_steps(num_steps=inference_steps)
    generation_config = {
        'do_sample': True,
        'temperature': temperature,
        'top_p': top_p,
    }
    if top_k > 0:
        generation_config['top_k'] = top_k
    # Generate
    logger.info(f"Generating audio ({inference_steps} steps)...")
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=processor.tokenizer,
                generation_config=generation_config,
                verbose=False
            )
        # Extract waveform; cast to float32 first because NumPy has no bfloat16 dtype
        waveform = outputs.speech_outputs[0].float().cpu().numpy()
        # Ensure correct shape
        if waveform.ndim == 1:
            waveform = waveform.reshape(1, -1)
        elif waveform.ndim == 2 and waveform.shape[0] > 1:
            # If multiple channels, take first
            waveform = waveform[0:1, :]
        # Convert to float32 for soundfile compatibility
        waveform = waveform.astype(np.float32)
        # Save audio
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        sf.write(output_path, waveform.T, 24000)
        logger.info(f"Audio saved to: {output_path}")
        return waveform
    except Exception as e:
        logger.error(f"Error during generation: {e}")
        raise
def main():
    """Example usage"""
    # Configuration
    model_name = "VibeVoice-Large"  # or "VibeVoice-1.5B"
    # Text to generate - supports multiple speakers
    text = """
    [1] Hello, this is speaker one. How are you today?
    [2] Hi there! This is speaker two responding to you. It's great to meet you.
    [1] Likewise! Let's generate some amazing speech together.
    [2] Absolutely! VibeVoice makes it so easy to create diverse voices.
    """
    # Reference audio for voice cloning (optional)
    # Speakers without an entry here use zero-shot TTS
    speaker_audio_paths = {
        1: "input/audio1.wav",  # Path to reference audio for speaker 1
        2: "input/laundry.mp3",  # Path to reference audio for speaker 2
    }
    # Generation parameters
    output_path = "output/vibevoice_generated.wav"
    cfg_scale = 1.3
    inference_steps = 10
    seed = 42  # or 0 for random
    temperature = 0.95
    top_p = 0.95
    top_k = 0
    print("=" * 60)
    print("VibeVoice TTS - Standalone Script")
    print("=" * 60)
    print(f"Model: {model_name}")
    print(f"Text: {text[:100]}...")
    print("=" * 60)
    try:
        generate_tts(
            text=text,
            model_name=model_name,
            speaker_audio_paths=speaker_audio_paths,
            output_path=output_path,
            cfg_scale=cfg_scale,
            inference_steps=inference_steps,
            seed=seed,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            cache_dir="./models",
            device="auto"
        )
        print("=" * 60)
        print("Generation complete!")
        print(f"Audio saved to: {output_path}")
        print("=" * 60)
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
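
The speaker-marker format handled by `parse_script` can be tried in isolation, without downloading a model. The sketch below reuses the same regular expression in a small helper; `split_speaker_line` and `LINE_FORMAT` are hypothetical names introduced for illustration and are not part of this commit.

```python
import re
from typing import Optional, Tuple

# Same pattern as in parse_script: "Speaker N:" or "[N]" followed by the text.
LINE_FORMAT = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE)


def split_speaker_line(line: str) -> Optional[Tuple[int, str]]:
    """Return (1-based speaker id, text) for a marked line, or None if unmarked.

    Hypothetical helper for illustration; it is not part of the commit.
    """
    match = LINE_FORMAT.match(line.strip())
    if match is None:
        return None
    speaker_id = int(match.group(1) or match.group(2))
    text = match.group(3)
    # "[1]: text" leaves a leading colon in the text group; drop it.
    if match.group(1) is None and text.lstrip().startswith(':'):
        text = text[text.find(':') + 1:]
    return speaker_id, text.strip()


if __name__ == "__main__":
    samples = ["[1] Hello there.", "speaker 2: Hi!", "[3]: With a colon.", "No marker here"]
    for sample in samples:
        print(repr(sample), "->", split_speaker_line(sample))
```

Unlike `parse_script`, this helper keeps the 1-based ID and does not fall back to "Speaker 1" for unmarked text; it only shows which line shapes the pattern accepts.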


@ -0,0 +1,7 @@
torch
transformers
numpy
scipy
soundfile
librosa
huggingface-hub


@ -309,12 +309,76 @@ class PromptServer():
        @routes.get("/")
        async def get_root(request):
            # If API key auth is enabled, check whether the request carries a valid key
            if args.api_key:
                auth_header = request.headers.get("Authorization", "")
                api_key_header = request.headers.get("X-API-Key", "")
                api_key_query = request.query.get("api_key", "")
                has_valid_key = False
                if auth_header.startswith("Bearer "):
                    provided_key = auth_header[7:]
                    has_valid_key = (provided_key == args.api_key)
                elif api_key_header:
                    has_valid_key = (api_key_header == args.api_key)
                elif api_key_query:
                    has_valid_key = (api_key_query == args.api_key)
                # Without a valid key, serve index.html with the auth script injected
                if not has_valid_key:
                    index_path = os.path.join(self.web_root, "index.html")
                    if os.path.exists(index_path):
                        with open(index_path, 'r', encoding='utf-8') as f:
                            html_content = f.read()
                        # Inject auth script before closing </head> tag
                        auth_script = '<script src="/auth_inject.js"></script>'
                        if '</head>' in html_content:
                            html_content = html_content.replace('</head>', f'{auth_script}\n</head>')
                        else:
                            # Fallback: add at the beginning of body
                            html_content = html_content.replace('<body>', f'<body>\n{auth_script}')
                        response = web.Response(text=html_content, content_type='text/html')
                    else:
                        response = web.FileResponse(index_path)
                    response.headers['Cache-Control'] = 'no-cache'
                    response.headers["Pragma"] = "no-cache"
                    response.headers["Expires"] = "0"
                    return response
            # No auth required or valid key provided
            response = web.FileResponse(os.path.join(self.web_root, "index.html"))
            response.headers['Cache-Control'] = 'no-cache'
            response.headers["Pragma"] = "no-cache"
            response.headers["Expires"] = "0"
            return response
        @routes.get("/auth_login.html")
        async def get_login_page(request):
            """Serve the login page"""
            login_page_path = os.path.join(os.path.dirname(__file__), "auth_login.html")
            if os.path.exists(login_page_path):
                response = web.FileResponse(login_page_path)
            else:
                # Fallback if file doesn't exist
                response = web.Response(text="Login page not found", status=404)
            response.headers['Cache-Control'] = 'no-cache'
            return response

        @routes.get("/auth_inject.js")
        async def get_auth_script(request):
            """Serve the auth injection script"""
            script_path = os.path.join(os.path.dirname(__file__), "auth_inject.js")
            if os.path.exists(script_path):
                response = web.FileResponse(script_path, headers={'Content-Type': 'application/javascript'})
            else:
                response = web.Response(text="// Auth script not found", status=404, content_type='application/javascript')
            response.headers['Cache-Control'] = 'no-cache'
            return response
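
Review note: the `/` handler compares the supplied key to `args.api_key` with `==`. One possible hardening, sketched here and not part of this commit, is to extract the key in a single helper and compare it with `hmac.compare_digest` from the standard library. The names `extract_api_key` and `is_valid_api_key` are hypothetical, and the helpers take plain mappings instead of an aiohttp request so the sketch runs on its own.

```python
import hmac
from typing import Mapping, Optional


def extract_api_key(headers: Mapping[str, str], query: Mapping[str, str]) -> Optional[str]:
    """Pick the key from a Bearer token, the X-API-Key header, or the api_key query
    parameter, in that order (the same precedence as the handler's if/elif chain)."""
    auth_header = headers.get("Authorization", "")
    if auth_header.startswith("Bearer "):
        return auth_header[len("Bearer "):]
    return headers.get("X-API-Key") or query.get("api_key") or None


def is_valid_api_key(provided: Optional[str], expected: str) -> bool:
    """Compare keys in constant time so timing does not reveal how much of the key matched."""
    if not provided:
        return False
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


if __name__ == "__main__":
    key = extract_api_key({"Authorization": "Bearer my-key"}, {})
    print(is_valid_api_key(key, "my-key"))
```

In the handler this would be called as `is_valid_api_key(extract_api_key(request.headers, request.query), args.api_key)`, assuming the request's header and query objects support `.get()` as they do in the code above.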
        @routes.get("/health")
        async def get_health(request):
            """Health check endpoint that returns the status of the server"""