mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 11:03:00 +08:00
feat(gds): implement GPUDirect Storage initialization based on CLI arguments
This commit is contained in:
parent
862e7784f4
commit
154b73835a
50
main.py
50
main.py
@ -155,6 +155,56 @@ import comfyui_version
|
||||
import app.logger
|
||||
import hook_breaker_ac10a0
|
||||
|
||||
# Initialize GPUDirect Storage if enabled
|
||||
def init_gds():
|
||||
"""Initialize GPUDirect Storage based on CLI arguments"""
|
||||
if hasattr(args, 'disable_gds') and args.disable_gds:
|
||||
logging.info("GDS explicitly disabled via --disable-gds")
|
||||
return
|
||||
|
||||
if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
|
||||
# GDS not explicitly requested, use auto-detection
|
||||
return
|
||||
|
||||
try:
|
||||
from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats
|
||||
|
||||
# Create GDS configuration from CLI args
|
||||
config = GDSConfig(
|
||||
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
|
||||
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
|
||||
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
|
||||
max_concurrent_streams=getattr(args, 'gds_streams', 4),
|
||||
prefetch_enabled=getattr(args, 'gds_prefetch', True),
|
||||
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False)
|
||||
)
|
||||
|
||||
# Configure GDS
|
||||
configure_gds(config)
|
||||
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
|
||||
|
||||
# Set up exit handler for stats if requested
|
||||
if getattr(args, 'gds_stats', False):
|
||||
import atexit
|
||||
def print_gds_stats():
|
||||
stats = get_gds_stats()
|
||||
logging.info("=== GDS Statistics ===")
|
||||
logging.info(f"Total loads: {stats['total_loads']}")
|
||||
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
|
||||
logging.info(f"Fallback loads: {stats['fallback_loads']}")
|
||||
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
|
||||
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
|
||||
logging.info("===================")
|
||||
atexit.register(print_gds_stats)
|
||||
|
||||
except ImportError as e:
|
||||
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"GDS initialization failed: {e}")
|
||||
|
||||
# Initialize GDS
|
||||
init_gds()
|
||||
|
||||
def cuda_malloc_warning():
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device_name = comfy.model_management.get_torch_device_name(device)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user