From fab09540779d7cb6597ec74e88e2e0b97365e380 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Wed, 8 Oct 2025 14:38:38 +0600 Subject: [PATCH] feat(gds): add GPUDirect Storage support for model loading and prefetching - limited to NVIDIA GPUs only --- comfy_extras/nodes_gds.py | 293 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100644 comfy_extras/nodes_gds.py diff --git a/comfy_extras/nodes_gds.py b/comfy_extras/nodes_gds.py new file mode 100644 index 000000000..fc3aa8a2f --- /dev/null +++ b/comfy_extras/nodes_gds.py @@ -0,0 +1,293 @@ +# copyright 2025 Maifee Ul Asad @ github.com/maifeeulasad +# copyright under GNU GENERAL PUBLIC LICENSE, Version 3, 29 June 2007 + +""" +Enhanced model loading nodes with GPUDirect Storage support +""" + +import logging +import time +import asyncio +from typing import Optional, Dict, Any + +import torch +import folder_paths +import comfy.sd +import comfy.utils +from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict + + +class CheckpointLoaderGDS(ComfyNodeABC): + """ + Enhanced checkpoint loader with GPUDirect Storage support + Provides direct SSD-to-GPU loading and prefetching capabilities + """ + + @classmethod + def INPUT_TYPES(s) -> InputTypeDict: + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"), { + "tooltip": "The name of the checkpoint (model) to load with GDS optimization." + }), + }, + "optional": { + "prefetch": ("BOOLEAN", { + "default": False, + "tooltip": "Prefetch model to GPU cache for faster loading." + }), + "use_gds": ("BOOLEAN", { + "default": True, + "tooltip": "Use GPUDirect Storage if available." + }), + "target_device": (["auto", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cpu"], { + "default": "auto", + "tooltip": "Target device for model loading." + }) + } + } + + RETURN_TYPES = ("MODEL", "CLIP", "VAE", "STRING") + RETURN_NAMES = ("model", "clip", "vae", "load_info") + OUTPUT_TOOLTIPS = ( + "The model used for denoising latents.", + "The CLIP model used for encoding text prompts.", + "The VAE model used for encoding and decoding images to and from latent space.", + "Loading information and statistics." + ) + FUNCTION = "load_checkpoint_gds" + CATEGORY = "loaders/advanced" + DESCRIPTION = "Enhanced checkpoint loader with GPUDirect Storage support for direct SSD-to-GPU loading." + EXPERIMENTAL = True + + def load_checkpoint_gds(self, ckpt_name: str, prefetch: bool = False, use_gds: bool = True, target_device: str = "auto"): + start_time = time.time() + + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + # Determine target device + if target_device == "auto": + device = None # Let the system decide + elif target_device == "cpu": + device = torch.device("cpu") + else: + device = torch.device(target_device) + + load_info = { + "file": ckpt_name, + "path": ckpt_path, + "target_device": str(device) if device else "auto", + "gds_enabled": use_gds, + "prefetch_used": prefetch + } + + try: + # Prefetch if requested + if prefetch and use_gds: + try: + from comfy.gds_loader import prefetch_model_gds + prefetch_success = prefetch_model_gds(ckpt_path) + load_info["prefetch_success"] = prefetch_success + if prefetch_success: + logging.info(f"Prefetched {ckpt_name} to GPU cache") + except Exception as e: + logging.warning(f"Prefetch failed for {ckpt_name}: {e}") + load_info["prefetch_error"] = str(e) + + # Load checkpoint with potential GDS optimization + if use_gds and device and device.type == 'cuda': + try: + from comfy.gds_loader import get_gds_instance + gds = get_gds_instance() + + # Check if GDS should be used for this file + if gds._should_use_gds(ckpt_path): + load_info["loader_used"] = "GDS" + logging.info(f"Loading {ckpt_name} with GDS") + else: + load_info["loader_used"] = "Standard" + logging.info(f"Loading {ckpt_name} with standard method (file too small for GDS)") + + except Exception as e: + logging.warning(f"GDS check failed, using standard loading: {e}") + load_info["loader_used"] = "Standard (GDS failed)" + else: + load_info["loader_used"] = "Standard" + + # Load the actual checkpoint + out = comfy.sd.load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=True, + embedding_directory=folder_paths.get_folder_paths("embeddings") + ) + + load_time = time.time() - start_time + load_info["load_time_seconds"] = round(load_time, 3) + load_info["load_success"] = True + + # Format load info as string + info_str = f"Loaded: {ckpt_name}\n" + info_str += f"Method: {load_info['loader_used']}\n" + info_str += f"Time: {load_info['load_time_seconds']}s\n" + info_str += f"Device: {load_info['target_device']}" + + if "prefetch_success" in load_info: + info_str += f"\nPrefetch: {'✓' if load_info['prefetch_success'] else '✗'}" + + logging.info(f"Checkpoint loaded: {ckpt_name} in {load_time:.3f}s using {load_info['loader_used']}") + + return (*out[:3], info_str) + + except Exception as e: + load_info["load_success"] = False + load_info["error"] = str(e) + error_str = f"Failed to load: {ckpt_name}\nError: {str(e)}" + logging.error(f"Checkpoint loading failed: {e}") + raise RuntimeError(error_str) + + +class ModelPrefetcher(ComfyNodeABC): + """ + Node for prefetching models to GPU cache + """ + + @classmethod + def INPUT_TYPES(s) -> InputTypeDict: + return { + "required": { + "checkpoint_names": ("STRING", { + "multiline": True, + "default": "", + "tooltip": "List of checkpoint names to prefetch (one per line)." + }), + "prefetch_enabled": ("BOOLEAN", { + "default": True, + "tooltip": "Enable/disable prefetching." + }) + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("prefetch_report",) + OUTPUT_TOOLTIPS = ("Report of prefetch operations.",) + FUNCTION = "prefetch_models" + CATEGORY = "loaders/advanced" + DESCRIPTION = "Prefetch multiple models to GPU cache for faster loading." + OUTPUT_NODE = True + + def prefetch_models(self, checkpoint_names: str, prefetch_enabled: bool = True): + if not prefetch_enabled: + return ("Prefetching disabled",) + + # Parse checkpoint names + names = [name.strip() for name in checkpoint_names.split('\n') if name.strip()] + + if not names: + return ("No checkpoints specified for prefetching",) + + try: + from comfy.gds_loader import prefetch_model_gds + except ImportError: + return ("GDS not available for prefetching",) + + results = [] + successful_prefetches = 0 + + for name in names: + try: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", name) + success = prefetch_model_gds(ckpt_path) + + if success: + results.append(f"✓ {name}") + successful_prefetches += 1 + else: + results.append(f"✗ {name} (prefetch failed)") + + except Exception as e: + results.append(f"✗ {name} (error: {str(e)[:50]})") + + report = f"Prefetch Report ({successful_prefetches}/{len(names)} successful):\n" + report += "\n".join(results) + + return (report,) + + +class GDSStats(ComfyNodeABC): + """ + Node for displaying GDS statistics + """ + + @classmethod + def INPUT_TYPES(s) -> InputTypeDict: + return { + "required": { + "refresh": ("BOOLEAN", { + "default": False, + "tooltip": "Refresh statistics." + }) + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("stats_report",) + OUTPUT_TOOLTIPS = ("GDS statistics and performance report.",) + FUNCTION = "get_stats" + CATEGORY = "utils/advanced" + DESCRIPTION = "Display GPUDirect Storage statistics and performance metrics." + OUTPUT_NODE = True + + def get_stats(self, refresh: bool = False): + try: + from comfy.gds_loader import get_gds_stats + stats = get_gds_stats() + + report = "=== GPUDirect Storage Statistics ===\n\n" + + # Availability + report += f"GDS Available: {'✓' if stats['gds_available'] else '✗'}\n" + + # Usage statistics + report += f"Total Loads: {stats['total_loads']}\n" + report += f"GDS Loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)\n" + report += f"Fallback Loads: {stats['fallback_loads']}\n\n" + + # Performance metrics + if stats['total_bytes_gds'] > 0: + gb_transferred = stats['total_bytes_gds'] / (1024**3) + report += f"Data Transferred: {gb_transferred:.2f} GB\n" + report += f"Average Bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s\n" + report += f"Total GDS Time: {stats['total_time_gds']:.2f}s\n\n" + + # Configuration + config = stats.get('config', {}) + if config: + report += "Configuration:\n" + report += f"- Enabled: {config.get('enabled', 'Unknown')}\n" + report += f"- Min File Size: {config.get('min_file_size_mb', 'Unknown')} MB\n" + report += f"- Chunk Size: {config.get('chunk_size_mb', 'Unknown')} MB\n" + report += f"- Max Streams: {config.get('max_concurrent_streams', 'Unknown')}\n" + report += f"- Prefetch: {config.get('prefetch_enabled', 'Unknown')}\n" + report += f"- Fallback: {config.get('fallback_to_cpu', 'Unknown')}\n" + + return (report,) + + except ImportError: + return ("GDS module not available",) + except Exception as e: + return (f"Error retrieving GDS stats: {str(e)}",) + + +# Node mappings +NODE_CLASS_MAPPINGS = { + "CheckpointLoaderGDS": CheckpointLoaderGDS, + "ModelPrefetcher": ModelPrefetcher, + "GDSStats": GDSStats, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "CheckpointLoaderGDS": "Load Checkpoint (GDS)", + "ModelPrefetcher": "Model Prefetcher", + "GDSStats": "GDS Statistics", +} \ No newline at end of file