mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 19:13:02 +08:00
feat(gds): add GPUDirect Storage support for model loading and prefetching
- limited to NVIDIA GPUs only
This commit is contained in:
parent
637221995f
commit
fab0954077
293
comfy_extras/nodes_gds.py
Normal file
293
comfy_extras/nodes_gds.py
Normal file
@ -0,0 +1,293 @@
|
||||
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
|
||||
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
|
||||
|
||||
"""
|
||||
Enhanced model loading nodes with GPUDirect Storage support
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import torch
|
||||
import folder_paths
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
|
||||
|
||||
class CheckpointLoaderGDS(ComfyNodeABC):
|
||||
"""
|
||||
Enhanced checkpoint loader with GPUDirect Storage support
|
||||
Provides direct SSD-to-GPU loading and prefetching capabilities
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
|
||||
"tooltip": "The name of the checkpoint (model) to load with GDS optimization."
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"prefetch": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Prefetch model to GPU cache for faster loading."
|
||||
}),
|
||||
"use_gds": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Use GPUDirect Storage if available."
|
||||
}),
|
||||
"target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
|
||||
"default": "auto",
|
||||
"tooltip": "Target device for model loading."
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
|
||||
RETURN_NAMES = ("model", "clip", "vae", "load_info")
|
||||
OUTPUT_TOOLTIPS = (
|
||||
"The model used for denoising latents.",
|
||||
"The CLIP model used for encoding text prompts.",
|
||||
"The VAE model used for encoding and decoding images to and from latent space.",
|
||||
"Loading information and statistics."
|
||||
)
|
||||
FUNCTION = "load_checkpoint_gds"
|
||||
CATEGORY = "loaders/advanced"
|
||||
DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||
|
||||
# Determine target device
|
||||
if target_device == "auto":
|
||||
device = None # Let the system decide
|
||||
elif target_device == "cpu":
|
||||
device = torch.device("cpu")
|
||||
else:
|
||||
device = torch.device(target_device)
|
||||
|
||||
load_info = {
|
||||
"file": ckpt_name,
|
||||
"path": ckpt_path,
|
||||
"target_device": str(device) if device else "auto",
|
||||
"gds_enabled": use_gds,
|
||||
"prefetch_used": prefetch
|
||||
}
|
||||
|
||||
try:
|
||||
# Prefetch if requested
|
||||
if prefetch and use_gds:
|
||||
try:
|
||||
from comfy.gds_loader import prefetch_model_gds
|
||||
prefetch_success = prefetch_model_gds(ckpt_path)
|
||||
load_info["prefetch_success"] = prefetch_success
|
||||
if prefetch_success:
|
||||
logging.info(f"Prefetched {ckpt_name} to GPU cache")
|
||||
except Exception as e:
|
||||
logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
|
||||
load_info["prefetch_error"] = str(e)
|
||||
|
||||
# Load checkpoint with potential GDS optimization
|
||||
if use_gds and device and device.type == 'cuda':
|
||||
try:
|
||||
from comfy.gds_loader import get_gds_instance
|
||||
gds = get_gds_instance()
|
||||
|
||||
# Check if GDS should be used for this file
|
||||
if gds._should_use_gds(ckpt_path):
|
||||
load_info["loader_used"] = "GDS"
|
||||
logging.info(f"Loading {ckpt_name} with GDS")
|
||||
else:
|
||||
load_info["loader_used"] = "Standard"
|
||||
logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"GDS check failed, using standard loading: {e}")
|
||||
load_info["loader_used"] = "Standard (GDS failed)"
|
||||
else:
|
||||
load_info["loader_used"] = "Standard"
|
||||
|
||||
# Load the actual checkpoint
|
||||
out = comfy.sd.load_checkpoint_guess_config(
|
||||
ckpt_path,
|
||||
output_vae=True,
|
||||
output_clip=True,
|
||||
embedding_directory=folder_paths.get_folder_paths("embeddings")
|
||||
)
|
||||
|
||||
load_time = time.time() - start_time
|
||||
load_info["load_time_seconds"] = round(load_time, 3)
|
||||
load_info["load_success"] = True
|
||||
|
||||
# Format load info as string
|
||||
info_str = f"Loaded: {ckpt_name}\n"
|
||||
info_str += f"Method: {load_info['loader_used']}\n"
|
||||
info_str += f"Time: {load_info['load_time_seconds']}s\n"
|
||||
info_str += f"Device: {load_info['target_device']}"
|
||||
|
||||
if "prefetch_success" in load_info:
|
||||
info_str += f"\nPrefetch: {'✓' if load_info['prefetch_success'] else '✗'}"
|
||||
|
||||
logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")
|
||||
|
||||
return (*out[:3], info_str)
|
||||
|
||||
except Exception as e:
|
||||
load_info["load_success"] = False
|
||||
load_info["error"] = str(e)
|
||||
error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
|
||||
logging.error(f"Checkpoint loading failed: {e}")
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
|
||||
class ModelPrefetcher(ComfyNodeABC):
|
||||
"""
|
||||
Node for prefetching models to GPU cache
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"checkpoint_names": ("STRING", {
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "List of checkpoint names to prefetch (one per line)."
|
||||
}),
|
||||
"prefetch_enabled": ("BOOLEAN", {
|
||||
"default": True,
|
||||
"tooltip": "Enable/disable prefetching."
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("prefetch_report",)
|
||||
OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
|
||||
FUNCTION = "prefetch_models"
|
||||
CATEGORY = "loaders/advanced"
|
||||
DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
|
||||
if not prefetch_enabled:
|
||||
return ("Prefetching disabled",)
|
||||
|
||||
# Parse checkpoint names
|
||||
names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()]
|
||||
|
||||
if not names:
|
||||
return ("No checkpoints specified for prefetching",)
|
||||
|
||||
try:
|
||||
from comfy.gds_loader import prefetch_model_gds
|
||||
except ImportError:
|
||||
return ("GDS not available for prefetching",)
|
||||
|
||||
results = []
|
||||
successful_prefetches = 0
|
||||
|
||||
for name in names:
|
||||
try:
|
||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name)
|
||||
success = prefetch_model_gds(ckpt_path)
|
||||
|
||||
if success:
|
||||
results.append(f"✓ {name}")
|
||||
successful_prefetches += 1
|
||||
else:
|
||||
results.append(f"✗ {name} (prefetch failed)")
|
||||
|
||||
except Exception as e:
|
||||
results.append(f"✗ {name} (error: {str(e)[:50]})")
|
||||
|
||||
report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n"
|
||||
report += "\n".join(results)
|
||||
|
||||
return (report,)
|
||||
|
||||
|
||||
class GDSStats(ComfyNodeABC):
|
||||
"""
|
||||
Node for displaying GDS statistics
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"refresh": ("BOOLEAN", {
|
||||
"default": False,
|
||||
"tooltip": "Refresh statistics."
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("stats_report",)
|
||||
OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
|
||||
FUNCTION = "get_stats"
|
||||
CATEGORY = "utils/advanced"
|
||||
DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def get_stats(self, refresh: bool = False):
|
||||
try:
|
||||
from comfy.gds_loader import get_gds_stats
|
||||
stats = get_gds_stats()
|
||||
|
||||
report = "=== GPUDirect Storage Statistics ===\n\n"
|
||||
|
||||
# Availability
|
||||
report += f"GDS Available: {'✓' if stats['gds_available'] else '✗'}\n"
|
||||
|
||||
# Usage statistics
|
||||
report += f"Total Loads: {stats['total_loads']}\n"
|
||||
report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n"
|
||||
report += f"Fallback Loads: {stats['fallback_loads']}\n\n"
|
||||
|
||||
# Performance metrics
|
||||
if stats['total_bytes_gds'] > 0:
|
||||
gb_transferred = stats['total_bytes_gds'] / (1024**3)
|
||||
report += f"Data Transferred: {gb_transferred:.2f} GB\n"
|
||||
report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n"
|
||||
report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n"
|
||||
|
||||
# Configuration
|
||||
config = stats.get('config', {})
|
||||
if config:
|
||||
report += "Configuration:\n"
|
||||
report += f"- Enabled: {config.get('enabled', 'Unknown')}\n"
|
||||
report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n"
|
||||
report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n"
|
||||
report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n"
|
||||
report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n"
|
||||
report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n"
|
||||
|
||||
return (report,)
|
||||
|
||||
except ImportError:
|
||||
return ("GDS module not available",)
|
||||
except Exception as e:
|
||||
return (f"Error retrieving GDS stats: {str(e)}",)
|
||||
|
||||
|
||||
# Node mappings
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"CheckpointLoaderGDS": CheckpointLoaderGDS,
|
||||
"ModelPrefetcher": ModelPrefetcher,
|
||||
"GDSStats": GDSStats,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CheckpointLoaderGDS": "Load Checkpoint (GDS)",
|
||||
"ModelPrefetcher": "Model Prefetcher",
|
||||
"GDSStats": "GDS Statistics",
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user