Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-06 11:32:31 +08:00)

Compare commits: 24 commits, 3d9bc9ea5f ... 97c63d6aab
Commits in this comparison (SHA1):

- 97c63d6aab
- 72f6be1690
- 16b9aabd52
- 245f6139b6
- 3365ad18a5
- f09904720d
- 7602203696
- ffa7a369ba
- 7ec1656735
- cee75f301a
- 1a59686ca8
- 6d96d26795
- e07a32c9b8
- a19f0a88e4
- 64811809a0
- 529109083a
- a7be9f6fc3
- 6075c44ec8
- 154b73835a
- 862e7784f4
- f6b6636bf3
- 83b00df3f0
- 5f24eb699c
- fab0954077
@@ -404,6 +404,14 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w

> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.

<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path, so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or PowerShell terminal.

## How to run heavy workflows on a mid-range GPU (NVIDIA/Linux)?

Use the `--enable-gds` flag to activate NVIDIA [GPUDirect Storage](https://docs.nvidia.com/gpudirect-storage/) (GDS), which allows data to be transferred directly between SSDs and GPUs. This eliminates the traditional CPU-mediated data path, significantly reducing I/O latency and CPU overhead. System RAM is still used for caching to further optimize performance alongside the SSD.

This feature has been tested only on NVIDIA GPUs on Linux-based systems.

Requires: `cupy-cuda12x>=12.0.0`, `pynvml>=11.4.1`, `cudf>=23.0.0`, `numba>=0.57.0`, `nvidia-ml-py>=12.0.0`.

## Support and dev channel

[Discord](https://comfy.org/discord): Try the #help or #feedback channels.

@@ -154,6 +154,17 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

# GPUDirect Storage (GDS) arguments
gds_group = parser.add_argument_group('gds', 'GPUDirect Storage options for direct SSD-to-GPU model loading')
gds_group.add_argument("--enable-gds", action="store_true", help="Enable GPUDirect Storage for direct SSD-to-GPU model loading (requires CUDA 11.4+, cuFile).")
gds_group.add_argument("--disable-gds", action="store_true", help="Explicitly disable GPUDirect Storage.")
gds_group.add_argument("--gds-min-file-size", type=int, default=100, help="Minimum file size in MB to use GDS (default: 100MB).")
gds_group.add_argument("--gds-chunk-size", type=int, default=64, help="GDS transfer chunk size in MB (default: 64MB).")
gds_group.add_argument("--gds-streams", type=int, default=4, help="Number of CUDA streams for GDS operations (default: 4).")
gds_group.add_argument("--gds-prefetch", action="store_true", help="Enable GDS prefetching for better performance.")
gds_group.add_argument("--gds-no-fallback", action="store_true", help="Disable fallback to CPU loading if GDS fails.")
gds_group.add_argument("--gds-stats", action="store_true", help="Print GDS statistics on exit.")

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"

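A minimal sketch of how these flags might be wired into the `GDSConfig` dataclass that this PR adds in `comfy/gds_loader.py`. The helper function and the `args` namespace below are illustrative assumptions, not code from the diff; the PR's actual wiring may differ.

```python
# Hypothetical glue code: build a GDSConfig from the parsed CLI flags above.
# `args` is assumed to be the argparse.Namespace produced by this parser.
from comfy.gds_loader import GDSConfig, init_gds  # module added by this PR

def gds_config_from_args(args) -> GDSConfig:
    return GDSConfig(
        enabled=args.enable_gds and not args.disable_gds,
        min_file_size_mb=args.gds_min_file_size,
        chunk_size_mb=args.gds_chunk_size,
        max_concurrent_streams=args.gds_streams,
        prefetch_enabled=args.gds_prefetch,
        fallback_to_cpu=not args.gds_no_fallback,
        show_stats=args.gds_stats,
    )

# init_gds(gds_config_from_args(args))  # would be called once at startup (assumption)
```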
comfy/gds_loader.py (new file, 494 lines)
@@ -0,0 +1,494 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007

"""
GPUDirect Storage (GDS) Integration for ComfyUI
Direct SSD-to-GPU model loading without RAM/CPU bottlenecks.
Some CPU/RAM usage remains, mostly for safetensors header parsing and small staging buffers.

This module provides GPUDirect Storage functionality to load models directly
from NVMe SSDs to GPU memory, bypassing system RAM and CPU.
"""

import os
import logging
import torch
import time
from typing import Optional, Dict, Any, Union
from pathlib import Path
import safetensors
import gc
import mmap
from dataclasses import dataclass

try:
    import cupy
    import cupy.cuda.runtime as cuda_runtime
    CUPY_AVAILABLE = True
except ImportError:
    CUPY_AVAILABLE = False
    logging.warning("CuPy not available. GDS will use fallback mode.")

try:
    import cudf  # RAPIDS for GPU dataframes
    RAPIDS_AVAILABLE = True
except ImportError:
    RAPIDS_AVAILABLE = False

try:
    import pynvml
    pynvml.nvmlInit()
    NVML_AVAILABLE = True
except ImportError:
    NVML_AVAILABLE = False
    logging.warning("NVIDIA-ML-Py not available. GPU monitoring disabled.")


@dataclass
class GDSConfig:
    """Configuration for GPUDirect Storage"""
    enabled: bool = True
    min_file_size_mb: int = 100  # Only use GDS for files larger than this
    chunk_size_mb: int = 64  # Size of chunks to transfer
    use_pinned_memory: bool = True
    prefetch_enabled: bool = True
    compression_aware: bool = True
    max_concurrent_streams: int = 4
    fallback_to_cpu: bool = True
    show_stats: bool = False  # Whether to show stats on exit

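As a reading aid, a tiny sketch of constructing this configuration; `configure_gds` is defined near the end of this module, and the values chosen here are illustrative only.

```python
# Illustrative only (not part of the file): raise the GDS size threshold,
# disable the CPU fallback, and print statistics on exit.
from comfy.gds_loader import GDSConfig, configure_gds

configure_gds(GDSConfig(min_file_size_mb=512, fallback_to_cpu=False, show_stats=True))
```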
class GDSError(Exception):
|
||||
"""GDS-specific errors"""
|
||||
pass
|
||||
|
||||
|
||||
class GPUDirectStorage:
|
||||
"""
|
||||
GPUDirect Storage implementation for ComfyUI
|
||||
Enables direct SSD-to-GPU transfers for model loading
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[GDSConfig] = None):
|
||||
self.config = config or GDSConfig()
|
||||
self.device = torch.cuda.current_device() if torch.cuda.is_available() else None
|
||||
self.cuda_streams = []
|
||||
self.pinned_buffers = {}
|
||||
self.stats = {
|
||||
'gds_loads': 0,
|
||||
'fallback_loads': 0,
|
||||
'total_bytes_gds': 0,
|
||||
'total_time_gds': 0.0,
|
||||
'avg_bandwidth_gbps': 0.0
|
||||
}
|
||||
|
||||
# Initialize GDS if available
|
||||
self._gds_available = self._check_gds_availability()
|
||||
if self._gds_available:
|
||||
self._init_gds()
|
||||
else:
|
||||
logging.warning("GDS not available, using fallback methods")
|
||||
|
||||
def _check_gds_availability(self) -> bool:
|
||||
"""Check if GDS is available on the system"""
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
|
||||
if not CUPY_AVAILABLE:
|
||||
return False
|
||||
|
||||
# Check for GPUDirect Storage support
|
||||
try:
|
||||
# Check CUDA version (GDS requires CUDA 11.4+)
|
||||
cuda_version = torch.version.cuda
|
||||
if cuda_version:
|
||||
major, minor = map(int, cuda_version.split('.')[:2])
|
||||
if major < 11 or (major == 11 and minor < 4):
|
||||
logging.warning(f"CUDA {cuda_version} detected. GDS requires CUDA 11.4+")
|
||||
return False
|
||||
|
||||
# Check if cuFile is available (part of CUDA toolkit)
|
||||
try:
|
||||
import cupy.cuda.cufile as cufile
|
||||
# Try to initialize cuFile
|
||||
cufile.initialize()
|
||||
return True
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logging.warning(f"cuFile not available: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"GDS availability check failed: {e}")
|
||||
return False
|
||||
|
||||
def _init_gds(self):
|
||||
"""Initialize GDS resources"""
|
||||
try:
|
||||
# Create CUDA streams for async operations
|
||||
for i in range(self.config.max_concurrent_streams):
|
||||
stream = torch.cuda.Stream()
|
||||
self.cuda_streams.append(stream)
|
||||
|
||||
# Pre-allocate pinned memory buffers
|
||||
if self.config.use_pinned_memory:
|
||||
self._allocate_pinned_buffers()
|
||||
|
||||
logging.info(f"GDS initialized with {len(self.cuda_streams)} streams")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize GDS: {e}")
|
||||
self._gds_available = False
|
||||
|
||||
def _allocate_pinned_buffers(self):
|
||||
"""Pre-allocate pinned memory buffers for staging"""
|
||||
try:
|
||||
# Allocate buffers of different sizes
|
||||
buffer_sizes = [16, 32, 64, 128, 256] # MB
|
||||
|
||||
for size_mb in buffer_sizes:
|
||||
size_bytes = size_mb * 1024 * 1024
|
||||
# Allocate pinned memory using CuPy
|
||||
if CUPY_AVAILABLE:
|
||||
buffer = cupy.cuda.alloc_pinned_memory(size_bytes)
|
||||
self.pinned_buffers[size_mb] = buffer
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to allocate pinned buffers: {e}")
|
||||
|
||||
def _get_file_size(self, file_path: str) -> int:
|
||||
"""Get file size in bytes"""
|
||||
return os.path.getsize(file_path)
|
||||
|
||||
def _should_use_gds(self, file_path: str) -> bool:
|
||||
"""Determine if GDS should be used for this file"""
|
||||
if not self._gds_available or not self.config.enabled:
|
||||
return False
|
||||
|
||||
file_size_mb = self._get_file_size(file_path) / (1024 * 1024)
|
||||
return file_size_mb >= self.config.min_file_size_mb
|
||||
|
||||
def _load_with_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Load model using GPUDirect Storage"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if file_path.lower().endswith(('.safetensors', '.sft')):
|
||||
return self._load_safetensors_gds(file_path)
|
||||
else:
|
||||
return self._load_pytorch_gds(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS loading failed for {file_path}: {e}")
|
||||
if self.config.fallback_to_cpu:
|
||||
logging.info("Falling back to CPU loading")
|
||||
self.stats['fallback_loads'] += 1
|
||||
return self._load_fallback(file_path)
|
||||
else:
|
||||
raise GDSError(f"GDS loading failed: {e}")
|
||||
finally:
|
||||
load_time = time.time() - start_time
|
||||
self.stats['total_time_gds'] += load_time
|
||||
|
||||
def _load_safetensors_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Load safetensors file using GDS"""
|
||||
try:
|
||||
import cupy.cuda.cufile as cufile
|
||||
|
||||
# Open file with cuFile for direct GPU loading
|
||||
with cufile.CuFileManager() as manager:
|
||||
# Memory-map the file for efficient access
|
||||
with open(file_path, 'rb') as f:
|
||||
# Use mmap for large files
|
||||
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped_file:
|
||||
|
||||
# Parse safetensors header
|
||||
header_size = int.from_bytes(mmapped_file[:8], 'little')
|
||||
header_bytes = mmapped_file[8:8+header_size]
|
||||
|
||||
import json
|
||||
header = json.loads(header_bytes.decode('utf-8'))
|
||||
|
||||
# Load tensors directly to GPU
|
||||
tensors = {}
|
||||
data_offset = 8 + header_size
|
||||
|
||||
for name, info in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
|
||||
dtype_map = {
|
||||
'F32': torch.float32,
|
||||
'F16': torch.float16,
|
||||
'BF16': torch.bfloat16,
|
||||
'I8': torch.int8,
|
||||
'I16': torch.int16,
|
||||
'I32': torch.int32,
|
||||
'I64': torch.int64,
|
||||
'U8': torch.uint8,
|
||||
}
|
||||
|
||||
dtype = dtype_map.get(info['dtype'], torch.float32)
|
||||
shape = info['shape']
|
||||
start_offset = data_offset + info['data_offsets'][0]
|
||||
end_offset = data_offset + info['data_offsets'][1]
|
||||
|
||||
# Direct GPU allocation
|
||||
tensor = torch.empty(shape, dtype=dtype, device=f'cuda:{self.device}')
|
||||
|
||||
# Use cuFile for direct transfer
|
||||
tensor_bytes = end_offset - start_offset
|
||||
|
||||
# Get GPU memory pointer
|
||||
gpu_ptr = tensor.data_ptr()
|
||||
|
||||
# Direct file-to-GPU transfer
|
||||
cufile.copy_from_file(
|
||||
gpu_ptr,
|
||||
mmapped_file[start_offset:end_offset],
|
||||
tensor_bytes
|
||||
)
|
||||
|
||||
tensors[name] = tensor
|
||||
|
||||
self.stats['gds_loads'] += 1
|
||||
self.stats['total_bytes_gds'] += self._get_file_size(file_path)
|
||||
|
||||
return tensors
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS safetensors loading failed: {e}")
|
||||
raise
|
||||
|
||||
def _load_pytorch_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Load PyTorch file using GDS with staging"""
|
||||
try:
|
||||
# For PyTorch files, we need to use a staging approach
|
||||
# since torch.load doesn't support direct GPU loading
|
||||
|
||||
# Load to pinned memory first
|
||||
with open(file_path, 'rb') as f:
|
||||
file_size = self._get_file_size(file_path)
|
||||
|
||||
# Choose appropriate buffer or allocate new one
|
||||
buffer_size_mb = min(256, max(64, file_size // (1024 * 1024)))
|
||||
|
||||
if buffer_size_mb in self.pinned_buffers:
|
||||
pinned_buffer = self.pinned_buffers[buffer_size_mb]
|
||||
else:
|
||||
# Allocate temporary pinned buffer
|
||||
pinned_buffer = cupy.cuda.alloc_pinned_memory(file_size)
|
||||
|
||||
# Read file to pinned memory
|
||||
f.readinto(pinned_buffer)
|
||||
|
||||
# Use torch.load with map_location to specific GPU
|
||||
# This will be faster due to pinned memory
|
||||
state_dict = torch.load(
|
||||
f,
|
||||
map_location=f'cuda:{self.device}',
|
||||
weights_only=True
|
||||
)
|
||||
|
||||
self.stats['gds_loads'] += 1
|
||||
self.stats['total_bytes_gds'] += file_size
|
||||
|
||||
return state_dict
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS PyTorch loading failed: {e}")
|
||||
raise
|
||||
|
||||
def _load_fallback(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Fallback loading method using standard approaches"""
|
||||
if file_path.lower().endswith(('.safetensors', '.sft')):
|
||||
# Use safetensors with device parameter
|
||||
with safetensors.safe_open(file_path, framework="pt", device=f'cuda:{self.device}') as f:
|
||||
return {k: f.get_tensor(k) for k in f.keys()}
|
||||
else:
|
||||
# Standard PyTorch loading
|
||||
return torch.load(file_path, map_location=f'cuda:{self.device}', weights_only=True)
|
||||
|
||||
def load_model(self, file_path: str, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Main entry point for loading models with GDS
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
device: Target device (if None, uses current CUDA device)
|
||||
|
||||
Returns:
|
||||
Dictionary of tensors loaded directly to GPU
|
||||
"""
|
||||
if device is not None and device.type == 'cuda':
|
||||
self.device = device.index or 0
|
||||
|
||||
if self._should_use_gds(file_path):
|
||||
logging.info(f"Loading {file_path} with GDS")
|
||||
return self._load_with_gds(file_path)
|
||||
else:
|
||||
logging.info(f"Loading {file_path} with standard method")
|
||||
self.stats['fallback_loads'] += 1
|
||||
return self._load_fallback(file_path)
|
||||
|
||||
def prefetch_model(self, file_path: str) -> bool:
|
||||
"""
|
||||
Prefetch model to GPU memory cache (if supported)
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
|
||||
Returns:
|
||||
True if prefetch was successful
|
||||
"""
|
||||
if not self.config.prefetch_enabled or not self._gds_available:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Basic prefetch implementation
|
||||
# This would ideally use NVIDIA's GPUDirect Storage API
|
||||
# to warm up the storage cache
|
||||
|
||||
file_size = self._get_file_size(file_path)
|
||||
logging.info(f"Prefetching {file_path} ({file_size // (1024*1024)} MB)")
|
||||
|
||||
# Read file metadata to warm caches
|
||||
with open(file_path, 'rb') as f:
|
||||
# Read first and last chunks to trigger prefetch
|
||||
f.read(1024 * 1024) # First 1MB
|
||||
f.seek(-min(1024 * 1024, file_size), 2) # Last 1MB
|
||||
f.read()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Prefetch failed for {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get loading statistics"""
|
||||
total_loads = self.stats['gds_loads'] + self.stats['fallback_loads']
|
||||
|
||||
if self.stats['total_time_gds'] > 0 and self.stats['total_bytes_gds'] > 0:
|
||||
bandwidth_gbps = (self.stats['total_bytes_gds'] / (1024**3)) / self.stats['total_time_gds']
|
||||
self.stats['avg_bandwidth_gbps'] = bandwidth_gbps
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
'total_loads': total_loads,
|
||||
'gds_usage_percent': (self.stats['gds_loads'] / max(1, total_loads)) * 100,
|
||||
'gds_available': self._gds_available,
|
||||
'config': self.config.__dict__
|
||||
}
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up GDS resources"""
|
||||
try:
|
||||
# Clear CUDA streams
|
||||
for stream in self.cuda_streams:
|
||||
stream.synchronize()
|
||||
self.cuda_streams.clear()
|
||||
|
||||
# Free pinned buffers
|
||||
for buffer in self.pinned_buffers.values():
|
||||
if CUPY_AVAILABLE:
|
||||
cupy.cuda.free_pinned_memory(buffer)
|
||||
self.pinned_buffers.clear()
|
||||
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"GDS cleanup failed: {e}")
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure cleanup"""
|
||||
self.cleanup()
|
||||
|
||||
|
||||
# Global GDS instance
|
||||
_gds_instance: Optional[GPUDirectStorage] = None
|
||||
|
||||
|
||||
def get_gds_instance(config: Optional[GDSConfig] = None) -> GPUDirectStorage:
|
||||
"""Get or create the global GDS instance"""
|
||||
global _gds_instance
|
||||
|
||||
if _gds_instance is None:
|
||||
_gds_instance = GPUDirectStorage(config)
|
||||
|
||||
return _gds_instance
|
||||
|
||||
|
||||
def load_torch_file_gds(ckpt: str, safe_load: bool = False, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
GDS-enabled replacement for comfy.utils.load_torch_file
|
||||
|
||||
Args:
|
||||
ckpt: Path to checkpoint file
|
||||
safe_load: Whether to use safe loading (for compatibility)
|
||||
device: Target device
|
||||
|
||||
Returns:
|
||||
Dictionary of loaded tensors
|
||||
"""
|
||||
gds = get_gds_instance()
|
||||
|
||||
try:
|
||||
# Load with GDS
|
||||
return gds.load_model(ckpt, device)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS loading failed, falling back to standard method: {e}")
|
||||
# Fallback to original method
|
||||
import comfy.utils
|
||||
return comfy.utils.load_torch_file(ckpt, safe_load=safe_load, device=device)
|
||||
|
||||
|
||||
def prefetch_model_gds(file_path: str) -> bool:
|
||||
"""Prefetch model for faster loading"""
|
||||
gds = get_gds_instance()
|
||||
return gds.prefetch_model(file_path)
|
||||
|
||||
|
||||
def get_gds_stats() -> Dict[str, Any]:
|
||||
"""Get GDS statistics"""
|
||||
gds = get_gds_instance()
|
||||
return gds.get_stats()
|
||||
|
||||
|
||||
def configure_gds(config: GDSConfig):
|
||||
"""Configure GDS settings"""
|
||||
global _gds_instance
|
||||
_gds_instance = GPUDirectStorage(config)
|
||||
|
||||
|
||||
def init_gds(config: GDSConfig):
|
||||
"""
|
||||
Initialize GPUDirect Storage with the provided configuration
|
||||
|
||||
Args:
|
||||
config: GDSConfig object with initialization parameters
|
||||
"""
|
||||
try:
|
||||
# Configure GDS
|
||||
configure_gds(config)
|
||||
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
|
||||
|
||||
# Set up exit handler for stats if requested
|
||||
if hasattr(config, 'show_stats') and config.show_stats:
|
||||
import atexit
|
||||
def print_gds_stats():
|
||||
stats = get_gds_stats()
|
||||
logging.info("=== GDS Statistics ===")
|
||||
logging.info(f"Total loads: {stats['total_loads']}")
|
||||
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
|
||||
logging.info(f"Fallback loads: {stats['fallback_loads']}")
|
||||
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
|
||||
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
|
||||
logging.info("===================")
|
||||
atexit.register(print_gds_stats)
|
||||
|
||||
except ImportError as e:
|
||||
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"GDS initialization failed: {e}")
|
||||
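A short usage sketch of the module-level helpers above, assuming a CUDA-capable machine; the checkpoint path is a placeholder, and the module transparently falls back to standard loading when GDS is unavailable.

```python
# Sketch of exercising the public API of comfy/gds_loader.py added above.
import torch
from comfy.gds_loader import load_torch_file_gds, prefetch_model_gds, get_gds_stats

ckpt = "/models/checkpoints/example.safetensors"  # placeholder path
prefetch_model_gds(ckpt)                          # optionally warm storage caches
sd = load_torch_file_gds(ckpt, safe_load=True, device=torch.device("cuda", 0))
print(get_gds_stats()["gds_usage_percent"], "% of loads went through GDS")
```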
@@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module):
        x(Tensor): Shape [B, L, num_heads, C / num_heads]
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
    """
    patches = transformer_options.get("patches", {})

    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

    def qkv_fn_q(x):

@@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module):
        transformer_options=transformer_options,
    )

    if "attn1_patch" in patches:
        for p in patches["attn1_patch"]:
            x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})

    x = self.o(x)
    return x

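The hunk above adds an `attn1_patch` hook (and the next hunk an `attn2_patch` hook) that the MultiTalk code registers through `transformer_options`. A minimal sketch of what such a patch callable looks like, inferred from the call site above; the class name here is illustrative, while the concrete implementation in this PR is `MultiTalkGetAttnMapPatch` in `comfy/ldm/wan/model_multitalk.py`.

```python
# Sketch of a patch callable matching the call site above: it receives a dict
# containing the hidden states plus q/k and must return the (possibly modified) x.
import torch

class LogShapesPatch:
    def __call__(self, kwargs: dict) -> torch.Tensor:
        x, q, k = kwargs["x"], kwargs["q"], kwargs["k"]
        print("attn1_patch:", x.shape, q.shape, k.shape)
        return x  # hand back the hidden states, changed or unchanged

# registration sketch: transformer_options["patches"]["attn1_patch"] = [LogShapesPatch()]
```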
@@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module):
    """
    # assert e.dtype == torch.float32

    patches = transformer_options.get("patches", {})

    if e.ndim < 4:
        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
    else:

@@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module):

    # cross-attention & ffn
    x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)

    if "attn2_patch" in patches:
        for p in patches["attn2_patch"]:
            x = p({"x": x, "transformer_options": transformer_options})

    y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
    x = torch.addcmul(x, y, repeat_e(e[5], x))
    return x

@@ -488,7 +501,7 @@ class WanModel(torch.nn.Module):
    self.blocks = nn.ModuleList([
        wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
                             window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
        for _ in range(num_layers)
        for i in range(num_layers)
    ])

    # head

@@ -541,6 +554,7 @@ class WanModel(torch.nn.Module):
    # embeddings
    x = self.patch_embedding(x.float()).to(x.dtype)
    grid_sizes = x.shape[2:]
    transformer_options["grid_sizes"] = grid_sizes
    x = x.flatten(2).transpose(1, 2)

    # time embeddings

@@ -738,6 +752,7 @@ class VaceWanModel(WanModel):
    # embeddings
    x = self.patch_embedding(x.float()).to(x.dtype)
    grid_sizes = x.shape[2:]
    transformer_options["grid_sizes"] = grid_sizes
    x = x.flatten(2).transpose(1, 2)

    # time embeddings

comfy/ldm/wan/model_multitalk.py (new file, 500 lines)
@@ -0,0 +1,500 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
import comfy
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
|
||||
scale = 1.0 / visual_q.shape[-1] ** 0.5
|
||||
visual_q = visual_q.transpose(1, 2) * scale
|
||||
|
||||
B, H, x_seqlens, K = visual_q.shape
|
||||
|
||||
x_ref_attn_maps = []
|
||||
for class_idx, ref_target_mask in enumerate(ref_target_masks):
|
||||
ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
|
||||
|
||||
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
||||
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
||||
|
||||
for i in range(0, x_seqlens, chunk_size):
|
||||
end_i = min(i + chunk_size, x_seqlens)
|
||||
|
||||
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
||||
|
||||
# Apply softmax
|
||||
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
||||
attn_chunk = (attn_chunk - attn_max).exp()
|
||||
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
||||
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
||||
|
||||
# Apply mask and sum
|
||||
masked_attn = attn_chunk * ref_target_mask
|
||||
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
||||
|
||||
del attn_chunk, masked_attn
|
||||
|
||||
# Average across heads
|
||||
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
||||
x_ref_attn_maps.append(x_ref_attnmap)
|
||||
|
||||
del visual_q, ref_k
|
||||
|
||||
return torch.cat(x_ref_attn_maps, dim=0)
|
||||
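A shape-level sketch of how `calculate_x_ref_attn_map` above is driven, using tiny tensors; the sizes and mask layout are illustrative only, and the import refers to the new file added by this PR.

```python
# Illustrative shapes: 2 speaker masks over 16 reference tokens, 8 heads, 64 query tokens.
import torch
from comfy.ldm.wan.model_multitalk import calculate_x_ref_attn_map

B, H, Lq, Lr, K = 1, 8, 64, 16, 32
visual_q = torch.randn(B, Lq, H, K)      # query tokens [B, M, H, K]
ref_k = torch.randn(B, Lr, H, K)         # reference keys [B, M_ref, H, K]
masks = [torch.zeros(Lr), torch.zeros(Lr)]
masks[0][:8], masks[1][8:] = 1.0, 1.0    # speaker 1 / speaker 2 regions

maps = calculate_x_ref_attn_map(visual_q, ref_k, masks)
print(maps.shape)  # torch.Size([2, 64]): per-speaker attention strength per query token
```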
|
||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
||||
"""Args:
|
||||
query (torch.tensor): B M H K
|
||||
key (torch.tensor): B M H K
|
||||
shape (tuple): (N_t, N_h, N_w)
|
||||
ref_target_masks: [B, N_h * N_w]
|
||||
"""
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_seqlens = N_h * N_w
|
||||
ref_k = ref_k[:, :x_seqlens]
|
||||
_, seq_lens, heads, _ = visual_q.shape
|
||||
class_num, _ = ref_target_masks.shape
|
||||
x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)
|
||||
|
||||
split_chunk = heads // split_num
|
||||
|
||||
for i in range(split_num):
|
||||
x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
|
||||
visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
|
||||
ref_target_masks
|
||||
)
|
||||
x_ref_attn_maps += x_ref_attn_maps_perhead
|
||||
|
||||
return x_ref_attn_maps / split_num
|
||||
|
||||
|
||||
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
|
||||
source_min, source_max = source_range
|
||||
new_min, new_max = target_range
|
||||
normalized = (column - source_min) / (source_max - source_min + epsilon)
|
||||
scaled = normalized * (new_max - new_min) + new_min
|
||||
return scaled
|
||||
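A small numeric sketch of `normalize_and_scale`, which the multi-speaker attention below uses to squeeze per-token attention strengths into a speaker-specific rotary-position bucket; the numbers are illustrative only.

```python
# Illustrative numbers: map raw scores in [0.1, 0.9] into the first speaker's
# rotary bucket [0, 4), as SingleStreamMultiAttention does further down.
import torch
from comfy.ldm.wan.model_multitalk import normalize_and_scale

scores = torch.tensor([0.1, 0.5, 0.9])
positions = normalize_and_scale(scores, (0.1, 0.9), (0, 4))
print(positions)  # tensor([0., 2., 4.]) up to the 1e-8 epsilon
```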
|
||||
|
||||
def rotate_half(x):
|
||||
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
||||
x1, x2 = x.unbind(dim=-1)
|
||||
x = torch.stack((-x2, x1), dim=-1)
|
||||
return rearrange(x, "... d r -> ... (d r)")
|
||||
|
||||
|
||||
def get_audio_embeds(encoded_audio, audio_start, audio_end):
|
||||
audio_embs = []
|
||||
human_num = len(encoded_audio)
|
||||
audio_frames = encoded_audio[0].shape[0]
|
||||
|
||||
indices = (torch.arange(4 + 1) - 2) * 1
|
||||
|
||||
for human_idx in range(human_num):
|
||||
if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
|
||||
pad_len = audio_end - audio_frames
|
||||
pad_shape = list(encoded_audio[human_idx].shape)
|
||||
pad_shape[0] = pad_len
|
||||
pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
|
||||
encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
|
||||
else:
|
||||
encoded_audio_in = encoded_audio[human_idx]
|
||||
center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
|
||||
center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
|
||||
audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
|
||||
audio_embs.append(audio_emb)
|
||||
|
||||
return torch.cat(audio_embs, dim=0)
|
||||
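A sketch of the windowing performed by `get_audio_embeds` above: each output frame gathers a five-frame audio window (offsets -2..+2, clamped at the clip boundaries), one row per speaker. The tensor sizes below are illustrative only.

```python
# Illustrative shapes for the audio window gather above.
import torch
from comfy.ldm.wan.model_multitalk import get_audio_embeds

encoded_audio = [torch.randn(81, 12, 768)]   # one speaker: [frames, blocks, channels]
emb = get_audio_embeds(encoded_audio, audio_start=0, audio_end=21)
print(emb.shape)  # torch.Size([1, 21, 5, 12, 768]): speaker, frame, 5-frame window, ...
```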
|
||||
|
||||
def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
|
||||
audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)
|
||||
|
||||
first_frame_audio_emb_s = audio_embs[:, :1, ...]
|
||||
latter_frame_audio_emb = audio_embs[:, 1:, ...]
|
||||
latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)
|
||||
|
||||
middle_index = audio_proj.seq_len // 2
|
||||
|
||||
latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
|
||||
latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
|
||||
latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
|
||||
latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
|
||||
latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
|
||||
|
||||
audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
|
||||
audio_emb = torch.cat(audio_emb.split(1), dim=2)
|
||||
|
||||
return audio_emb
|
||||
|
||||
|
||||
class RotaryPositionalEmbedding1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
head_dim,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.base = 10000
|
||||
|
||||
def precompute_freqs_cis_1d(self, pos_indices):
|
||||
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
|
||||
freqs = freqs.to(pos_indices.device)
|
||||
freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
|
||||
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, pos_indices):
|
||||
freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
|
||||
|
||||
x_ = x.float()
|
||||
|
||||
freqs_cis = freqs_cis.float().to(x.device)
|
||||
cos, sin = freqs_cis.cos(), freqs_cis.sin()
|
||||
cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
|
||||
x_ = (x_ * cos) + (rotate_half(x_) * sin)
|
||||
|
||||
return x_.type_as(x)
|
||||
|
||||
class SingleStreamAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.encoder_hidden_states_dim = encoder_hidden_states_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
expected_tokens = N_t * N_h * N_w
|
||||
actual_tokens = x.shape[1]
|
||||
x_extra = None
|
||||
|
||||
if actual_tokens != expected_tokens:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
|
||||
B = x.shape[0]
|
||||
S = N_h * N_w
|
||||
x = x.view(B * N_t, S, self.dim)
|
||||
|
||||
# get q for hidden_state
|
||||
q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)
|
||||
|
||||
# get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
|
||||
kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)
|
||||
|
||||
#print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# linear transform
|
||||
x = self.proj(x.reshape(B * N_t, S, self.dim))
|
||||
x = x.view(B, N_t * S, self.dim)
|
||||
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
class SingleStreamMultiAttention(SingleStreamAttention):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
encoder_hidden_states_dim: int,
|
||||
num_heads: int,
|
||||
qkv_bias: bool,
|
||||
class_range: int = 24,
|
||||
class_interval: int = 4,
|
||||
device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
encoder_hidden_states_dim=encoder_hidden_states_dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
operations=operations
|
||||
)
|
||||
|
||||
# Rotary-embedding layout parameters
|
||||
self.class_interval = class_interval
|
||||
self.class_range = class_range
|
||||
self.max_humans = self.class_range // self.class_interval
|
||||
|
||||
# Constant bucket used for background tokens
|
||||
self.rope_bak = int(self.class_range // 2)
|
||||
|
||||
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
shape=None,
|
||||
x_ref_attn_map=None
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
|
||||
human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
|
||||
# Single-speaker fall-through
|
||||
if human_num <= 1:
|
||||
return super().forward(x, encoder_hidden_states, shape)
|
||||
|
||||
N_t, N_h, N_w = shape
|
||||
|
||||
x_extra = None
|
||||
if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
|
||||
x_extra = x[:, -N_h * N_w:, :]
|
||||
x = x[:, :-N_h * N_w, :]
|
||||
N_t = N_t - 1
|
||||
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
|
||||
|
||||
# Query projection
|
||||
B, N, C = x.shape
|
||||
q = self.q_linear(x)
|
||||
q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
||||
|
||||
# Use `class_range` logic for 2 speakers
|
||||
rope_h1 = (0, self.class_interval)
|
||||
rope_h2 = (self.class_range - self.class_interval, self.class_range)
|
||||
rope_bak = int(self.class_range // 2)
|
||||
|
||||
# Normalize and scale attention maps for each speaker
|
||||
max_values = x_ref_attn_map.max(1).values[:, None, None]
|
||||
min_values = x_ref_attn_map.min(1).values[:, None, None]
|
||||
max_min_values = torch.cat([max_values, min_values], dim=2)
|
||||
|
||||
human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
|
||||
human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
|
||||
|
||||
human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
|
||||
human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
|
||||
back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)
|
||||
|
||||
# Token-wise speaker dominance
|
||||
max_indices = x_ref_attn_map.argmax(dim=0)
|
||||
normalized_map = torch.stack([human1, human2, back], dim=1)
|
||||
normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
|
||||
|
||||
# Apply rotary to Q
|
||||
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
q = self.rope_1d(q, normalized_pos)
|
||||
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Keys / Values
|
||||
_, N_a, _ = encoder_hidden_states.shape
|
||||
encoder_kv = self.kv_linear(encoder_hidden_states)
|
||||
encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
encoder_k, encoder_v = encoder_kv.unbind(0)
|
||||
|
||||
# Rotary for keys – assign centre of each speaker bucket to its context tokens
|
||||
per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
|
||||
per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
|
||||
per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
|
||||
encoder_pos = torch.cat([per_frame] * N_t, dim=0)
|
||||
|
||||
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
|
||||
encoder_k = self.rope_1d(encoder_k, encoder_pos)
|
||||
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
|
||||
|
||||
# Final attention
|
||||
q = rearrange(q, "B H M K -> B M H K")
|
||||
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
|
||||
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
|
||||
|
||||
x = optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
encoder_k.transpose(1, 2),
|
||||
encoder_v.transpose(1, 2),
|
||||
heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)
|
||||
|
||||
# Linear projection
|
||||
x = x.reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
# Restore original layout
|
||||
x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
|
||||
if x_extra is not None:
|
||||
x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkAudioProjModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
seq_len: int = 5,
|
||||
seq_len_vf: int = 12,
|
||||
blocks: int = 12,
|
||||
channels: int = 768,
|
||||
intermediate_dim: int = 512,
|
||||
out_dim: int = 768,
|
||||
context_tokens: int = 32,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.seq_len = seq_len
|
||||
self.blocks = blocks
|
||||
self.channels = channels
|
||||
self.input_dim = seq_len * blocks * channels
|
||||
self.input_dim_vf = seq_len_vf * blocks * channels
|
||||
self.intermediate_dim = intermediate_dim
|
||||
self.context_tokens = context_tokens
|
||||
self.out_dim = out_dim
|
||||
|
||||
# define multiple linear layers
|
||||
self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
|
||||
self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, audio_embeds, audio_embeds_vf):
|
||||
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
||||
B, _, _, S, C = audio_embeds.shape
|
||||
|
||||
# process audio of first frame
|
||||
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
||||
batch_size, window_size, blocks, channels = audio_embeds.shape
|
||||
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
||||
|
||||
# process audio of latter frame
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
||||
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
||||
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
||||
|
||||
# first projection
|
||||
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
||||
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
||||
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
||||
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
||||
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
||||
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
||||
|
||||
# second projection
|
||||
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
||||
|
||||
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)
|
||||
|
||||
# normalization and reshape
|
||||
context_tokens = self.norm(context_tokens)
|
||||
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
||||
|
||||
return context_tokens
|
||||
|
||||
|
||||
class WanMultiTalkAttentionBlock(torch.nn.Module):
|
||||
def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
|
||||
self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)
|
||||
|
||||
|
||||
class MultiTalkGetAttnMapPatch:
|
||||
def __init__(self, ref_target_masks=None):
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x = kwargs["x"]
|
||||
|
||||
if self.ref_target_masks is not None:
|
||||
x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
|
||||
transformer_options["x_ref_attn_map"] = x_ref_attn_map
|
||||
return x
|
||||
|
||||
|
||||
class MultiTalkCrossAttnPatch:
|
||||
def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
|
||||
self.model_patch = model_patch
|
||||
self.audio_scale = audio_scale
|
||||
self.ref_target_masks = ref_target_masks
|
||||
|
||||
def __call__(self, kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
block_idx = transformer_options.get("block_index", None)
|
||||
x = kwargs["x"]
|
||||
if block_idx is None:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
audio_embeds = transformer_options.get("audio_embeds")
|
||||
x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None)
|
||||
|
||||
norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
|
||||
x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
|
||||
norm_x, audio_embeds.to(x.dtype),
|
||||
shape=transformer_options["grid_sizes"],
|
||||
x_ref_attn_map=x_ref_attn_map
|
||||
)
|
||||
x = x + x_audio * self.audio_scale
|
||||
return x
|
||||
|
||||
def models(self):
|
||||
return [self.model_patch]
|
||||
|
||||
class MultiTalkApplyModelWrapper:
|
||||
def __init__(self, init_latents):
|
||||
self.init_latents = init_latents
|
||||
|
||||
def __call__(self, executor, x, *args, **kwargs):
|
||||
x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
|
||||
samples = executor(x, *args, **kwargs)
|
||||
return samples
|
||||
|
||||
|
||||
class InfiniteTalkOuterSampleWrapper:
|
||||
def __init__(self, motion_frames_latent, model_patch, is_extend=False):
|
||||
self.motion_frames_latent = motion_frames_latent
|
||||
self.model_patch = model_patch
|
||||
self.is_extend = is_extend
|
||||
|
||||
def __call__(self, executor, *args, **kwargs):
|
||||
model_patcher = executor.class_obj.model_patcher
|
||||
model_options = executor.class_obj.model_options
|
||||
process_latent_in = model_patcher.model.process_latent_in
|
||||
|
||||
# for InfiniteTalk, model input first latent(s) need to always be replaced on every step
|
||||
if self.motion_frames_latent is not None:
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
|
||||
w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))]
|
||||
|
||||
# run the sampling process
|
||||
result = executor(*args, **kwargs)
|
||||
|
||||
# insert motion frames before decoding
|
||||
if self.is_extend:
|
||||
overlap = self.motion_frames_latent.shape[2]
|
||||
result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)
|
||||
|
||||
return result
|
||||
|
||||
def to(self, device_or_dtype):
|
||||
if isinstance(device_or_dtype, torch.device):
|
||||
if self.motion_frames_latent is not None:
|
||||
self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype)
|
||||
return self
|
||||
@@ -636,14 +636,13 @@ class VAE:
    self.upscale_index_formula = (4, 16, 16)
    self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
    self.downscale_index_formula = (4, 16, 16)
    if self.latent_channels == 48: # Wan 2.2
    if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2
        self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
        self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
        self.process_input = self.process_output = lambda image: image
        self.process_output = lambda image: image
        self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
    elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
        self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
        self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
        self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
    else:
        if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical

@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
|
||||
def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True),
|
||||
latent_format=None, show_progress_bar=False):
|
||||
super().__init__()
|
||||
self.image_channels = 3
|
||||
self.patch_size = 1
|
||||
@ -124,6 +125,9 @@ class TAEHV(nn.Module):
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
elif self.latent_channels == 128: # LTX2
|
||||
self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True)
|
||||
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
@ -131,41 +135,52 @@ class TAEHV(nn.Module):
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
)
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool))
|
||||
self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow))
|
||||
self.frames_to_trim = self.t_upscale - 1
|
||||
self._show_progress_bar = show_progress_bar
|
||||
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
n_pad = 4 - x.shape[1] % 4
|
||||
if self.patch_size > 1:
|
||||
B, T, C, H, W = x.shape
|
||||
x = x.reshape(B * T, C, H, W)
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
|
||||
if x.shape[1] % self.t_downscale != 0:
|
||||
# pad at end to multiple of t_downscale
|
||||
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
||||
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1:
|
||||
|
||||
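Illustrative arithmetic for the temporal padding rule in `TAEHV.encode` above: with the LTX2 settings (three temporal downscales), `t_downscale = 2**3 = 8`, so a 21-frame clip is padded by repeating its last frame until the frame count is a multiple of 8. The numbers below are only an example.

```python
# Padding-to-multiple arithmetic used by TAEHV.encode (illustrative values).
t_downscale = 2 ** 3          # three stride-2 TPool layers (LTX2 case)
frames = 21
n_pad = (t_downscale - frames % t_downscale) if frames % t_downscale else 0
print(n_pad, frames + n_pad)  # 3 24
```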
@ -118,9 +118,18 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
missing, unexpected = self.load_state_dict(sdo, strict=False)
|
||||
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
|
||||
return (missing, unexpected)
|
||||
|
||||
missing_all = []
|
||||
unexpected_all = []
|
||||
|
||||
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
|
||||
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
|
||||
if component_sd:
|
||||
missing, unexpected = component.load_state_dict(component_sd, strict=False)
|
||||
missing_all.extend([f"{prefix}{k}" for k in missing])
|
||||
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
|
||||
|
||||
return (missing_all, unexpected_all)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
|
||||
@@ -57,6 +57,18 @@ else:
    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    # Try GDS loading first if available and device is GPU
    if device is not None and device.type == 'cuda':
        try:
            from . import gds_loader
            gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
            if return_metadata:
                # For GDS, we return empty metadata for now (can be enhanced)
                return (gds_result, {})
            return gds_result
        except Exception as e:
            logging.debug(f"GDS loading failed, using fallback: {e}")

    if device is None:
        device = torch.device("cpu")
    metadata = None

@ -754,7 +754,7 @@ class AnyType(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="MODEL_PATCH")
|
||||
class MODEL_PATCH(ComfyTypeIO):
|
||||
class ModelPatch(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO_ENCODER")
|
||||
@ -2038,6 +2038,7 @@ __all__ = [
|
||||
"ControlNet",
|
||||
"Vae",
|
||||
"Model",
|
||||
"ModelPatch",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"AudioEncoder",
|
||||
|
||||
@ -24,7 +24,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaImageEditNode",
|
||||
display_name="Bria Image Edit",
|
||||
display_name="Bria FIBO Image Edit",
|
||||
category="api node/image/Bria",
|
||||
description="Edit images using Bria latest model",
|
||||
inputs=[
|
||||
|
||||
@ -364,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 1",
|
||||
display_name="OpenAI GPT Image 1.5",
|
||||
category="api node/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -429,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gpt-image-1", "gpt-image-1.5"],
|
||||
default="gpt-image-1.5",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
|
||||
@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if there isn't a cache diff for current conds, we cannot skip this step
can_apply_cache_diff = easycache.can_apply_cache_diff(uuids)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.skip_current_step and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
return easycache.apply_cache_diff(x, uuids)

@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
easycache.cumulative_change_rate += approx_output_change_rate
if easycache.cumulative_change_rate < easycache.reuse_threshold:
if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff:
if easycache.verbose:
logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
# other conds should also skip this step, and instead use their cached values

@ -240,6 +242,9 @@ class EasyCacheHolder:
return to_return.clone()
return to_return

def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
return all(uuid in self.uuid_cache_diffs for uuid in uuids)

def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
if self.first_cond_uuid in uuids:
self.total_steps_skipped += 1

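The new guard is simple set membership: a step may only be skipped for the current batch of conds when every UUID already has a cached diff. A standalone sketch of that check, using a plain dict in place of the real `EasyCacheHolder.uuid_cache_diffs`:

```python
from uuid import UUID, uuid4

# Hypothetical stand-in for EasyCacheHolder.uuid_cache_diffs: cached diffs keyed by cond UUID.
uuid_cache_diffs: dict[UUID, str] = {}

def can_apply_cache_diff(uuids: list[UUID]) -> bool:
    # Skipping is only safe when *every* cond in this step already has a cached diff.
    return all(u in uuid_cache_diffs for u in uuids)

a, b = uuid4(), uuid4()
uuid_cache_diffs[a] = "diff_a"
print(can_apply_cache_diff([a]))     # True
print(can_apply_cache_diff([a, b]))  # False -> this step cannot be skipped yet
```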
comfy_extras/nodes_gds.py (new file, 293 lines)

@ -0,0 +1,293 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007

"""
Enhanced model loading nodes with GPUDirect Storage support
"""

import logging
import time
import asyncio
from typing import Optional, Dict, Any

import torch
import folder_paths
import comfy.sd
import comfy.utils
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict


class CheckpointLoaderGDS(ComfyNodeABC):
"""
Enhanced checkpoint loader with GPUDirect Storage support
Provides direct SSD-to-GPU loading and prefetching capabilities
"""

@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
"tooltip": "The name of the checkpoint (model) to load with GDS optimization."
}),
},
"optional": {
"prefetch": ("BOOLEAN", {
"default": False,
"tooltip": "Prefetch model to GPU cache for faster loading."
}),
"use_gds": ("BOOLEAN", {
"default": True,
"tooltip": "Use GPUDirect Storage if available."
}),
"target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
"default": "auto",
"tooltip": "Target device for model loading."
})
}
}

RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
RETURN_NAMES = ("model", "clip", "vae", "load_info")
OUTPUT_TOOLTIPS = (
"The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.",
"Loading information and statistics."
)
FUNCTION = "load_checkpoint_gds"
CATEGORY = "loaders/advanced"
DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
EXPERIMENTAL = True

def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
start_time = time.time()

ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

# Determine target device
if target_device == "auto":
device = None # Let the system decide
elif target_device == "cpu":
device = torch.device("cpu")
else:
device = torch.device(target_device)

load_info = {
"file": ckpt_name,
"path": ckpt_path,
"target_device": str(device) if device else "auto",
"gds_enabled": use_gds,
"prefetch_used": prefetch
}

try:
# Prefetch if requested
if prefetch and use_gds:
try:
from comfy.gds_loader import prefetch_model_gds
prefetch_success = prefetch_model_gds(ckpt_path)
load_info["prefetch_success"] = prefetch_success
if prefetch_success:
logging.info(f"Prefetched {ckpt_name} to GPU cache")
except Exception as e:
logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
load_info["prefetch_error"] = str(e)

# Load checkpoint with potential GDS optimization
if use_gds and device and device.type == 'cuda':
try:
from comfy.gds_loader import get_gds_instance
gds = get_gds_instance()

# Check if GDS should be used for this file
if gds._should_use_gds(ckpt_path):
load_info["loader_used"] = "GDS"
logging.info(f"Loading {ckpt_name} with GDS")
else:
load_info["loader_used"] = "Standard"
logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")

except Exception as e:
logging.warning(f"GDS check failed, using standard loading: {e}")
load_info["loader_used"] = "Standard (GDS failed)"
else:
load_info["loader_used"] = "Standard"

# Load the actual checkpoint
out = comfy.sd.load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings")
)

load_time = time.time() - start_time
load_info["load_time_seconds"] = round(load_time, 3)
load_info["load_success"] = True

# Format load info as string
info_str = f"Loaded: {ckpt_name}\n"
info_str += f"Method: {load_info['loader_used']}\n"
info_str += f"Time: {load_info['load_time_seconds']}s\n"
info_str += f"Device: {load_info['target_device']}"

if "prefetch_success" in load_info:
info_str += f"\nPrefetch: {'✓' if load_info['prefetch_success'] else '✗'}"

logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")

return (*out[:3], info_str)

except Exception as e:
load_info["load_success"] = False
load_info["error"] = str(e)
error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
logging.error(f"Checkpoint loading failed: {e}")
raise RuntimeError(error_str)

class ModelPrefetcher(ComfyNodeABC):
"""
Node for prefetching models to GPU cache
"""

@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"checkpoint_names": ("STRING", {
"multiline": True,
"default": "",
"tooltip": "List of checkpoint names to prefetch (one per line)."
}),
"prefetch_enabled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable/disable prefetching."
})
}
}

RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("prefetch_report",)
OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
FUNCTION = "prefetch_models"
CATEGORY = "loaders/advanced"
DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
OUTPUT_NODE = True

def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
if not prefetch_enabled:
return ("Prefetching disabled",)

# Parse checkpoint names
names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()]

if not names:
return ("No checkpoints specified for prefetching",)

try:
from comfy.gds_loader import prefetch_model_gds
except ImportError:
return ("GDS not available for prefetching",)

results = []
successful_prefetches = 0

for name in names:
try:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name)
success = prefetch_model_gds(ckpt_path)

if success:
results.append(f"✓ {name}")
successful_prefetches += 1
else:
results.append(f"✗ {name} (prefetch failed)")

except Exception as e:
results.append(f"✗ {name} (error: {str(e)[:50]})")

report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n"
report += "\n".join(results)

return (report,)

class GDSStats(ComfyNodeABC):
"""
Node for displaying GDS statistics
"""

@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"refresh": ("BOOLEAN", {
"default": False,
"tooltip": "Refresh statistics."
})
}
}

RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("stats_report",)
OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
FUNCTION = "get_stats"
CATEGORY = "utils/advanced"
DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
OUTPUT_NODE = True

def get_stats(self, refresh: bool = False):
try:
from comfy.gds_loader import get_gds_stats
stats = get_gds_stats()

report = "=== GPUDirect Storage Statistics ===\n\n"

# Availability
report += f"GDS Available: {'✓' if stats['gds_available'] else '✗'}\n"

# Usage statistics
report += f"Total Loads: {stats['total_loads']}\n"
report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n"
report += f"Fallback Loads: {stats['fallback_loads']}\n\n"

# Performance metrics
if stats['total_bytes_gds'] > 0:
gb_transferred = stats['total_bytes_gds'] / (1024**3)
report += f"Data Transferred: {gb_transferred:.2f} GB\n"
report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n"
report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n"

# Configuration
config = stats.get('config', {})
if config:
report += "Configuration:\n"
report += f"- Enabled: {config.get('enabled', 'Unknown')}\n"
report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n"
report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n"
report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n"
report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n"
report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n"

return (report,)

except ImportError:
return ("GDS module not available",)
except Exception as e:
return (f"Error retrieving GDS stats: {str(e)}",)

# Node mappings
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CheckpointLoaderGDS": CheckpointLoaderGDS,
|
||||
"ModelPrefetcher": ModelPrefetcher,
|
||||
"GDSStats": GDSStats,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CheckpointLoaderGDS": "Load Checkpoint (GDS)",
|
||||
"ModelPrefetcher": "Model Prefetcher",
|
||||
"GDSStats": "GDS Statistics",
|
||||
}
|
||||
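The three nodes register like any other built-in extra node (the nodes.py hunk further down adds "nodes_gds.py" to the built-in list). A usage sketch that exercises the loader outside the UI, assuming ComfyUI's Python environment is importable and at least one checkpoint is installed:

```python
# Hypothetical headless use of the new node; names come from the class above.
import folder_paths
from comfy_extras.nodes_gds import CheckpointLoaderGDS

ckpts = folder_paths.get_filename_list("checkpoints")
if ckpts:
    loader = CheckpointLoaderGDS()
    model, clip, vae, load_info = loader.load_checkpoint_gds(
        ckpts[0], prefetch=True, use_gds=True, target_device="auto")
    print(load_info)  # e.g. "Loaded: ... / Method: GDS / Time: 2.1s / Device: auto"
```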
@ -7,6 +7,7 @@ import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
import comfy.ldm.lumina.controlnet
from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel

class BlockWiseControlBlock(torch.nn.Module):

@ -257,6 +258,14 @@ class ModelPatchLoader:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
elif "audio_proj.proj1.weight" in sd:
model = MultiTalkModelPatch(
audio_window=5, context_tokens=32, vae_scale=4,
in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
out_dim=sd["audio_proj.norm.weight"].shape[0],
device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast)

model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())

@ -524,6 +533,38 @@ class USOStyleReference:
return (model_patched,)

class MultiTalkModelPatch(torch.nn.Module):
def __init__(
self,
audio_window: int = 5,
intermediate_dim: int = 512,
in_dim: int = 5120,
out_dim: int = 768,
context_tokens: int = 32,
vae_scale: int = 4,
num_layers: int = 40,

device=None, dtype=None, operations=None
):
super().__init__()
self.audio_proj = MultiTalkAudioProjModel(
seq_len=audio_window,
seq_len_vf=audio_window+vae_scale-1,
intermediate_dim=intermediate_dim,
out_dim=out_dim,
context_tokens=context_tokens,
device=device,
dtype=dtype,
operations=operations
)
self.blocks = torch.nn.ModuleList(
[
WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)

NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,

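The loader picks the patch type by sniffing keys in the state dict and infers layer sizes from tensor shapes rather than from a separate config. A toy illustration of that shape-driven dispatch, with a plain dict of zero tensors standing in for a real MultiTalk state dict:

```python
import torch

# Fabricated state dict that only mimics the keys and shapes the loader inspects.
sd = {
    "audio_proj.proj1.weight": torch.zeros(512, 768),
    "audio_proj.norm.weight": torch.zeros(768),
    "blocks.0.audio_cross_attn.proj.weight": torch.zeros(5120, 5120),
}

if "audio_proj.proj1.weight" in sd:
    in_dim = sd["blocks.0.audio_cross_attn.proj.weight"].shape[0]
    intermediate_dim = sd["audio_proj.proj1.weight"].shape[0]
    out_dim = sd["audio_proj.norm.weight"].shape[0]
    print(in_dim, intermediate_dim, out_dim)  # 5120 512 768
```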
@ -8,9 +8,10 @@ import comfy.latent_formats
import comfy.clip_vision
import json
import numpy as np
from typing import Tuple
from typing import Tuple, TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import logging

class WanImageToVideo(io.ComfyNode):
@classmethod

@ -1288,6 +1289,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
return io.NodeOutput(out_latent)

from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features

class WanInfiniteTalkToVideo(io.ComfyNode):
class DCValues(TypedDict):
mode: str
audio_encoder_output_2: io.AudioEncoderOutput.Type
mask: io.Mask.Type

@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanInfiniteTalkToVideo",
category="conditioning/video_models",
inputs=[
io.DynamicCombo.Input("mode", options=[
io.DynamicCombo.Option("single_speaker", []),
io.DynamicCombo.Option("two_speakers", [
io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
]),
]),
io.Model.Input("model"),
io.ModelPatch.Input("model_patch"),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("start_image", optional=True),
io.AudioEncoderOutput.Input("audio_encoder_output_1"),
io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
io.Image.Input("previous_frames", optional=True),
],
outputs=[
io.Model.Output(display_name="model"),
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Int.Output(display_name="trim_image"),
],
)

@classmethod
def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count,
start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:

if previous_frames is not None and previous_frames.shape[0] < motion_frame_count:
raise ValueError("Not enough previous frames provided.")

if mode["mode"] == "two_speakers":
audio_encoder_output_2 = mode["audio_encoder_output_2"]
mask_1 = mode["mask_1"]
mask_2 = mode["mask_2"]

if audio_encoder_output_2 is not None:
if mask_1 is None or mask_2 is None:
raise ValueError("Masks must be provided if two audio encoder outputs are used.")

ref_masks = None
if mask_1 is not None and mask_2 is not None:
if audio_encoder_output_2 is None:
raise ValueError("Second audio encoder output must be provided if two masks are used.")
ref_masks = torch.cat([mask_1, mask_2])

latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image

concat_latent_image = vae.encode(image[:, :, :, :3])
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})

if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})

model_patched = model.clone()

encoded_audio_list = []
seq_lengths = []

for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
if audio_encoder_output is None:
continue
all_layers = audio_encoder_output["encoded_audio_all_layers"]
encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
encoded_audio_list.append(encoded_audio)
seq_lengths.append(encoded_audio.shape[0])

# Pad / combine depending on multi_audio_type
multi_audio_type = "add"
if len(encoded_audio_list) > 1:
if multi_audio_type == "para":
max_len = max(seq_lengths)
padded = []
for emb in encoded_audio_list:
if emb.shape[0] < max_len:
pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
emb = torch.cat([emb, pad], dim=0)
padded.append(emb)
encoded_audio_list = padded
elif multi_audio_type == "add":
total_len = sum(seq_lengths)
full_list = []
offset = 0
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
full[offset:offset+seq_len] = emb
full_list.append(full)
offset += seq_len
encoded_audio_list = full_list

token_ref_target_masks = None
if ref_masks is not None:
token_ref_target_masks = torch.nn.functional.interpolate(
ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)

# when extending from previous frames
if previous_frames is not None:
motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
frame_offset = previous_frames.shape[0] - motion_frame_count

audio_start = frame_offset
audio_end = audio_start + length
logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}")

motion_frames_latent = vae.encode(motion_frames[:, :, :, :3])
trim_image = motion_frame_count
else:
audio_start = trim_image = 0
audio_end = length
motion_frames_latent = concat_latent_image[:, :, :1]

audio_embed = project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype())
model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed

# add outer sample wrapper
model_patched.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
"infinite_talk_outer_sample",
InfiniteTalkOuterSampleWrapper(
motion_frames_latent,
model_patch,
is_extend=previous_frames is not None,
))
# add cross-attention patch
model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
if token_ref_target_masks is not None:
model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")

out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)

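For two speakers the default multi_audio_type = "add" branch lays the speakers out sequentially in time: each embedding is written into its own slice of a zero tensor spanning the combined length. A toy sketch of just that bookkeeping, with random tensors standing in for the per-speaker [T, num_layers, 512] embeddings:

```python
import torch

# Stand-ins for two speakers' audio embeddings after interpolation: [T, num_layers, 512].
emb_a = torch.randn(10, 5, 512)
emb_b = torch.randn(6, 5, 512)
encoded_audio_list, seq_lengths = [emb_a, emb_b], [10, 6]

total_len = sum(seq_lengths)
full_list, offset = [], 0
for emb, seq_len in zip(encoded_audio_list, seq_lengths):
    full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
    full[offset:offset + seq_len] = emb  # each speaker occupies its own time slice
    full_list.append(full)
    offset += seq_len

print(full_list[0].shape, full_list[1].shape)  # torch.Size([16, 5, 512]) twice
```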
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:

@ -1307,6 +1473,7 @@ class WanExtension(ComfyExtension):
WanHuMoImageToVideo,
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
]

async def comfy_entrypoint() -> WanExtension:

@ -11,7 +11,7 @@ import logging
default_preview_method = args.preview_method

MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]

def preview_to_image(latent_image, do_scale=True):
if do_scale:

main.py (29 changed lines)

@ -184,6 +184,35 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0

# Initialize GPUDirect Storage if enabled
def init_gds():
"""Initialize GPUDirect Storage based on CLI arguments"""
if hasattr(args, 'disable_gds') and args.disable_gds:
logging.info("GDS explicitly disabled via --disable-gds")
return

if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
# GDS not explicitly requested, use auto-detection
return

if hasattr(args, 'enable_gds') and args.enable_gds:
from comfy.gds_loader import GDSConfig, init_gds as gds_init

config = GDSConfig(
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
max_concurrent_streams=getattr(args, 'gds_streams', 4),
prefetch_enabled=getattr(args, 'gds_prefetch', True),
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
show_stats=getattr(args, 'gds_stats', False)
)

gds_init(config)

# Initialize GDS
init_gds()

def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)

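The CLI flags map one-to-one onto the loader configuration built here. A sketch of the same wiring done by hand, assuming the comfy.gds_loader module introduced by this change set is importable:

```python
# Hand-wiring the GDS configuration instead of going through the CLI flags.
from comfy.gds_loader import GDSConfig, init_gds

config = GDSConfig(
    enabled=True,               # same effect as --enable-gds
    min_file_size_mb=100,       # --gds-min-file-size
    chunk_size_mb=64,           # --gds-chunk-size
    max_concurrent_streams=4,   # --gds-streams
    prefetch_enabled=True,      # --gds-prefetch
    fallback_to_cpu=True,       # leave --gds-no-fallback unset
    show_stats=False,
)
init_gds(config)
```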
nodes.py (3 changed lines)

@ -707,7 +707,7 @@ class LoraLoaderModelOnly(LoraLoader):
return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)

class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod
def vae_list(s):

@ -2382,6 +2382,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_gds.py",
"nodes_rope.py",
"nodes_logic.py",
"nodes_nop.py",

@ -27,4 +27,4 @@ comfy-kitchen>=0.2.7
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
pydantic-settings~=2.0