diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 0f4445d33..4b0c5b9c5 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -72,7 +72,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
 
     # Load with mmap - this doesn't load all data into RAM
     mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
-    # Register cleanup callback
+    # Register cleanup callback - will be called when tensor is garbage collected
    def _cleanup():
         try:
             if os.path.exists(temp_file):
@@ -83,34 +83,35 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
 
     weakref.finalize(mmap_tensor, _cleanup)
 
-    # Save original 'to' method
-    original_to = mmap_tensor.to
+    # # Save original 'to' method
+    # original_to = mmap_tensor.to
 
-    # Create custom 'to' method that cleans up file when moving to CUDA
-    def custom_to(*args, **kwargs):
-        # Determine target device
-        target_device = None
-        if len(args) > 0:
-            if isinstance(args[0], torch.device):
-                target_device = args[0]
-            elif isinstance(args[0], str):
-                target_device = torch.device(args[0])
-        if 'device' in kwargs:
-            target_device = kwargs['device']
-            if isinstance(target_device, str):
-                target_device = torch.device(target_device)
-
-        # Call original 'to' method first to move data
-        result = original_to(*args, **kwargs)
-
-        # If moved to CUDA, cleanup the mmap file after the move
-        if target_device is not None and target_device.type == 'cuda':
-            _cleanup()
-
-        return result
+    # # Create custom 'to' method that cleans up file when moving to CUDA
+    # def custom_to(*args, **kwargs):
+    #     # Determine target device
+    #     target_device = None
+    #     if len(args) > 0:
+    #         if isinstance(args[0], torch.device):
+    #             target_device = args[0]
+    #         elif isinstance(args[0], str):
+    #             target_device = torch.device(args[0])
+    #     if 'device' in kwargs:
+    #         target_device = kwargs['device']
+    #         if isinstance(target_device, str):
+    #             target_device = torch.device(target_device)
+    #
+    #     # Call original 'to' method first to move data
+    #     result = original_to(*args, **kwargs)
+    #
+    #     # NOTE: Cleanup disabled to avoid blocking model load performance
+    #     # If moved to CUDA, cleanup the mmap file after the move
+    #     if target_device is not None and target_device.type == 'cuda':
+    #         _cleanup()
+    #
+    #     return result
 
-    # Replace the 'to' method
-    mmap_tensor.to = custom_to
+    # # Replace the 'to' method
+    # mmap_tensor.to = custom_to
 
     return mmap_tensor
 
diff --git a/tests/execution/test_model_mmap.py b/tests/execution/test_model_mmap.py
index 65dbe01bd..7a608c931 100644
--- a/tests/execution/test_model_mmap.py
+++ b/tests/execution/test_model_mmap.py
@@ -170,7 +170,7 @@ def test_to_mmap_cuda_cycle():
     3. GPU memory is freed when converting to mmap
     4. mmap tensor can be moved back to CUDA
     5. Data remains consistent throughout the cycle
-    6. mmap file is automatically cleaned up when moved to CUDA
+    6. mmap file is automatically cleaned up via garbage collection
     """
 
     # Check if CUDA is available
@@ -251,24 +251,26 @@ def test_to_mmap_cuda_cycle():
     print(f"   Difference: {sum_diff:.6f}")
     assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}"
 
-    # Step 5: Verify file cleanup
+    # Step 5: Verify file cleanup (delayed until garbage collection)
     print("\n5. Verifying file cleanup...")
+    # Delete the mmap tensor reference to trigger garbage collection
+    del mmap_tensor
     gc.collect()
     import time
     time.sleep(0.1)  # Give OS time to clean up
     temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
-    print(f"   Temp mmap files after: {temp_files_after}")
-    # File should be cleaned up when moved to CUDA
-    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA"
+    print(f"   Temp mmap files after GC: {temp_files_after}")
+    # File should be cleaned up after garbage collection
+    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection"
 
     print("\n✓ Test passed!")
     print("   CUDA -> mmap -> CUDA cycle works correctly")
     print(f"   CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)")
     print("   Data consistency maintained")
-    print("   File cleanup successful")
+    print("   File cleanup successful (via garbage collection)")
 
     # Cleanup
-    del mmap_tensor, cuda_tensor
+    del cuda_tensor  # mmap_tensor already deleted in Step 5
     gc.collect()
     torch.cuda.empty_cache()
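
Note on the mechanism this patch settles on: with the custom_to override commented out, the only thing that deletes the comfy_mmap_ temp file is the weakref.finalize(mmap_tensor, _cleanup) registration, which fires when the mmap tensor is garbage collected. A minimal sketch of that lifecycle, assuming only torch (>= 2.1 for mmap=True) and the stdlib; to_mmap_sketch and its internals are illustrative, not the ComfyUI API:

    import gc
    import os
    import tempfile
    import weakref

    import torch

    def to_mmap_sketch(t: torch.Tensor) -> torch.Tensor:
        # Serialize to a named temp file, then reload memory-mapped so the
        # data stays on disk until pages are actually touched.
        fd, temp_file = tempfile.mkstemp(prefix='comfy_mmap_')
        os.close(fd)
        torch.save(t, temp_file)
        mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

        # Cleanup now runs when the tensor is garbage collected, not on .to('cuda').
        def _cleanup(path=temp_file):
            if os.path.exists(path):
                os.remove(path)
        weakref.finalize(mmap_tensor, _cleanup)
        return mmap_tensor

    if __name__ == '__main__':
        t = to_mmap_sketch(torch.randn(4, 4))
        print(t.sum())  # touching the data faults pages in from the temp file
        del t           # drop the last strong reference...
        gc.collect()    # ...so the finalizer runs and the temp file is removed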
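
One trade-off worth noting, and the reason the test now does del mmap_tensor before checking the temp directory: weakref.finalize only fires once the last reference to the tensor is gone, so temp files persist for as long as the mmap tensors stay alive. Under CPython it is usually the del (reference count reaching zero) rather than gc.collect() that triggers the callback; the explicit gc.collect() and the short sleep just make the cleanup deterministic enough to assert on.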