# ComfyUI/comfy_extras/nodes_gds.py
# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad
# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007
"""
Enhanced model loading nodes with GPUDirect Storage support
"""
import logging
import time
import asyncio
from typing import Optional, Dict, Any
import torch
import folder_paths
import comfy.sd
import comfy.utils
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
class CheckpointLoaderGDS(ComfyNodeABC):
    """
    Enhanced checkpoint loader with GPUDirect Storage support.
    Provides direct SSD-to-GPU loading and prefetching capabilities.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), {
                    "tooltip": "The name of the checkpoint (model) to load with GDS optimization."
                }),
            },
            "optional": {
                "prefetch": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Prefetch the model to the GPU cache for faster loading."
                }),
                "use_gds": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Use GPUDirect Storage if available."
                }),
                "target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], {
                    "default": "auto",
                    "tooltip": "Target device for model loading."
                }),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING")
    RETURN_NAMES = ("model", "clip", "vae", "load_info")
    OUTPUT_TOOLTIPS = (
        "The model used for denoising latents.",
        "The CLIP model used for encoding text prompts.",
        "The VAE model used for encoding and decoding images to and from latent space.",
        "Loading information and statistics.",
    )
    FUNCTION = "load_checkpoint_gds"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading."
    EXPERIMENTAL = True
    def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"):
        start_time = time.time()
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

        # Determine the target device
        if target_device == "auto":
            device = None  # Let the system decide
        elif target_device == "cpu":
            device = torch.device("cpu")
        else:
            device = torch.device(target_device)

        load_info = {
            "file": ckpt_name,
            "path": ckpt_path,
            "target_device": str(device) if device else "auto",
            "gds_enabled": use_gds,
            "prefetch_used": prefetch,
        }

        try:
            # Prefetch if requested
            if prefetch and use_gds:
                try:
                    from comfy.gds_loader import prefetch_model_gds
                    prefetch_success = prefetch_model_gds(ckpt_path)
                    load_info["prefetch_success"] = prefetch_success
                    if prefetch_success:
                        logging.info(f"Prefetched {ckpt_name} to GPU cache")
                except Exception as e:
                    logging.warning(f"Prefetch failed for {ckpt_name}: {e}")
                    load_info["prefetch_error"] = str(e)

            # Record which loader (GDS vs. standard) will be used for this file
            if use_gds and device and device.type == "cuda":
                try:
                    from comfy.gds_loader import get_gds_instance
                    gds = get_gds_instance()
                    if gds._should_use_gds(ckpt_path):
                        load_info["loader_used"] = "GDS"
                        logging.info(f"Loading {ckpt_name} with GDS")
                    else:
                        load_info["loader_used"] = "Standard"
                        logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)")
                except Exception as e:
                    logging.warning(f"GDS check failed, using standard loading: {e}")
                    load_info["loader_used"] = "Standard (GDS failed)"
            else:
                load_info["loader_used"] = "Standard"

            # Load the actual checkpoint
            out = comfy.sd.load_checkpoint_guess_config(
                ckpt_path,
                output_vae=True,
                output_clip=True,
                embedding_directory=folder_paths.get_folder_paths("embeddings"),
            )

            load_time = time.time() - start_time
            load_info["load_time_seconds"] = round(load_time, 3)
            load_info["load_success"] = True

            # Format the load info as a human-readable string
            info_str = f"Loaded: {ckpt_name}\n"
            info_str += f"Method: {load_info['loader_used']}\n"
            info_str += f"Time: {load_info['load_time_seconds']}s\n"
            info_str += f"Device: {load_info['target_device']}"
            if "prefetch_success" in load_info:
                info_str += f"\nPrefetch: {'succeeded' if load_info['prefetch_success'] else 'failed'}"

            logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}")
            return (*out[:3], info_str)
        except Exception as e:
            load_info["load_success"] = False
            load_info["error"] = str(e)
            error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}"
            logging.error(f"Checkpoint loading failed: {e}")
            raise RuntimeError(error_str) from e
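
# A minimal usage sketch, assuming a configured ComfyUI environment and an
# existing checkpoint; the filename below is hypothetical. In normal use the
# node is invoked by the graph executor rather than called directly.
#
#   loader = CheckpointLoaderGDS()
#   model, clip, vae, info = loader.load_checkpoint_gds(
#       "sd_xl_base_1.0.safetensors",  # hypothetical checkpoint name
#       prefetch=True,
#       use_gds=True,
#       target_device="cuda:0",
#   )
#   print(info)
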

class ModelPrefetcher(ComfyNodeABC):
    """
    Node for prefetching models to the GPU cache.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "checkpoint_names": ("STRING", {
                    "multiline": True,
                    "default": "",
                    "tooltip": "List of checkpoint names to prefetch (one per line)."
                }),
                "prefetch_enabled": ("BOOLEAN", {
                    "default": True,
                    "tooltip": "Enable/disable prefetching."
                }),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("prefetch_report",)
    OUTPUT_TOOLTIPS = ("Report of prefetch operations.",)
    FUNCTION = "prefetch_models"
    CATEGORY = "loaders/advanced"
    DESCRIPTION = "Prefetch multiple models to the GPU cache for faster loading."
    OUTPUT_NODE = True

    def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True):
        if not prefetch_enabled:
            return ("Prefetching disabled",)

        # Parse the checkpoint names, one per line
        names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()]
        if not names:
            return ("No checkpoints specified for prefetching",)

        try:
            from comfy.gds_loader import prefetch_model_gds
        except ImportError:
            return ("GDS not available for prefetching",)

        results = []
        successful_prefetches = 0
        for name in names:
            try:
                ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name)
                success = prefetch_model_gds(ckpt_path)
                if success:
                    results.append(f"{name} (prefetched)")
                    successful_prefetches += 1
                else:
                    results.append(f"{name} (prefetch failed)")
            except Exception as e:
                results.append(f"{name} (error: {str(e)[:50]})")

        report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n"
        report += "\n".join(results)
        return (report,)
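
# A minimal usage sketch; the filenames are hypothetical. Names are passed as
# one multiline string, one checkpoint per line, as the tooltip describes.
#
#   prefetcher = ModelPrefetcher()
#   (report,) = prefetcher.prefetch_models(
#       "sd_xl_base_1.0.safetensors\nsd_xl_refiner_1.0.safetensors",
#       prefetch_enabled=True,
#   )
#   print(report)
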

class GDSStats(ComfyNodeABC):
    """
    Node for displaying GDS statistics.
    """

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "refresh": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Refresh statistics."
                }),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("stats_report",)
    OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",)
    FUNCTION = "get_stats"
    CATEGORY = "utils/advanced"
    DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics."
    OUTPUT_NODE = True

    def get_stats(self, refresh: bool = False):
        try:
            from comfy.gds_loader import get_gds_stats
            stats = get_gds_stats()

            report = "=== GPUDirect Storage Statistics ===\n\n"

            # Availability
            report += f"GDS Available: {'yes' if stats['gds_available'] else 'no'}\n"

            # Usage statistics
            report += f"Total Loads: {stats['total_loads']}\n"
            report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n"
            report += f"Fallback Loads: {stats['fallback_loads']}\n\n"

            # Performance metrics
            if stats['total_bytes_gds'] > 0:
                gb_transferred = stats['total_bytes_gds'] / (1024**3)
                report += f"Data Transferred: {gb_transferred:.2f} GB\n"
                report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n"
                report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n"

            # Configuration
            config = stats.get('config', {})
            if config:
                report += "Configuration:\n"
                report += f"- Enabled: {config.get('enabled', 'Unknown')}\n"
                report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n"
                report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n"
                report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n"
                report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n"
                report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n"

            return (report,)
        except ImportError:
            return ("GDS module not available",)
        except Exception as e:
            return (f"Error retrieving GDS stats: {str(e)}",)

# Node mappings
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderGDS": CheckpointLoaderGDS,
    "ModelPrefetcher": ModelPrefetcher,
    "GDSStats": GDSStats,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderGDS": "Load Checkpoint (GDS)",
    "ModelPrefetcher": "Model Prefetcher",
    "GDSStats": "GDS Statistics",
}
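
# Sketch of the comfy.gds_loader surface these nodes rely on, inferred purely
# from the calls above; the real module may differ. Useful as a reference when
# stubbing the dependency for tests.
#
#   prefetch_model_gds(path: str) -> bool
#   get_gds_instance() -> object   # exposes _should_use_gds(path: str) -> bool
#   get_gds_stats() -> dict        # keys used here: gds_available, total_loads,
#                                  #   gds_loads, gds_usage_percent, fallback_loads,
#                                  #   total_bytes_gds, avg_bandwidth_gbps,
#                                  #   total_time_gds, config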