diff --git a/comfy/model_management.py b/comfy/model_management.py index 40717b1e4..5105111c6 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,18 @@ import importlib import platform import weakref import gc +import os + +from functools import lru_cache + +@lru_cache(maxsize=1) +def get_mmap_mem_threshold_gb(): + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}") + return mmap_mem_threshold_gb + +def get_free_disk(): + return psutil.disk_usage("/").free class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -521,16 +533,33 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - if freed >= memory_to_free: - return False - self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None - self.real_model = None - return True + if memory_to_free is None: + # free the full model + memory_to_free = self.model.loaded_size() + + available_memory = get_free_memory(self.model.offload_device) + + mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage + if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size(): + partially_unload = True + else: + partially_unload = False + + if partially_unload: + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + if freed < memory_to_free: + logging.debug(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB") + else: + self.model.detach(unpatch_weights) + self.model_finalizer.detach() + self.model_finalizer = None + self.real_model = None + + if partially_unload: + return False + else: + return True + def model_use_more_vram(self, extra_memory, force_patch_weights=False): return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..11b658c7e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,6 +27,10 @@ import uuid from typing import Callable, Optional import torch +import os +import tempfile +import weakref +import gc import comfy.float import comfy.hooks @@ -37,6 +41,76 @@ import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk +from comfy.quant_ops import QuantizedTensor + +def need_mmap() -> bool: + free_cpu_mem = get_free_memory(torch.device("cpu")) + mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() + if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: + logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") + return True + return False + +def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: + """ + Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. + """ + # Create temporary file + if filename is None: + temp_file = tempfile.mkstemp(suffix='.pt', prefix='comfy_mmap_')[1] + else: + temp_file = filename + + # Save tensor to file + cpu_tensor = t.cpu() + torch.save(cpu_tensor, temp_file) + + # If we created a CPU copy from other device, delete it to free memory + if not t.device.type == 'cpu': + del cpu_tensor + gc.collect() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + + # Register cleanup callback - will be called when tensor is garbage collected + def _cleanup(): + try: + if os.path.exists(temp_file): + os.remove(temp_file) + logging.debug(f"Cleaned up mmap file: {temp_file}") + except Exception: + pass + + weakref.finalize(mmap_tensor, _cleanup) + + return mmap_tensor + +def model_to_mmap(model: torch.nn.Module): + """Convert all parameters and buffers to memory-mapped tensors + + Args: + model: PyTorch module to convert + + Returns: + The same model with all tensors converted to memory-mapped format + """ + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + + def convert_fn(t): + if isinstance(t, torch.nn.Parameter): + new_tensor = to_mmap(t.detach()) + return torch.nn.Parameter(new_tensor, requires_grad=t.requires_grad) + elif isinstance(t, torch.Tensor): + return to_mmap(t) + return t + + new_model = model._apply(convert_fn) + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + return new_model def string_to_seed(data): @@ -506,6 +580,7 @@ class ModelPatcher: return comfy.utils.get_attr(self.model, name) def model_patches_to(self, device): + # TODO(sf): to mmap to = self.model_options["transformer_options"] if "patches" in to: patches = to["patches"] @@ -853,9 +928,15 @@ class ModelPatcher: self.model.current_weight_patches_uuid = None self.backup.clear() + if device_to is not None: - self.model.to(device_to) + if need_mmap(): + # offload to mmap + model_to_mmap(self.model) + else: + self.model.to(device_to) self.model.device = device_to + self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -914,7 +995,14 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + if need_mmap(): + if get_free_disk() < module_mem: + logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") + break + # offload to mmap + model_to_mmap(m) + else: + m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d7..b88a2661f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -130,7 +130,19 @@ class QuantizedTensor(torch.Tensor): layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + # Use as_subclass so the QuantizedTensor instance shares the same + # storage and metadata as the underlying qdata tensor. This ensures + # torch.save/torch.load and the torch serialization storage scanning + # see a valid underlying storage (fixes data_ptr errors). + if not isinstance(qdata, torch.Tensor): + raise TypeError("qdata must be a torch.Tensor") + obj = qdata.as_subclass(cls) + # Ensure grad flag is consistent for quantized tensors + try: + obj.requires_grad_(False) + except Exception: + pass + return obj def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata @@ -578,3 +590,34 @@ def fp8_func(func, args, kwargs): ar[0] = plain_input return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) return func(*args, **kwargs) + +def _rebuild_quantized_tensor(qdata, layout_type, layout_params): + """Rebuild QuantizedTensor during unpickling when qdata is already a tensor.""" + return QuantizedTensor(qdata, layout_type, layout_params) + + +def _rebuild_quantized_tensor_from_base(qdata_reduce, layout_type, layout_params): + """Rebuild QuantizedTensor during unpickling given the base tensor's reduce tuple. + + qdata_reduce is the tuple returned by qdata.__reduce_ex__(protocol) on the original + inner tensor. We call the provided rebuild function with its args to recreate the + inner tensor, then wrap it in QuantizedTensor. + """ + rebuild_fn, rebuild_args = qdata_reduce + qdata = rebuild_fn(*rebuild_args) + return QuantizedTensor(qdata, layout_type, layout_params) + + +# Register custom globals with torch.serialization so torch.load(..., weights_only=True) +# accepts these during unpickling. Wrapped in try/except for older PyTorch versions. +try: + import torch as _torch_serial + if hasattr(_torch_serial, "serialization") and hasattr(_torch_serial.serialization, "add_safe_globals"): + _torch_serial.serialization.add_safe_globals([ + QuantizedTensor, + _rebuild_quantized_tensor, + _rebuild_quantized_tensor_from_base, + ]) +except Exception: + # If add_safe_globals doesn't exist or registration fails, we silently continue. + pass diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 9cb54ede8..51d27dd26 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -47,6 +47,29 @@ class TestQuantizedTensor(unittest.TestCase): self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + def test_save_load(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") + + torch.save(qt, "test.pt") + loaded_qt = torch.load("test.pt", weights_only=False) + # loaded_qt = torch.load("test.pt", map_location='cpu', mmap=True, weights_only=False) + + self.assertEqual(loaded_qt._layout_type, "TensorCoreFP8Layout") + self.assertEqual(loaded_qt._layout_params['scale'], scale) + self.assertEqual(loaded_qt._layout_params['orig_dtype'], torch.bfloat16) + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) diff --git a/tests/inference/test_model_mmap.py b/tests/inference/test_model_mmap.py new file mode 100644 index 000000000..a7bff3bfc --- /dev/null +++ b/tests/inference/test_model_mmap.py @@ -0,0 +1,287 @@ +import pytest +import torch +import torch.nn as nn +import psutil +import os +import gc +import tempfile +import sys + +# Ensure the project root is on the Python path (so `import comfy` works when running tests from this folder) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + +from comfy.model_patcher import model_to_mmap, to_mmap + + +class LargeModel(nn.Module): + """A simple model with large parameters for testing memory mapping""" + + def __init__(self, size_gb=10): + super().__init__() + # Calculate number of float32 elements needed for target size + # 1 GB = 1024^3 bytes, float32 = 4 bytes + bytes_per_gb = 1024 * 1024 * 1024 + elements_per_gb = bytes_per_gb // 4 # float32 is 4 bytes + total_elements = int(size_gb * elements_per_gb) + + # Create a large linear layer + # Split into multiple layers to avoid single tensor size limits + self.layers = nn.ModuleList() + elements_per_layer = 500 * 1024 * 1024 # 500M elements per layer (~2GB) + num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer + + for i in range(num_layers): + if i == num_layers - 1: + # Last layer gets the remaining elements + remaining = total_elements - (i * elements_per_layer) + in_features = int(remaining ** 0.5) + out_features = (remaining + in_features - 1) // in_features + else: + in_features = int(elements_per_layer ** 0.5) + out_features = (elements_per_layer + in_features - 1) // in_features + + # Create layer without bias to control size precisely + self.layers.append(nn.Linear(in_features, out_features, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_process_memory_gb(): + """Get current process memory usage in GB""" + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss / (1024 ** 3) # Convert to GB + + +def get_model_size_gb(model): + """Calculate model size in GB""" + total_size = 0 + for param in model.parameters(): + total_size += param.nelement() * param.element_size() + for buffer in model.buffers(): + total_size += buffer.nelement() * buffer.element_size() + return total_size / (1024 ** 3) + + +def test_model_to_mmap_memory_efficiency(): + """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB + + The typical use case is: + 1. Load a large model on CUDA + 2. Convert to mmap to offload from GPU to disk-backed memory + 3. This frees GPU memory and reduces CPU RAM usage + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection before starting + gc.collect() + torch.cuda.empty_cache() + + # Record initial memory + initial_cpu_memory = get_process_memory_gb() + initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB") + print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB") + + # Create a 10GB model + print("Creating 10GB model...") + model = LargeModel(size_gb=10) + + # Verify model size + model_size = get_model_size_gb(model) + print(f"Model size: {model_size:.2f} GB") + assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB" + + # Move model to CUDA + print("Moving model to CUDA...") + model = model.cuda() + torch.cuda.synchronize() + + # Memory after moving to CUDA + cpu_after_cuda = get_process_memory_gb() + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB") + print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB") + + # Convert to mmap (this should move model from GPU to disk-backed memory) + # Note: model_to_mmap modifies the model in-place via _apply() + # so model and model_mmap will be the same object + print("Converting model to mmap...") + model_mmap = model_to_mmap(model) + + # Verify that model and model_mmap are the same object (in-place modification) + assert model is model_mmap, "model_to_mmap should modify the model in-place" + + # Force garbage collection and clear CUDA cache + # The original CUDA tensors should be automatically freed when replaced + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Memory after mmap conversion + cpu_after_mmap = get_process_memory_gb() + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB") + print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB") + + # Calculate memory changes from CUDA state (the baseline we're converting from) + cpu_increase = cpu_after_mmap - cpu_after_cuda + gpu_decrease = gpu_after_cuda - gpu_after_mmap # Should be positive (freed) + print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB") + print(f"GPU memory freed: {gpu_decrease:.2f} GB") + + # Verify that CPU memory usage increase is less than 1GB + # The mmap should use disk-backed storage, keeping CPU RAM usage low + # We use 1.5 GB threshold to account for overhead + assert cpu_increase < 1.5, ( + f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. " + f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB" + ) + + # Verify that GPU memory has been freed + # We expect at least 9 GB to be freed (original 10GB model with some tolerance) + assert gpu_decrease > 9.0, ( + f"GPU memory should be freed after mmap. " + f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB" + ) + + # Verify the model is still functional (basic sanity check) + assert model_mmap is not None + assert len(list(model_mmap.parameters())) > 0 + + print(f"\n✓ Test passed!") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB") + print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB") + print(f" Model successfully offloaded from GPU to disk-backed memory") + + # Cleanup (model and model_mmap are the same object) + del model, model_mmap + gc.collect() + torch.cuda.empty_cache() + + +def test_to_mmap_cuda_cycle(): + """Test CUDA -> mmap -> CUDA cycle + + This test verifies: + 1. CUDA tensor can be converted to mmap tensor + 2. CPU memory increase is minimal when using mmap (< 0.1 GB) + 3. GPU memory is freed when converting to mmap + 4. mmap tensor can be moved back to CUDA + 5. Data remains consistent throughout the cycle + 6. mmap file is automatically cleaned up via garbage collection + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + print("\nTest: CUDA -> mmap -> CUDA cycle") + + # Record initial CPU memory + initial_cpu_memory = get_process_memory_gb() + print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB") + + # Step 1: Create a CUDA tensor + print("\n1. Creating CUDA tensor...") + original_data = torch.randn(5000, 5000).cuda() + original_sum = original_data.sum().item() + print(f" Shape: {original_data.shape}") + print(f" Device: {original_data.device}") + print(f" Sum: {original_sum:.2f}") + + # Record GPU and CPU memory after CUDA allocation + cpu_after_cuda = get_process_memory_gb() + gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_before_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_cuda:.2f} GB") + + # Step 2: Convert to mmap tensor + print("\n2. Converting to mmap tensor...") + mmap_tensor = to_mmap(original_data) + del original_data + gc.collect() + torch.cuda.empty_cache() + + print(f" Device: {mmap_tensor.device}") + print(f" Sum: {mmap_tensor.sum().item():.2f}") + + # Verify GPU memory is freed + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + cpu_after_mmap = get_process_memory_gb() + print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_mmap:.2f} GB") + + # Verify GPU memory is freed + assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated" + + # Verify CPU memory increase is minimal (should be close to 0 due to mmap) + cpu_increase = cpu_after_mmap - cpu_after_cuda + print(f" CPU memory increase: {cpu_increase:.2f} GB") + assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB" + + # Get the temp file path (we'll check if it gets cleaned up) + # The file should exist at this point + temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files exist: {temp_files_before}") + + # Step 3: Move back to CUDA + print("\n3. Moving back to CUDA...") + cuda_tensor = mmap_tensor.to('cuda') + torch.cuda.synchronize() + + print(f" Device: {cuda_tensor.device}") + final_sum = cuda_tensor.sum().item() + print(f" Sum: {final_sum:.2f}") + + # Verify GPU memory is used again + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_after_cuda:.2f} GB") + + # Step 4: Verify data consistency + print("\n4. Verifying data consistency...") + sum_diff = abs(original_sum - final_sum) + print(f" Original sum: {original_sum:.2f}") + print(f" Final sum: {final_sum:.2f}") + print(f" Difference: {sum_diff:.6f}") + assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" + + # Step 5: Verify file cleanup (delayed until garbage collection) + print("\n5. Verifying file cleanup...") + # Delete the mmap tensor reference to trigger garbage collection + del mmap_tensor + gc.collect() + import time + time.sleep(0.1) # Give OS time to clean up + temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files after GC: {temp_files_after}") + # File should be cleaned up after garbage collection + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection" + + print("\n✓ Test passed!") + print(" CUDA -> mmap -> CUDA cycle works correctly") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") + print(" Data consistency maintained") + print(" File cleanup successful (via garbage collection)") + + # Cleanup + del cuda_tensor # mmap_tensor already deleted in Step 5 + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Run the tests directly + test_model_to_mmap_memory_efficiency() + test_to_mmap_cuda_cycle() +