From 99bed5e19fe831bb80b26420e96c646d26c0fc9b Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 18 Nov 2025 09:42:56 +1000 Subject: [PATCH] mm: make model offloading deferred with weakrefs RAMPressure caching may need to purge the same model that you are currently trying to offload for VRAM freeing. In this case, RAMPressure cache takes priority and needs to be able to pull the trigger on dumping the whole model and freeing the ModelPatcher in question. To do this, defer the actual transfer of model weights from GPU to RAM to model_management state and not as part of ModelPatcher. This is done as a list of weakrefs. If the RAM cache decides to free the model you are currently unloading, then the ModelPatcher and refs simply disappear in the middle of the unloading process, and both RAM and VRAM will be freed. The unpatcher now queues the individual leaf modules to be offloaded one-by-one so that RAM levels can be monitored. Note that the UnloadPartially that is potentially done as part of a load will not be freeable this way; however, it shouldn't be anyway, as that is the currently active model and RAM cache cannot save you if you can't even fit the one model you are currently trying to use. 
--- comfy/model_management.py | 16 ++++++++++++++-- comfy/model_patcher.py | 24 ++++++++++++++++-------- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 18a700905..5a7f23e30 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -535,12 +535,17 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): + if self.model is None: + return True if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free) + offload_modules(modules_to_offload, self.model.offload_device) if freed >= memory_to_free: return False - self.model.detach(unpatch_weights) + if self.model is not None: + modules_to_offload = self.model.detach(unpatch_weights) + offload_modules(modules_to_offload, self.model.offload_device) self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -592,6 +597,13 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() +def offload_modules(modules, offload_device): + for module in modules: + if module() is None: + continue + module().to(offload_device) + free_ram() + def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() unloaded_model = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..078c23019 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import inspect import logging import math import uuid +import weakref from typing import Callable, Optional import torch @@ -830,6 +831,7 @@ class ModelPatcher: def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() + modules_to_move = [] if unpatch_weights: self.unpatch_hooks() 
self.unpin_all_weights() @@ -854,7 +856,8 @@ class ModelPatcher: self.backup.clear() if device_to is not None: - self.model.to(device_to) + modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ] + modules_to_move.append(weakref.ref(self.model)) self.model.device = device_to self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -868,12 +871,14 @@ class ModelPatcher: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() + return modules_to_move def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 patch_counter = 0 + modules_to_move = [] unload_list = self._load_list() unload_list.sort() offload_buffer = self.model.model_offload_buffer_memory @@ -910,7 +915,7 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + modules_to_move.append(weakref.ref(m)) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: @@ -946,20 +951,22 @@ class ModelPatcher: self.model.model_loaded_weight_memory -= memory_freed self.model.model_offload_buffer_memory = offload_buffer logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) - return memory_freed + return memory_freed, modules_to_move def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): with self.use_ejected(skip_and_inject_on_exit_only=True): unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model used = 
self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + comfy.model_management.offload_modules(modules_to_offload, self.offload_device) if unpatch_weights: extra_memory += (used - self.model.model_loaded_weight_memory) self.patch_model(load_weights=False) if extra_memory < 0 and not unpatch_weights: - self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + comfy.model_management.offload_modules(modules_to_offload, self.offload_device) return 0 full_load = False if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: @@ -971,7 +978,7 @@ class ModelPatcher: try: self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) except Exception as e: - self.detach() + comfy.model_management.offload_modules(self.detach(), self.offload_device) raise e return self.model.model_loaded_weight_memory - current_used @@ -979,11 +986,12 @@ class ModelPatcher: def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) + modules_to_offload = [] if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): callback(self, unpatch_all) - return self.model + return modules_to_offload def current_loaded_device(self): return self.model.device