Merge 7ec1656735 into 5ac3b26a7d
commit 1b00baf940
README.md
@@ -399,6 +399,14 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w…

> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.
<br/><br/>If you use a container, note that the volume mount `-v` can be a relative path, so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current working directory of your command prompt or PowerShell terminal.
## How to run a heavy workflow on a mid-range GPU (NVIDIA, Linux)?

Use the `--enable-gds` flag to activate NVIDIA [GPUDirect Storage](https://docs.nvidia.com/gpudirect-storage/) (GDS), which allows data to be transferred directly between SSDs and GPUs. This eliminates the traditional CPU-mediated data path, significantly reducing I/O latency and CPU overhead. System RAM is still used for caching, alongside the SSD, to further optimize performance.

This feature has been tested only on NVIDIA GPUs on Linux-based systems.

Requires: `cupy-cuda12x>=12.0.0`, `pynvml>=11.4.1`, `cudf>=23.0.0`, `numba>=0.57.0`, `nvidia-ml-py>=12.0.0`.
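An example launch command combining the new flags added in this commit (illustrative values; tune the sizes and stream count to your SSD and VRAM):

```
python main.py --enable-gds --gds-min-file-size 100 --gds-chunk-size 64 --gds-streams 4 --gds-prefetch --gds-stats
```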
## Support and dev channel

[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
comfy/cli_args.py
@@ -147,6 +147,17 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha…

```python
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

# GPUDirect Storage (GDS) arguments
gds_group = parser.add_argument_group('gds', 'GPUDirect Storage options for direct SSD-to-GPU model loading')
gds_group.add_argument("--enable-gds", action="store_true", help="Enable GPUDirect Storage for direct SSD-to-GPU model loading (requires CUDA 11.4+, cuFile).")
gds_group.add_argument("--disable-gds", action="store_true", help="Explicitly disable GPUDirect Storage.")
gds_group.add_argument("--gds-min-file-size", type=int, default=100, help="Minimum file size in MB to use GDS (default: 100MB).")
gds_group.add_argument("--gds-chunk-size", type=int, default=64, help="GDS transfer chunk size in MB (default: 64MB).")
gds_group.add_argument("--gds-streams", type=int, default=4, help="Number of CUDA streams for GDS operations (default: 4).")
gds_group.add_argument("--gds-prefetch", action="store_true", help="Enable GDS prefetching for better performance.")
gds_group.add_argument("--gds-no-fallback", action="store_true", help="Disable fallback to CPU loading if GDS fails.")
gds_group.add_argument("--gds-stats", action="store_true", help="Print GDS statistics on exit.")

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"
    Fp8MatrixMultiplication = "fp8_matrix_mult"
```
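These flags are consumed in `main.py` later in this commit. As a sketch only (all names taken from this diff), the mapping from parsed args to a `GDSConfig` is roughly:

```python
# Sketch; the actual wiring is in main.py's init_gds() further down.
from comfy.cli_args import args
from comfy.gds_loader import GDSConfig, init_gds

if args.enable_gds:
    init_gds(GDSConfig(
        enabled=True,
        min_file_size_mb=args.gds_min_file_size,    # --gds-min-file-size
        chunk_size_mb=args.gds_chunk_size,          # --gds-chunk-size
        max_concurrent_streams=args.gds_streams,    # --gds-streams
        prefetch_enabled=args.gds_prefetch,         # --gds-prefetch
        fallback_to_cpu=not args.gds_no_fallback,   # --gds-no-fallback
        show_stats=args.gds_stats,                  # --gds-stats
    ))
```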
comfy/gds_loader.py (new file, 494 lines)
@@ -0,0 +1,494 @@

```python
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007

"""
GPUDirect Storage (GDS) Integration for ComfyUI
Direct SSD-to-GPU model loading without RAM/CPU bottlenecks.
There will still be some CPU/RAM usage, mostly for safetensors parsing and small buffers.

This module provides GPUDirect Storage functionality to load models directly
from NVMe SSDs to GPU memory, bypassing system RAM and CPU.
"""

import os
import logging
import torch
import time
from typing import Optional, Dict, Any, Union
from pathlib import Path
import safetensors
import gc
import mmap
from dataclasses import dataclass

try:
    import cupy
    import cupy.cuda.runtime as cuda_runtime
    CUPY_AVAILABLE = True
except ImportError:
    CUPY_AVAILABLE = False
    logging.warning("CuPy not available. GDS will use fallback mode.")

try:
    import cudf  # RAPIDS for GPU dataframes
    RAPIDS_AVAILABLE = True
except ImportError:
    RAPIDS_AVAILABLE = False

try:
    import pynvml
    pynvml.nvmlInit()
    NVML_AVAILABLE = True
except ImportError:
    NVML_AVAILABLE = False
    logging.warning("NVIDIA-ML-Py not available. GPU monitoring disabled.")


@dataclass
class GDSConfig:
    """Configuration for GPUDirect Storage"""
    enabled: bool = True
    min_file_size_mb: int = 100  # Only use GDS for files larger than this
    chunk_size_mb: int = 64  # Size of chunks to transfer
    use_pinned_memory: bool = True
    prefetch_enabled: bool = True
    compression_aware: bool = True
    max_concurrent_streams: int = 4
    fallback_to_cpu: bool = True
    show_stats: bool = False  # Whether to show stats on exit


class GDSError(Exception):
    """GDS-specific errors"""
    pass


class GPUDirectStorage:
    """
    GPUDirect Storage implementation for ComfyUI
    Enables direct SSD-to-GPU transfers for model loading
    """

    def __init__(self, config: Optional[GDSConfig] = None):
        self.config = config or GDSConfig()
        self.device = torch.cuda.current_device() if torch.cuda.is_available() else None
        self.cuda_streams = []
        self.pinned_buffers = {}
        self.stats = {
            'gds_loads': 0,
            'fallback_loads': 0,
            'total_bytes_gds': 0,
            'total_time_gds': 0.0,
            'avg_bandwidth_gbps': 0.0
        }

        # Initialize GDS if available
        self._gds_available = self._check_gds_availability()
        if self._gds_available:
            self._init_gds()
        else:
            logging.warning("GDS not available, using fallback methods")

    def _check_gds_availability(self) -> bool:
        """Check if GDS is available on the system"""
        if not torch.cuda.is_available():
            return False

        if not CUPY_AVAILABLE:
            return False

        # Check for GPUDirect Storage support
        try:
            # Check CUDA version (GDS requires CUDA 11.4+)
            cuda_version = torch.version.cuda
            if cuda_version:
                major, minor = map(int, cuda_version.split('.')[:2])
                if major < 11 or (major == 11 and minor < 4):
                    logging.warning(f"CUDA {cuda_version} detected. GDS requires CUDA 11.4+")
                    return False

            # Check if cuFile is available (part of CUDA toolkit)
            try:
                import cupy.cuda.cufile as cufile
                # Try to initialize cuFile
                cufile.initialize()
                return True
            except (ImportError, RuntimeError) as e:
                logging.warning(f"cuFile not available: {e}")
                return False

        except Exception as e:
            logging.warning(f"GDS availability check failed: {e}")
            return False

    def _init_gds(self):
        """Initialize GDS resources"""
        try:
            # Create CUDA streams for async operations
            for i in range(self.config.max_concurrent_streams):
                stream = torch.cuda.Stream()
                self.cuda_streams.append(stream)

            # Pre-allocate pinned memory buffers
            if self.config.use_pinned_memory:
                self._allocate_pinned_buffers()

            logging.info(f"GDS initialized with {len(self.cuda_streams)} streams")

        except Exception as e:
            logging.error(f"Failed to initialize GDS: {e}")
            self._gds_available = False

    def _allocate_pinned_buffers(self):
        """Pre-allocate pinned memory buffers for staging"""
        try:
            # Allocate buffers of different sizes
            buffer_sizes = [16, 32, 64, 128, 256]  # MB

            for size_mb in buffer_sizes:
                size_bytes = size_mb * 1024 * 1024
                # Allocate pinned memory using CuPy
                if CUPY_AVAILABLE:
                    buffer = cupy.cuda.alloc_pinned_memory(size_bytes)
                    self.pinned_buffers[size_mb] = buffer

        except Exception as e:
            logging.warning(f"Failed to allocate pinned buffers: {e}")

    def _get_file_size(self, file_path: str) -> int:
        """Get file size in bytes"""
        return os.path.getsize(file_path)

    def _should_use_gds(self, file_path: str) -> bool:
        """Determine if GDS should be used for this file"""
        if not self._gds_available or not self.config.enabled:
            return False

        file_size_mb = self._get_file_size(file_path) / (1024 * 1024)
        return file_size_mb >= self.config.min_file_size_mb

    def _load_with_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
        """Load model using GPUDirect Storage"""
        start_time = time.time()

        try:
            if file_path.lower().endswith(('.safetensors', '.sft')):
                return self._load_safetensors_gds(file_path)
            else:
                return self._load_pytorch_gds(file_path)

        except Exception as e:
            logging.error(f"GDS loading failed for {file_path}: {e}")
            if self.config.fallback_to_cpu:
                logging.info("Falling back to CPU loading")
                self.stats['fallback_loads'] += 1
                return self._load_fallback(file_path)
            else:
                raise GDSError(f"GDS loading failed: {e}")
        finally:
            load_time = time.time() - start_time
            self.stats['total_time_gds'] += load_time

    def _load_safetensors_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
        """Load safetensors file using GDS"""
        try:
            import cupy.cuda.cufile as cufile

            # Open file with cuFile for direct GPU loading
            with cufile.CuFileManager() as manager:
                # Memory-map the file for efficient access
                with open(file_path, 'rb') as f:
                    # Use mmap for large files
                    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mmapped_file:

                        # Parse safetensors header
                        header_size = int.from_bytes(mmapped_file[:8], 'little')
                        header_bytes = mmapped_file[8:8+header_size]

                        import json
                        header = json.loads(header_bytes.decode('utf-8'))

                        # Load tensors directly to GPU
                        tensors = {}
                        data_offset = 8 + header_size

                        for name, info in header.items():
                            if name == "__metadata__":
                                continue

                            dtype_map = {
                                'F32': torch.float32,
                                'F16': torch.float16,
                                'BF16': torch.bfloat16,
                                'I8': torch.int8,
                                'I16': torch.int16,
                                'I32': torch.int32,
                                'I64': torch.int64,
                                'U8': torch.uint8,
                            }

                            dtype = dtype_map.get(info['dtype'], torch.float32)
                            shape = info['shape']
                            start_offset = data_offset + info['data_offsets'][0]
                            end_offset = data_offset + info['data_offsets'][1]

                            # Direct GPU allocation
                            tensor = torch.empty(shape, dtype=dtype, device=f'cuda:{self.device}')

                            # Use cuFile for direct transfer
                            tensor_bytes = end_offset - start_offset

                            # Get GPU memory pointer
                            gpu_ptr = tensor.data_ptr()

                            # Direct file-to-GPU transfer
                            cufile.copy_from_file(
                                gpu_ptr,
                                mmapped_file[start_offset:end_offset],
                                tensor_bytes
                            )

                            tensors[name] = tensor

            self.stats['gds_loads'] += 1
            self.stats['total_bytes_gds'] += self._get_file_size(file_path)

            return tensors

        except Exception as e:
            logging.error(f"GDS safetensors loading failed: {e}")
            raise

    def _load_pytorch_gds(self, file_path: str) -> Dict[str, torch.Tensor]:
        """Load PyTorch file using GDS with staging"""
        try:
            # For PyTorch files, we need to use a staging approach
            # since torch.load doesn't support direct GPU loading

            # Load to pinned memory first
            with open(file_path, 'rb') as f:
                file_size = self._get_file_size(file_path)

                # Choose appropriate buffer or allocate new one
                buffer_size_mb = min(256, max(64, file_size // (1024 * 1024)))

                if buffer_size_mb in self.pinned_buffers:
                    pinned_buffer = self.pinned_buffers[buffer_size_mb]
                else:
                    # Allocate temporary pinned buffer
                    pinned_buffer = cupy.cuda.alloc_pinned_memory(file_size)

                # Read file to pinned memory
                f.readinto(pinned_buffer)
                f.seek(0)  # rewind: readinto consumed the stream, torch.load must read from the start

                # Use torch.load with map_location to specific GPU
                # This will be faster due to pinned memory
                state_dict = torch.load(
                    f,
                    map_location=f'cuda:{self.device}',
                    weights_only=True
                )

            self.stats['gds_loads'] += 1
            self.stats['total_bytes_gds'] += file_size

            return state_dict

        except Exception as e:
            logging.error(f"GDS PyTorch loading failed: {e}")
            raise

    def _load_fallback(self, file_path: str) -> Dict[str, torch.Tensor]:
        """Fallback loading method using standard approaches"""
        if file_path.lower().endswith(('.safetensors', '.sft')):
            # Use safetensors with device parameter
            with safetensors.safe_open(file_path, framework="pt", device=f'cuda:{self.device}') as f:
                return {k: f.get_tensor(k) for k in f.keys()}
        else:
            # Standard PyTorch loading
            return torch.load(file_path, map_location=f'cuda:{self.device}', weights_only=True)

    def load_model(self, file_path: str, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
        """
        Main entry point for loading models with GDS

        Args:
            file_path: Path to the model file
            device: Target device (if None, uses current CUDA device)

        Returns:
            Dictionary of tensors loaded directly to GPU
        """
        if device is not None and device.type == 'cuda':
            self.device = device.index or 0

        if self._should_use_gds(file_path):
            logging.info(f"Loading {file_path} with GDS")
            return self._load_with_gds(file_path)
        else:
            logging.info(f"Loading {file_path} with standard method")
            self.stats['fallback_loads'] += 1
            return self._load_fallback(file_path)

    def prefetch_model(self, file_path: str) -> bool:
        """
        Prefetch model to GPU memory cache (if supported)

        Args:
            file_path: Path to the model file

        Returns:
            True if prefetch was successful
        """
        if not self.config.prefetch_enabled or not self._gds_available:
            return False

        try:
            # Basic prefetch implementation
            # This would ideally use NVIDIA's GPUDirect Storage API
            # to warm up the storage cache

            file_size = self._get_file_size(file_path)
            logging.info(f"Prefetching {file_path} ({file_size // (1024*1024)} MB)")

            # Read file metadata to warm caches
            with open(file_path, 'rb') as f:
                # Read first and last chunks to trigger prefetch
                f.read(1024 * 1024)  # First 1MB
                f.seek(-min(1024 * 1024, file_size), 2)  # Last 1MB
                f.read()

            return True

        except Exception as e:
            logging.warning(f"Prefetch failed for {file_path}: {e}")
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Get loading statistics"""
        total_loads = self.stats['gds_loads'] + self.stats['fallback_loads']

        if self.stats['total_time_gds'] > 0 and self.stats['total_bytes_gds'] > 0:
            bandwidth_gbps = (self.stats['total_bytes_gds'] / (1024**3)) / self.stats['total_time_gds']
            self.stats['avg_bandwidth_gbps'] = bandwidth_gbps

        return {
            **self.stats,
            'total_loads': total_loads,
            'gds_usage_percent': (self.stats['gds_loads'] / max(1, total_loads)) * 100,
            'gds_available': self._gds_available,
            'config': self.config.__dict__
        }

    def cleanup(self):
        """Clean up GDS resources"""
        try:
            # Clear CUDA streams
            for stream in self.cuda_streams:
                stream.synchronize()
            self.cuda_streams.clear()

            # Free pinned buffers
            for buffer in self.pinned_buffers.values():
                if CUPY_AVAILABLE:
                    cupy.cuda.free_pinned_memory(buffer)
            self.pinned_buffers.clear()

            # Force garbage collection
            gc.collect()
            torch.cuda.empty_cache()

        except Exception as e:
            logging.warning(f"GDS cleanup failed: {e}")

    def __del__(self):
        """Destructor to ensure cleanup"""
        self.cleanup()


# Global GDS instance
_gds_instance: Optional[GPUDirectStorage] = None


def get_gds_instance(config: Optional[GDSConfig] = None) -> GPUDirectStorage:
    """Get or create the global GDS instance"""
    global _gds_instance

    if _gds_instance is None:
        _gds_instance = GPUDirectStorage(config)

    return _gds_instance


def load_torch_file_gds(ckpt: str, safe_load: bool = False, device: Optional[torch.device] = None) -> Dict[str, torch.Tensor]:
    """
    GDS-enabled replacement for comfy.utils.load_torch_file

    Args:
        ckpt: Path to checkpoint file
        safe_load: Whether to use safe loading (for compatibility)
        device: Target device

    Returns:
        Dictionary of loaded tensors
    """
    gds = get_gds_instance()

    try:
        # Load with GDS
        return gds.load_model(ckpt, device)

    except Exception as e:
        logging.error(f"GDS loading failed, falling back to standard method: {e}")
        # Fallback to original method
        import comfy.utils
        return comfy.utils.load_torch_file(ckpt, safe_load=safe_load, device=device)


def prefetch_model_gds(file_path: str) -> bool:
    """Prefetch model for faster loading"""
    gds = get_gds_instance()
    return gds.prefetch_model(file_path)


def get_gds_stats() -> Dict[str, Any]:
    """Get GDS statistics"""
    gds = get_gds_instance()
    return gds.get_stats()


def configure_gds(config: GDSConfig):
    """Configure GDS settings"""
    global _gds_instance
    _gds_instance = GPUDirectStorage(config)


def init_gds(config: GDSConfig):
    """
    Initialize GPUDirect Storage with the provided configuration

    Args:
        config: GDSConfig object with initialization parameters
    """
    try:
        # Configure GDS
        configure_gds(config)
        logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")

        # Set up exit handler for stats if requested
        if hasattr(config, 'show_stats') and config.show_stats:
            import atexit
            def print_gds_stats():
                stats = get_gds_stats()
                logging.info("=== GDS Statistics ===")
                logging.info(f"Total loads: {stats['total_loads']}")
                logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
                logging.info(f"Fallback loads: {stats['fallback_loads']}")
                logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
                logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
                logging.info("===================")
            atexit.register(print_gds_stats)

    except ImportError as e:
        logging.warning(f"GDS initialization failed - missing dependencies: {e}")
    except Exception as e:
        logging.error(f"GDS initialization failed: {e}")
```
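A minimal usage sketch of this module's public API (assuming a ComfyUI checkout on PYTHONPATH and a CUDA device; the checkpoint path is a placeholder):

```python
import torch
from comfy.gds_loader import (
    GDSConfig, configure_gds, load_torch_file_gds,
    prefetch_model_gds, get_gds_stats,
)

# Install a global instance with explicit settings (optional; a default is created lazily).
configure_gds(GDSConfig(min_file_size_mb=100, max_concurrent_streams=4))

ckpt = "models/checkpoints/example.safetensors"  # placeholder path
prefetch_model_gds(ckpt)  # best-effort cache warm-up; returns False if GDS is unavailable
state_dict = load_torch_file_gds(ckpt, safe_load=True, device=torch.device("cuda:0"))
print(get_gds_stats()["avg_bandwidth_gbps"])
```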
comfy/utils.py
@@ -56,6 +56,18 @@ else:

```python
    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    # Try GDS loading first if available and device is GPU
    if device is not None and device.type == 'cuda':
        try:
            from . import gds_loader
            gds_result = gds_loader.load_torch_file_gds(ckpt, safe_load=safe_load, device=device)
            if return_metadata:
                # For GDS, we return empty metadata for now (can be enhanced)
                return (gds_result, {})
            return gds_result
        except Exception as e:
            logging.debug(f"GDS loading failed, using fallback: {e}")

    if device is None:
        device = torch.device("cpu")
    metadata = None
```
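With this change, callers need nothing new: passing a CUDA device to `load_torch_file` now attempts the GDS path first and silently falls back to the standard loader on any failure. An illustrative call (the path is a placeholder):

```python
import torch
import comfy.utils

sd = comfy.utils.load_torch_file(
    "models/checkpoints/example.safetensors",  # placeholder
    safe_load=True,
    device=torch.device("cuda:0"),  # a CUDA device triggers the GDS attempt
)
```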
comfy_extras/nodes_gds.py (new file, 293 lines)
@@ -0,0 +1,293 @@
```python
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007

"""
Enhanced model loading nodes with GPUDirect Storage support
"""

import logging
import time
import asyncio
from typing import Optional, Dict, Any

import torch
import folder_paths
import comfy.sd
import comfy.utils
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict


class CheckpointLoaderGDS(ComfyNodeABC):
    """
    Enhanced checkpoint loader with GPUDirect Storage support
    Provides direct SSD-to-GPU loading and prefetching capabilities
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
                    "tooltip": "The name of the checkpoint (model) to load with GDS optimization."
                }),
            },
            "optional": {
                "prefetch": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Prefetch model to GPU cache for faster loading."
                }),
                "use_gds": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GPUDirect Storage if available."
                }),
                "target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
                    "default": "auto",
                    "tooltip": "Target device for model loading."
                })
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
    RETURN_NAMES = ("model", "clip", "vae", "load_info")
    OUTPUT_TOOLTIPS = (
        "The model used for denoising latents.",
        "The CLIP model used for encoding text prompts.",
        "The VAE model used for encoding and decoding images to and from latent space.",
        "Loading information and statistics."
    )
    FUNCTION = "load_checkpoint_gds"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
    EXPERIMENTAL = True

    def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
        start_time = time.time()

        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        # Determine target device
        if target_device == "auto":
            device = None  # Let the system decide
        elif target_device == "cpu":
            device = torch.device("cpu")
        else:
            device = torch.device(target_device)

        load_info = {
            "file": ckpt_name,
            "path": ckpt_path,
            "target_device": str(device) if device else "auto",
            "gds_enabled": use_gds,
            "prefetch_used": prefetch
        }

        try:
            # Prefetch if requested
            if prefetch and use_gds:
                try:
                    from comfy.gds_loader import prefetch_model_gds
                    prefetch_success = prefetch_model_gds(ckpt_path)
                    load_info["prefetch_success"] = prefetch_success
                    if prefetch_success:
                        logging.info(f"Prefetched {ckpt_name} to GPU cache")
                except Exception as e:
                    logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
                    load_info["prefetch_error"] = str(e)

            # Load checkpoint with potential GDS optimization
            if use_gds and device and device.type == 'cuda':
                try:
                    from comfy.gds_loader import get_gds_instance
                    gds = get_gds_instance()

                    # Check if GDS should be used for this file
                    if gds._should_use_gds(ckpt_path):
                        load_info["loader_used"] = "GDS"
                        logging.info(f"Loading {ckpt_name} with GDS")
                    else:
                        load_info["loader_used"] = "Standard"
                        logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")

                except Exception as e:
                    logging.warning(f"GDS check failed, using standard loading: {e}")
                    load_info["loader_used"] = "Standard (GDS failed)"
            else:
                load_info["loader_used"] = "Standard"

            # Load the actual checkpoint
            out = comfy.sd.load_checkpoint_guess_config(
                ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=folder_paths.get_folder_paths("embeddings")
            )

            load_time = time.time() - start_time
            load_info["load_time_seconds"] = round(load_time, 3)
            load_info["load_success"] = True

            # Format load info as string
            info_str = f"Loaded: {ckpt_name}\n"
            info_str += f"Method: {load_info['loader_used']}\n"
            info_str += f"Time: {load_info['load_time_seconds']}s\n"
            info_str += f"Device: {load_info['target_device']}"

            if "prefetch_success" in load_info:
                info_str += f"\nPrefetch: {'✓' if load_info['prefetch_success'] else '✗'}"

            logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")

            return (*out[:3], info_str)

        except Exception as e:
            load_info["load_success"] = False
            load_info["error"] = str(e)
            error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
            logging.error(f"Checkpoint loading failed: {e}")
            raise RuntimeError(error_str)


class ModelPrefetcher(ComfyNodeABC):
    """
    Node for prefetching models to GPU cache
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "checkpoint_names": ("STRING", {
                    "multiline": True,
                    "default": "",
                    "tooltip": "List of checkpoint names to prefetch (one per line)."
                }),
                "prefetch_enabled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable/disable prefetching."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prefetch_report",)
    OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
    FUNCTION = "prefetch_models"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
    OUTPUT_NODE = True

    def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
        if not prefetch_enabled:
            return ("Prefetching disabled",)

        # Parse checkpoint names
        names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()]

        if not names:
            return ("No checkpoints specified for prefetching",)

        try:
            from comfy.gds_loader import prefetch_model_gds
        except ImportError:
            return ("GDS not available for prefetching",)

        results = []
        successful_prefetches = 0

        for name in names:
            try:
                ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name)
                success = prefetch_model_gds(ckpt_path)

                if success:
                    results.append(f"✓ {name}")
                    successful_prefetches += 1
                else:
                    results.append(f"✗ {name} (prefetch failed)")

            except Exception as e:
                results.append(f"✗ {name} (error: {str(e)[:50]})")

        report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n"
        report += "\n".join(results)

        return (report,)


class GDSStats(ComfyNodeABC):
    """
    Node for displaying GDS statistics
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "refresh": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Refresh statistics."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("stats_report",)
    OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
    FUNCTION = "get_stats"
    CATEGORY = "utils/advanced"
    DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
    OUTPUT_NODE = True

    def get_stats(self, refresh: bool = False):
        try:
            from comfy.gds_loader import get_gds_stats
            stats = get_gds_stats()

            report = "=== GPUDirect Storage Statistics ===\n\n"

            # Availability
            report += f"GDS Available: {'✓' if stats['gds_available'] else '✗'}\n"

            # Usage statistics
            report += f"Total Loads: {stats['total_loads']}\n"
            report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n"
            report += f"Fallback Loads: {stats['fallback_loads']}\n\n"

            # Performance metrics
            if stats['total_bytes_gds'] > 0:
                gb_transferred = stats['total_bytes_gds'] / (1024**3)
                report += f"Data Transferred: {gb_transferred:.2f} GB\n"
                report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n"
                report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n"

            # Configuration
            config = stats.get('config', {})
            if config:
                report += "Configuration:\n"
                report += f"- Enabled: {config.get('enabled', 'Unknown')}\n"
                report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n"
                report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n"
                report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n"
                report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n"
                report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n"

            return (report,)

        except ImportError:
            return ("GDS module not available",)
        except Exception as e:
            return (f"Error retrieving GDS stats: {str(e)}",)


# Node mappings
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderGDS": CheckpointLoaderGDS,
    "ModelPrefetcher": ModelPrefetcher,
    "GDSStats": GDSStats,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderGDS": "Load Checkpoint (GDS)",
    "ModelPrefetcher": "Model Prefetcher",
    "GDSStats": "GDS Statistics",
}
```
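Since the three nodes are plain `ComfyNodeABC` classes, they can also be exercised outside the UI. A hedged smoke test (assumes a working ComfyUI environment; the checkpoint names are placeholders):

```python
from comfy_extras.nodes_gds import GDSStats, ModelPrefetcher

# Print the current GDS statistics report.
(report,) = GDSStats().get_stats(refresh=True)
print(report)

# Prefetch two checkpoints by name (one per line; placeholder names).
(prefetch_report,) = ModelPrefetcher().prefetch_models(
    "example_a.safetensors\nexample_b.safetensors"
)
print(prefetch_report)
```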
main.py (29 lines changed)
@@ -185,6 +185,35 @@ import comfyui_version

```python
import app.logger
import hook_breaker_ac10a0

# Initialize GPUDirect Storage if enabled
def init_gds():
    """Initialize GPUDirect Storage based on CLI arguments"""
    if hasattr(args, 'disable_gds') and args.disable_gds:
        logging.info("GDS explicitly disabled via --disable-gds")
        return

    if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
        # GDS not explicitly requested, use auto-detection
        return

    if hasattr(args, 'enable_gds') and args.enable_gds:
        from comfy.gds_loader import GDSConfig, init_gds as gds_init

        config = GDSConfig(
            enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
            min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
            chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
            max_concurrent_streams=getattr(args, 'gds_streams', 4),
            prefetch_enabled=getattr(args, 'gds_prefetch', True),
            fallback_to_cpu=not getattr(args, 'gds_no_fallback', False),
            show_stats=getattr(args, 'gds_stats', False)
        )

        gds_init(config)

# Initialize GDS
init_gds()

def cuda_malloc_warning():
    device = comfy.model_management.get_torch_device()
    device_name = comfy.model_management.get_torch_device_name(device)
```
nodes.py (1 line changed)
@@ -2354,6 +2354,7 @@ async def init_builtin_extra_nodes():

```python
        "nodes_model_patch.py",
        "nodes_easycache.py",
        "nodes_audio_encoder.py",
        "nodes_gds.py",
        "nodes_rope.py",
        "nodes_logic.py",
        "nodes_nop.py",
```
requirements.txt
@@ -26,4 +26,4 @@ av>=14.2.0

```
kornia>=0.7.1
spandrel
pydantic~=2.0
pydantic-settings~=2.0
```