mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 19:13:02 +08:00
feat(gds): implement GPUDirect Storage initialization based on CLI arguments
This commit is contained in:
parent
862e7784f4
commit
154b73835a
50
main.py
50
main.py
@ -155,6 +155,56 @@ import comfyui_version
|
|||||||
import app.logger
|
import app.logger
|
||||||
import hook_breaker_ac10a0
|
import hook_breaker_ac10a0
|
||||||
|
|
||||||
|
# Initialize GPUDirect Storage if enabled
|
||||||
|
def init_gds():
|
||||||
|
"""Initialize GPUDirect Storage based on CLI arguments"""
|
||||||
|
if hasattr(args, 'disable_gds') and args.disable_gds:
|
||||||
|
logging.info("GDS explicitly disabled via --disable-gds")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not hasattr(args, 'enable_gds') and not hasattr(args, 'gds_prefetch') and not hasattr(args, 'gds_stats'):
|
||||||
|
# GDS not explicitly requested, use auto-detection
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from comfy.gds_loader import GDSConfig, configure_gds, get_gds_stats
|
||||||
|
|
||||||
|
# Create GDS configuration from CLI args
|
||||||
|
config = GDSConfig(
|
||||||
|
enabled=getattr(args, 'enable_gds', False) or getattr(args, 'gds_prefetch', False),
|
||||||
|
min_file_size_mb=getattr(args, 'gds_min_file_size', 100),
|
||||||
|
chunk_size_mb=getattr(args, 'gds_chunk_size', 64),
|
||||||
|
max_concurrent_streams=getattr(args, 'gds_streams', 4),
|
||||||
|
prefetch_enabled=getattr(args, 'gds_prefetch', True),
|
||||||
|
fallback_to_cpu=not getattr(args, 'gds_no_fallback', False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure GDS
|
||||||
|
configure_gds(config)
|
||||||
|
logging.info(f"GDS initialized: enabled={config.enabled}, min_size={config.min_file_size_mb}MB, streams={config.max_concurrent_streams}")
|
||||||
|
|
||||||
|
# Set up exit handler for stats if requested
|
||||||
|
if getattr(args, 'gds_stats', False):
|
||||||
|
import atexit
|
||||||
|
def print_gds_stats():
|
||||||
|
stats = get_gds_stats()
|
||||||
|
logging.info("=== GDS Statistics ===")
|
||||||
|
logging.info(f"Total loads: {stats['total_loads']}")
|
||||||
|
logging.info(f"GDS loads: {stats['gds_loads']} ({stats['gds_usage_percent']:.1f}%)")
|
||||||
|
logging.info(f"Fallback loads: {stats['fallback_loads']}")
|
||||||
|
logging.info(f"Total bytes via GDS: {stats['total_bytes_gds'] / (1024**3):.2f} GB")
|
||||||
|
logging.info(f"Average bandwidth: {stats['avg_bandwidth_gbps']:.2f} GB/s")
|
||||||
|
logging.info("===================")
|
||||||
|
atexit.register(print_gds_stats)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
logging.warning(f"GDS initialization failed - missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"GDS initialization failed: {e}")
|
||||||
|
|
||||||
|
# Initialize GDS
|
||||||
|
init_gds()
|
||||||
|
|
||||||
def cuda_malloc_warning():
|
def cuda_malloc_warning():
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device_name = comfy.model_management.get_torch_device_name(device)
|
device_name = comfy.model_management.get_torch_device_name(device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user