Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-31 00:30:21 +08:00)

Compare commits: 04a0cba2f5 ... 9afb605b4e (30 commits)

Commits in this range:
9afb605b4e, 8ccc0c94fa, 4edb87aa50, 0fc3b6e3a6, 2108167f9f, 9d273d3ab1, 70c91b8248, 0da5a0fe58, e0eacb0688, 7458e20465,
b931b37e30, 866a4619db, 7602203696, ffa7a369ba, 7ec1656735, cee75f301a, 1a59686ca8, 6d96d26795, e07a32c9b8, a19f0a88e4,
64811809a0, 529109083a, a7be9f6fc3, 6075c44ec8, 154b73835a, 862e7784f4, f6b6636bf3, 83b00df3f0, 5f24eb699c, fab0954077
@ -404,6 +404,14 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path, so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or PowerShell terminal.

## How to run heavy workflows on a mid-range GPU (NVIDIA, Linux)?

Use the `--enable-gds` flag to activate NVIDIA [GPUDirect Storage](https://docs.nvidia.com/gpudirect-storage/) (GDS), which transfers data directly between SSDs and GPUs. This bypasses the traditional CPU-mediated data path, significantly reducing I/O latency and CPU overhead. System RAM is still used for caching, alongside the SSD, to further optimize performance. For example, launching with `--enable-gds --gds-stats` enables direct loading and prints transfer statistics on exit.

This feature has been tested only on NVIDIA GPUs on Linux-based systems.

Requires: `cupy-cuda12x>=12.0.0`, `pynvml>=11.4.1`, `cudf>=23.0.0`, `numba>=0.57.0`, `nvidia-ml-py>=12.0.0`.

## Support and dev channel

[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
@ -154,6 +154,17 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

# GPUDirect Storage (GDS) arguments
gds_group = parser.add_argument_group('gds', 'GPUDirect Storage options for direct SSD-to-GPU model loading')
gds_group.add_argument("--enable-gds", action="store_true", help="Enable GPUDirect Storage for direct SSD-to-GPU model loading (requires CUDA 11.4+, cuFile).")
gds_group.add_argument("--disable-gds", action="store_true", help="Explicitly disable GPUDirect Storage.")
gds_group.add_argument("--gds-min-file-size", type=int, default=100, help="Minimum file size in MB to use GDS (default: 100MB).")
gds_group.add_argument("--gds-chunk-size", type=int, default=64, help="GDS transfer chunk size in MB (default: 64MB).")
gds_group.add_argument("--gds-streams", type=int, default=4, help="Number of CUDA streams for GDS operations (default: 4).")
gds_group.add_argument("--gds-prefetch", action="store_true", help="Enable GDS prefetching for better performance.")
gds_group.add_argument("--gds-no-fallback", action="store_true", help="Disable fallback to CPU loading if GDS fails.")
gds_group.add_argument("--gds-stats", action="store_true", help="Print GDS statistics on exit.")

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
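How these flags reach the loader is not shown in this hunk, so the following is a hedged sketch (not part of the diff) of how the parsed arguments could be mapped onto the `GDSConfig` and `init_gds()` helpers defined in the new `comfy/gds_loader.py` below; `args` is assumed to be the namespace produced by this parser.

```python
# Illustrative only: wiring the new CLI flags into comfy.gds_loader.
from comfy.gds_loader import GDSConfig, init_gds

def init_gds_from_args(args) -> None:
    # GDS stays off unless explicitly enabled and not explicitly disabled.
    if not args.enable_gds or args.disable_gds:
        return
    config = GDSConfig(
        enabled=True,
        min_file_size_mb=args.gds_min_file_size,
        chunk_size_mb=args.gds_chunk_size,
        max_concurrent_streams=args.gds_streams,
        prefetch_enabled=args.gds_prefetch,
        fallback_to_cpu=not args.gds_no_fallback,
        show_stats=args.gds_stats,
    )
    init_gds(config)
```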
comfy/gds_loader.py (new file, 494 lines)
@ -0,0 +1,494 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007

"""
GPUDirect Storage (GDS) Integration for ComfyUI
Direct SSD-to-GPU model loading without RAM/CPU bottlenecks
Still there will be some CPU/RAM usage, mostly for safetensors parsing and small buffers.

This module provides GPUDirect Storage functionality to load models directly
from NVMe SSDs to GPU memory, bypassing system RAM and CPU.
"""

import os
import logging
import torch
import time
from typing import Optional, Dict, Any, Union
from pathlib import Path
import safetensors
import gc
import mmap
from dataclasses import dataclass

try:
    import cupy
    import cupy.cuda.runtime as cuda_runtime
    CUPY_AVAILABLE = True
except ImportError:
    CUPY_AVAILABLE = False
    logging.warning("CuPy not available. GDS will use fallback mode.")

try:
    import cudf  # RAPIDS for GPU dataframes
    RAPIDS_AVAILABLE = True
except ImportError:
    RAPIDS_AVAILABLE = False

try:
    import pynvml
    pynvml.nvmlInit()
    NVML_AVAILABLE = True
except ImportError:
    NVML_AVAILABLE = False
    logging.warning("NVIDIA-ML-Py not available. GPU monitoring disabled.")

@dataclass
class GDSConfig:
    """Configuration for GPUDirect Storage"""
    enabled: bool = True
    min_file_size_mb: int = 100  # Only use GDS for files larger than this
    chunk_size_mb: int = 64  # Size of chunks to transfer
    use_pinned_memory: bool = True
    prefetch_enabled: bool = True
    compression_aware: bool = True
    max_concurrent_streams: int = 4
    fallback_to_cpu: bool = True
    show_stats: bool = False  # Whether to show stats on exit


class GDSError(Exception):
    """GDS-specific errors"""
    pass

class GPUDirectStorage:
|
||||
"""
|
||||
GPUDirect Storage implementation for ComfyUI
|
||||
Enables direct SSD-to-GPU transfers for model loading
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[GDSConfig] = None):
|
||||
self.config = config or GDSConfig()
|
||||
self.device = torch.cuda.current_device() if torch.cuda.is_available() else None
|
||||
self.cuda_streams = []
|
||||
self.pinned_buffers = {}
|
||||
self.stats = {
|
||||
'gds_loads': 0,
|
||||
'fallback_loads': 0,
|
||||
'total_bytes_gds': 0,
|
||||
'total_time_gds': 0.0,
|
||||
'avg_bandwidth_gbps': 0.0
|
||||
}
|
||||
|
||||
# Initialize GDS if available
|
||||
self._gds_available = self._check_gds_availability()
|
||||
if self._gds_available:
|
||||
self._init_gds()
|
||||
else:
|
||||
logging.warning("GDS not available, using fallback methods")
|
||||
|
||||
def _check_gds_availability(self) -> bool:
|
||||
"""Check if GDS is available on the system"""
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
|
||||
if not CUPY_AVAILABLE:
|
||||
return False
|
||||
|
||||
# Check for GPUDirect Storage support
|
||||
try:
|
||||
# Check CUDA version (GDS requires CUDA 11.4+)
|
||||
cuda_version = torch.version.cuda
|
||||
if cuda_version:
|
||||
major, minor = map(int, cuda_version.split('.')[:2])
|
||||
if major < 11 or (major == 11 and minor < 4):
|
||||
logging.warning(f"CUDA {cuda_version} detected. GDS requires CUDA 11.4+")
|
||||
return False
|
||||
|
||||
# Check if cuFile is available (part of CUDA toolkit)
|
||||
try:
|
||||
import cupy.cuda.cufile as cufile
|
||||
# Try to initialize cuFile
|
||||
cufile.initialize()
|
||||
return True
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logging.warning(f"cuFile not available: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"GDS availability check failed: {e}")
|
||||
return False
|
||||
|
||||
def _init_gds(self):
|
||||
"""Initialize GDS resources"""
|
||||
try:
|
||||
# Create CUDA streams for async operations
|
||||
for i in range(self.config.max_concurrent_streams):
|
||||
stream = torch.cuda.Stream()
|
||||
self.cuda_streams.append(stream)
|
||||
|
||||
# Pre-allocate pinned memory buffers
|
||||
if self.config.use_pinned_memory:
|
||||
self._allocate_pinned_buffers()
|
||||
|
||||
logging.info(f"GDS initialized with {len(self.cuda_streams)} streams")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize GDS: {e}")
|
||||
self._gds_available = False
|
||||
|
||||
def _allocate_pinned_buffers(self):
|
||||
"""Pre-allocate pinned memory buffers for staging"""
|
||||
try:
|
||||
# Allocate buffers of different sizes
|
||||
buffer_sizes = [16, 32, 64, 128, 256] # MB
|
||||
|
||||
for size_mb in buffer_sizes:
|
||||
size_bytes = size_mb * 1024 * 1024
|
||||
# Allocate pinned memory using CuPy
|
||||
if CUPY_AVAILABLE:
|
||||
buffer = cupy.cuda.alloc_pinned_memory(size_bytes)
|
||||
self.pinned_buffers[size_mb] = buffer
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to allocate pinned buffers: {e}")
|
||||
|
||||
def _get_file_size(self, file_path: str) -> int:
|
||||
"""Get file size in bytes"""
|
||||
return os.path.getsize(file_path)
|
||||
|
||||
def _should_use_gds(self, file_path: str) -> bool:
|
||||
"""Determine if GDS should be used for this file"""
|
||||
if not self._gds_available or not self.config.enabled:
|
||||
return False
|
||||
|
||||
file_size_mb = self._get_file_size(file_path) / (1024 * 1024)
|
||||
return file_size_mb >= self.config.min_file_size_mb
|
||||
|
||||
def _load_with_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Load model using GPUDirect Storage"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if file_path.lower().endswith(('.safetensors', '.sft')):
|
||||
return self._load_safetensors_gds(file_path)
|
||||
else:
|
||||
return self._load_pytorch_gds(file_path)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS loading failed for {file_path}: {e}")
|
||||
if self.config.fallback_to_cpu:
|
||||
logging.info("Falling back to CPU loading")
|
||||
self.stats['fallback_loads'] += 1
|
||||
return self._load_fallback(file_path)
|
||||
else:
|
||||
raise GDSError(f"GDS loading failed: {e}")
|
||||
finally:
|
||||
load_time = time.time() - start_time
|
||||
self.stats['total_time_gds'] += load_time
|
||||
|
||||
def _load_safetensors_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Load safetensors file using GDS"""
|
||||
try:
|
||||
import cupy.cuda.cufile as cufile
|
||||
|
||||
# Open file with cuFile for direct GPU loading
|
||||
with cufile.CuFileManager() as manager:
|
||||
# Memory-map the file for efficient access
|
||||
with open(file_path, 'rb') as f:
|
||||
# Use mmap for large files
|
||||
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped_file:
|
||||
|
||||
# Parse safetensors header
|
||||
header_size = int.from_bytes(mmapped_file[:8], 'little')
|
||||
header_bytes = mmapped_file[8:8+header_size]
|
||||
|
||||
import json
|
||||
header = json.loads(header_bytes.decode('utf-8'))
|
||||
|
||||
# Load tensors directly to GPU
|
||||
tensors = {}
|
||||
data_offset = 8 + header_size
|
||||
|
||||
for name, info in header.items():
|
||||
if name == "__metadata__":
|
||||
continue
|
||||
|
||||
dtype_map = {
|
||||
'F32': torch.float32,
|
||||
'F16': torch.float16,
|
||||
'BF16': torch.bfloat16,
|
||||
'I8': torch.int8,
|
||||
'I16': torch.int16,
|
||||
'I32': torch.int32,
|
||||
'I64': torch.int64,
|
||||
'U8': torch.uint8,
|
||||
}
|
||||
|
||||
dtype = dtype_map.get(info['dtype'], torch.float32)
|
||||
shape = info['shape']
|
||||
start_offset = data_offset + info['data_offsets'][0]
|
||||
end_offset = data_offset + info['data_offsets'][1]
|
||||
|
||||
# Direct GPU allocation
|
||||
tensor = torch.empty(shape, dtype=dtype, device=f'cuda:{self.device}')
|
||||
|
||||
# Use cuFile for direct transfer
|
||||
tensor_bytes = end_offset - start_offset
|
||||
|
||||
# Get GPU memory pointer
|
||||
gpu_ptr = tensor.data_ptr()
|
||||
|
||||
# Direct file-to-GPU transfer
|
||||
cufile.copy_from_file(
|
||||
gpu_ptr,
|
||||
mmapped_file[start_offset:end_offset],
|
||||
tensor_bytes
|
||||
)
|
||||
|
||||
tensors[name] = tensor
|
||||
|
||||
self.stats['gds_loads'] += 1
|
||||
self.stats['total_bytes_gds'] += self._get_file_size(file_path)
|
||||
|
||||
return tensors
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS safetensors loading failed: {e}")
|
||||
raise
|
||||
|
||||
def _load_pytorch_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Load PyTorch file using GDS with staging"""
|
||||
try:
|
||||
# For PyTorch files, we need to use a staging approach
|
||||
# since torch.load doesn't support direct GPU loading
|
||||
|
||||
# Load to pinned memory first
|
||||
with open(file_path, 'rb') as f:
|
||||
file_size = self._get_file_size(file_path)
|
||||
|
||||
# Choose appropriate buffer or allocate new one
|
||||
buffer_size_mb = min(256, max(64, file_size // (1024 * 1024)))
|
||||
|
||||
if buffer_size_mb in self.pinned_buffers:
|
||||
pinned_buffer = self.pinned_buffers[buffer_size_mb]
|
||||
else:
|
||||
# Allocate temporary pinned buffer
|
||||
pinned_buffer = cupy.cuda.alloc_pinned_memory(file_size)
|
||||
|
||||
# Read file to pinned memory
|
||||
f.readinto(pinned_buffer)
|
||||
|
||||
# Use torch.load with map_location to specific GPU
|
||||
# This will be faster due to pinned memory
|
||||
state_dict = torch.load(
|
||||
f,
|
||||
map_location=f'cuda:{self.device}',
|
||||
weights_only=True
|
||||
)
|
||||
|
||||
self.stats['gds_loads'] += 1
|
||||
self.stats['total_bytes_gds'] += file_size
|
||||
|
||||
return state_dict
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS PyTorch loading failed: {e}")
|
||||
raise
|
||||
|
||||
def _load_fallback(self, file_path: str) -> Dict[str, torch.Tensor]:
|
||||
"""Fallback loading method using standard approaches"""
|
||||
if file_path.lower().endswith(('.safetensors', '.sft')):
|
||||
# Use safetensors with device parameter
|
||||
with safetensors.safe_open(file_path, framework="pt", device=f'cuda:{self.device}') as f:
|
||||
return {k: f.get_tensor(k) for k in f.keys()}
|
||||
else:
|
||||
# Standard PyTorch loading
|
||||
return torch.load(file_path, map_location=f'cuda:{self.device}', weights_only=True)
|
||||
|
||||
def load_model(self, file_path: str, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Main entry point for loading models with GDS
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
device: Target device (if None, uses current CUDA device)
|
||||
|
||||
Returns:
|
||||
Dictionary of tensors loaded directly to GPU
|
||||
"""
|
||||
if device is not None and device.type == 'cuda':
|
||||
self.device = device.index or 0
|
||||
|
||||
if self._should_use_gds(file_path):
|
||||
logging.info(f"Loading {file_path} with GDS")
|
||||
return self._load_with_gds(file_path)
|
||||
else:
|
||||
logging.info(f"Loading {file_path} with standard method")
|
||||
self.stats['fallback_loads'] += 1
|
||||
return self._load_fallback(file_path)
|
||||
|
||||
def prefetch_model(self, file_path: str) -> bool:
|
||||
"""
|
||||
Prefetch model to GPU memory cache (if supported)
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
|
||||
Returns:
|
||||
True if prefetch was successful
|
||||
"""
|
||||
if not self.config.prefetch_enabled or not self._gds_available:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Basic prefetch implementation
|
||||
# This would ideally use NVIDIA's GPUDirect Storage API
|
||||
# to warm up the storage cache
|
||||
|
||||
file_size = self._get_file_size(file_path)
|
||||
logging.info(f"Prefetching {file_path} ({file_size // (1024*1024)} MB)")
|
||||
|
||||
# Read file metadata to warm caches
|
||||
with open(file_path, 'rb') as f:
|
||||
# Read first and last chunks to trigger prefetch
|
||||
f.read(1024 * 1024) # First 1MB
|
||||
f.seek(-min(1024 * 1024, file_size), 2) # Last 1MB
|
||||
f.read()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Prefetch failed for {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get loading statistics"""
|
||||
total_loads = self.stats['gds_loads'] + self.stats['fallback_loads']
|
||||
|
||||
if self.stats['total_time_gds'] > 0 and self.stats['total_bytes_gds'] > 0:
|
||||
bandwidth_gbps = (self.stats['total_bytes_gds'] / (1024**3)) / self.stats['total_time_gds']
|
||||
self.stats['avg_bandwidth_gbps'] = bandwidth_gbps
|
||||
|
||||
return {
|
||||
**self.stats,
|
||||
'total_loads': total_loads,
|
||||
'gds_usage_percent': (self.stats['gds_loads'] / max(1, total_loads)) * 100,
|
||||
'gds_available': self._gds_available,
|
||||
'config': self.config.__dict__
|
||||
}
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up GDS resources"""
|
||||
try:
|
||||
# Clear CUDA streams
|
||||
for stream in self.cuda_streams:
|
||||
stream.synchronize()
|
||||
self.cuda_streams.clear()
|
||||
|
||||
# Free pinned buffers
|
||||
for buffer in self.pinned_buffers.values():
|
||||
if CUPY_AVAILABLE:
|
||||
cupy.cuda.free_pinned_memory(buffer)
|
||||
self.pinned_buffers.clear()
|
||||
|
||||
# Force garbage collection
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"GDS cleanup failed: {e}")
|
||||
|
||||
def __del__(self):
|
||||
"""Destructor to ensure cleanup"""
|
||||
self.cleanup()
|
||||
|
||||
|
||||
# Global GDS instance
|
||||
_gds_instance: Optional[GPUDirectStorage] = None
|
||||
|
||||
|
||||
def get_gds_instance(config: Optional[GDSConfig] = None) -> GPUDirectStorage:
|
||||
"""Get or create the global GDS instance"""
|
||||
global _gds_instance
|
||||
|
||||
if _gds_instance is None:
|
||||
_gds_instance = GPUDirectStorage(config)
|
||||
|
||||
return _gds_instance
|
||||
|
||||
|
||||
def load_torch_file_gds(ckpt: str, safe_load: bool = False, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
GDS-enabled replacement for comfy.utils.load_torch_file
|
||||
|
||||
Args:
|
||||
ckpt: Path to checkpoint file
|
||||
safe_load: Whether to use safe loading (for compatibility)
|
||||
device: Target device
|
||||
|
||||
Returns:
|
||||
Dictionary of loaded tensors
|
||||
"""
|
||||
gds = get_gds_instance()
|
||||
|
||||
try:
|
||||
# Load with GDS
|
||||
return gds.load_model(ckpt, device)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"GDS loading failed, falling back to standard method: {e}")
|
||||
# Fallback to original method
|
||||
import comfy.utils
|
||||
return comfy.utils.load_torch_file(ckpt, safe_load=safe_load, device=device)
|
||||
|
||||
|
||||
def prefetch_model_gds(file_path: str) -> bool:
|
||||
"""Prefetch model for faster loading"""
|
||||
gds = get_gds_instance()
|
||||
return gds.prefetch_model(file_path)
|
||||
|
||||
|
||||
def get_gds_stats() -> Dict[str, Any]:
|
||||
"""Get GDS statistics"""
|
||||
gds = get_gds_instance()
|
||||
return gds.get_stats()
|
||||
|
||||
|
||||
def configure_gds(config: GDSConfig):
|
||||
"""Configure GDS settings"""
|
||||
global _gds_instance
|
||||
_gds_instance = GPUDirectStorage(config)
|
||||
|
||||
|
||||
def init_gds(config: GDSConfig):
|
||||
"""
|
||||
Initialize GPUDirect Storage with the provided configuration
|
||||
|
||||
Args:
|
||||
config: GDSConfig object with initialization parameters
|
||||
"""
|
||||
try:
|
||||
# Configure GDS
|
||||
configure_gds(config)
|
||||
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
|
||||
|
||||
# Set up exit handler for stats if requested
|
||||
if hasattr(config, 'show_stats') and config.show_stats:
|
||||
import atexit
|
||||
def print_gds_stats():
|
||||
stats = get_gds_stats()
|
||||
logging.info("=== GDS Statistics ===")
|
||||
logging.info(f"Total loads: {stats['total_loads']}")
|
||||
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
|
||||
logging.info(f"Fallback loads: {stats['fallback_loads']}")
|
||||
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
|
||||
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
|
||||
logging.info("===================")
|
||||
atexit.register(print_gds_stats)
|
||||
|
||||
except ImportError as e:
|
||||
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"GDS initialization failed: {e}")
|
||||
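The new module is consumed through its module-level helpers (`init_gds`, `load_torch_file_gds`, `get_gds_stats`). A hedged usage sketch, not part of the diff; the checkpoint path is a placeholder:

```python
# Illustrative only: initialize GDS, load a checkpoint, then inspect statistics.
import torch
from comfy.gds_loader import GDSConfig, init_gds, load_torch_file_gds, get_gds_stats

init_gds(GDSConfig(min_file_size_mb=100, show_stats=True))

if torch.cuda.is_available():
    state_dict = load_torch_file_gds(
        "/path/to/model.safetensors",   # placeholder path
        safe_load=True,
        device=torch.device("cuda:0"),
    )
    stats = get_gds_stats()
    print(f"GDS loads: {stats['gds_loads']}, avg bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
```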
@ -189,9 +189,12 @@ class AudioVAE(torch.nn.Module):
        waveform = self.device_manager.move_to_load_device(waveform)
        expected_channels = self.autoencoder.encoder.in_channels
        if waveform.shape[1] != expected_channels:
            raise ValueError(
                f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
            )
        if waveform.shape[1] == 1:
            waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
        else:
            raise ValueError(
                f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
            )

        mel_spec = self.preprocessor.waveform_to_mel(
            waveform, waveform_sample_rate, device=self.device_manager.load_device
@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils


def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))
def invert_slices(slices, length):
    sorted_slices = sorted(slices)
    result = []
    current = 0

    for start, end in sorted_slices:
        if current < start:
            result.append((current, start))
        current = max(current, end)

    if current < length:
        result.append((current, length))

    return result


def modulate(x, scale, timestep_zero_index=None):
    if timestep_zero_index is None:
        return x * (1 + scale.unsqueeze(1))
    else:
        scale = (1 + scale.unsqueeze(1))
        actual_batch = scale.size(0) // 2
        slices = timestep_zero_index
        invert = invert_slices(timestep_zero_index, x.shape[1])
        for s in slices:
            x[:, s[0]:s[1]] *= scale[actual_batch:]
        for s in invert:
            x[:, s[0]:s[1]] *= scale[:actual_batch]
        return x


def apply_gate(gate, x, timestep_zero_index=None):
    if timestep_zero_index is None:
        return gate * x
    else:
        actual_batch = gate.size(0) // 2

        slices = timestep_zero_index
        invert = invert_slices(timestep_zero_index, x.shape[1])
        for s in slices:
            x[:, s[0]:s[1]] *= gate[actual_batch:]
        for s in invert:
            x[:, s[0]:s[1]] *= gate[:actual_batch]
        return x

#############################################################################
# Core NextDiT Model #
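To make the slice bookkeeping concrete, here is a small illustrative check of `invert_slices` (values chosen for illustration, not taken from the diff): it returns the gaps of `[0, length)` not covered by the given slices. In `modulate` and `apply_gate`, the listed slices are then scaled with the second half of the doubled batch and the gaps with the first half.

```python
# Illustrative check of invert_slices semantics.
assert invert_slices([(2, 5)], 8) == [(0, 2), (5, 8)]
assert invert_slices([(0, 3), (6, 8)], 8) == [(3, 6)]
```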
@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
|
||||
clamp_fp16(self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
))
|
||||
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
def forward(self, x, c, timestep_zero_index=None):
|
||||
scale = self.adaLN_modulation(c)
|
||||
x = modulate(self.norm_final(x), scale)
|
||||
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def pad_zimage(feats, pad_token, pad_tokens_multiple):
|
||||
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
|
||||
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
|
||||
|
||||
|
||||
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = start_t
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
return x_pos_ids
|
||||
|
||||
|
||||
class NextDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
@ -378,6 +446,7 @@ class NextDiT(nn.Module):
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
siglip_feat_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -491,6 +560,41 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
siglip_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
@ -531,70 +635,168 @@ class NextDiT(nn.Module):
|
||||
imgs = torch.stack(imgs, dim=0)
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
orig_x = x
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
|
||||
if cap_feats is not None:
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats_len = cap_feats.shape[1]
|
||||
if self.pad_tokens_multiple is not None:
|
||||
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
|
||||
else:
|
||||
cap_feats_len = 0
|
||||
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
|
||||
embeds = (cap_feats,)
|
||||
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len
|
||||
|
||||
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
|
||||
bsz = 1
|
||||
pH = pW = self.patch_size
|
||||
device = x.device
|
||||
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
|
||||
|
||||
if (not omni) or self.siglip_embedder is None:
|
||||
cap_feats_len = embeds[0].shape[1] + offset
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
cap_feats_len += offset
|
||||
if siglip_feats is not None:
|
||||
b, h, w, c = siglip_feats.shape
|
||||
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
|
||||
siglip_feats = self.siglip_embedder(siglip_feats)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
|
||||
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
|
||||
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
|
||||
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
|
||||
else:
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
|
||||
if siglip_feats is None:
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
embeds += (siglip_feats,)
|
||||
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
embeds += (x,)
|
||||
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
|
||||
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = x.shape[0]
|
||||
cap_mask = None # TODO?
|
||||
main_siglip = None
|
||||
orig_x = x
|
||||
|
||||
embeds = ([], [], [])
|
||||
freqs_cis = ([], [], [])
|
||||
leftover_cap = []
|
||||
|
||||
start_t = 0
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
for i, ref in enumerate(ref_latents):
|
||||
if i < len(ref_contexts):
|
||||
ref_con = ref_contexts[i]
|
||||
else:
|
||||
ref_con = None
|
||||
if i < len(siglip_feats):
|
||||
sig_feat = siglip_feats[i]
|
||||
else:
|
||||
sig_feat = None
|
||||
|
||||
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
leftover_cap = ref_contexts[len(ref_latents):]
|
||||
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
img_sizes = [(H, W)] * bsz
|
||||
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
img_len = out[0][-1].shape[1]
|
||||
cap_len = out[0][0].shape[1]
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
e = comfy.utils.repeat_to_batch_size(e, bsz)
|
||||
embeds[i].append(e)
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
|
||||
for cap in leftover_cap:
|
||||
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
|
||||
cap_len += out[0][0].shape[1]
|
||||
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
|
||||
freqs_cis[0].append(out[1][0])
|
||||
start_t += out[2]
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
cap_feats = torch.cat(embeds[0], dim=1)
|
||||
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
feats = (cap_feats,)
|
||||
fc = (cap_freqs_cis,)
|
||||
|
||||
if omni and len(embeds[1]) > 0:
|
||||
siglip_mask = None
|
||||
siglip_feats_combined = torch.cat(embeds[1], dim=1)
|
||||
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
|
||||
if self.siglip_refiner is not None:
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
|
||||
feats += (siglip_feats_combined,)
|
||||
fc += (siglip_feats_freqs_cis,)
|
||||
|
||||
padded_img_mask = None
|
||||
x = torch.cat(embeds[-1], dim=1)
|
||||
fc_x = torch.cat(freqs_cis[-1], dim=1)
|
||||
if omni:
|
||||
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
|
||||
else:
|
||||
timestep_zero_index = None
|
||||
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
padded_full_embed = torch.cat(feats + (x,), dim=1)
|
||||
if timestep_zero_index is not None:
|
||||
ind = padded_full_embed.shape[1] - x.shape[1]
|
||||
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
|
||||
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
|
||||
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@ -604,7 +806,11 @@ class NextDiT(nn.Module):
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
|
||||
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@ -619,8 +825,6 @@ class NextDiT(nn.Module):
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
@ -632,7 +836,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
@ -640,7 +844,7 @@ class NextDiT(nn.Module):
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
@ -649,8 +853,7 @@ class NextDiT(nn.Module):
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -img
|
||||
|
||||
|
||||
@ -1150,6 +1150,7 @@ class CosmosPredict2(BaseModel):
class Lumina2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
        self.memory_usage_factor_conds = ("ref_latents",)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
@ -1169,6 +1170,35 @@ class Lumina2(BaseModel):
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)

        clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}]))))  # Z Image omni
        if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
            sigfeats = []
            for clip_vision_output in clip_vision_outputs:
                if clip_vision_output is not None:
                    image_size = clip_vision_output.image_sizes[0]
                    shape = clip_vision_output.last_hidden_state.shape
                    sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
            if len(sigfeats) > 0:
                out['siglip_feats'] = comfy.conds.CONDList(sigfeats)

        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            latents = []
            for lat in ref_latents:
                latents.append(self.process_latent_in(lat))
            out['ref_latents'] = comfy.conds.CONDList(latents)

        ref_contexts = kwargs.get("reference_latents_text_embeds", None)
        if ref_contexts is not None:
            out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)

        return out

    def extra_conds_shapes(self, **kwargs):
        out = {}
        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
        return out

class WAN21(BaseModel):

@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["time_scale"] = 1000.0
        if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
            dit_config["pad_tokens_multiple"] = 32
        sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
        if sig_weight is not None:
            dit_config["siglip_feat_dim"] = sig_weight.shape[0]

        return dit_config

@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
            if dtype_llama is not None:
                dtype = dtype_llama
            if llama_quantization_metadata is not None:
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return OvisTEModel_

@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
            if dtype_llama is not None:
                dtype = dtype_llama
            if llama_quantization_metadata is not None:
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return ZImageTEModel_

@ -57,6 +57,18 @@ else:
    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    # Try GDS loading first if available and device is GPU
    if device is not None and device.type == 'cuda':
        try:
            from . import gds_loader
            gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
            if return_metadata:
                # For GDS, we return empty metadata for now (can be enhanced)
                return (gds_result, {})
            return gds_result
        except Exception as e:
            logging.debug(f"GDS loading failed, using fallback: {e}")

    if device is None:
        device = torch.device("cpu")
    metadata = None
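With this patch, existing callers of `comfy.utils.load_torch_file` only take the GDS path when they pass a CUDA device explicitly; CPU loads are unchanged. A hedged illustration (not part of the diff; the file name is a placeholder):

```python
# Illustrative only: the same entry point, with and without the GDS fast path.
import torch
import comfy.utils

sd_cpu = comfy.utils.load_torch_file("model.safetensors", safe_load=True)   # standard loading
sd_gpu = comfy.utils.load_torch_file("model.safetensors", safe_load=True,
                                     device=torch.device("cuda:0"))         # tries GDS first, falls back on error
```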
@ -639,6 +651,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
            "proj_out.bias": "linear2.bias",
            "attn.norm_q.weight": "norm.query_norm.scale",
            "attn.norm_k.weight": "norm.key_norm.scale",
            "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
            "attn.to_out.weight": "linear2.weight", # Flux 2
        }

        for k in block_map:

@ -1000,20 +1000,38 @@ class Autogrow(ComfyTypeI):
|
||||
names = [f"{prefix}{i}" for i in range(max)]
|
||||
# need to create a new input based on the contents of input
|
||||
template_input = None
|
||||
for _, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input
|
||||
template_required = True
|
||||
for _input_type, dict_input in input.items():
|
||||
# for now, get just the first value from dict_input; if not required, min can be ignored
|
||||
if len(dict_input) == 0:
|
||||
continue
|
||||
template_input = list(dict_input.values())[0]
|
||||
template_required = _input_type == "required"
|
||||
break
|
||||
if template_input is None:
|
||||
raise Exception("template_input could not be determined from required or optional; this should never happen.")
|
||||
new_dict = {}
|
||||
new_dict_added_to = False
|
||||
# first, add possible inputs into out_dict
|
||||
for i, name in enumerate(names):
|
||||
expected_id = finalize_prefix(curr_prefix, name)
|
||||
# required
|
||||
if i < min and template_required:
|
||||
out_dict["required"][expected_id] = template_input
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
out_dict["optional"][expected_id] = template_input
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
if expected_id in live_inputs:
|
||||
# required
|
||||
if i < min:
|
||||
type_dict = new_dict.setdefault("required", {})
|
||||
# optional
|
||||
else:
|
||||
type_dict = new_dict.setdefault("optional", {})
|
||||
# NOTE: prefix gets added in parse_class_inputs
|
||||
type_dict[name] = template_input
|
||||
new_dict_added_to = True
|
||||
# account for the edge case that all inputs are optional and no values are received
|
||||
if not new_dict_added_to:
|
||||
finalized_prefix = finalize_prefix(curr_prefix)
|
||||
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
|
||||
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT
|
||||
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||
@ -1151,6 +1169,8 @@ class V3Data(TypedDict):
|
||||
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
|
||||
dynamic_paths: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
|
||||
dynamic_paths_default_value: dict[str, Any]
|
||||
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
|
||||
create_dynamic_tuple: bool
|
||||
'When True, the value of the dynamic input will be in the format (value, path_key).'
|
||||
|
||||
@ -1504,6 +1524,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
"required": {},
|
||||
"optional": {},
|
||||
"dynamic_paths": {},
|
||||
"dynamic_paths_default_value": {},
|
||||
}
|
||||
d = d.copy()
|
||||
# ignore hidden for parsing
|
||||
@ -1513,8 +1534,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
|
||||
out_dict["hidden"] = hidden
|
||||
v3_data = {}
|
||||
dynamic_paths = out_dict.pop("dynamic_paths", None)
|
||||
if dynamic_paths is not None:
|
||||
if dynamic_paths is not None and len(dynamic_paths) > 0:
|
||||
v3_data["dynamic_paths"] = dynamic_paths
|
||||
# this list is used for autogrow, in the case all inputs are optional and no values are passed
|
||||
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
|
||||
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
|
||||
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
|
||||
return out_dict, hidden, v3_data
|
||||
|
||||
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
|
||||
@ -1551,11 +1576,16 @@ def add_to_dict_v1(i: Input, d: dict):
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
class DynamicPathsDefaultValue:
|
||||
EMPTY_DICT = "empty_dict"
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
|
||||
result = {}
|
||||
|
||||
create_tuple = v3_data.get("create_dynamic_tuple", False)
|
||||
@ -1569,6 +1599,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
|
||||
if is_last:
|
||||
value = values.pop(key, None)
|
||||
if value is None:
|
||||
# see if a default value was provided for this key
|
||||
default_option = default_value_dict.get(key, None)
|
||||
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
|
||||
value = {}
|
||||
if create_tuple:
|
||||
value = (value, key)
|
||||
current[p] = value
|
||||
|
||||
comfy_api_nodes/apis/bria.py (new file, 61 lines)
@ -0,0 +1,61 @@
from typing import TypedDict

from pydantic import BaseModel, Field


class InputModerationSettings(TypedDict):
    prompt_content_moderation: bool
    visual_input_moderation: bool
    visual_output_moderation: bool


class BriaEditImageRequest(BaseModel):
    instruction: str | None = Field(...)
    structured_instruction: str | None = Field(
        ...,
        description="Use this instead of instruction for precise, programmatic control.",
    )
    images: list[str] = Field(
        ...,
        description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.",
    )
    mask: str | None = Field(
        None,
        description="Mask image (black and white). Black areas will be preserved, white areas will be edited. "
        "If omitted, the edit applies to the entire image. "
        "The input image and the input mask must be of the same size.",
    )
    negative_prompt: str | None = Field(None)
    guidance_scale: float = Field(...)
    model_version: str = Field(...)
    steps_num: int = Field(...)
    seed: int = Field(...)
    ip_signal: bool = Field(
        False,
        description="If true, returns a warning for potential IP content in the instruction.",
    )
    prompt_content_moderation: bool = Field(
        False, description="If true, returns 422 on instruction moderation failure."
    )
    visual_input_content_moderation: bool = Field(
        False, description="If true, returns 422 on images or mask moderation failure."
    )
    visual_output_content_moderation: bool = Field(
        False, description="If true, returns 422 on visual output moderation failure."
    )


class BriaStatusResponse(BaseModel):
    request_id: str = Field(...)
    status_url: str = Field(...)
    warning: str | None = Field(None)


class BriaResult(BaseModel):
    structured_prompt: str = Field(...)
    image_url: str = Field(...)


class BriaResponse(BaseModel):
    status: str = Field(...)
    result: BriaResult | None = Field(None)
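A hedged sketch (not part of the diff) of constructing the request model directly, mirroring the fields the node below sends; the image URL is a placeholder:

```python
# Illustrative only: populate the required fields of BriaEditImageRequest.
req = BriaEditImageRequest(
    instruction="make the sky pink",
    structured_instruction=None,
    images=["https://example.com/input.png"],  # must contain exactly one item
    guidance_scale=3.0,
    model_version="FIBO",
    steps_num=50,
    seed=1,
)
print(req)
```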
comfy_api_nodes/nodes_bria.py (new file, 198 lines)
@ -0,0 +1,198 @@
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
    BriaEditImageRequest,
    BriaResponse,
    BriaStatusResponse,
    InputModerationSettings,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
    convert_mask_to_image,
    download_url_to_image_tensor,
    get_number_of_images,
    poll_op,
    sync_op,
    upload_images_to_comfyapi,
)


class BriaImageEditNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="BriaImageEditNode",
            display_name="Bria Image Edit",
            category="api node/image/Bria",
            description="Edit images using Bria latest model",
            inputs=[
                IO.Combo.Input("model", options=["FIBO"]),
                IO.Image.Input("image"),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Instruction to edit image",
                ),
                IO.String.Input("negative_prompt", multiline=True, default=""),
                IO.String.Input(
                    "structured_prompt",
                    multiline=True,
                    default="",
                    tooltip="A string containing the structured edit prompt in JSON format. "
                    "Use this instead of usual prompt for precise, programmatic control.",
                ),
                IO.Int.Input(
                    "seed",
                    default=1,
                    min=1,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                ),
                IO.Float.Input(
                    "guidance_scale",
                    default=3,
                    min=3,
                    max=5,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Higher value makes the image follow the prompt more closely.",
                ),
                IO.Int.Input(
                    "steps",
                    default=50,
                    min=20,
                    max=50,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                ),
                IO.DynamicCombo.Input(
                    "moderation",
                    options=[
                        IO.DynamicCombo.Option(
                            "true",
                            [
                                IO.Boolean.Input(
                                    "prompt_content_moderation", default=False
                                ),
                                IO.Boolean.Input(
                                    "visual_input_moderation", default=False
                                ),
                                IO.Boolean.Input(
                                    "visual_output_moderation", default=True
                                ),
                            ],
                        ),
                        IO.DynamicCombo.Option("false", []),
                    ],
                    tooltip="Moderation settings",
                ),
                IO.Mask.Input(
                    "mask",
                    tooltip="If omitted, the edit applies to the entire image.",
                    optional=True,
                ),
            ],
            outputs=[
                IO.Image.Output(),
                IO.String.Output(display_name="structured_prompt"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.04}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        image: Input.Image,
        prompt: str,
        negative_prompt: str,
        structured_prompt: str,
        seed: int,
        guidance_scale: float,
        steps: int,
        moderation: InputModerationSettings,
        mask: Input.Image | None = None,
    ) -> IO.NodeOutput:
        if not prompt and not structured_prompt:
            raise ValueError(
                "One of prompt or structured_prompt is required to be non-empty."
            )
        if get_number_of_images(image) != 1:
            raise ValueError("Exactly one input image is required.")
        mask_url = None
        if mask is not None:
            mask_url = (
                await upload_images_to_comfyapi(
                    cls,
                    convert_mask_to_image(mask),
                    max_images=1,
                    mime_type="image/png",
                    wait_label="Uploading mask",
                )
            )[0]
        response = await sync_op(
            cls,
            ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
            data=BriaEditImageRequest(
                instruction=prompt if prompt else None,
                structured_instruction=structured_prompt if structured_prompt else None,
                images=await upload_images_to_comfyapi(
                    cls,
                    image,
                    max_images=1,
                    mime_type="image/png",
                    wait_label="Uploading image",
                ),
                mask=mask_url,
                negative_prompt=negative_prompt if negative_prompt else None,
                guidance_scale=guidance_scale,
                seed=seed,
                model_version=model,
                steps_num=steps,
                prompt_content_moderation=moderation.get(
                    "prompt_content_moderation", False
                ),
                visual_input_content_moderation=moderation.get(
                    "visual_input_moderation", False
                ),
                visual_output_content_moderation=moderation.get(
                    "visual_output_moderation", False
                ),
            ),
            response_model=BriaStatusResponse,
        )
        response = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
            status_extractor=lambda r: r.status,
            response_model=BriaResponse,
        )
        return IO.NodeOutput(
            await download_url_to_image_tensor(response.result.image_url),
            response.result.structured_prompt,
        )


class BriaExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            BriaImageEditNode,
        ]


async def comfy_entrypoint() -> BriaExtension:
    return BriaExtension()
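Editor's note on the `moderation` input above: a minimal sketch of how the DynamicCombo selection reaches the Bria request flags, assuming (as the `.get()` calls in `execute` suggest) that the combo resolves to a plain dict keyed by the nested boolean input names when "true" is selected, and to an empty dict when "false" is selected.

```python
# Minimal sketch (assumption: the "moderation" DynamicCombo is delivered to
# execute() as a plain dict of its nested boolean inputs).
moderation = {
    "prompt_content_moderation": False,   # schema default
    "visual_input_moderation": False,     # schema default
    "visual_output_moderation": True,     # schema default
}

# These mirror the lookups in BriaImageEditNode.execute():
prompt_flag = moderation.get("prompt_content_moderation", False)        # -> False
input_flag = moderation.get("visual_input_moderation", False)           # -> False
output_flag = moderation.get("visual_output_moderation", False)         # -> True

# With the "false" option selected the dict is empty, so all three
# lookups fall back to False.
```

Note that `visual_output_moderation` defaults to True in the schema, while the request-side `.get()` falls back to False when the key is absent (i.e. when moderation is switched off entirely).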
@ -11,6 +11,7 @@ from .conversions import (
    audio_input_to_mp3,
    audio_to_base64_string,
    bytesio_to_image_tensor,
    convert_mask_to_image,
    downscale_image_tensor,
    image_tensor_pair_to_batch,
    pil_to_bytesio,
@ -72,6 +73,7 @@ __all__ = [
    "audio_input_to_mp3",
    "audio_to_base64_string",
    "bytesio_to_image_tensor",
    "convert_mask_to_image",
    "downscale_image_tensor",
    "image_tensor_pair_to_batch",
    "pil_to_bytesio",
@ -451,6 +451,12 @@ def resize_mask_to_image(
    return mask


def convert_mask_to_image(mask: Input.Image) -> torch.Tensor:
    """Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image."""
    mask = mask.unsqueeze(-1)
    return torch.cat([mask] * 3, dim=-1)


def text_filepath_to_base64_string(filepath: str) -> str:
    """Converts a text file to a base64 string."""
    with open(filepath, "rb") as f:
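A minimal sketch of the shape change `convert_mask_to_image` performs; the tensor sizes here are hypothetical, but the operations are exactly those in the helper above:

```python
import torch

# A single-channel mask of shape [B, H, W] gains a trailing channel dim and
# is repeated to 3 channels, giving an image-like tensor of shape [B, H, W, 3].
mask = torch.zeros(1, 512, 512)
image_like = torch.cat([mask.unsqueeze(-1)] * 3, dim=-1)
print(image_like.shape)  # torch.Size([1, 512, 512, 3])
```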
293 comfy_extras/nodes_gds.py Normal file
@ -0,0 +1,293 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007

"""
Enhanced model loading nodes with GPUDirect Storage support
"""

import logging
import time
import asyncio
from typing import Optional, Dict, Any

import torch
import folder_paths
import comfy.sd
import comfy.utils
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict


class CheckpointLoaderGDS(ComfyNodeABC):
    """
    Enhanced checkpoint loader with GPUDirect Storage support
    Provides direct SSD-to-GPU loading and prefetching capabilities
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
                    "tooltip": "The name of the checkpoint (model) to load with GDS optimization."
                }),
            },
            "optional": {
                "prefetch": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Prefetch model to GPU cache for faster loading."
                }),
                "use_gds": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GPUDirect Storage if available."
                }),
                "target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
                    "default": "auto",
                    "tooltip": "Target device for model loading."
                })
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
    RETURN_NAMES = ("model", "clip", "vae", "load_info")
    OUTPUT_TOOLTIPS = (
        "The model used for denoising latents.",
        "The CLIP model used for encoding text prompts.",
        "The VAE model used for encoding and decoding images to and from latent space.",
        "Loading information and statistics."
    )
    FUNCTION = "load_checkpoint_gds"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
    EXPERIMENTAL = True

    def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
        start_time = time.time()

        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        # Determine target device
        if target_device == "auto":
            device = None  # Let the system decide
        elif target_device == "cpu":
            device = torch.device("cpu")
        else:
            device = torch.device(target_device)

        load_info = {
            "file": ckpt_name,
            "path": ckpt_path,
            "target_device": str(device) if device else "auto",
            "gds_enabled": use_gds,
            "prefetch_used": prefetch
        }

        try:
            # Prefetch if requested
            if prefetch and use_gds:
                try:
                    from comfy.gds_loader import prefetch_model_gds
                    prefetch_success = prefetch_model_gds(ckpt_path)
                    load_info["prefetch_success"] = prefetch_success
                    if prefetch_success:
                        logging.info(f"Prefetched {ckpt_name} to GPU cache")
                except Exception as e:
                    logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
                    load_info["prefetch_error"] = str(e)

            # Load checkpoint with potential GDS optimization
            if use_gds and device and device.type == 'cuda':
                try:
                    from comfy.gds_loader import get_gds_instance
                    gds = get_gds_instance()

                    # Check if GDS should be used for this file
                    if gds._should_use_gds(ckpt_path):
                        load_info["loader_used"] = "GDS"
                        logging.info(f"Loading {ckpt_name} with GDS")
                    else:
                        load_info["loader_used"] = "Standard"
                        logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")

                except Exception as e:
                    logging.warning(f"GDS check failed, using standard loading: {e}")
                    load_info["loader_used"] = "Standard (GDS failed)"
            else:
                load_info["loader_used"] = "Standard"

            # Load the actual checkpoint
            out = comfy.sd.load_checkpoint_guess_config(
                ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=folder_paths.get_folder_paths("embeddings")
            )

            load_time = time.time() - start_time
            load_info["load_time_seconds"] = round(load_time, 3)
            load_info["load_success"] = True

            # Format load info as string
            info_str = f"Loaded: {ckpt_name}\n"
            info_str += f"Method: {load_info['loader_used']}\n"
            info_str += f"Time: {load_info['load_time_seconds']}s\n"
            info_str += f"Device: {load_info['target_device']}"

            if "prefetch_success" in load_info:
                info_str += f"\nPrefetch: {'✓' if load_info['prefetch_success'] else '✗'}"

            logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")

            return (*out[:3], info_str)

        except Exception as e:
            load_info["load_success"] = False
            load_info["error"] = str(e)
            error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
            logging.error(f"Checkpoint loading failed: {e}")
            raise RuntimeError(error_str)


class ModelPrefetcher(ComfyNodeABC):
    """
    Node for prefetching models to GPU cache
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "checkpoint_names": ("STRING", {
                    "multiline": True,
                    "default": "",
                    "tooltip": "List of checkpoint names to prefetch (one per line)."
                }),
                "prefetch_enabled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable/disable prefetching."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prefetch_report",)
    OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
    FUNCTION = "prefetch_models"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
    OUTPUT_NODE = True

    def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
        if not prefetch_enabled:
            return ("Prefetching disabled",)

        # Parse checkpoint names
        names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()]

        if not names:
            return ("No checkpoints specified for prefetching",)

        try:
            from comfy.gds_loader import prefetch_model_gds
        except ImportError:
            return ("GDS not available for prefetching",)

        results = []
        successful_prefetches = 0

        for name in names:
            try:
                ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name)
                success = prefetch_model_gds(ckpt_path)

                if success:
                    results.append(f"✓ {name}")
                    successful_prefetches += 1
                else:
                    results.append(f"✗ {name} (prefetch failed)")

            except Exception as e:
                results.append(f"✗ {name} (error: {str(e)[:50]})")

        report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n"
        report += "\n".join(results)

        return (report,)


class GDSStats(ComfyNodeABC):
    """
    Node for displaying GDS statistics
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "refresh": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Refresh statistics."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("stats_report",)
    OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
    FUNCTION = "get_stats"
    CATEGORY = "utils/advanced"
    DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
    OUTPUT_NODE = True

    def get_stats(self, refresh: bool = False):
        try:
            from comfy.gds_loader import get_gds_stats
            stats = get_gds_stats()

            report = "=== GPUDirect Storage Statistics ===\n\n"

            # Availability
            report += f"GDS Available: {'✓' if stats['gds_available'] else '✗'}\n"

            # Usage statistics
            report += f"Total Loads: {stats['total_loads']}\n"
            report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n"
            report += f"Fallback Loads: {stats['fallback_loads']}\n\n"

            # Performance metrics
            if stats['total_bytes_gds'] > 0:
                gb_transferred = stats['total_bytes_gds'] / (1024**3)
                report += f"Data Transferred: {gb_transferred:.2f} GB\n"
                report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n"
                report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n"

            # Configuration
            config = stats.get('config', {})
            if config:
                report += "Configuration:\n"
                report += f"- Enabled: {config.get('enabled', 'Unknown')}\n"
                report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n"
                report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n"
                report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n"
                report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n"
                report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n"

            return (report,)

        except ImportError:
            return ("GDS module not available",)
        except Exception as e:
            return (f"Error retrieving GDS stats: {str(e)}",)


# Node mappings
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderGDS": CheckpointLoaderGDS,
    "ModelPrefetcher": ModelPrefetcher,
    "GDSStats": GDSStats,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderGDS": "Load Checkpoint (GDS)",
    "ModelPrefetcher": "Model Prefetcher",
    "GDSStats": "GDS Statistics",
}
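A minimal usage sketch for the new loader node, assuming a ComfyUI checkout where `comfy_extras.nodes_gds` and `folder_paths` are importable and the named checkpoint exists under the configured `checkpoints` directory (the file name below is hypothetical); in normal use the node is wired into a workflow rather than called directly:

```python
# Hedged sketch: calls the node's FUNCTION directly, mirroring its signature above.
from comfy_extras.nodes_gds import CheckpointLoaderGDS

loader = CheckpointLoaderGDS()
model, clip, vae, load_info = loader.load_checkpoint_gds(
    "sd_xl_base_1.0.safetensors",  # hypothetical checkpoint name
    prefetch=True,                 # try to warm the GPU cache via comfy.gds_loader
    use_gds=True,                  # falls back to standard loading if GDS is unavailable
    target_device="cuda:0",
)
print(load_info)  # e.g. "Loaded: ...\nMethod: GDS\nTime: 1.234s\nDevice: cuda:0"
```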
88 comfy_extras/nodes_zimage.py Normal file
@ -0,0 +1,88 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math
import comfy.utils


class TextEncodeZImageOmni(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="TextEncodeZImageOmni",
            category="advanced/conditioning",
            is_experimental=True,
            inputs=[
                io.Clip.Input("clip"),
                io.ClipVision.Input("image_encoder", optional=True),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
                io.Boolean.Input("auto_resize_images", default=True),
                io.Vae.Input("vae", optional=True),
                io.Image.Input("image1", optional=True),
                io.Image.Input("image2", optional=True),
                io.Image.Input("image3", optional=True),
            ],
            outputs=[
                io.Conditioning.Output(),
            ],
        )

    @classmethod
    def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
        ref_latents = []
        images = list(filter(lambda a: a is not None, [image1, image2, image3]))

        prompt_list = []
        template = None
        if len(images) > 0:
            prompt_list = ["<|im_start|>user\n<|vision_start|>"]
            prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1)
            prompt_list += ["<|vision_end|><|im_end|>"]
            template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"

        encoded_images = []

        for i, image in enumerate(images):
            if image_encoder is not None:
                encoded_images.append(image_encoder.encode_image(image))

            if vae is not None:
                if auto_resize_images:
                    samples = image.movedim(-1, 1)
                    total = int(1024 * 1024)
                    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
                    width = round(samples.shape[3] * scale_by / 8.0) * 8
                    height = round(samples.shape[2] * scale_by / 8.0) * 8

                    image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
                ref_latents.append(vae.encode(image))

        tokens = clip.tokenize(prompt, llama_template=template)
        conditioning = clip.encode_from_tokens_scheduled(tokens)

        extra_text_embeds = []
        for p in prompt_list:
            tokens = clip.tokenize(p, llama_template="{}")
            text_embeds = clip.encode_from_tokens_scheduled(tokens)
            extra_text_embeds.append(text_embeds[0][0])

        if len(ref_latents) > 0:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
        if len(encoded_images) > 0:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True)
        if len(extra_text_embeds) > 0:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True)

        return io.NodeOutput(conditioning)


class ZImageExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            TextEncodeZImageOmni,
        ]


async def comfy_entrypoint() -> ZImageExtension:
    return ZImageExtension()
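A minimal sketch of the `auto_resize_images` arithmetic above: each reference image is rescaled so its area is roughly 1024x1024 pixels, with width and height rounded to multiples of 8 (the example input dimensions are hypothetical):

```python
import math

# Example: a 1280x720 reference image is scaled to ~1 megapixel, /8-aligned.
h, w = 720, 1280
scale_by = math.sqrt((1024 * 1024) / (w * h))
width = round(w * scale_by / 8.0) * 8
height = round(h * scale_by / 8.0) * 8
print(width, height)  # 1368 768 -> area ≈ 1.05 MP
```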
@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.9.2"
__version__ = "0.10.0"
29 main.py
@ -184,6 +184,35 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0

# Initialize GPUDirect Storage if enabled
def init_gds():
    """Initialize GPUDirect Storage based on CLI arguments"""
    if hasattr(args, 'disable_gds') and args.disable_gds:
        logging.info("GDS explicitly disabled via --disable-gds")
        return

    if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
        # GDS not explicitly requested, use auto-detection
        return

    if hasattr(args, 'enable_gds') and args.enable_gds:
        from comfy.gds_loader import GDSConfig, init_gds as gds_init

        config = GDSConfig(
            enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
            min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
            chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
            max_concurrent_streams=getattr(args, 'gds_streams', 4),
            prefetch_enabled=getattr(args, 'gds_prefetch', True),
            fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
            show_stats=getattr(args, 'gds_stats', False)
        )

        gds_init(config)

# Initialize GDS
init_gds()

def cuda_malloc_warning():
    device = comfy.model_management.get_torch_device()
    device_name = comfy.model_management.get_torch_device_name(device)
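A minimal sketch of how the CLI flags end up in a `GDSConfig`, assuming the `comfy.gds_loader` module introduced by this changeset exposes `GDSConfig` and `init_gds` with the keyword arguments used above; the flag values shown are illustrative:

```python
# Hedged sketch: launching e.g.
#   python main.py --enable-gds --gds-prefetch --gds-chunk-size 128
# would, per init_gds() above, build roughly this configuration.
from comfy.gds_loader import GDSConfig, init_gds

init_gds(GDSConfig(
    enabled=True,              # --enable-gds
    min_file_size_mb=100,      # --gds-min-file-size default
    chunk_size_mb=128,         # --gds-chunk-size 128
    max_concurrent_streams=4,  # --gds-streams default
    prefetch_enabled=True,     # --gds-prefetch
    fallback_to_cpu=True,      # no --gds-no-fallback given
    show_stats=False,          # no --gds-stats given
))
```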
2 nodes.py
@ -2367,12 +2367,14 @@ async def init_builtin_extra_nodes():
        "nodes_model_patch.py",
        "nodes_easycache.py",
        "nodes_audio_encoder.py",
        "nodes_gds.py",
        "nodes_rope.py",
        "nodes_logic.py",
        "nodes_nop.py",
        "nodes_kandinsky5.py",
        "nodes_wanmove.py",
        "nodes_image_compare.py",
        "nodes_zimage.py",
    ]

    import_failed = []
@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.9.2"
version = "0.10.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
@ -1,5 +1,5 @@
comfyui-frontend-package==1.36.14
comfyui-workflow-templates==0.8.11
comfyui-frontend-package==1.37.11
comfyui-workflow-templates==0.8.15
comfyui-embedded-docs==0.4.0
torch
torchsde
@ -27,4 +27,4 @@ comfy-kitchen>=0.2.7
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
pydantic-settings~=2.0