feat(gds): implement GPUDirect Storage initialization based on CLI arguments

This commit is contained in:
Maifee Ul Asad 2025-10-08 14:40:59 +06:00
parent 862e7784f4
commit 154b73835a

50
main.py
View File

@ -155,6 +155,56 @@ import comfyui_version
import app.logger import app.logger
import hook_breaker_ac10a0 import hook_breaker_ac10a0
# Initialize GPUDirect Storage if enabled
def init_gds():
"""Initialize GPUDirect Storage based on CLI arguments"""
if hasattr(args, 'disable_gds') and args.disable_gds:
logging.info("GDS explicitly disabled via --disable-gds")
return
if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
# GDS not explicitly requested, use auto-detection
return
try:
from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats
# Create GDS configuration from CLI args
config = GDSConfig(
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
max_concurrent_streams=getattr(args, 'gds_streams', 4),
prefetch_enabled=getattr(args, 'gds_prefetch', True),
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False)
)
# Configure GDS
configure_gds(config)
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
# Set up exit handler for stats if requested
if getattr(args, 'gds_stats', False):
import atexit
def print_gds_stats():
stats = get_gds_stats()
logging.info("=== GDS Statistics ===")
logging.info(f"Total loads: {stats['total_loads']}")
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
logging.info(f"Fallback loads: {stats['fallback_loads']}")
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
logging.info("===================")
atexit.register(print_gds_stats)
except ImportError as e:
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
except Exception as e:
logging.error(f"GDS initialization failed: {e}")
# Initialize GDS
init_gds()
def cuda_malloc_warning(): def cuda_malloc_warning():
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device) device_name = comfy.model_management.get_torch_device_name(device)