diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dae9a895d..efc28c160 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -143,6 +143,7 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
 
 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+parser.add_argument("--offload-reserve-ram-gb", type=float, default=None, help="Set the amount of RAM in GB you want to reserve for other software. When free RAM drops below this limit, models offloaded from VRAM are memory-mapped to disk instead of being kept in RAM.")
 
 parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
 parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7f5a8aee9..c6c5151ff 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -28,6 +28,29 @@
 import weakref
 import gc
 import os
+
+from functools import lru_cache
+
+@lru_cache(maxsize=1)
+def get_offload_reserve_ram_gb():
+    offload_reserve_ram_gb = 0
+    try:
+        val = getattr(args, 'offload_reserve_ram_gb', None)
+    except Exception:
+        val = None
+
+    if val is not None:
+        try:
+            offload_reserve_ram_gb = float(val)
+        except (TypeError, ValueError):
+            logging.warning(f"Invalid args.offload_reserve_ram_gb value: {val}, defaulting to 0")
+            offload_reserve_ram_gb = 0
+    return offload_reserve_ram_gb
+
+def get_free_disk():
+    return psutil.disk_usage("/").free
+
+
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
     NO_VRAM = 1     #Very low vram: enable all the options to save vram
@@ -524,16 +547,33 @@ class LoadedModel:
         return False
 
 
     def model_unload(self, memory_to_free=None, unpatch_weights=True):
-        if memory_to_free is not None:
-            if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
-                if freed >= memory_to_free:
-                    return False
-        self.model.detach(unpatch_weights)
-        self.model_finalizer.detach()
-        self.model_finalizer = None
-        self.real_model = None
-        return True
+        if memory_to_free is None:
+            # free the full model
+            memory_to_free = self.model.loaded_size()
+
+        available_memory = get_free_memory(self.model.offload_device)
+
+        offload_reserve_ram = get_offload_reserve_ram_gb() * 1024 * 1024 * 1024  # memory reserved for the OS and other software
+        if memory_to_free > available_memory - offload_reserve_ram or memory_to_free < self.model.loaded_size():
+            partially_unload = True
+        else:
+            partially_unload = False
+
+        if partially_unload:
+            freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+            if freed < memory_to_free:
+                logging.debug(f"Partial unload freed too little memory: freed {freed/(1024*1024*1024):.2f} GB, requested {memory_to_free/(1024*1024*1024):.2f} GB")
+        else:
+            self.model.detach(unpatch_weights)
+            self.model_finalizer.detach()
+            self.model_finalizer = None
+            self.real_model = None
+
+        if partially_unload:
+            return False
+        else:
+            return True
+
     def model_use_more_vram(self, extra_memory, force_patch_weights=False):
         return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
@@ -587,7 +627,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
     can_unload = []
     unloaded_models = []
 
-    for i in range(len(current_loaded_models) -1, -1, -1):
+    for i in range(len(current_loaded_models) - 1, -1, -1):
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded and not shift_model.is_dead():
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 93d26c690..84793617d 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -27,6 +27,10 @@
 import uuid
 from typing import Callable, Optional
 
 import torch
+import os
+import tempfile
+import weakref
+import gc
 import comfy.float
 import comfy.hooks
@@ -37,6 +41,80 @@
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
+from comfy.model_management import get_free_memory, get_offload_reserve_ram_gb, get_free_disk
+
+def enable_offload_to_mmap() -> bool:
+    if comfy.utils.DISABLE_MMAP:
+        return False
+
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    offload_reserve_ram_gb = get_offload_reserve_ram_gb()
+    if free_cpu_mem <= offload_reserve_ram_gb * 1024 * 1024 * 1024:
+        logging.debug(f"Enabling offload to mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024):.2f} GB <= reserved {offload_reserve_ram_gb} GB")
+        return True
+
+    return False
+
+def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
+    """
+    Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
+    """
+    # Create a temporary file to back the memory map
+    if filename is None:
+        fd, temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')
+        os.close(fd)  # only the path is needed; torch.save reopens the file itself
+    else:
+        temp_file = filename
+
+    # Save tensor to file
+    cpu_tensor = t.cpu()
+    torch.save(cpu_tensor, temp_file)
+
+    # If we created a CPU copy from another device, delete it to free memory
+    if t.device.type != 'cpu':
+        del cpu_tensor
+        gc.collect()
+
+    # Load with mmap - this doesn't load all the data into RAM
+    mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)
+
+    # Register cleanup callback - called when the tensor is garbage collected
+    def _cleanup():
+        try:
+            if os.path.exists(temp_file):
+                os.remove(temp_file)
+                logging.debug(f"Cleaned up mmap file: {temp_file}")
+        except Exception:
+            pass
+
+    weakref.finalize(mmap_tensor, _cleanup)
+
+    return mmap_tensor
+
+def model_to_mmap(model: torch.nn.Module):
+    """Convert all parameters and buffers to memory-mapped tensors
+
+    Args:
+        model: PyTorch module to convert
+
+    Returns:
+        The same model with all tensors converted to memory-mapped format
+    """
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024):.2f} GB")
+
+    def convert_fn(t):
+        if isinstance(t, torch.nn.Parameter):
+            new_tensor = to_mmap(t.detach())
+            return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad)
+        elif isinstance(t, torch.Tensor):
+            return to_mmap(t)
+        return t
+
+    new_model = model._apply(convert_fn)
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024):.2f} GB")
+    return new_model
 
 
 def string_to_seed(data):
@@ -506,6 +584,7 @@ class ModelPatcher:
         return comfy.utils.get_attr(self.model, name)
 
     def model_patches_to(self, device):
+        # TODO(sf): consider offloading these patches to mmap as well
         to = self.model_options["transformer_options"]
         if "patches" in to:
             patches = to["patches"]
@@ -853,9 +932,15 @@
 
             self.model.current_weight_patches_uuid = None
             self.backup.clear()
 
+
         if device_to is not None:
-            self.model.to(device_to)
+            if enable_offload_to_mmap():
+                # offload the weights to mmap-backed CPU tensors instead of resident RAM
+                model_to_mmap(self.model)
+            else:
+                self.model.to(device_to)
             self.model.device = device_to
+
             self.model.model_loaded_weight_memory = 0
             self.model.model_offload_buffer_memory = 0
@@ -914,7 +999,14 @@
             bias_key = "{}.bias".format(n)
             if move_weight:
                 cast_weight = self.force_cast_weights
-                m.to(device_to)
+                if enable_offload_to_mmap():
+                    if get_free_disk() < module_mem:
+                        logging.warning(f"Not enough disk space to offload {n} to mmap, free disk space {get_free_disk()/(1024*1024*1024):.2f} GB < {module_mem/(1024*1024*1024):.2f} GB")
+                        break
+                    # offload this module to mmap
+                    model_to_mmap(m)
+                else:
+                    m.to(device_to)
                 module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
                     if weight_key in self.patches:
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index cd96541d7..b88a2661f 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor):
             layout_type: Layout class (subclass of QuantizedLayout)
             layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
+        # Use as_subclass so the QuantizedTensor instance shares the same
+        # storage and metadata as the underlying qdata tensor. This ensures
+        # torch.save/torch.load and the torch serialization storage scanning
+        # see a valid underlying storage (fixes data_ptr errors).
+        if not isinstance(qdata, torch.Tensor):
+            raise TypeError("qdata must be a torch.Tensor")
+        obj = qdata.as_subclass(cls)
+        # Ensure the grad flag is consistent for quantized tensors
+        try:
+            obj.requires_grad_(False)
+        except Exception:
+            pass
+        return obj
 
     def __init__(self, qdata, layout_type, layout_params):
         self._qdata = qdata
@@ -578,3 +590,34 @@ def fp8_func(func, args, kwargs):
             ar[0] = plain_input
             return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
     return func(*args, **kwargs)
+
+def _rebuild_quantized_tensor(qdata, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling when qdata is already a tensor."""
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params):
+    """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple.
+
+    qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original
+    inner tensor. We call the provided rebuild function with its args to recreate the
+    inner tensor, then wrap it in QuantizedTensor.
+    """
+    rebuild_fn, rebuild_args = qdata_reduce
+    qdata = rebuild_fn(*rebuild_args)
+    return QuantizedTensor(qdata, layout_type, layout_params)
+
+
+# Register custom globals with torch.serialization so torch.load(..., weights_only=True)
+# accepts these during unpickling. Wrapped in try/except for older PyTorch versions.
+try:
+    import torch as _torch_serial
+    if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"):
+        _torch_serial.serialization.add_safe_globals([
+            QuantizedTensor,
+            _rebuild_quantized_tensor,
+            _rebuild_quantized_tensor_from_base,
+        ])
+except Exception:
+    # If add_safe_globals doesn't exist or registration fails, silently continue.
+    pass
diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py
index 9cb54ede8..51d27dd26 100644
--- a/tests-unit/comfy_quant/test_quant_registry.py
+++ b/tests-unit/comfy_quant/test_quant_registry.py
@@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase):
         self.assertEqual(dequantized.dtype, torch.float32)
         self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
 
+    def test_save_load(self):
+        """Test that a QuantizedTensor with TensorCoreFP8Layout survives a torch.save/torch.load round trip"""
+        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
+        scale = torch.tensor(2.0)
+        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
+
+        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
+
+        self.assertIsInstance(qt, QuantizedTensor)
+        self.assertEqual(qt.shape, (256, 128))
+        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
+        self.assertEqual(qt._layout_params['scale'], scale)
+        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
+        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
+
+        torch.save(qt, "test.pt")
+        loaded_qt = torch.load("test.pt", weights_only=False)
+        # mmap loading also works: torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False)
+
+        self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(loaded_qt._layout_params['scale'], scale)
+        self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16)
+
     def test_from_float(self):
         """Test creating QuantizedTensor from float tensor"""
         float_tensor = torch.randn(64, 32, dtype=torch.float32)
diff --git a/tests/inference/test_model_mmap.py b/tests/inference/test_model_mmap.py
new file mode 100644
index 000000000..a7bff3bfc
--- /dev/null
+++ b/tests/inference/test_model_mmap.py
@@ -0,0 +1,287 @@
+import pytest
+import torch
+import torch.nn as nn
+import psutil
+import os
+import gc
+import tempfile
+import sys
+
+# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder)
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+
+from comfy.model_patcher import model_to_mmap, to_mmap
+
+
+class LargeModel(nn.Module):
+    """A simple model with large parameters for testing memory mapping"""
+
+    def __init__(self, size_gb=10):
+        super().__init__()
+        # Calculate the number of float32 elements needed for the target size
+        # 1 GB = 1024^3 bytes, float32 = 4 bytes
+        bytes_per_gb = 1024 * 1024 * 1024
+        elements_per_gb = bytes_per_gb // 4  # float32 is 4 bytes
+        total_elements = int(size_gb * elements_per_gb)
+
+        # Create large linear layers
+        # Split into multiple layers to avoid single tensor size limits
+        self.layers = nn.ModuleList()
+        elements_per_layer = 500 * 1024 * 1024  # 500M elements per layer (~2GB)
+        num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer
+
+        for i in range(num_layers):
+            if i == num_layers - 1:
+                # The last layer gets the remaining elements
+                remaining = total_elements - (i * elements_per_layer)
+                in_features = int(remaining ** 0.5)
+                out_features = (remaining + in_features - 1) // in_features
+            else:
+                in_features = int(elements_per_layer ** 0.5)
+                out_features = (elements_per_layer + in_features - 1) // in_features
+
+            # Create layers without bias to control size precisely
+            self.layers.append(nn.Linear(in_features, out_features, bias=False))
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+def get_process_memory_gb():
+    """Get current process memory usage in GB"""
+    process = psutil.Process(os.getpid())
+    mem_info = process.memory_info()
+    return mem_info.rss / (1024 ** 3)  # Convert to GB
+
+
+def get_model_size_gb(model):
+    """Calculate model size in GB"""
+    total_size = 0
+    for param in model.parameters():
+        total_size += param.nelement() * param.element_size()
+    for buffer in model.buffers():
+        total_size += buffer.nelement() * buffer.element_size()
+    return total_size / (1024 ** 3)
+
+
+def test_model_to_mmap_memory_efficiency():
+    """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB
+
+    The typical use case is:
+    1. Load a large model on CUDA
+    2. Convert to mmap to offload it from the GPU to disk-backed memory
+    3. This frees GPU memory and reduces CPU RAM usage
+    """
+
+    # Check if CUDA is available
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available, skipping test")
+
+    # Force garbage collection before starting
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Record initial memory
+    initial_cpu_memory = get_process_memory_gb()
+    initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3)
+    print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB")
+    print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB")
+
+    # Create a 10GB model
+    print("Creating 10GB model...")
+    model = LargeModel(size_gb=10)
+
+    # Verify model size
+    model_size = get_model_size_gb(model)
+    print(f"Model size: {model_size:.2f} GB")
+    assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB"
+
+    # Move model to CUDA
+    print("Moving model to CUDA...")
+    model = model.cuda()
+    torch.cuda.synchronize()
+
+    # Memory after moving to CUDA
+    cpu_after_cuda = get_process_memory_gb()
+    gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3)
+    print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB")
+    print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB")
+
+    # Convert to mmap (this should move the model from GPU to disk-backed memory)
+    # Note: model_to_mmap modifies the model in-place via _apply()
+    # so model and model_mmap will be the same object
+    print("Converting model to mmap...")
+    model_mmap = model_to_mmap(model)
+
+    # Verify that model and model_mmap are the same object (in-place modification)
+    assert model is model_mmap, "model_to_mmap should modify the model in-place"
+
+    # Force garbage collection and clear CUDA cache
+    # The original CUDA tensors should be automatically freed when replaced
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+
+    # Memory after mmap conversion
+    cpu_after_mmap = get_process_memory_gb()
+    gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
+    print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB")
+    print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB")
+
+    # Calculate memory changes from the CUDA state (the baseline we're converting from)
+    cpu_increase = cpu_after_mmap - cpu_after_cuda
+    gpu_decrease = gpu_after_cuda - gpu_after_mmap  # Should be positive (freed)
+    print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB")
+    print(f"GPU memory freed: {gpu_decrease:.2f} GB")
+
+    # Verify that the CPU memory usage increase is less than 1GB
+    # The mmap should use disk-backed storage, keeping CPU RAM usage low
+    # We use a 1.5 GB threshold to account for overhead
+    assert cpu_increase < 1.5, (
+        f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. "
+        f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB"
+    )
+
+    # Verify that GPU memory has been freed
+    # We expect at least 9 GB to be freed (original 10GB model with some tolerance)
+    assert gpu_decrease > 9.0, (
+        f"GPU memory should be freed after mmap. "
+        f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB"
+    )
+
+    # Verify the model is still functional (basic sanity check)
+    assert model_mmap is not None
+    assert len(list(model_mmap.parameters())) > 0
+
+    print("\n✓ Test passed!")
+    print(f"  CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB")
+    print(f"  GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB")
+    print("  Model successfully offloaded from GPU to disk-backed memory")
+
+    # Cleanup (model and model_mmap are the same object)
+    del model, model_mmap
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+def test_to_mmap_cuda_cycle():
+    """Test CUDA -> mmap -> CUDA cycle
+
+    This test verifies:
+    1. A CUDA tensor can be converted to an mmap tensor
+    2. CPU memory increase is minimal when using mmap (< 0.1 GB)
+    3. GPU memory is freed when converting to mmap
+    4. The mmap tensor can be moved back to CUDA
+    5. Data remains consistent throughout the cycle
+    6. The mmap file is automatically cleaned up via garbage collection
+    """
+
+    # Check if CUDA is available
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available, skipping test")
+
+    # Force garbage collection
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    print("\nTest: CUDA -> mmap -> CUDA cycle")
+
+    # Record initial CPU memory
+    initial_cpu_memory = get_process_memory_gb()
+    print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB")
+
+    # Step 1: Create a CUDA tensor
+    print("\n1. Creating CUDA tensor...")
+    original_data = torch.randn(5000, 5000).cuda()
+    original_sum = original_data.sum().item()
+    print(f"   Shape: {original_data.shape}")
+    print(f"   Device: {original_data.device}")
+    print(f"   Sum: {original_sum:.2f}")
+
+    # Record GPU and CPU memory after the CUDA allocation
+    cpu_after_cuda = get_process_memory_gb()
+    gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
+    print(f"   GPU memory: {gpu_before_mmap:.2f} GB")
+    print(f"   CPU memory: {cpu_after_cuda:.2f} GB")
+
+    # Step 2: Convert to mmap tensor
+    print("\n2. Converting to mmap tensor...")
+    mmap_tensor = to_mmap(original_data)
+    del original_data
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    print(f"   Device: {mmap_tensor.device}")
+    print(f"   Sum: {mmap_tensor.sum().item():.2f}")
+
+    # Verify GPU memory is freed
+    gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3)
+    cpu_after_mmap = get_process_memory_gb()
+    print(f"   GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB")
+    print(f"   CPU memory: {cpu_after_mmap:.2f} GB")
+
+    # Verify GPU memory is freed
+    assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated"
+
+    # Verify CPU memory increase is minimal (should be close to 0 due to mmap)
+    cpu_increase = cpu_after_mmap - cpu_after_cuda
+    print(f"   CPU memory increase: {cpu_increase:.2f} GB")
+    assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB"
+
+    # Count the temp files (we'll check that they get cleaned up)
+    # The backing file should exist at this point
+    temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
+    print(f"   Temp mmap files exist: {temp_files_before}")
+
+    # Step 3: Move back to CUDA
+    print("\n3. Moving back to CUDA...")
+    cuda_tensor = mmap_tensor.to('cuda')
+    torch.cuda.synchronize()
+
+    print(f"   Device: {cuda_tensor.device}")
+    final_sum = cuda_tensor.sum().item()
+    print(f"   Sum: {final_sum:.2f}")
+
+    # Verify GPU memory is used again
+    gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3)
+    print(f"   GPU memory: {gpu_after_cuda:.2f} GB")
+
+    # Step 4: Verify data consistency
+    print("\n4. Verifying data consistency...")
+    sum_diff = abs(original_sum - final_sum)
+    print(f"   Original sum: {original_sum:.2f}")
+    print(f"   Final sum: {final_sum:.2f}")
+    print(f"   Difference: {sum_diff:.6f}")
+    assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}"
+
+    # Step 5: Verify file cleanup (delayed until garbage collection)
+    print("\n5. Verifying file cleanup...")
+    # Delete the mmap tensor reference to trigger garbage collection
+    del mmap_tensor
+    gc.collect()
+    import time
+    time.sleep(0.1)  # Give the OS time to clean up
+    temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')])
+    print(f"   Temp mmap files after GC: {temp_files_after}")
+    # The file should be cleaned up after garbage collection
+    assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection"
+
+    print("\n✓ Test passed!")
+    print("  CUDA -> mmap -> CUDA cycle works correctly")
+    print(f"  CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)")
+    print("  Data consistency maintained")
+    print("  File cleanup successful (via garbage collection)")
+
+    # Cleanup
+    del cuda_tensor  # mmap_tensor already deleted in Step 5
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    # Run the tests directly
+    test_model_to_mmap_memory_efficiency()
+    test_to_mmap_cuda_cycle()
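
Reviewer note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the save-then-mmap-load round trip that the new to_mmap() helper builds on, using only the standard library and stock PyTorch calls (torch.save and torch.load(..., mmap=True), available since PyTorch 2.1). The function name offload_to_mmap and the file prefix are made up for the example; the reserve-RAM thresholds, the ComfyUI helpers such as get_offload_reserve_ram_gb(), and the weakref-based file cleanup are deliberately left out.

# Standalone sketch (assumed names, not ComfyUI code): serialize a tensor to a
# temporary file, then reopen it as a memory-mapped CPU tensor so the data stays
# on disk and pages are only faulted into RAM when accessed.
import os
import tempfile

import torch


def offload_to_mmap(t: torch.Tensor):
    """Return (mmap-backed CPU copy of t, path of the backing file)."""
    fd, path = tempfile.mkstemp(suffix=".pt", prefix="mmap_sketch_")
    os.close(fd)  # only the path is needed; torch.save reopens the file itself
    torch.save(t.cpu(), path)
    mapped = torch.load(path, map_location="cpu", mmap=True, weights_only=True)
    return mapped, path


if __name__ == "__main__":
    original = torch.randn(1024, 1024)
    mapped, path = offload_to_mmap(original)
    assert torch.equal(original, mapped)   # same values, but disk-backed storage
    restored = mapped.clone()              # materialize back into RAM (or .to("cuda"))
    assert torch.equal(original, restored)
    del mapped                             # drop the mapping before deleting the file
    os.remove(path)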