feat(gds): add GPUDirect Storage options for SSD-to-GPU model loading;

This commit is contained in:
Maifee Ul Asad 2025-10-08 14:39:22 +06:00
parent 5f24eb699c
commit 83b00df3f0

View File

@ -139,6 +139,17 @@ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
# GPUDirect Storage (GDS) arguments
gds_group = parser.add_argument_group('gds', 'GPUDirect Storage options for direct SSD-to-GPU model loading')
gds_group.add_argument("--enable-gds", action="store_true", help="Enable GPUDirect Storage for direct SSD-to-GPU model loading (requires CUDA 11.4+, cuFile).")
gds_group.add_argument("--disable-gds", action="store_true", help="Explicitly disable GPUDirect Storage.")
gds_group.add_argument("--gds-min-file-size", type=int, default=100, help="Minimum file size in MB to use GDS (default: 100MB).")
gds_group.add_argument("--gds-chunk-size", type=int, default=64, help="GDS transfer chunk size in MB (default: 64MB).")
gds_group.add_argument("--gds-streams", type=int, default=4, help="Number of CUDA streams for GDS operations (default: 4).")
gds_group.add_argument("--gds-prefetch", action="store_true", help="Enable GDS prefetching for better performance.")
gds_group.add_argument("--gds-no-fallback", action="store_true", help="Disable fallback to CPU loading if GDS fails.")
gds_group.add_argument("--gds-stats", action="store_true", help="Print GDS statistics on exit.")
class PerformanceFeature(enum.Enum): class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation" Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult" Fp8MatrixMultiplication = "fp8_matrix_mult"