Compare commits

...

30 Commits

Author SHA1 Message Date
Maifee Ul Asad
9afb605b4e
Merge 7602203696 into 8ccc0c94fa 2026-01-20 16:12:28 +07:00
comfyanonymous
8ccc0c94fa
Make omni stuff work on regular z image for easier testing. (#11985)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-20 00:32:00 -05:00
Comfy Org PR Bot
4edb87aa50
Bump comfyui-frontend-package to 1.37.11 (#11976) 2026-01-19 23:57:50 -05:00
ComfyUI Wiki
0fc3b6e3a6
chore: update workflow templates to v0.8.15 (#11984) 2026-01-19 23:17:56 -05:00
comfyanonymous
2108167f9f
Support zimage omni base model. (#11979) 2026-01-19 23:17:38 -05:00
comfyanonymous
9d273d3ab1 ComfyUI v0.10.0 2026-01-19 22:40:18 -05:00
comfyanonymous
70c91b8248
Fix #11963 (#11982) 2026-01-19 22:32:40 -05:00
rkfg
0da5a0fe58
Convert mono audio to fake stereo for LTXV VAE encoding (#11965)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-19 22:12:02 -05:00
comfyanonymous
e0eacb0688
Simpler way to implement the #11980 loras. (#11981) 2026-01-19 22:00:36 -05:00
Jedrzej Kosinski
7458e20465
Make Autogrow validation work properly (#11977)
* In-progress autogrow validation fixes - properly looks at required/optional inputs, now working on the edge case that all inputs are optional and nothing is plugged in (should just be an empty dictionary passed into node)

* Allow autogrow to work with all inputs being optional

* Revert accidentally pushed changes to nodes_logic.py
2026-01-19 16:58:30 -08:00
Jedrzej Kosinski
b931b37e30
feat(api-nodes): add Bria Edit node (#11978)
Co-authored-by: Alexander Piskun <bigcat88@icloud.com>
2026-01-19 16:47:14 -08:00
ComfyUI Wiki
866a4619db
chore: update workflow templates to v0.8.14 (#11974) 2026-01-19 14:21:35 -08:00
Maifee Ul Asad
7602203696
Merge branch 'Comfy-Org:master' into offloader-maifee 2026-01-11 12:59:53 +06:00
Maifee Ul Asad
ffa7a369ba
Merge branch 'Comfy-Org:master' into offloader-maifee 2026-01-07 21:26:04 +06:00
Maifee Ul Asad
7ec1656735
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-12-07 15:32:22 +06:00
Maifee Ul Asad
cee75f301a
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-11-27 08:47:41 +06:00
Maifee Ul Asad
1a59686ca8
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-11-25 22:09:53 +06:00
Maifee Ul Asad
6d96d26795
Merge branch 'comfyanonymous:master' into offloader-maifee 2025-11-25 22:08:51 +06:00
Maifee Ul Asad
e07a32c9b8
Merge branch 'master' into offloader-maifee 2025-11-09 17:25:46 +06:00
Maifee Ul Asad
a19f0a88e4 refactor(gds): add show_stats option; moved GDS initialization to dedicated file; 2025-10-12 00:54:04 +06:00
Maifee Ul Asad
64811809a0 docs: updated GDS environment limitations;
- works only on linux and nvidia
2025-10-11 21:51:24 +06:00
Maifee Ul Asad
529109083a docs: added GDS docs and deps documentation; 2025-10-10 22:23:13 +06:00
Maifee Ul Asad
a7be9f6fc3 review: remove GDS-related dependencies from requirements.txt 2025-10-10 22:08:58 +06:00
Maifee Ul Asad
6075c44ec8 feat(gds): add GDS-related dependencies to requirements.txt 2025-10-08 14:43:55 +06:00
Maifee Ul Asad
154b73835a feat(gds): implement GPUDirect Storage initialization based on CLI arguments 2025-10-08 14:40:59 +06:00
Maifee Ul Asad
862e7784f4 feat(gds): add nodes_gds.py to built-in extra nodes initialization 2025-10-08 14:40:41 +06:00
Maifee Ul Asad
f6b6636bf3 feat(gds): implement GDS loading fallback in load_torch_file function;
- need to work with tensorflow and other formats
 - afaik, almost all models shared now is in torch format
 - converting types should not be that big of a deal
2025-10-08 14:40:29 +06:00
Maifee Ul Asad
83b00df3f0 feat(gds): add GPUDirect Storage options for SSD-to-GPU model loading; 2025-10-08 14:39:22 +06:00
Maifee Ul Asad
5f24eb699c feat(gds): implement GPUDirect Storage for efficient model loading 2025-10-08 14:38:49 +06:00
Maifee Ul Asad
fab0954077 feat(gds): add GPUDirect Storage support for model loading and prefetching
- limited to NVIDIA GPUs only
2025-10-08 14:38:38 +06:00
22 changed files with 1557 additions and 75 deletions

View File

@ -404,6 +404,14 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal.
## How to run heavy workflow on mid range GPU (NVIDIA-Linux)?
Use the `--enable-gds` flag to activate NVIDIA [GPUDirect Storage](https://docs.nvidia.com/gpudirect-storage/) (GDS), which allows data to be transferred directly between SSDs and GPUs. This eliminates traditional CPU-mediated data paths, significantly reducing I/O latency and CPU overhead. System RAM will still be utilized for caching to further optimize performance, along with SSD.
This feature is tested on NVIDIA GPUs on Linux based system only.
Requires: `cupy-cuda12x>=12.0.0`, `pynvml>=11.4.1`, `cudf>=23.0.0`, `numba>=0.57.0`, `nvidia-ml-py>=12.0.0`.
## Support and dev channel
[Discord](https://comfy.org/discord): Try the #help or #feedback channels.

View File

@ -154,6 +154,17 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
# GPUDirect Storage (GDS) arguments
gds_group = parser.add_argument_group('gds', 'GPUDirect Storage options for direct SSD-to-GPU model loading')
gds_group.add_argument("--enable-gds", action="store_true", help="Enable GPUDirect Storage for direct SSD-to-GPU model loading (requires CUDA 11.4+, cuFile).")
gds_group.add_argument("--disable-gds", action="store_true", help="Explicitly disable GPUDirect Storage.")
gds_group.add_argument("--gds-min-file-size", type=int, default=100, help="Minimum file size in MB to use GDS (default: 100MB).")
gds_group.add_argument("--gds-chunk-size", type=int, default=64, help="GDS transfer chunk size in MB (default: 64MB).")
gds_group.add_argument("--gds-streams", type=int, default=4, help="Number of CUDA streams for GDS operations (default: 4).")
gds_group.add_argument("--gds-prefetch", action="store_true", help="Enable GDS prefetching for better performance.")
gds_group.add_argument("--gds-no-fallback", action="store_true", help="Disable fallback to CPU loading if GDS fails.")
gds_group.add_argument("--gds-stats", action="store_true", help="Print GDS statistics on exit.")
class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"

494
comfy/gds_loader.py Normal file
View File

@ -0,0 +1,494 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
"""
GPUDirect Storage (GDS) Integration for ComfyUI
Direct SSD-to-GPU model loading without RAM/CPU bottlenecks
Still there will be some CPU/RAM usage, mostly for safetensors parsing and small buffers.
This module provides GPUDirect Storage functionality to load models directly
from NVMe SSDs to GPU memory, bypassing system RAM and CPU.
"""
import os
import logging
import torch
import time
from typing import Optional, Dict, Any, Union
from pathlib import Path
import safetensors
import gc
import mmap
from dataclasses import dataclass
try:
import cupy
import cupy.cuda.runtime as cuda_runtime
CUPY_AVAILABLE = True
except ImportError:
CUPY_AVAILABLE = False
logging.warning("CuPy not available. GDS will use fallback mode.")
try:
import cudf # RAPIDS for GPU dataframes
RAPIDS_AVAILABLE = True
except ImportError:
RAPIDS_AVAILABLE = False
try:
import pynvml
pynvml.nvmlInit()
NVML_AVAILABLE = True
except ImportError:
NVML_AVAILABLE = False
logging.warning("NVIDIA-ML-Py not available. GPU monitoring disabled.")
@dataclass
class GDSConfig:
"""Configuration for GPUDirect Storage"""
enabled: bool = True
min_file_size_mb: int = 100 # Only use GDS for files larger than this
chunk_size_mb: int = 64 # Size of chunks to transfer
use_pinned_memory: bool = True
prefetch_enabled: bool = True
compression_aware: bool = True
max_concurrent_streams: int = 4
fallback_to_cpu: bool = True
show_stats: bool = False # Whether to show stats on exit
class GDSError(Exception):
"""GDS-specific errors"""
pass
class GPUDirectStorage:
"""
GPUDirect Storage implementation for ComfyUI
Enables direct SSD-to-GPU transfers for model loading
"""
def __init__(self, config: Optional[GDSConfig] = None):
self.config = config or GDSConfig()
self.device = torch.cuda.current_device() if torch.cuda.is_available() else None
self.cuda_streams = []
self.pinned_buffers = {}
self.stats = {
'gds_loads': 0,
'fallback_loads': 0,
'total_bytes_gds': 0,
'total_time_gds': 0.0,
'avg_bandwidth_gbps': 0.0
}
# Initialize GDS if available
self._gds_available = self._check_gds_availability()
if self._gds_available:
self._init_gds()
else:
logging.warning("GDS not available, using fallback methods")
def _check_gds_availability(self) -> bool:
"""Check if GDS is available on the system"""
if not torch.cuda.is_available():
return False
if not CUPY_AVAILABLE:
return False
# Check for GPUDirect Storage support
try:
# Check CUDA version (GDS requires CUDA 11.4+)
cuda_version = torch.version.cuda
if cuda_version:
major, minor = map(int, cuda_version.split('.')[:2])
if major < 11 or (major == 11 and minor < 4):
logging.warning(f"CUDA {cuda_version} detected. GDS requires CUDA 11.4+")
return False
# Check if cuFile is available (part of CUDA toolkit)
try:
import cupy.cuda.cufile as cufile
# Try to initialize cuFile
cufile.initialize()
return True
except (ImportError, RuntimeError) as e:
logging.warning(f"cuFile not available: {e}")
return False
except Exception as e:
logging.warning(f"GDS availability check failed: {e}")
return False
def _init_gds(self):
"""Initialize GDS resources"""
try:
# Create CUDA streams for async operations
for i in range(self.config.max_concurrent_streams):
stream = torch.cuda.Stream()
self.cuda_streams.append(stream)
# Pre-allocate pinned memory buffers
if self.config.use_pinned_memory:
self._allocate_pinned_buffers()
logging.info(f"GDS initialized with {len(self.cuda_streams)} streams")
except Exception as e:
logging.error(f"Failed to initialize GDS: {e}")
self._gds_available = False
def _allocate_pinned_buffers(self):
"""Pre-allocate pinned memory buffers for staging"""
try:
# Allocate buffers of different sizes
buffer_sizes = [16, 32, 64, 128, 256] # MB
for size_mb in buffer_sizes:
size_bytes = size_mb * 1024 * 1024
# Allocate pinned memory using CuPy
if CUPY_AVAILABLE:
buffer = cupy.cuda.alloc_pinned_memory(size_bytes)
self.pinned_buffers[size_mb] = buffer
except Exception as e:
logging.warning(f"Failed to allocate pinned buffers: {e}")
def _get_file_size(self, file_path: str) -> int:
"""Get file size in bytes"""
return os.path.getsize(file_path)
def _should_use_gds(self, file_path: str) -> bool:
"""Determine if GDS should be used for this file"""
if not self._gds_available or not self.config.enabled:
return False
file_size_mb = self._get_file_size(file_path) / (1024 * 1024)
return file_size_mb >= self.config.min_file_size_mb
def _load_with_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
"""Load model using GPUDirect Storage"""
start_time = time.time()
try:
if file_path.lower().endswith(('.safetensors', '.sft')):
return self._load_safetensors_gds(file_path)
else:
return self._load_pytorch_gds(file_path)
except Exception as e:
logging.error(f"GDS loading failed for {file_path}: {e}")
if self.config.fallback_to_cpu:
logging.info("Falling back to CPU loading")
self.stats['fallback_loads'] += 1
return self._load_fallback(file_path)
else:
raise GDSError(f"GDS loading failed: {e}")
finally:
load_time = time.time() - start_time
self.stats['total_time_gds'] += load_time
def _load_safetensors_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
"""Load safetensors file using GDS"""
try:
import cupy.cuda.cufile as cufile
# Open file with cuFile for direct GPU loading
with cufile.CuFileManager() as manager:
# Memory-map the file for efficient access
with open(file_path, 'rb') as f:
# Use mmap for large files
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped_file:
# Parse safetensors header
header_size = int.from_bytes(mmapped_file[:8], 'little')
header_bytes = mmapped_file[8:8+header_size]
import json
header = json.loads(header_bytes.decode('utf-8'))
# Load tensors directly to GPU
tensors = {}
data_offset = 8 + header_size
for name, info in header.items():
if name == "__metadata__":
continue
dtype_map = {
'F32': torch.float32,
'F16': torch.float16,
'BF16': torch.bfloat16,
'I8': torch.int8,
'I16': torch.int16,
'I32': torch.int32,
'I64': torch.int64,
'U8': torch.uint8,
}
dtype = dtype_map.get(info['dtype'], torch.float32)
shape = info['shape']
start_offset = data_offset + info['data_offsets'][0]
end_offset = data_offset + info['data_offsets'][1]
# Direct GPU allocation
tensor = torch.empty(shape, dtype=dtype, device=f'cuda:{self.device}')
# Use cuFile for direct transfer
tensor_bytes = end_offset - start_offset
# Get GPU memory pointer
gpu_ptr = tensor.data_ptr()
# Direct file-to-GPU transfer
cufile.copy_from_file(
gpu_ptr,
mmapped_file[start_offset:end_offset],
tensor_bytes
)
tensors[name] = tensor
self.stats['gds_loads'] += 1
self.stats['total_bytes_gds'] += self._get_file_size(file_path)
return tensors
except Exception as e:
logging.error(f"GDS safetensors loading failed: {e}")
raise
def _load_pytorch_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
"""Load PyTorch file using GDS with staging"""
try:
# For PyTorch files, we need to use a staging approach
# since torch.load doesn't support direct GPU loading
# Load to pinned memory first
with open(file_path, 'rb') as f:
file_size = self._get_file_size(file_path)
# Choose appropriate buffer or allocate new one
buffer_size_mb = min(256, max(64, file_size // (1024 * 1024)))
if buffer_size_mb in self.pinned_buffers:
pinned_buffer = self.pinned_buffers[buffer_size_mb]
else:
# Allocate temporary pinned buffer
pinned_buffer = cupy.cuda.alloc_pinned_memory(file_size)
# Read file to pinned memory
f.readinto(pinned_buffer)
# Use torch.load with map_location to specific GPU
# This will be faster due to pinned memory
state_dict = torch.load(
f,
map_location=f'cuda:{self.device}',
weights_only=True
)
self.stats['gds_loads'] += 1
self.stats['total_bytes_gds'] += file_size
return state_dict
except Exception as e:
logging.error(f"GDS PyTorch loading failed: {e}")
raise
def _load_fallback(self, file_path: str) -> Dict[str, torch.Tensor]:
"""Fallback loading method using standard approaches"""
if file_path.lower().endswith(('.safetensors', '.sft')):
# Use safetensors with device parameter
with safetensors.safe_open(file_path, framework="pt", device=f'cuda:{self.device}') as f:
return {k: f.get_tensor(k) for k in f.keys()}
else:
# Standard PyTorch loading
return torch.load(file_path, map_location=f'cuda:{self.device}', weights_only=True)
def load_model(self, file_path: str, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
"""
Main entry point for loading models with GDS
Args:
file_path: Path to the model file
device: Target device (if None, uses current CUDA device)
Returns:
Dictionary of tensors loaded directly to GPU
"""
if device is not None and device.type == 'cuda':
self.device = device.index or 0
if self._should_use_gds(file_path):
logging.info(f"Loading {file_path} with GDS")
return self._load_with_gds(file_path)
else:
logging.info(f"Loading {file_path} with standard method")
self.stats['fallback_loads'] += 1
return self._load_fallback(file_path)
def prefetch_model(self, file_path: str) -> bool:
"""
Prefetch model to GPU memory cache (if supported)
Args:
file_path: Path to the model file
Returns:
True if prefetch was successful
"""
if not self.config.prefetch_enabled or not self._gds_available:
return False
try:
# Basic prefetch implementation
# This would ideally use NVIDIA's GPUDirect Storage API
# to warm up the storage cache
file_size = self._get_file_size(file_path)
logging.info(f"Prefetching {file_path} ({file_size // (1024*1024)} MB)")
# Read file metadata to warm caches
with open(file_path, 'rb') as f:
# Read first and last chunks to trigger prefetch
f.read(1024 * 1024) # First 1MB
f.seek(-min(1024 * 1024, file_size), 2) # Last 1MB
f.read()
return True
except Exception as e:
logging.warning(f"Prefetch failed for {file_path}: {e}")
return False
def get_stats(self) -> Dict[str, Any]:
"""Get loading statistics"""
total_loads = self.stats['gds_loads'] + self.stats['fallback_loads']
if self.stats['total_time_gds'] > 0 and self.stats['total_bytes_gds'] > 0:
bandwidth_gbps = (self.stats['total_bytes_gds'] / (1024**3)) / self.stats['total_time_gds']
self.stats['avg_bandwidth_gbps'] = bandwidth_gbps
return {
**self.stats,
'total_loads': total_loads,
'gds_usage_percent': (self.stats['gds_loads'] / max(1, total_loads)) * 100,
'gds_available': self._gds_available,
'config': self.config.__dict__
}
def cleanup(self):
"""Clean up GDS resources"""
try:
# Clear CUDA streams
for stream in self.cuda_streams:
stream.synchronize()
self.cuda_streams.clear()
# Free pinned buffers
for buffer in self.pinned_buffers.values():
if CUPY_AVAILABLE:
cupy.cuda.free_pinned_memory(buffer)
self.pinned_buffers.clear()
# Force garbage collection
gc.collect()
torch.cuda.empty_cache()
except Exception as e:
logging.warning(f"GDS cleanup failed: {e}")
def __del__(self):
"""Destructor to ensure cleanup"""
self.cleanup()
# Global GDS instance
_gds_instance: Optional[GPUDirectStorage] = None
def get_gds_instance(config: Optional[GDSConfig] = None) -> GPUDirectStorage:
"""Get or create the global GDS instance"""
global _gds_instance
if _gds_instance is None:
_gds_instance = GPUDirectStorage(config)
return _gds_instance
def load_torch_file_gds(ckpt: str, safe_load: bool = False, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
"""
GDS-enabled replacement for comfy.utils.load_torch_file
Args:
ckpt: Path to checkpoint file
safe_load: Whether to use safe loading (for compatibility)
device: Target device
Returns:
Dictionary of loaded tensors
"""
gds = get_gds_instance()
try:
# Load with GDS
return gds.load_model(ckpt, device)
except Exception as e:
logging.error(f"GDS loading failed, falling back to standard method: {e}")
# Fallback to original method
import comfy.utils
return comfy.utils.load_torch_file(ckpt, safe_load=safe_load, device=device)
def prefetch_model_gds(file_path: str) -> bool:
"""Prefetch model for faster loading"""
gds = get_gds_instance()
return gds.prefetch_model(file_path)
def get_gds_stats() -> Dict[str, Any]:
"""Get GDS statistics"""
gds = get_gds_instance()
return gds.get_stats()
def configure_gds(config: GDSConfig):
"""Configure GDS settings"""
global _gds_instance
_gds_instance = GPUDirectStorage(config)
def init_gds(config: GDSConfig):
"""
Initialize GPUDirect Storage with the provided configuration
Args:
config: GDSConfig object with initialization parameters
"""
try:
# Configure GDS
configure_gds(config)
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
# Set up exit handler for stats if requested
if hasattr(config, 'show_stats') and config.show_stats:
import atexit
def print_gds_stats():
stats = get_gds_stats()
logging.info("=== GDS Statistics ===")
logging.info(f"Total loads: {stats['total_loads']}")
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
logging.info(f"Fallback loads: {stats['fallback_loads']}")
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
logging.info("===================")
atexit.register(print_gds_stats)
except ImportError as e:
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
except Exception as e:
logging.error(f"GDS initialization failed: {e}")

View File

@ -189,9 +189,12 @@ class AudioVAE(torch.nn.Module):
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
if waveform.shape[1] == 1:
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
else:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device

View File

@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
def modulate(x, scale, timestep_zero_index=None):
if timestep_zero_index is None:
return x * (1 + scale.unsqueeze(1))
else:
scale = (1 + scale.unsqueeze(1))
actual_batch = scale.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= scale[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= scale[:actual_batch]
return x
def apply_gate(gate, x, timestep_zero_index=None):
if timestep_zero_index is None:
return gate * x
else:
actual_batch = gate.size(0) // 2
slices = timestep_zero_index
invert = invert_slices(timestep_zero_index, x.shape[1])
for s in slices:
x[:, s[0]:s[1]] *= gate[actual_batch:]
for s in invert:
x[:, s[0]:s[1]] *= gate[:actual_batch]
return x
#############################################################################
# Core NextDiT Model #
@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
timestep_zero_index=None,
transformer_options={},
):
"""
@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
x_mask,
freqs_cis,
transformer_options=transformer_options,
))
))), timestep_zero_index=timestep_zero_index
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
))
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
))), timestep_zero_index=timestep_zero_index
)
else:
assert adaln_input is None
@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
),
)
def forward(self, x, c):
def forward(self, x, c, timestep_zero_index=None):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
x = self.linear(x)
return x
def pad_zimage(feats, pad_token, pad_tokens_multiple):
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = start_t
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
return x_pos_ids
class NextDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
@ -378,6 +446,7 @@ class NextDiT(nn.Module):
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
siglip_feat_dim=None,
image_model=None,
device=None,
dtype=None,
@ -491,6 +560,41 @@ class NextDiT(nn.Module):
for layer_id in range(n_layers)
]
)
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
siglip_feat_dim,
dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.siglip_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=False,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
@ -531,70 +635,168 @@ class NextDiT(nn.Module):
imgs = torch.stack(imgs, dim=0)
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
if cap_feats is not None:
cap_feats = self.cap_embedder(cap_feats)
cap_feats_len = cap_feats.shape[1]
if self.pad_tokens_multiple is not None:
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
else:
cap_feats_len = 0
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
embeds = (cap_feats,)
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
bsz = 1
pH = pW = self.patch_size
device = x.device
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
if (not omni) or self.siglip_embedder is None:
cap_feats_len = embeds[0].shape[1] + offset
embeds += (None,)
freqs_cis += (None,)
else:
cap_feats_len += offset
if siglip_feats is not None:
b, h, w, c = siglip_feats.shape
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
siglip_feats = self.siglip_embedder(siglip_feats)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
if self.siglip_pad_token is not None:
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
else:
if self.siglip_pad_token is not None:
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
if siglip_feats is None:
embeds += (None,)
freqs_cis += (None,)
else:
embeds += (siglip_feats,)
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
embeds += (x,)
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
def patchify_and_embed(
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = x.shape[0]
cap_mask = None # TODO?
main_siglip = None
orig_x = x
embeds = ([], [], [])
freqs_cis = ([], [], [])
leftover_cap = []
start_t = 0
omni = len(ref_latents) > 0
if omni:
for i, ref in enumerate(ref_latents):
if i < len(ref_contexts):
ref_con = ref_contexts[i]
else:
ref_con = None
if i < len(siglip_feats):
sig_feat = siglip_feats[i]
else:
sig_feat = None
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
for i, e in enumerate(out[0]):
if e is not None:
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
freqs_cis[i].append(out[1][i])
start_t = out[2]
leftover_cap = ref_contexts[len(ref_latents):]
H, W = x.shape[-2], x.shape[-1]
img_sizes = [(H, W)] * bsz
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
img_len = out[0][-1].shape[1]
cap_len = out[0][0].shape[1]
for i, e in enumerate(out[0]):
if e is not None:
e = comfy.utils.repeat_to_batch_size(e, bsz)
embeds[i].append(e)
freqs_cis[i].append(out[1][i])
start_t = out[2]
for cap in leftover_cap:
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
cap_len += out[0][0].shape[1]
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
freqs_cis[0].append(out[1][0])
start_t += out[2]
patches = transformer_options.get("patches", {})
# refine context
cap_feats = torch.cat(embeds[0], dim=1)
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
feats = (cap_feats,)
fc = (cap_freqs_cis,)
if omni and len(embeds[1]) > 0:
siglip_mask = None
siglip_feats_combined = torch.cat(embeds[1], dim=1)
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
if self.siglip_refiner is not None:
for layer in self.siglip_refiner:
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
feats += (siglip_feats_combined,)
fc += (siglip_feats_freqs_cis,)
padded_img_mask = None
x = torch.cat(embeds[-1], dim=1)
fc_x = torch.cat(freqs_cis[-1], dim=1)
if omni:
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
else:
timestep_zero_index = None
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
padded_full_embed = torch.cat(feats + (x,), dim=1)
if timestep_zero_index is not None:
ind = padded_full_embed.shape[1] - x.shape[1]
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@ -604,7 +806,11 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@ -619,8 +825,6 @@ class NextDiT(nn.Module):
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
@ -632,7 +836,7 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
@ -640,7 +844,7 @@ class NextDiT(nn.Module):
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
@ -649,8 +853,7 @@ class NextDiT(nn.Module):
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
img = self.final_layer(img, adaln_input)
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img

View File

@ -1150,6 +1150,7 @@ class CosmosPredict2(BaseModel):
class Lumina2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -1169,6 +1170,35 @@ class Lumina2(BaseModel):
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni
if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
sigfeats = []
for clip_vision_output in clip_vision_outputs:
if clip_vision_output is not None:
image_size = clip_vision_output.image_sizes[0]
shape = clip_vision_output.last_hidden_state.shape
sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
if len(sigfeats) > 0:
out['siglip_feats'] = comfy.conds.CONDList(sigfeats)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_contexts = kwargs.get("reference_latents_text_embeds", None)
if ref_contexts is not None:
out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class WAN21(BaseModel):

View File

@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
return dit_config

View File

@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return OvisTEModel_

View File

@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@ -57,6 +57,18 @@ else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
# Try GDS loading first if available and device is GPU
if device is not None and device.type == 'cuda':
try:
from . import gds_loader
gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
if return_metadata:
# For GDS, we return empty metadata for now (can be enhanced)
return (gds_result, {})
return gds_result
except Exception as e:
logging.debug(f"GDS loading failed, using fallback: {e}")
if device is None:
device = torch.device("cpu")
metadata = None
@ -639,6 +651,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
for k in block_map:

View File

@ -1000,20 +1000,38 @@ class Autogrow(ComfyTypeI):
names = [f"{prefix}{i}" for i in range(max)]
# need to create a new input based on the contents of input
template_input = None
for _, dict_input in input.items():
# for now, get just the first value from dict_input
template_required = True
for _input_type, dict_input in input.items():
# for now, get just the first value from dict_input; if not required, min can be ignored
if len(dict_input) == 0:
continue
template_input = list(dict_input.values())[0]
template_required = _input_type == "required"
break
if template_input is None:
raise Exception("template_input could not be determined from required or optional; this should never happen.")
new_dict = {}
new_dict_added_to = False
# first, add possible inputs into out_dict
for i, name in enumerate(names):
expected_id = finalize_prefix(curr_prefix, name)
# required
if i < min and template_required:
out_dict["required"][expected_id] = template_input
type_dict = new_dict.setdefault("required", {})
# optional
else:
out_dict["optional"][expected_id] = template_input
type_dict = new_dict.setdefault("optional", {})
if expected_id in live_inputs:
# required
if i < min:
type_dict = new_dict.setdefault("required", {})
# optional
else:
type_dict = new_dict.setdefault("optional", {})
# NOTE: prefix gets added in parse_class_inputs
type_dict[name] = template_input
new_dict_added_to = True
# account for the edge case that all inputs are optional and no values are received
if not new_dict_added_to:
finalized_prefix = finalize_prefix(curr_prefix)
out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix
out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
@ -1151,6 +1169,8 @@ class V3Data(TypedDict):
'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.'
dynamic_paths: dict[str, Any]
'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.'
dynamic_paths_default_value: dict[str, Any]
'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.'
create_dynamic_tuple: bool
'When True, the value of the dynamic input will be in the format (value, path_key).'
@ -1504,6 +1524,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
"required": {},
"optional": {},
"dynamic_paths": {},
"dynamic_paths_default_value": {},
}
d = d.copy()
# ignore hidden for parsing
@ -1513,8 +1534,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i
out_dict["hidden"] = hidden
v3_data = {}
dynamic_paths = out_dict.pop("dynamic_paths", None)
if dynamic_paths is not None:
if dynamic_paths is not None and len(dynamic_paths) > 0:
v3_data["dynamic_paths"] = dynamic_paths
# this list is used for autogrow, in the case all inputs are optional and no values are passed
dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
return out_dict, hidden, v3_data
def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
@ -1551,11 +1576,16 @@ def add_to_dict_v1(i: Input, d: dict):
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
class DynamicPathsDefaultValue:
EMPTY_DICT = "empty_dict"
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
default_value_dict = v3_data.get("dynamic_paths_default_value", {})
if paths is None:
return values
values = values.copy()
result = {}
create_tuple = v3_data.get("create_dynamic_tuple", False)
@ -1569,6 +1599,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
if is_last:
value = values.pop(key, None)
if value is None:
# see if a default value was provided for this key
default_option = default_value_dict.get(key, None)
if default_option == DynamicPathsDefaultValue.EMPTY_DICT:
value = {}
if create_tuple:
value = (value, key)
current[p] = value

View File

@ -0,0 +1,61 @@
from typing import TypedDict
from pydantic import BaseModel, Field
class InputModerationSettings(TypedDict):
prompt_content_moderation: bool
visual_input_moderation: bool
visual_output_moderation: bool
class BriaEditImageRequest(BaseModel):
instruction: str | None = Field(...)
structured_instruction: str | None = Field(
...,
description="Use this instead of instruction for precise, programmatic control.",
)
images: list[str] = Field(
...,
description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.",
)
mask: str | None = Field(
None,
description="Mask image (black and white). Black areas will be preserved, white areas will be edited. "
"If omitted, the edit applies to the entire image. "
"The input image and the the input mask must be of the same size.",
)
negative_prompt: str | None = Field(None)
guidance_scale: float = Field(...)
model_version: str = Field(...)
steps_num: int = Field(...)
seed: int = Field(...)
ip_signal: bool = Field(
False,
description="If true, returns a warning for potential IP content in the instruction.",
)
prompt_content_moderation: bool = Field(
False, description="If true, returns 422 on instruction moderation failure."
)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on images or mask moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
class BriaResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
class BriaResponse(BaseModel):
status: str = Field(...)
result: BriaResult | None = Field(None)

View File

@ -0,0 +1,198 @@
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
BriaResponse,
BriaStatusResponse,
InputModerationSettings,
)
from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
)
class BriaImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaImageEditNode",
display_name="Bria Image Edit",
category="api node/image/Bria",
description="Edit images using Bria latest model",
inputs=[
IO.Combo.Input("model", options=["FIBO"]),
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Instruction to edit image",
),
IO.String.Input("negative_prompt", multiline=True, default=""),
IO.String.Input(
"structured_prompt",
multiline=True,
default="",
tooltip="A string containing the structured edit prompt in JSON format. "
"Use this instead of usual prompt for precise, programmatic control.",
),
IO.Int.Input(
"seed",
default=1,
min=1,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Float.Input(
"guidance_scale",
default=3,
min=3,
max=5,
step=0.01,
display_mode=IO.NumberDisplay.number,
tooltip="Higher value makes the image follow the prompt more closely.",
),
IO.Int.Input(
"steps",
default=50,
min=20,
max=50,
step=1,
display_mode=IO.NumberDisplay.number,
),
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input(
"prompt_content_moderation", default=False
),
IO.Boolean.Input(
"visual_input_moderation", default=False
),
IO.Boolean.Input(
"visual_output_moderation", default=True
),
],
),
IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
IO.Mask.Input(
"mask",
tooltip="If omitted, the edit applies to the entire image.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(display_name="structured_prompt"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.04}""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
prompt: str,
negative_prompt: str,
structured_prompt: str,
seed: int,
guidance_scale: float,
steps: int,
moderation: InputModerationSettings,
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
raise ValueError(
"One of prompt or structured_prompt is required to be non-empty."
)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
mask_url = None
if mask is not None:
mask_url = (
await upload_images_to_comfyapi(
cls,
convert_mask_to_image(mask),
max_images=1,
mime_type="image/png",
wait_label="Uploading mask",
)
)[0]
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
images=await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
wait_label="Uploading image",
),
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
prompt_content_moderation=moderation.get(
"prompt_content_moderation", False
),
visual_input_content_moderation=moderation.get(
"visual_input_moderation", False
),
visual_output_content_moderation=moderation.get(
"visual_output_moderation", False
),
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
response.result.structured_prompt,
)
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
]
async def comfy_entrypoint() -> BriaExtension:
return BriaExtension()

View File

@ -11,6 +11,7 @@ from .conversions import (
audio_input_to_mp3,
audio_to_base64_string,
bytesio_to_image_tensor,
convert_mask_to_image,
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
@ -72,6 +73,7 @@ __all__ = [
"audio_input_to_mp3",
"audio_to_base64_string",
"bytesio_to_image_tensor",
"convert_mask_to_image",
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",

View File

@ -451,6 +451,12 @@ def resize_mask_to_image(
return mask
def convert_mask_to_image(mask: Input.Image) -> torch.Tensor:
"""Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image."""
mask = mask.unsqueeze(-1)
return torch.cat([mask] * 3, dim=-1)
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:

293
comfy_extras/nodes_gds.py Normal file
View File

@ -0,0 +1,293 @@
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
"""
Enhanced model loading nodes with GPUDirect Storage support
"""
import logging
import time
import asyncio
from typing import Optional, Dict, Any
import torch
import folder_paths
import comfy.sd
import comfy.utils
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
class CheckpointLoaderGDS(ComfyNodeABC):
"""
Enhanced checkpoint loader with GPUDirect Storage support
Provides direct SSD-to-GPU loading and prefetching capabilities
"""
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
"tooltip": "The name of the checkpoint (model) to load with GDS optimization."
}),
},
"optional": {
"prefetch": ("BOOLEAN", {
"default": False,
"tooltip": "Prefetch model to GPU cache for faster loading."
}),
"use_gds": ("BOOLEAN", {
"default": True,
"tooltip": "Use GPUDirect Storage if available."
}),
"target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
"default": "auto",
"tooltip": "Target device for model loading."
})
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
RETURN_NAMES = ("model", "clip", "vae", "load_info")
OUTPUT_TOOLTIPS = (
"The model used for denoising latents.",
"The CLIP model used for encoding text prompts.",
"The VAE model used for encoding and decoding images to and from latent space.",
"Loading information and statistics."
)
FUNCTION = "load_checkpoint_gds"
CATEGORY = "loaders/advanced"
DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
EXPERIMENTAL = True
def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
start_time = time.time()
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
# Determine target device
if target_device == "auto":
device = None # Let the system decide
elif target_device == "cpu":
device = torch.device("cpu")
else:
device = torch.device(target_device)
load_info = {
"file": ckpt_name,
"path": ckpt_path,
"target_device": str(device) if device else "auto",
"gds_enabled": use_gds,
"prefetch_used": prefetch
}
try:
# Prefetch if requested
if prefetch and use_gds:
try:
from comfy.gds_loader import prefetch_model_gds
prefetch_success = prefetch_model_gds(ckpt_path)
load_info["prefetch_success"] = prefetch_success
if prefetch_success:
logging.info(f"Prefetched {ckpt_name} to GPU cache")
except Exception as e:
logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
load_info["prefetch_error"] = str(e)
# Load checkpoint with potential GDS optimization
if use_gds and device and device.type == 'cuda':
try:
from comfy.gds_loader import get_gds_instance
gds = get_gds_instance()
# Check if GDS should be used for this file
if gds._should_use_gds(ckpt_path):
load_info["loader_used"] = "GDS"
logging.info(f"Loading {ckpt_name} with GDS")
else:
load_info["loader_used"] = "Standard"
logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")
except Exception as e:
logging.warning(f"GDS check failed, using standard loading: {e}")
load_info["loader_used"] = "Standard (GDS failed)"
else:
load_info["loader_used"] = "Standard"
# Load the actual checkpoint
out = comfy.sd.load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings")
)
load_time = time.time() - start_time
load_info["load_time_seconds"] = round(load_time, 3)
load_info["load_success"] = True
# Format load info as string
info_str = f"Loaded: {ckpt_name}\n"
info_str += f"Method: {load_info['loader_used']}\n"
info_str += f"Time: {load_info['load_time_seconds']}s\n"
info_str += f"Device: {load_info['target_device']}"
if "prefetch_success" in load_info:
info_str += f"\nPrefetch: {'' if load_info['prefetch_success'] else ''}"
logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")
return (*out[:3], info_str)
except Exception as e:
load_info["load_success"] = False
load_info["error"] = str(e)
error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
logging.error(f"Checkpoint loading failed: {e}")
raise RuntimeError(error_str)
class ModelPrefetcher(ComfyNodeABC):
"""
Node for prefetching models to GPU cache
"""
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"checkpoint_names": ("STRING", {
"multiline": True,
"default": "",
"tooltip": "List of checkpoint names to prefetch (one per line)."
}),
"prefetch_enabled": ("BOOLEAN", {
"default": True,
"tooltip": "Enable/disable prefetching."
})
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("prefetch_report",)
OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
FUNCTION = "prefetch_models"
CATEGORY = "loaders/advanced"
DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
OUTPUT_NODE = True
def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
if not prefetch_enabled:
return ("Prefetching disabled",)
# Parse checkpoint names
names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()]
if not names:
return ("No checkpoints specified for prefetching",)
try:
from comfy.gds_loader import prefetch_model_gds
except ImportError:
return ("GDS not available for prefetching",)
results = []
successful_prefetches = 0
for name in names:
try:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name)
success = prefetch_model_gds(ckpt_path)
if success:
results.append(f"{name}")
successful_prefetches += 1
else:
results.append(f"{name} (prefetch failed)")
except Exception as e:
results.append(f"{name} (error: {str(e)[:50]})")
report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n"
report += "\n".join(results)
return (report,)
class GDSStats(ComfyNodeABC):
"""
Node for displaying GDS statistics
"""
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
return {
"required": {
"refresh": ("BOOLEAN", {
"default": False,
"tooltip": "Refresh statistics."
})
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("stats_report",)
OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
FUNCTION = "get_stats"
CATEGORY = "utils/advanced"
DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
OUTPUT_NODE = True
def get_stats(self, refresh: bool = False):
try:
from comfy.gds_loader import get_gds_stats
stats = get_gds_stats()
report = "=== GPUDirect Storage Statistics ===\n\n"
# Availability
report += f"GDS Available: {'' if stats['gds_available'] else ''}\n"
# Usage statistics
report += f"Total Loads: {stats['total_loads']}\n"
report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n"
report += f"Fallback Loads: {stats['fallback_loads']}\n\n"
# Performance metrics
if stats['total_bytes_gds'] > 0:
gb_transferred = stats['total_bytes_gds'] / (1024**3)
report += f"Data Transferred: {gb_transferred:.2f} GB\n"
report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n"
report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n"
# Configuration
config = stats.get('config', {})
if config:
report += "Configuration:\n"
report += f"- Enabled: {config.get('enabled', 'Unknown')}\n"
report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n"
report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n"
report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n"
report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n"
report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n"
return (report,)
except ImportError:
return ("GDS module not available",)
except Exception as e:
return (f"Error retrieving GDS stats: {str(e)}",)
# Node mappings
NODE_CLASS_MAPPINGS = {
"CheckpointLoaderGDS": CheckpointLoaderGDS,
"ModelPrefetcher": ModelPrefetcher,
"GDSStats": GDSStats,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoaderGDS": "Load Checkpoint (GDS)",
"ModelPrefetcher": "Model Prefetcher",
"GDSStats": "GDS Statistics",
}

View File

@ -0,0 +1,88 @@
import node_helpers
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math
import comfy.utils
class TextEncodeZImageOmni(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeZImageOmni",
category="advanced/conditioning",
is_experimental=True,
inputs=[
io.Clip.Input("clip"),
io.ClipVision.Input("image_encoder", optional=True),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Boolean.Input("auto_resize_images", default=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
ref_latents = []
images = list(filter(lambda a: a is not None, [image1, image2, image3]))
prompt_list = []
template = None
if len(images) > 0:
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1)
prompt_list += ["<|vision_end|><|im_end|>"]
template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
encoded_images = []
for i, image in enumerate(images):
if image_encoder is not None:
encoded_images.append(image_encoder.encode_image(image))
if vae is not None:
if auto_resize_images:
samples = image.movedim(-1, 1)
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by / 8.0) * 8
height = round(samples.shape[2] * scale_by / 8.0) * 8
image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
ref_latents.append(vae.encode(image))
tokens = clip.tokenize(prompt, llama_template=template)
conditioning = clip.encode_from_tokens_scheduled(tokens)
extra_text_embeds = []
for p in prompt_list:
tokens = clip.tokenize(p, llama_template="{}")
text_embeds = clip.encode_from_tokens_scheduled(tokens)
extra_text_embeds.append(text_embeds[0][0])
if len(ref_latents) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
if len(encoded_images) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True)
if len(extra_text_embeds) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True)
return io.NodeOutput(conditioning)
class ZImageExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeZImageOmni,
]
async def comfy_entrypoint() -> ZImageExtension:
return ZImageExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.9.2"
__version__ = "0.10.0"

29
main.py
View File

@ -184,6 +184,35 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0
# Initialize GPUDirect Storage if enabled
def init_gds():
"""Initialize GPUDirect Storage based on CLI arguments"""
if hasattr(args, 'disable_gds') and args.disable_gds:
logging.info("GDS explicitly disabled via --disable-gds")
return
if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
# GDS not explicitly requested, use auto-detection
return
if hasattr(args, 'enable_gds') and args.enable_gds:
from comfy.gds_loader import GDSConfig, init_gds as gds_init
config = GDSConfig(
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
max_concurrent_streams=getattr(args, 'gds_streams', 4),
prefetch_enabled=getattr(args, 'gds_prefetch', True),
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
show_stats=getattr(args, 'gds_stats', False)
)
gds_init(config)
# Initialize GDS
init_gds()
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)

View File

@ -2367,12 +2367,14 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_gds.py",
"nodes_rope.py",
"nodes_logic.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
"nodes_image_compare.py",
"nodes_zimage.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.9.2"
version = "0.10.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.36.14
comfyui-workflow-templates==0.8.11
comfyui-frontend-package==1.37.11
comfyui-workflow-templates==0.8.15
comfyui-embedded-docs==0.4.0
torch
torchsde
@ -27,4 +27,4 @@ comfy-kitchen>=0.2.7
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
pydantic-settings~=2.0