From 97189bf6bb7a095accb62be4dff137a8b71cb9e0 Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Fri, 9 Jan 2026 00:38:37 +0200
Subject: [PATCH] Integrate disk offload into memory management

---
 comfy/disk_weights.py     | 75 +++++++++++++++++++++++++++++++++++++--
 comfy/model_base.py       | 10 +++++-
 comfy/model_management.py | 44 +++++++++++++++++++++--
 comfy/model_patcher.py    | 23 ++++++++----
 4 files changed, 139 insertions(+), 13 deletions(-)

diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index bceef4c30..fb9186db4 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -372,7 +372,11 @@ def _device_free_memory(device: torch.device) -> int:
 def _evict_ram_for_budget(required_bytes: int) -> int:
     if required_bytes <= 0:
         return 0
-    return evict_ram_cache(required_bytes)
+    freed = evict_ram_cache(required_bytes)
+    if freed < required_bytes:
+        from . import model_management
+        freed += model_management.evict_ram_to_disk(required_bytes - freed)
+    return freed
 
 
 def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
@@ -654,6 +658,16 @@ def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
     return check(kwargs)
 
 
+def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]:
+    if manual_cast_dtype is not None:
+        return manual_cast_dtype
+    if input_dtype is None:
+        return None
+    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
+        return input_dtype
+    return None
+
+
 def ensure_module_materialized(
     module: torch.nn.Module,
     target_device: torch.device,
@@ -744,7 +758,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = manual_cast_dtype or input_dtype
+    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
     if getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
         fallback_device = _find_tensor_device(args, kwargs)
@@ -793,6 +807,15 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
     return None
 
 
+def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
+    if "dtype" in kwargs and kwargs["dtype"] is not None:
+        return kwargs["dtype"]
+    for arg in args:
+        if isinstance(arg, torch.dtype):
+            return arg
+    return None
+
+
 def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     for param in module.parameters(recurse=True):
         if param is not None and param.device.type != "meta":
@@ -803,12 +826,58 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     return None
 
+
+def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
+    def _move(tensor):
+        if tensor is None:
+            return None
+        if tensor.device.type == "meta":
+            return tensor
+        if dtype_override is not None and tensor.dtype != dtype_override:
+            return tensor.to(device=device_to, dtype=dtype_override)
+        return tensor.to(device=device_to)
+
+    module._apply(_move)
+    return module
+
+
+def offload_module_weights(module: torch.nn.Module) -> int:
+    if not disk_weights_enabled():
+        return 0
+    refs = REGISTRY.get(module)
+    if not refs:
+        return 0
+    offloaded_bytes = 0
+    if module in LAZY_MODULE_STATE:
+        ref_name = next(iter(refs.keys()), None)
+        if ref_name is not None:
+            _evict_module_weight(module, ref_name, False)
+        for disk_ref in refs.values():
+            nbytes = _meta_nbytes(disk_ref.meta)
+            if nbytes is not None:
+                offloaded_bytes += nbytes
+        return offloaded_bytes
+    for name, disk_ref in refs.items():
+        _evict_module_weight(module, name, disk_ref.is_buffer)
+        nbytes = _meta_nbytes(disk_ref.meta)
+        if nbytes is not None:
+            offloaded_bytes += nbytes
+    return offloaded_bytes
+
 
 def module_to(module: torch.nn.Module, *args, **kwargs):
+    allow_materialize = kwargs.pop("allow_materialize", True)
     if disk_weights_enabled():
         target_device = _extract_to_device(args, kwargs)
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
-        materialize_module_tree(module, target_device)
+        if target_device.type == "meta":
+            offload_module_weights(module)
+            return module
+        if allow_materialize:
+            materialize_module_tree(module, target_device)
+            return module.to(*args, **kwargs)
+        dtype_override = _extract_to_dtype(args, kwargs)
+        return move_module_tensors(module, target_device, dtype_override=dtype_override)
     return module.to(*args, **kwargs)
 
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 3bb155f2c..84766591b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -56,6 +56,7 @@ import comfy.conds
 import comfy.ops
 from enum import Enum
 from . import utils
+from . import safetensors_stream
 import comfy.latent_formats
 import comfy.model_sampling
 import math
@@ -299,7 +300,14 @@ class BaseModel(torch.nn.Module):
         return out
 
     def load_model_weights(self, sd, unet_prefix=""):
-        to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
+        replace_prefix = {unet_prefix: ""} if unet_prefix else {}
+        if replace_prefix:
+            if utils.is_stream_state_dict(sd):
+                to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
+            else:
+                to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
+        else:
+            to_load = sd
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
         if len(m) > 0:
diff --git a/comfy/model_management.py b/comfy/model_management.py
index f389ad857..f32c77583 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -530,7 +530,12 @@ class LoadedModel:
         freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
         if freed >= memory_to_free:
             return False
-        self.model.detach(unpatch_weights)
+        offload_device = None
+        if comfy.disk_weights.disk_weights_enabled():
+            offload_device = torch.device("meta")
+        self.model.detach(unpatch_weights, offload_device=offload_device)
+        if offload_device is not None and offload_device.type == "meta":
+            logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -585,7 +590,9 @@ def minimum_inference_memory():
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
-        comfy.disk_weights.evict_ram_cache(memory_required)
+        freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+        if freed_cache < memory_required:
+            evict_ram_to_disk(memory_required - freed_cache)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -621,6 +628,34 @@
     soft_empty_cache()
     return unloaded_models
 
+
+def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
+    if memory_to_free <= 0:
+        return 0
+    if not comfy.disk_weights.disk_weights_enabled():
+        return 0
+
+    freed = 0
+    can_unload = []
+    for i in range(len(current_loaded_models) - 1, -1, -1):
+        shift_model = current_loaded_models[i]
+        if shift_model not in keep_loaded and not shift_model.is_dead():
+            loaded_memory = shift_model.model_loaded_memory()
+            if loaded_memory > 0:
+                can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+
+    for x in sorted(can_unload):
+        i = x[-1]
+        memory_needed = memory_to_free - freed
+        if memory_needed <= 0:
+            break
+        logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
+        freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
+
+    if freed > 0:
+        logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
+    return freed
+
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
@@ -1293,7 +1328,10 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and dev.type == "meta":
+        mem_free_total = sys.maxsize
+        mem_free_torch = mem_free_total
+    elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c2427490d..d709b466e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -857,7 +857,7 @@
             self.backup.clear()
 
         if device_to is not None:
-            comfy.disk_weights.module_to(self.model, device_to)
+            comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
             self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
             self.model.model_offload_buffer_memory = 0
@@ -917,7 +917,16 @@
                     bias_key = "{}.bias".format(n)
                     if move_weight:
                         cast_weight = self.force_cast_weights
-                        m.to(device_to)
+                        freed_bytes = module_mem
+                        if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
+                            freed_bytes = comfy.disk_weights.offload_module_weights(m)
+                            if freed_bytes == 0:
+                                freed_bytes = module_mem
+                        else:
+                            if comfy.disk_weights.disk_weights_enabled():
+                                comfy.disk_weights.move_module_tensors(m, device_to)
+                            else:
+                                m.to(device_to)
                         module_mem += move_weight_functions(m, device_to)
                         if lowvram_possible:
                             if weight_key in self.patches:
@@ -940,7 +949,7 @@
                             m.prev_comfy_cast_weights = m.comfy_cast_weights
                             m.comfy_cast_weights = True
                         m.comfy_patched_weights = False
-                        memory_freed += module_mem
+                        memory_freed += freed_bytes
                         offload_buffer = max(offload_buffer, potential_offload)
                         offload_weight_factor.append(module_mem)
                         offload_weight_factor.pop(0)
@@ -954,7 +963,8 @@
             self.model.lowvram_patch_counter += patch_counter
             self.model.model_loaded_weight_memory -= memory_freed
             self.model.model_offload_buffer_memory = offload_buffer
-            logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
+            target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
+            logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
         return memory_freed
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -985,11 +995,12 @@
             return self.model.model_loaded_weight_memory - current_used
 
-    def detach(self, unpatch_all=True):
+    def detach(self, unpatch_all=True, offload_device=None):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        target_device = self.offload_device if offload_device is None else offload_device
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            self.unpatch_model(target_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
         return self.model
 
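
Reviewer note: the eviction order this patch introduces is two-stage. free_memory() (and _evict_ram_for_budget() in disk_weights.py) first drops already-cached RAM copies via evict_ram_cache(), and only if that falls short does it spill still-loaded models to the meta device, i.e. to disk, through the new evict_ram_to_disk(). The standalone sketch below only illustrates that ordering under simplified assumptions: FakeModel, partially_unload_to_disk() and the byte figures are invented stand-ins, not the real comfy.model_management / ModelPatcher objects.

# Illustrative only: the two-stage RAM eviction order added by this patch.
# FakeModel is a hypothetical stand-in for a loaded model whose weights can be
# pushed to disk via partially_unload(torch.device("meta"), n) in the real code.

class FakeModel:
    def __init__(self, name, loaded_bytes):
        self.name = name
        self.loaded_bytes = loaded_bytes

    def partially_unload_to_disk(self, max_bytes):
        # Returns how many bytes were actually moved out of RAM.
        freed = min(self.loaded_bytes, max_bytes)
        self.loaded_bytes -= freed
        return freed


def evict_ram_cache(required_bytes, cache):
    # Stage 1: drop cached CPU weight copies (cheap, no extra disk traffic).
    freed = 0
    while cache and freed < required_bytes:
        freed += cache.pop()
    return freed


def evict_ram_to_disk(required_bytes, loaded_models):
    # Stage 2: spill loaded models, largest first, until the budget is met
    # (mirrors the (-loaded_memory, ...) sort key used in the patch).
    freed = 0
    for model in sorted(loaded_models, key=lambda m: -m.loaded_bytes):
        if freed >= required_bytes:
            break
        freed += model.partially_unload_to_disk(required_bytes - freed)
    return freed


def free_ram(required_bytes, cache, loaded_models):
    # Same ordering as free_memory(): cache first, disk spill second.
    freed = evict_ram_cache(required_bytes, cache)
    if freed < required_bytes:
        freed += evict_ram_to_disk(required_bytes - freed, loaded_models)
    return freed


if __name__ == "__main__":
    cache = [2 * 1024 ** 3]  # 2 GB of cached weight copies
    models = [FakeModel("unet", 6 * 1024 ** 3), FakeModel("clip", 1 * 1024 ** 3)]
    freed = free_ram(5 * 1024 ** 3, cache, models)  # budget of 5 GB
    print("freed {:.2f} MB".format(freed / (1024 * 1024)))  # 2 GB cache + 3 GB unet spill

The largest-loaded-first order keeps the number of models touched per eviction small, which matches the intent of the (-loaded_memory, refcount, ...) sort key in evict_ram_to_disk().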