# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2025-12-19 02:53:05 +08:00
# (293 lines, 11 KiB, Python)
# Copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# Copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
"""
|
|
Enhanced model loading nodes with GPUDirect Storage support
|
|
"""
|
|
|
|
import asyncio
import logging
import time
from typing import Any, Dict, Optional

import torch

import comfy.sd
import comfy.utils
import folder_paths
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
class CheckpointLoaderGDS(ComfyNodeABC):
    """
    Enhanced checkpoint loader with GPUDirect Storage (GDS) support.

    Optionally prefetches the checkpoint into a GPU-side cache, probes the
    GDS backend for CUDA targets, and delegates the actual checkpoint parsing
    to ``comfy.sd.load_checkpoint_guess_config``.  In addition to the usual
    MODEL/CLIP/VAE outputs it emits a human-readable load report string.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
                    "tooltip": "The name of the checkpoint (model) to load with GDS optimization."
                }),
            },
            "optional": {
                "prefetch": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Prefetch model to GPU cache for faster loading."
                }),
                "use_gds": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GPUDirect Storage if available."
                }),
                "target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
                    "default": "auto",
                    "tooltip": "Target device for model loading."
                })
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
    RETURN_NAMES = ("model", "clip", "vae", "load_info")
    OUTPUT_TOOLTIPS = (
        "The model used for denoising latents.",
        "The CLIP model used for encoding text prompts.",
        "The VAE model used for encoding and decoding images to and from latent space.",
        "Loading information and statistics."
    )
    FUNCTION = "load_checkpoint_gds"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
    EXPERIMENTAL = True

    @staticmethod
    def _resolve_device(target_device: str) -> Optional[torch.device]:
        """Map the UI device choice to a torch.device; ``None`` means 'auto'."""
        if target_device == "auto":
            return None  # Let the system decide
        return torch.device(target_device)

    def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
        """
        Load a checkpoint with optional GDS prefetch/probing.

        Args:
            ckpt_name: Filename of the checkpoint in the ``checkpoints`` folder.
            prefetch: If True (and ``use_gds``), prefetch the file into the GPU cache first.
            use_gds: Use GPUDirect Storage if available for the target device.
            target_device: "auto", "cpu", or an explicit "cuda:N" device string.

        Returns:
            (model, clip, vae, info_str) — the three loaded components plus a
            human-readable load report.

        Raises:
            RuntimeError: If loading fails; the original exception is chained.
        """
        start_time = time.time()

        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
        device = self._resolve_device(target_device)

        # Running record of how the load went; summarized into the STRING output.
        load_info = {
            "file": ckpt_name,
            "path": ckpt_path,
            "target_device": str(device) if device else "auto",
            "gds_enabled": use_gds,
            "prefetch_used": prefetch
        }

        try:
            # Best-effort prefetch into the GPU cache; failure is non-fatal.
            if prefetch and use_gds:
                try:
                    from comfy.gds_loader import prefetch_model_gds
                    prefetch_success = prefetch_model_gds(ckpt_path)
                    load_info["prefetch_success"] = prefetch_success
                    if prefetch_success:
                        logging.info(f"Prefetched {ckpt_name} to GPU cache")
                except Exception as e:
                    logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
                    load_info["prefetch_error"] = str(e)

            # Decide which loader label applies.  NOTE(review): when
            # target_device is "auto", device is None here, so this GDS branch
            # is skipped even on CUDA systems — confirm this is intended.
            if use_gds and device and device.type == 'cuda':
                try:
                    from comfy.gds_loader import get_gds_instance
                    gds = get_gds_instance()

                    # Ask the loader whether GDS pays off for this file
                    # (size heuristic; _should_use_gds is private loader API).
                    if gds._should_use_gds(ckpt_path):
                        load_info["loader_used"] = "GDS"
                        logging.info(f"Loading {ckpt_name} with GDS")
                    else:
                        load_info["loader_used"] = "Standard"
                        logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")

                except Exception as e:
                    logging.warning(f"GDS check failed, using standard loading: {e}")
                    load_info["loader_used"] = "Standard (GDS failed)"
            else:
                load_info["loader_used"] = "Standard"

            # Load the actual checkpoint.
            out = comfy.sd.load_checkpoint_guess_config(
                ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=folder_paths.get_folder_paths("embeddings")
            )

            load_time = time.time() - start_time
            load_info["load_time_seconds"] = round(load_time, 3)
            load_info["load_success"] = True

            # Human-readable summary for the STRING output.
            info_str = f"Loaded: {ckpt_name}\n"
            info_str += f"Method: {load_info['loader_used']}\n"
            info_str += f"Time: {load_info['load_time_seconds']}s\n"
            info_str += f"Device: {load_info['target_device']}"

            if "prefetch_success" in load_info:
                info_str += f"\nPrefetch: {'✓' if load_info['prefetch_success'] else '✗'}"

            logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")

            return (*out[:3], info_str)

        except Exception as e:
            load_info["load_success"] = False
            load_info["error"] = str(e)
            error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
            # logging.exception keeps the traceback; bare logging.error dropped it.
            logging.exception(f"Checkpoint loading failed: {e}")
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(error_str) from e
class ModelPrefetcher(ComfyNodeABC):
    """
    Prefetch a batch of checkpoints into the GPU cache ahead of time.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "checkpoint_names": ("STRING", {
                    "multiline": True,
                    "default": "",
                    "tooltip": "List of checkpoint names to prefetch (one per line)."
                }),
                "prefetch_enabled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable/disable prefetching."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prefetch_report",)
    OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
    FUNCTION = "prefetch_models"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading."
    OUTPUT_NODE = True

    def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
        """Prefetch each named checkpoint; return a one-line-per-file report."""
        if not prefetch_enabled:
            return ("Prefetching disabled",)

        # One checkpoint name per non-blank line of the multiline input.
        targets = [line.strip() for line in checkpoint_names.split('\n') if line.strip()]
        if not targets:
            return ("No checkpoints specified for prefetching",)

        try:
            from comfy.gds_loader import prefetch_model_gds
        except ImportError:
            return ("GDS not available for prefetching",)

        ok_count = 0
        lines = []
        for target in targets:
            try:
                path = folder_paths.get_full_path_or_raise("checkpoints", target)
                if prefetch_model_gds(path):
                    lines.append(f"✓ {target}")
                    ok_count += 1
                else:
                    lines.append(f"✗ {target} (prefetch failed)")
            except Exception as e:
                # Truncate long error messages so the report stays readable.
                lines.append(f"✗ {target} (error: {str(e)[:50]})")

        header = f"Prefetch Report ({ok_count}/{len(targets)} successful):\n"
        return (header + "\n".join(lines),)
class GDSStats(ComfyNodeABC):
    """
    Report GPUDirect Storage availability, usage, and configuration.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "refresh": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Refresh statistics."
                })
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("stats_report",)
    OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
    FUNCTION = "get_stats"
    CATEGORY = "utils/advanced"
    DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
    OUTPUT_NODE = True

    def get_stats(self, refresh: bool = False):
        """Build a plain-text GDS statistics report; errors become report text."""
        # `refresh` is only an input toggle; its value is not read here.
        try:
            from comfy.gds_loader import get_gds_stats
            stats = get_gds_stats()

            parts = ["=== GPUDirect Storage Statistics ===\n\n"]

            # Availability
            parts.append(f"GDS Available: {'✓' if stats['gds_available'] else '✗'}\n")

            # Usage statistics
            parts.append(f"Total Loads: {stats['total_loads']}\n")
            parts.append(f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n")
            parts.append(f"Fallback Loads: {stats['fallback_loads']}\n\n")

            # Performance metrics (only when any data actually went through GDS)
            if stats['total_bytes_gds'] > 0:
                transferred_gb = stats['total_bytes_gds'] / (1024**3)
                parts.append(f"Data Transferred: {transferred_gb:.2f} GB\n")
                parts.append(f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n")
                parts.append(f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n")

            # Configuration
            cfg = stats.get('config', {})
            if cfg:
                parts.append("Configuration:\n")
                parts.append(f"- Enabled: {cfg.get('enabled', 'Unknown')}\n")
                parts.append(f"- Min File Size: {cfg.get('min_file_size_mb', 'Unknown')} MB\n")
                parts.append(f"- Chunk Size: {cfg.get('chunk_size_mb', 'Unknown')} MB\n")
                parts.append(f"- Max Streams: {cfg.get('max_concurrent_streams', 'Unknown')}\n")
                parts.append(f"- Prefetch: {cfg.get('prefetch_enabled', 'Unknown')}\n")
                parts.append(f"- Fallback: {cfg.get('fallback_to_cpu', 'Unknown')}\n")

            return ("".join(parts),)

        except ImportError:
            return ("GDS module not available",)
        except Exception as e:
            return (f"Error retrieving GDS stats: {str(e)}",)
# Node mappings
# Registry consumed by ComfyUI: maps the internal node class name (the key
# workflows reference) to the implementing class.
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderGDS": CheckpointLoaderGDS,
    "ModelPrefetcher": ModelPrefetcher,
    "GDSStats": GDSStats,
}

# Human-readable display names shown in the UI for each node class above.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderGDS": "Load Checkpoint (GDS)",
    "ModelPrefetcher": "Model Prefetcher",
    "GDSStats": "GDS Statistics",
}