From 154b73835a14567a7df1b5e3934dcf39759e9eec Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Wed, 8 Oct 2025 14:40:59 +0600 Subject: [PATCH] feat(gds): implement GPUDirect Storage initialization based on CLI arguments --- main.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/main.py b/main.py index 35857dba8..9c4414a17 100644 --- a/main.py +++ b/main.py @@ -155,6 +155,56 @@ import comfyui_version import app.logger import hook_breaker_ac10a0 +# Initialize GPUDirect Storage if enabled +def init_gds(): + """Initialize GPUDirect Storage based on CLI arguments""" + if hasattr(args, 'disable_gds') and args.disable_gds: + logging.info("GDS explicitly disabled via --disable-gds") + return + + if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'): + # GDS not explicitly requested, use auto-detection + return + + try: + from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats + + # Create GDS configuration from CLI args + config = GDSConfig( + enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False), + min_file_size_mb=getattr(args, 'gds_min_file_size', 100), + chunk_size_mb=getattr(args, 'gds_chunk_size', 64), + max_concurrent_streams=getattr(args, 'gds_streams', 4), + prefetch_enabled=getattr(args, 'gds_prefetch', True), + fallback_to_cpu=not getattr(args, 'gds_no_fallback', False) + ) + + # Configure GDS + configure_gds(config) + logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}") + + # Set up exit handler for stats if requested + if getattr(args, 'gds_stats', False): + import atexit + def print_gds_stats(): + stats = get_gds_stats() + logging.info("=== GDS Statistics ===") + logging.info(f"Total loads: {stats['total_loads']}") + logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)") + logging.info(f"Fallback loads: {stats['fallback_loads']}") + logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB") + logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s") + logging.info("===================") + atexit.register(print_gds_stats) + + except ImportError as e: + logging.warning(f"GDS initialization failed - missing dependencies: {e}") + except Exception as e: + logging.error(f"GDS initialization failed: {e}") + +# Initialize GDS +init_gds() + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device)