From c5ad1381bf2140be086374daee25a1845ea1d077 Mon Sep 17 00:00:00 2001 From: daverbj Date: Thu, 11 Dec 2025 15:49:23 +0300 Subject: [PATCH] Add complete authentication system with documentation - Frontend authentication with login page (auth_login.html) - API key injection script (auth_inject.js) - Session management (localStorage/sessionStorage) - Logout via URL: http://127.0.0.1:8188/auth_login.html?logout=true - Modified server.py to inject auth scripts into index.html - Added comprehensive documentation: * AUTHENTICATION_GUIDE.md - Complete authentication guide * FRONTEND_AUTH_GUIDE.md - Frontend-specific guide - Health endpoint accessible without authentication - Multiple auth methods: Bearer token, X-API-Key header, query parameter --- AUTHENTICATION_GUIDE.md | 272 +++++++++++++++ FRONTEND_AUTH_GUIDE.md | 233 +++++++++++++ auth_inject.js | 110 ++++++ auth_login.html | 278 +++++++++++++++ facebook_audio.py | 8 + generate_audio_standalone.py | 275 +++++++++++++++ generate_vibevoice_standalone.py | 476 ++++++++++++++++++++++++++ requirements_vibevoice_standalone.txt | 7 + server.py | 64 ++++ 9 files changed, 1723 insertions(+) create mode 100644 AUTHENTICATION_GUIDE.md create mode 100644 FRONTEND_AUTH_GUIDE.md create mode 100644 auth_inject.js create mode 100644 auth_login.html create mode 100644 facebook_audio.py create mode 100644 generate_audio_standalone.py create mode 100644 generate_vibevoice_standalone.py create mode 100644 requirements_vibevoice_standalone.txt diff --git a/AUTHENTICATION_GUIDE.md b/AUTHENTICATION_GUIDE.md new file mode 100644 index 000000000..54472f09a --- /dev/null +++ b/AUTHENTICATION_GUIDE.md @@ -0,0 +1,272 @@ +# ComfyUI API Authentication Guide + +## Overview + +ComfyUI now supports API key authentication to protect your REST API endpoints. This guide covers setup, usage, and logout procedures. + +## Features + +- **API Key Authentication**: Protect all API endpoints with a simple API key +- **Multiple Auth Methods**: Bearer token, X-API-Key header, or query parameter +- **Health Endpoint**: Monitor server status without authentication +- **Frontend Login**: Built-in login page for browser access +- **Session Management**: Remember API key across sessions or per-session only + +## Quick Start + +### 1. Start Server with API Key + +```bash +# Using command line argument +python3 main.py --api-key "your-secure-api-key-here" + +# Or using a file (recommended for production) +echo "your-secure-api-key-here" > api_key.txt +python3 main.py --api-key-file api_key.txt +``` + +### 2. Access ComfyUI + +When you navigate to `http://127.0.0.1:8188`, you'll be redirected to a login page. + +### 3. Login + +1. Enter your API key in the password field +2. Optionally check "Remember me" to persist the key across browser sessions +3. Click "Login" + +The system validates your key against the `/health` endpoint before storing it. + +### 4. Logout + +To logout, visit: +``` +http://127.0.0.1:8188/auth_login.html?logout=true +``` + +This will clear your stored API key and redirect you to the login page. + +## API Usage + +### Authentication Methods + +You can authenticate API requests in three ways: + +#### 1. Bearer Token (Recommended) +```bash +curl -H "Authorization: Bearer your-api-key" http://127.0.0.1:8188/prompt +``` + +#### 2. X-API-Key Header +```bash +curl -H "X-API-Key: your-api-key" http://127.0.0.1:8188/prompt +``` + +#### 3. 
Query Parameter +```bash +curl "http://127.0.0.1:8188/prompt?api_key=your-api-key" +``` + +### Python Example + +```python +import requests + +API_KEY = "your-api-key" +BASE_URL = "http://127.0.0.1:8188" + +# Using Bearer token +headers = {"Authorization": f"Bearer {API_KEY}"} +response = requests.get(f"{BASE_URL}/queue", headers=headers) + +# Using X-API-Key header +headers = {"X-API-Key": API_KEY} +response = requests.get(f"{BASE_URL}/queue", headers=headers) + +# Using query parameter +response = requests.get(f"{BASE_URL}/queue?api_key={API_KEY}") +``` + +### JavaScript Example + +```javascript +const API_KEY = "your-api-key"; +const BASE_URL = "http://127.0.0.1:8188"; + +// Using Bearer token +fetch(`${BASE_URL}/queue`, { + headers: { + "Authorization": `Bearer ${API_KEY}` + } +}) +.then(response => response.json()) +.then(data => console.log(data)); +``` + +## Health Endpoint + +The `/health` endpoint is always accessible without authentication: + +```bash +curl http://127.0.0.1:8188/health +``` + +Response includes: +- Server status +- Queue information (pending/running tasks) +- Device information (GPU/CPU) +- VRAM statistics + +Example response: +```json +{ + "status": "ok", + "queue": { + "pending": 0, + "running": 0 + }, + "device": { + "name": "mps", + "type": "mps" + }, + "vram": { + "total": 68719476736, + "free": 68719476736 + } +} +``` + +## Exempt Endpoints + +The following paths are exempt from authentication: + +### Static Files +- `.html`, `.js`, `.css`, `.json` files +- `.png`, `.jpg`, `.jpeg`, `.gif`, `.webp` images +- `.svg`, `.ico` icons +- `.woff`, `.woff2`, `.ttf` fonts +- `.mp3`, `.wav` audio files +- `.mp4`, `.webm` video files + +### API Paths +- `/` - Root path (serves login if not authenticated) +- `/health` - Health check endpoint +- `/ws` - WebSocket endpoint +- `/auth_login.html` - Login page +- `/auth_inject.js` - Auth injection script +- `/extensions/*` - Extension files +- `/templates/*` - Template files +- `/docs/*` - Documentation + +## Security Best Practices + +1. **Use Strong API Keys**: Generate random, long API keys (32+ characters) + ```bash + # Generate secure API key on macOS/Linux + openssl rand -base64 32 + ``` + +2. **Use API Key Files**: Store keys in files rather than command line + ```bash + python3 main.py --api-key-file /secure/path/api_key.txt + ``` + +3. **File Permissions**: Restrict key file access + ```bash + chmod 600 api_key.txt + ``` + +4. **HTTPS**: Use reverse proxy with SSL in production + ```nginx + server { + listen 443 ssl; + server_name your-domain.com; + + ssl_certificate /path/to/cert.pem; + ssl_certificate_key /path/to/key.pem; + + location / { + proxy_pass http://127.0.0.1:8188; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + } + } + ``` + +5. **Environment Variables**: Store keys in environment + ```bash + export COMFYUI_API_KEY=$(cat api_key.txt) + python3 main.py --api-key "$COMFYUI_API_KEY" + ``` + +## Frontend Integration + +The authentication system automatically handles frontend requests: + +1. When authentication is enabled, `auth_inject.js` is injected into `index.html` +2. This script intercepts all `fetch()` and `XMLHttpRequest` calls +3. Authorization headers are added automatically to all requests +4. 
On 401 responses, the user is redirected to the login page + +### Session Storage + +- **localStorage**: API key persists across browser sessions (when "Remember me" is checked) +- **sessionStorage**: API key cleared when browser/tab closes (when "Remember me" is not checked) + +## Troubleshooting + +### 401 Unauthorized Errors + +1. Verify API key matches server configuration +2. Check authentication header format +3. Ensure endpoint isn't expecting different auth method + +### Login Page Not Appearing + +1. Clear browser cache +2. Verify `auth_login.html` and `auth_inject.js` exist in ComfyUI root +3. Check server logs for errors + +### WebSocket Connection Issues + +WebSocket connections (`/ws`) are exempt from authentication, but may require authentication for initial HTTP upgrade depending on your setup. + +### Logout Not Working + +Visit the logout URL directly: +``` +http://127.0.0.1:8188/auth_login.html?logout=true +``` + +Or clear storage manually in browser DevTools: +```javascript +localStorage.removeItem('comfyui_api_key'); +sessionStorage.removeItem('comfyui_api_key'); +``` + +## Disabling Authentication + +To disable authentication, simply start the server without the `--api-key` or `--api-key-file` arguments: + +```bash +python3 main.py +``` + +## Migration from Unauthenticated Setup + +If you're adding authentication to an existing ComfyUI installation: + +1. Ensure `auth_login.html` and `auth_inject.js` exist in root directory +2. Update `server.py` with authentication routes +3. Add `middleware/auth_middleware.py` +4. Update `comfy/cli_args.py` with API key arguments +5. Restart server with `--api-key` argument +6. Update API clients to include authentication + +## Support + +For issues or questions: +- Check server logs for authentication errors +- Verify middleware is properly configured +- Test with `/health` endpoint first (no auth required) +- Review `middleware/auth_middleware.py` for exempt paths diff --git a/FRONTEND_AUTH_GUIDE.md b/FRONTEND_AUTH_GUIDE.md new file mode 100644 index 000000000..bca43da5d --- /dev/null +++ b/FRONTEND_AUTH_GUIDE.md @@ -0,0 +1,233 @@ +# ComfyUI Frontend Authentication Guide + +## Overview + +When API key authentication is enabled, the ComfyUI frontend will automatically require users to log in with the API key before accessing the interface. + +## How It Works + +1. **Automatic Redirection**: When you access ComfyUI with authentication enabled, you'll be automatically redirected to a login page if no valid API key is stored. + +2. **API Key Storage**: After successful login, your API key is stored in your browser (localStorage or sessionStorage) depending on your "Remember me" choice. + +3. **Automatic Injection**: The API key is automatically added to all API requests via the `Authorization: Bearer ` header. + +4. **Session Management**: + - **Remember me (checked)**: API key stored in localStorage (persists across browser sessions) + - **Remember me (unchecked)**: API key stored in sessionStorage (cleared when browser closes) + +## Using the Frontend with Authentication + +### First Time Access + +1. Start ComfyUI with an API key: + ```bash + python main.py --api-key "your-secret-key-123" + ``` + +2. Open your browser and navigate to `http://localhost:8188` + +3. You'll be presented with a login page + +4. Enter your API key and click "Login" + +5. 
Choose whether to remember your login (recommended for personal devices only) + +### Login Page Features + +- **Show API Key**: Toggle to view the key you're entering +- **Remember Me**: Keep your session active across browser restarts +- **Auto-validation**: The system validates your key before storing it + +### Logging Out + +To log out and clear your stored API key: + +**Option 1: JavaScript Console** +```javascript +comfyuiLogout() +``` + +**Option 2: Browser DevTools** +1. Open DevTools (F12) +2. Go to Application > Storage +3. Clear localStorage and sessionStorage +4. Refresh the page + +**Option 3: Manual URL** +Navigate to: `http://localhost:8188/auth_login.html` + +### Security Considerations + +#### For Personal Use +- ✅ Enable "Remember me" for convenience +- ✅ Use strong, unique API keys +- ✅ Keep your browser updated + +#### For Shared/Public Computers +- ❌ **DO NOT** enable "Remember me" +- ✅ Always log out when finished +- ✅ Clear browser data after use +- ✅ Consider using a private/incognito window + +#### For Production Deployments +- ✅ Always use HTTPS (combine with `--tls-keyfile` and `--tls-certfile`) +- ✅ Use strong, randomly generated API keys +- ✅ Rotate keys regularly +- ✅ Monitor access logs for unauthorized attempts +- ✅ Consider using additional security layers (VPN, firewall, etc.) + +## Troubleshooting + +### "Invalid API Key" Error + +**Problem**: Login fails with invalid API key message + +**Solutions**: +1. Verify you're using the correct API key +2. Check for extra spaces or newlines +3. Ensure the server was started with the same API key +4. Check server logs for authentication attempts + +### Automatic Logout + +**Problem**: Frequently logged out automatically + +**Possible Causes**: +1. API key changed on server (restart required) +2. Browser cleared storage automatically +3. "Remember me" was not checked +4. Session expired (if using sessionStorage) + +**Solution**: Check "Remember me" when logging in + +### Login Page Not Showing + +**Problem**: Can't access login page + +**Possible Causes**: +1. Authentication not enabled (no `--api-key` argument) +2. Browser cached old version + +**Solution**: +1. Hard refresh the page (Ctrl+Shift+R or Cmd+Shift+R) +2. Clear browser cache +3. Try accessing directly: `http://localhost:8188/auth_login.html` + +### API Requests Still Failing + +**Problem**: Some API requests return 401 after login + +**Solutions**: +1. Check browser console for errors +2. Verify API key is stored: Open DevTools > Application > Storage +3. Try logging out and back in +4. Clear all browser data and try again + +## Advanced Usage + +### Programmatic Access with Frontend + +If you need to access the API programmatically while using the frontend: + +```javascript +// Get the stored API key +const apiKey = localStorage.getItem('comfyui_api_key') || + sessionStorage.getItem('comfyui_api_key'); + +// Make authenticated requests +fetch('/api/system_stats', { + headers: { + 'Authorization': `Bearer ${apiKey}` + } +}).then(r => r.json()).then(console.log); +``` + +### Custom Frontend Integration + +If you're building a custom frontend, you can use the same mechanism: + +1. **Login Flow**: + ```javascript + // Validate API key + const response = await fetch('/health', { + headers: { + 'Authorization': `Bearer ${apiKey}` + } + }); + + if (response.ok) { + // Store the key + localStorage.setItem('comfyui_api_key', apiKey); + } + ``` + +2. 
**Request Interceptor**: + ```javascript + // Add to all requests + const apiKey = localStorage.getItem('comfyui_api_key'); + + fetch(url, { + headers: { + 'Authorization': `Bearer ${apiKey}`, + ...otherHeaders + } + }); + ``` + +3. **Handle 401 Responses**: + ```javascript + if (response.status === 401) { + // Clear stored key and redirect to login + localStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + } + ``` + +## API Endpoints Reference + +### Public Endpoints (No Authentication Required) +- `GET /` - Main page (with auth script injected) +- `GET /health` - Health check +- `GET /auth_login.html` - Login page +- `GET /auth_inject.js` - Auth injection script +- `GET /ws` - WebSocket connection +- Static files (`.js`, `.css`, `.html`, etc.) + +### Protected Endpoints (Authentication Required) +- All `/api/*` endpoints +- All `/internal/*` endpoints +- Most other API endpoints + +## Browser Compatibility + +The authentication system works with all modern browsers: +- ✅ Chrome/Edge 90+ +- ✅ Firefox 88+ +- ✅ Safari 14+ +- ✅ Opera 76+ + +## FAQs + +**Q: Can I use ComfyUI without authentication?** +A: Yes! Simply start ComfyUI without the `--api-key` argument. + +**Q: Can I change the API key without losing my workflows?** +A: Yes, workflows are stored separately. Just update the key on the server and re-login. + +**Q: Is my API key secure?** +A: The key is stored in browser storage and sent over HTTPS (if configured). For maximum security, use HTTPS and strong keys. + +**Q: Can multiple users use different API keys?** +A: Currently, the system supports a single API key. For multi-user scenarios, each user must use the same key. + +**Q: What happens if I forget my API key?** +A: Check your server startup command or the file specified in `--api-key-file`. + +## Support + +For issues or questions: +1. Check the server logs +2. Review browser console for errors +3. Refer to the main authentication documentation: `API_AUTHENTICATION.md` +4. 
Check the quick start guide: `QUICK_START_AUTH.md` diff --git a/auth_inject.js b/auth_inject.js new file mode 100644 index 000000000..123b8e387 --- /dev/null +++ b/auth_inject.js @@ -0,0 +1,110 @@ +/** + * ComfyUI API Key Injection + * This script automatically adds the stored API key to all HTTP requests + */ + +(function() { + 'use strict'; + + // Get the stored API key + function getApiKey() { + return localStorage.getItem('comfyui_api_key') || sessionStorage.getItem('comfyui_api_key'); + } + + // Check if user is authenticated + function checkAuth() { + const apiKey = getApiKey(); + if (!apiKey && window.location.pathname !== '/auth_login.html') { + // Redirect to login page if no API key is found + window.location.href = '/auth_login.html'; + return false; + } + return true; + } + + // Intercept fetch requests + const originalFetch = window.fetch; + window.fetch = function(...args) { + const apiKey = getApiKey(); + + if (apiKey) { + // Clone or create the options object + let [url, options = {}] = args; + + // Initialize headers if not present + if (!options.headers) { + options.headers = {}; + } + + // Convert Headers object to plain object if needed + if (options.headers instanceof Headers) { + const headersObj = {}; + options.headers.forEach((value, key) => { + headersObj[key] = value; + }); + options.headers = headersObj; + } + + // Add Authorization header if not already present + if (!options.headers['Authorization'] && !options.headers['authorization']) { + options.headers['Authorization'] = `Bearer ${apiKey}`; + } + + // Update args + args = [url, options]; + } + + // Call original fetch and handle 401 errors + return originalFetch.apply(this, args).then(response => { + if (response.status === 401 && window.location.pathname !== '/auth_login.html') { + // Clear stored API key and redirect to login + localStorage.removeItem('comfyui_api_key'); + sessionStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + } + return response; + }); + }; + + // Intercept XMLHttpRequest + const originalOpen = XMLHttpRequest.prototype.open; + const originalSend = XMLHttpRequest.prototype.send; + + XMLHttpRequest.prototype.open = function(method, url, ...rest) { + this._url = url; + return originalOpen.apply(this, [method, url, ...rest]); + }; + + XMLHttpRequest.prototype.send = function(...args) { + const apiKey = getApiKey(); + + if (apiKey && !this.getRequestHeader('Authorization')) { + this.setRequestHeader('Authorization', `Bearer ${apiKey}`); + } + + // Handle 401 responses + this.addEventListener('load', function() { + if (this.status === 401 && window.location.pathname !== '/auth_login.html') { + localStorage.removeItem('comfyui_api_key'); + sessionStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + } + }); + + return originalSend.apply(this, args); + }; + + // Add logout function to window + window.comfyuiLogout = function() { + localStorage.removeItem('comfyui_api_key'); + sessionStorage.removeItem('comfyui_api_key'); + window.location.href = '/auth_login.html'; + }; + + // Check authentication on page load (except for login page) + if (window.location.pathname !== '/auth_login.html') { + checkAuth(); + } + + console.log('[ComfyUI Auth] API key injection enabled'); +})(); diff --git a/auth_login.html b/auth_login.html new file mode 100644 index 000000000..913aa362d --- /dev/null +++ b/auth_login.html @@ -0,0 +1,278 @@ + + + + + + ComfyUI - API Key Required + + + + + + + + diff --git a/facebook_audio.py b/facebook_audio.py 
new file mode 100644 index 000000000..fb605ea05 --- /dev/null +++ b/facebook_audio.py @@ -0,0 +1,8 @@ +from transformers import pipeline +import scipy + +synthesiser = pipeline("text-to-audio", "facebook/musicgen-large") + +music = synthesiser("lo-fi music with a soothing melody", forward_params={"do_sample": True}) + +scipy.io.wavfile.write("musicgen_out.wav", rate=music["sampling_rate"], data=music["audio"]) diff --git a/generate_audio_standalone.py b/generate_audio_standalone.py new file mode 100644 index 000000000..ca12d7bd4 --- /dev/null +++ b/generate_audio_standalone.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +""" +Standalone script to generate music from text using Stable Audio in ComfyUI. +Based on the workflow: user/default/workflows/audio_stable_audio_example.json + +This script replicates the workflow: +1. Load checkpoint model (stable-audio-open-1.0.safetensors) +2. Load CLIP text encoder (t5-base.safetensors) +3. Encode positive prompt (music description) +4. Encode negative prompt (empty) +5. Create empty latent audio (47.6 seconds) +6. Sample using KSampler +7. Decode audio from latent using VAE +8. Save as MP3 + +Requirements: +- stable-audio-open-1.0.safetensors in models/checkpoints/ +- t5-base.safetensors in models/text_encoders/ +""" + +import torch +import sys +import os +import random +import av +from io import BytesIO + +# Add ComfyUI to path +script_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, script_dir) + +import comfy.sd +import comfy.sample +import comfy.samplers +import comfy.model_management +import folder_paths +import latent_preview +import comfy.utils + + +def load_checkpoint(ckpt_name): + """Load checkpoint model - returns MODEL, CLIP, VAE""" + print(f"Loading checkpoint: {ckpt_name}") + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=folder_paths.get_folder_paths("embeddings") + ) + return out[:3] # MODEL, CLIP, VAE + + +def load_clip(clip_name, clip_type="stable_audio"): + """Load CLIP text encoder""" + print(f"Loading CLIP: {clip_name}") + clip_type_enum = getattr(comfy.sd.CLIPType, clip_type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) + + clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name) + clip = comfy.sd.load_clip( + ckpt_paths=[clip_path], + embedding_directory=folder_paths.get_folder_paths("embeddings"), + clip_type=clip_type_enum, + model_options={} + ) + return clip + + +def encode_text(clip, text): + """Encode text using CLIP - returns CONDITIONING""" + print(f"Encoding text: '{text}'") + if clip is None: + raise RuntimeError("ERROR: clip input is invalid: None") + tokens = clip.tokenize(text) + return clip.encode_from_tokens_scheduled(tokens) + + +def create_empty_latent_audio(seconds, batch_size=1): + """Create empty latent audio tensor""" + print(f"Creating empty latent audio: {seconds} seconds") + length = round((seconds * 44100 / 2048) / 2) * 2 + latent = torch.zeros( + [batch_size, 64, length], + device=comfy.model_management.intermediate_device() + ) + return {"samples": latent, "type": "audio"} + + +def sample_audio(model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise=1.0): + """Run KSampler to generate audio latents""" + print(f"Sampling with seed={seed}, steps={steps}, cfg={cfg}, sampler={sampler_name}, scheduler={scheduler}") + + latent_samples = latent_image["samples"] + latent_samples = 
comfy.sample.fix_empty_latent_channels(model, latent_samples) + + # Prepare noise + batch_inds = latent_image["batch_index"] if "batch_index" in latent_image else None + noise = comfy.sample.prepare_noise(latent_samples, seed, batch_inds) + + # Check for noise mask + noise_mask = latent_image.get("noise_mask", None) + + # Prepare callback for progress + callback = latent_preview.prepare_callback(model, steps) + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + + # Sample + samples = comfy.sample.sample( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, latent_samples, + denoise=denoise, + disable_noise=False, + start_step=None, + last_step=None, + force_full_denoise=False, + noise_mask=noise_mask, + callback=callback, + disable_pbar=disable_pbar, + seed=seed + ) + + out = latent_image.copy() + out["samples"] = samples + return out + + +def decode_audio(vae, samples): + """Decode audio from latent samples using VAE""" + print("Decoding audio from latents") + audio = vae.decode(samples["samples"]).movedim(-1, 1) + + # Normalize audio + std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 + std[std < 1.0] = 1.0 + audio /= std + + return {"waveform": audio, "sample_rate": 44100} + + +def save_audio_mp3(audio, filename, quality="V0"): + """Save audio as MP3 file using PyAV (same as ComfyUI)""" + print(f"Saving audio to: {filename}") + + # Create output directory if needed + os.makedirs(os.path.dirname(filename), exist_ok=True) + + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + + # Ensure audio is in CPU + waveform = waveform.cpu() + + # Process each audio in batch (usually just 1) + for batch_number, waveform_item in enumerate(waveform): + if batch_number > 0: + # Add batch number to filename if multiple + base, ext = os.path.splitext(filename) + output_path = f"{base}_{batch_number}{ext}" + else: + output_path = filename + + # Create output buffer + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + # Determine audio layout - waveform_item shape is [channels, samples] + num_channels = waveform_item.shape[0] if waveform_item.dim() > 1 else 1 + layout = "mono" if num_channels == 1 else "stereo" + + # Set up the MP3 output stream + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) + + # Set quality + if quality == "V0": + out_stream.codec_context.qscale = 1 # Highest VBR quality + elif quality == "128k": + out_stream.bit_rate = 128000 + elif quality == "320k": + out_stream.bit_rate = 320000 + + # Prepare waveform for PyAV: needs to be [samples, channels] + # Use detach() to avoid gradient tracking issues + if waveform_item.dim() == 1: + # Mono audio, add channel dimension + waveform_numpy = waveform_item.unsqueeze(1).float().detach().numpy() + else: + # Transpose from [channels, samples] to [samples, channels] + waveform_numpy = waveform_item.transpose(0, 1).float().detach().numpy() + + # Reshape to [1, samples * channels] for PyAV + waveform_numpy = waveform_numpy.reshape(1, -1) + + # Create audio frame + frame = av.AudioFrame.from_ndarray( + waveform_numpy, + format="flt", + layout=layout, + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + # Encode + output_container.mux(out_stream.encode(frame)) + + # Flush encoder + output_container.mux(out_stream.encode(None)) + + # Close container + output_container.close() + + # Write to file + output_buffer.seek(0) + with open(output_path, "wb") as f: + f.write(output_buffer.getbuffer()) + + print(f"Audio saved 
successfully: {output_path}") + + +def main(): + # Configuration + checkpoint_name = "stable-audio-open-1.0.safetensors" + clip_name = "t5-base.safetensors" + positive_prompt = "A soft melodious acoustic guitar music" + negative_prompt = "" + audio_duration = 47.6 # seconds + seed = random.randint(0, 0xffffffffffffffff) # Random seed, or use specific value + steps = 50 + cfg = 4.98 + sampler_name = "dpmpp_3m_sde_gpu" + scheduler = "exponential" + denoise = 1.0 + output_filename = "output/audio/generated_music.mp3" + quality = "V0" + + print("=" * 60) + print("Stable Audio - Music Generation Script") + print("=" * 60) + print(f"Positive Prompt: {positive_prompt}") + print(f"Duration: {audio_duration} seconds") + print(f"Seed: {seed}") + print("=" * 60) + + # 1. Load checkpoint (MODEL, CLIP, VAE) + model, checkpoint_clip, vae = load_checkpoint(checkpoint_name) + + # 2. Load separate CLIP text encoder for stable audio + clip = load_clip(clip_name, "stable_audio") + + # 3. Encode positive and negative prompts + positive_conditioning = encode_text(clip, positive_prompt) + negative_conditioning = encode_text(clip, negative_prompt) + + # 4. Create empty latent audio + latent_audio = create_empty_latent_audio(audio_duration, batch_size=1) + + # 5. Sample using KSampler + sampled_latent = sample_audio( + model, seed, steps, cfg, sampler_name, scheduler, + positive_conditioning, negative_conditioning, latent_audio, denoise + ) + + # 6. Decode audio from latent using VAE + audio = decode_audio(vae, sampled_latent) + + # 7. Save as MP3 + save_audio_mp3(audio, output_filename, quality) + + print("=" * 60) + print("Generation complete!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/generate_vibevoice_standalone.py b/generate_vibevoice_standalone.py new file mode 100644 index 000000000..281b62e71 --- /dev/null +++ b/generate_vibevoice_standalone.py @@ -0,0 +1,476 @@ +#!/usr/bin/env python3 +""" +Standalone script to generate TTS audio using VibeVoice. +This script has NO ComfyUI dependencies and uses the models directly from HuggingFace. + +Based on Microsoft's VibeVoice: https://github.com/microsoft/VibeVoice + +Requirements: + pip install torch transformers numpy scipy soundfile librosa huggingface-hub + +Usage: + python generate_vibevoice_standalone.py +""" + +import torch +import numpy as np +import soundfile as sf +import os +import random +import re +import logging +from typing import Optional, List, Tuple +from huggingface_hub import snapshot_download + +logging.basicConfig(level=logging.INFO, format='[VibeVoice] %(message)s') +logger = logging.getLogger(__name__) + +try: + import librosa + LIBROSA_AVAILABLE = True +except ImportError: + logger.warning("librosa not available - resampling will not work") + LIBROSA_AVAILABLE = False + + +def set_seed(seed: int): + """Set random seeds for reproducibility""" + if seed == 0: + seed = random.randint(1, 0xffffffffffffffff) + + MAX_NUMPY_SEED = 2**32 - 1 + numpy_seed = seed % MAX_NUMPY_SEED + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(numpy_seed) + random.seed(seed) + + return seed + + +def parse_script(script: str) -> Tuple[List[Tuple[int, str]], List[int]]: + """ + Parse speaker script into (speaker_id, text) tuples. + + Supports formats: + [1] Some text... + Speaker 1: Some text... 
+ + Returns: + parsed_lines: List of (0-based speaker_id, text) tuples + speaker_ids: List of unique 1-based speaker IDs in order of appearance + """ + parsed_lines = [] + speaker_ids_in_script = [] + + line_format_regex = re.compile(r'^(?:Speaker\s+(\d+)\s*:|\[(\d+)\])\s*(.*)$', re.IGNORECASE) + + for line in script.strip().split("\n"): + line = line.strip() + if not line: + continue + + match = line_format_regex.match(line) + if match: + speaker_id_str = match.group(1) or match.group(2) + speaker_id = int(speaker_id_str) + text_content = match.group(3) + + if match.group(1) is None and text_content.lstrip().startswith(':'): + colon_index = text_content.find(':') + text_content = text_content[colon_index + 1:] + + if speaker_id < 1: + logger.warning(f"Speaker ID must be 1 or greater. Skipping line: '{line}'") + continue + + text = text_content.strip() + internal_speaker_id = speaker_id - 1 + parsed_lines.append((internal_speaker_id, text)) + + if speaker_id not in speaker_ids_in_script: + speaker_ids_in_script.append(speaker_id) + else: + logger.warning(f"Could not parse speaker marker, ignoring: '{line}'") + + if not parsed_lines and script.strip(): + logger.info("No speaker markers found. Treating entire text as Speaker 1.") + parsed_lines.append((0, script.strip())) + speaker_ids_in_script.append(1) + + return parsed_lines, sorted(list(set(speaker_ids_in_script))) + + +def load_audio_file(audio_path: str, target_sr: int = 24000) -> Optional[np.ndarray]: + """Load audio file and convert to mono at target sample rate""" + if not os.path.exists(audio_path): + logger.error(f"Audio file not found: {audio_path}") + return None + + logger.info(f"Loading audio: {audio_path}") + + try: + # Load audio using soundfile + waveform, sr = sf.read(audio_path) + + # Convert to mono if stereo + if waveform.ndim > 1: + waveform = np.mean(waveform, axis=1) + + # Resample if needed + if sr != target_sr: + if not LIBROSA_AVAILABLE: + raise ImportError("librosa is required for resampling. Install with: pip install librosa") + logger.info(f"Resampling from {sr}Hz to {target_sr}Hz") + waveform = librosa.resample(y=waveform, orig_sr=sr, target_sr=target_sr) + + # Validate audio + if np.any(np.isnan(waveform)) or np.any(np.isinf(waveform)): + logger.error("Audio contains NaN or Inf values, replacing with zeros") + waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0) + + if np.all(waveform == 0): + logger.warning("Audio waveform is completely silent") + + # Normalize extreme values + max_val = np.abs(waveform).max() + if max_val > 10.0: + logger.warning(f"Audio values are very large (max: {max_val}), normalizing") + waveform = waveform / max_val + + return waveform.astype(np.float32) + + except Exception as e: + logger.error(f"Error loading audio: {e}") + return None + + +def download_model(model_name: str = "VibeVoice-1.5B", cache_dir: str = "./models"): + """Download VibeVoice model from HuggingFace""" + + repo_mapping = { + "VibeVoice-1.5B": "microsoft/VibeVoice-1.5B", + "VibeVoice-Large": "aoi-ot/VibeVoice-Large" + } + + if model_name not in repo_mapping: + raise ValueError(f"Unknown model: {model_name}. 
Choose from: {list(repo_mapping.keys())}") + + repo_id = repo_mapping[model_name] + model_path = os.path.join(cache_dir, model_name) + + if os.path.exists(os.path.join(model_path, "config.json")): + logger.info(f"Model already downloaded: {model_path}") + return model_path + + logger.info(f"Downloading model from {repo_id}...") + os.makedirs(cache_dir, exist_ok=True) + + model_path = snapshot_download( + repo_id=repo_id, + local_dir=model_path, + local_dir_use_symlinks=False + ) + + logger.info(f"Model downloaded to: {model_path}") + return model_path + + +def generate_tts( + text: str, + model_name: str = "VibeVoice-Large", + speaker_audio_paths: Optional[dict] = None, + output_path: str = "output.wav", + cfg_scale: float = 1.3, + inference_steps: int = 10, + seed: int = 42, + temperature: float = 0.95, + top_p: float = 0.95, + top_k: int = 0, + cache_dir: str = "./models", + device: str = "auto" +): + """ + Generate TTS audio using VibeVoice + + Args: + text: Text script with speaker markers like "[1] text" or "Speaker 1: text" + model_name: Model to use ("VibeVoice-1.5B" or "VibeVoice-Large") + speaker_audio_paths: Dict mapping speaker IDs to audio file paths for voice cloning + e.g., {1: "voice1.wav", 2: "voice2.wav"} + output_path: Where to save the generated audio + cfg_scale: Classifier-Free Guidance scale (higher = more adherence to prompt) + inference_steps: Number of diffusion steps + seed: Random seed for reproducibility + temperature: Sampling temperature + top_p: Nucleus sampling parameter + top_k: Top-K sampling parameter + cache_dir: Directory to cache downloaded models + device: Device to use ("cuda", "mps", "cpu", or "auto" for automatic detection) + """ + + # Set seed + actual_seed = set_seed(seed) + logger.info(f"Using seed: {actual_seed}") + + # Determine device - with MPS support for Mac + if device == "auto": + if torch.cuda.is_available(): + device = "cuda" + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + device = "mps" + logger.info("MPS (Metal Performance Shaders) detected - using Mac GPU acceleration") + else: + device = "cpu" + logger.info(f"Using device: {device}") + + # Download model if needed + model_path = download_model(model_name, cache_dir) + + # Import VibeVoice components + logger.info("Loading VibeVoice model...") + try: + # Add the VibeVoice custom model code to path + import sys + vibevoice_custom_path = os.path.join(os.path.dirname(__file__), "custom_nodes", "ComfyUI-VibeVoice") + if vibevoice_custom_path not in sys.path: + sys.path.insert(0, vibevoice_custom_path) + + # Import custom VibeVoice model + from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference + from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor + from vibevoice.processor.vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor + from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast + from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig + import json + + # Load config + config_path = os.path.join(model_path, "config.json") + config = VibeVoiceConfig.from_pretrained(config_path) + + # Load tokenizer - download if not present + tokenizer_file = os.path.join(model_path, "tokenizer.json") + if not os.path.exists(tokenizer_file): + logger.info(f"tokenizer.json not found, downloading from HuggingFace...") + from huggingface_hub import hf_hub_download + + # Determine which Qwen model to use based on model size + qwen_repo = 
"Qwen/Qwen2.5-1.5B" if "1.5B" in model_name else "Qwen/Qwen2.5-7B" + + try: + hf_hub_download( + repo_id=qwen_repo, + filename="tokenizer.json", + local_dir=model_path, + local_dir_use_symlinks=False + ) + logger.info("tokenizer.json downloaded successfully") + except Exception as e: + logger.error(f"Failed to download tokenizer.json: {e}") + raise FileNotFoundError(f"Could not download tokenizer.json from {qwen_repo}") + + tokenizer = VibeVoiceTextTokenizerFast(tokenizer_file=tokenizer_file) + + # Load processor config + preprocessor_config_path = os.path.join(model_path, "preprocessor_config.json") + processor_config_data = {} + if os.path.exists(preprocessor_config_path): + with open(preprocessor_config_path, 'r') as f: + processor_config_data = json.load(f) + + audio_processor = VibeVoiceTokenizerProcessor() + processor = VibeVoiceProcessor( + tokenizer=tokenizer, + audio_processor=audio_processor, + speech_tok_compress_ratio=processor_config_data.get("speech_tok_compress_ratio", 3200), + db_normalize=processor_config_data.get("db_normalize", True) + ) + + # Load model + # MPS doesn't support bfloat16 well, use float16 + if device == "mps": + dtype = torch.float16 + logger.info("Using float16 for MPS device") + elif torch.cuda.is_available() and torch.cuda.is_bf16_supported(): + dtype = torch.bfloat16 + else: + dtype = torch.float16 + + model = VibeVoiceForConditionalGenerationInference.from_pretrained( + model_path, + config=config, + torch_dtype=dtype, + device_map=device, + attn_implementation="sdpa" + ) + + model.eval() + logger.info("Model loaded successfully") + + except Exception as e: + logger.error(f"Failed to load model: {e}") + import traceback + traceback.print_exc() + raise + + # Parse script + parsed_lines, speaker_ids = parse_script(text) + if not parsed_lines: + raise ValueError("Script is empty or invalid") + + logger.info(f"Parsed {len(parsed_lines)} lines with speakers: {speaker_ids}") + + # Load speaker audio samples + voice_samples = [] + if speaker_audio_paths is None: + speaker_audio_paths = {} + + for speaker_id in speaker_ids: + audio_path = speaker_audio_paths.get(speaker_id) + if audio_path: + audio = load_audio_file(audio_path, target_sr=24000) + if audio is None: + logger.warning(f"Could not load audio for speaker {speaker_id}, using zero-shot TTS") + voice_samples.append(None) + else: + voice_samples.append(audio) + else: + logger.info(f"No reference audio for speaker {speaker_id}, using zero-shot TTS") + voice_samples.append(None) + + # Prepare inputs + logger.info("Processing inputs...") + try: + inputs = processor( + parsed_scripts=[parsed_lines], + voice_samples=[voice_samples], + speaker_ids_for_prompt=[speaker_ids], + padding=True, + return_tensors="pt", + return_attention_mask=True + ) + + # Move to device + inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + + except Exception as e: + logger.error(f"Error processing inputs: {e}") + raise + + # Configure generation + model.set_ddpm_inference_steps(num_steps=inference_steps) + + generation_config = { + 'do_sample': True, + 'temperature': temperature, + 'top_p': top_p, + } + if top_k > 0: + generation_config['top_k'] = top_k + + # Generate + logger.info(f"Generating audio ({inference_steps} steps)...") + try: + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=None, + cfg_scale=cfg_scale, + tokenizer=processor.tokenizer, + generation_config=generation_config, + verbose=False + ) + + # Extract waveform + waveform = 
outputs.speech_outputs[0].cpu().numpy() + + # Ensure correct shape + if waveform.ndim == 1: + waveform = waveform.reshape(1, -1) + elif waveform.ndim == 2 and waveform.shape[0] > 1: + # If multiple channels, take first + waveform = waveform[0:1, :] + + # Convert to float32 for soundfile compatibility + waveform = waveform.astype(np.float32) + + # Save audio + os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True) + sf.write(output_path, waveform.T, 24000) + logger.info(f"Audio saved to: {output_path}") + + return waveform + + except Exception as e: + logger.error(f"Error during generation: {e}") + raise + + +def main(): + """Example usage""" + + # Configuration + model_name = "VibeVoice-Large" # or "VibeVoice-Large" + + # Text to generate - supports multiple speakers + text = """ + [1] Hello, this is speaker one. How are you today? + [2] Hi there! This is speaker two responding to you. It's great to meet you. + [1] Likewise! Let's generate some amazing speech together. + [2] Absolutely! VibeVoice makes it so easy to create diverse voices. + """ + + # Reference audio for voice cloning (optional) + # If not provided, will use zero-shot TTS + speaker_audio_paths = { + 1: "input/audio1.wav", # Path to reference audio for speaker 1 + 2: "input/laundry.mp3", # Uncomment to provide reference for speaker 2 + } + + # Generation parameters + output_path = "output/vibevoice_generated.wav" + cfg_scale = 1.3 + inference_steps = 10 + seed = 42 # or 0 for random + temperature = 0.95 + top_p = 0.95 + top_k = 0 + + print("=" * 60) + print("VibeVoice TTS - Standalone Script") + print("=" * 60) + print(f"Model: {model_name}") + print(f"Text: {text[:100]}...") + print("=" * 60) + + try: + generate_tts( + text=text, + model_name=model_name, + speaker_audio_paths=speaker_audio_paths, + output_path=output_path, + cfg_scale=cfg_scale, + inference_steps=inference_steps, + seed=seed, + temperature=temperature, + top_p=top_p, + top_k=top_k, + cache_dir="./models", + device="auto" + ) + + print("=" * 60) + print("Generation complete!") + print(f"Audio saved to: {output_path}") + print("=" * 60) + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/requirements_vibevoice_standalone.txt b/requirements_vibevoice_standalone.txt new file mode 100644 index 000000000..2825490f1 --- /dev/null +++ b/requirements_vibevoice_standalone.txt @@ -0,0 +1,7 @@ +torch +transformers +numpy +scipy +soundfile +librosa +huggingface-hub diff --git a/server.py b/server.py index 0d9b80d7c..7d6d1f506 100644 --- a/server.py +++ b/server.py @@ -309,12 +309,76 @@ class PromptServer(): @routes.get("/") async def get_root(request): + # If API key is enabled and not provided, redirect to login + if args.api_key: + auth_header = request.headers.get("Authorization", "") + api_key_header = request.headers.get("X-API-Key", "") + api_key_query = request.query.get("api_key", "") + + has_valid_key = False + if auth_header.startswith("Bearer "): + provided_key = auth_header[7:] + has_valid_key = (provided_key == args.api_key) + elif api_key_header: + has_valid_key = (api_key_header == args.api_key) + elif api_key_query: + has_valid_key = (api_key_query == args.api_key) + + # If no valid key, check if coming from login page, otherwise show login + if not has_valid_key: + # Serve modified index.html with auth script injected + index_path = os.path.join(self.web_root, "index.html") + if os.path.exists(index_path): + with 
open(index_path, 'r', encoding='utf-8') as f: + html_content = f.read() + + # Inject auth script before the closing </head> tag + auth_script = '<script src="/auth_inject.js"></script>' + if '</head>' in html_content: + html_content = html_content.replace('</head>', f'{auth_script}\n</head>') + else: + # Fallback: add at the beginning of body + html_content = html_content.replace('<body>', f'<body>\n{auth_script}') + + response = web.Response(text=html_content, content_type='text/html') + else: + # index.html is missing entirely + response = web.Response(text="index.html not found", status=404) + + response.headers['Cache-Control'] = 'no-cache' + response.headers["Pragma"] = "no-cache" + response.headers["Expires"] = "0" + return response + + # No auth required or valid key provided response = web.FileResponse(os.path.join(self.web_root, "index.html")) response.headers['Cache-Control'] = 'no-cache' response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response + @routes.get("/auth_login.html") + async def get_login_page(request): + """Serve the login page""" + login_page_path = os.path.join(os.path.dirname(__file__), "auth_login.html") + if os.path.exists(login_page_path): + response = web.FileResponse(login_page_path) + else: + # Fallback if file doesn't exist + response = web.Response(text="Login page not found", status=404) + response.headers['Cache-Control'] = 'no-cache' + return response + + @routes.get("/auth_inject.js") + async def get_auth_script(request): + """Serve the auth injection script""" + script_path = os.path.join(os.path.dirname(__file__), "auth_inject.js") + if os.path.exists(script_path): + response = web.FileResponse(script_path, headers={'Content-Type': 'application/javascript'}) + else: + response = web.Response(text="// Auth script not found", status=404, content_type='application/javascript') + response.headers['Cache-Control'] = 'no-cache' + return response + + @routes.get("/health") + async def get_health(request): + """Health check endpoint that returns the status of the server"""
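Note: the commit message and AUTHENTICATION_GUIDE.md refer to a `middleware/auth_middleware.py` that is not included in this hunk. As a rough illustration of how the three documented auth methods (Bearer token, X-API-Key header, query parameter) and the exempt-path handling could be wired into aiohttp, a minimal sketch follows; the name `create_auth_middleware`, the exempt lists, and the wiring line are illustrative assumptions, not the actual implementation shipped with the patch:

```python
# Hypothetical sketch only -- the real middleware/auth_middleware.py is not part of this hunk.
from aiohttp import web

EXEMPT_PATHS = {"/", "/health", "/ws", "/auth_login.html", "/auth_inject.js"}  # assumed subset
EXEMPT_SUFFIXES = (".html", ".js", ".css", ".png", ".ico")                     # assumed subset of static types

def create_auth_middleware(api_key: str):
    @web.middleware
    async def auth_middleware(request: web.Request, handler):
        # Skip authentication for exempt paths and static assets
        if request.path in EXEMPT_PATHS or request.path.endswith(EXEMPT_SUFFIXES):
            return await handler(request)

        # Accept Bearer token, X-API-Key header, or api_key query parameter
        provided = None
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            provided = auth_header[7:]
        elif request.headers.get("X-API-Key"):
            provided = request.headers["X-API-Key"]
        elif request.query.get("api_key"):
            provided = request.query["api_key"]

        if provided != api_key:
            return web.json_response({"error": "unauthorized"}, status=401)
        return await handler(request)

    return auth_middleware

# Assumed wiring on the aiohttp application:
# app = web.Application(middlewares=[create_auth_middleware(args.api_key)])
```

Registering a middleware like this on the `web.Application` would keep individual route handlers (such as `get_root` above) free of repeated key checks, while the per-route logic in `server.py` only needs to decide whether to serve the login-injected `index.html`.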