mm: make model offloading deferred with weakrefs

RAMPressure caching may need to purge the same model that is currently
being offloaded to free VRAM. In that case the RAMPressure cache takes
priority and must be able to pull the trigger on dumping the whole
model and freeing the ModelPatcher in question. To allow this, defer
the actual transfer of model weights from GPU to RAM to
model_management state instead of doing it inside ModelPatcher. The
deferred work is tracked as a list of weakrefs to the modules.
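
A minimal sketch of the shape of that handoff, using toy names rather
than the real ComfyUI API (the actual change is in the diff below): the
patcher only collects weakrefs, and the manager performs the moves.

    import weakref
    import torch

    class ToyPatcher:
        def __init__(self, model):
            self.model = model

        def unpatch(self):
            # Defer the GPU->RAM transfer: collect weakrefs to the leaf
            # modules instead of calling .to(offload_device) here.
            return [weakref.ref(m) for m in self.model.modules()]

    def offload_modules(module_refs, offload_device):
        # Runs later, owned by the manager rather than the patcher.
        for ref in module_refs:
            module = ref()
            if module is None:
                continue  # referent was purged meanwhile; nothing to move
            module.to(offload_device)

    patcher = ToyPatcher(torch.nn.Linear(8, 8))
    offload_modules(patcher.unpatch(), torch.device("cpu"))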

If the RAM cache decides to free the model that is currently being
unloaded, the ModelPatcher and its refs simply disappear in the middle
of the unloading process, and both RAM and VRAM are freed.
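
To illustrate with a throwaway module (assuming nothing else holds a
strong reference to it):

    import weakref
    import torch

    m = torch.nn.Linear(4, 4)
    refs = [weakref.ref(m)]
    del m                     # e.g. the RAMPressure cache purged the model
    assert refs[0]() is None  # dead ref: the deferred offload is a no-op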

The unpatcher now queues the individual leaf modules to be offloaded
one by one so that RAM levels can be monitored.
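
Roughly the loop that results; check_ram_pressure is a hypothetical
stand-in for the RAM monitoring hook (the offload_modules added in this
commit calls free_ram()).

    def offload_one_by_one(module_refs, offload_device, check_ram_pressure):
        # Moving leaf modules individually lets RAM pressure be re-checked
        # after every transfer instead of only once per model.
        for ref in module_refs:
            module = ref()
            if module is None:
                continue
            module.to(offload_device)
            check_ram_pressure()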

Note that the partial unload that is potentially done as part of a
load will not be freeable this way, but it shouldn't need to be: that
is the currently active model, and the RAM cache cannot save you if
you can't even fit the one model you are currently trying to use.
Rattus 2025-11-18 09:42:56 +10:00
parent 07d7cd9618
commit 99bed5e19f
2 changed files with 30 additions and 10 deletions


@@ -535,12 +535,17 @@ class LoadedModel:
         return False

     def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if self.model is None:
+            return True
         if memory_to_free is not None:
             if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                offload_modules(modules_to_offload, self.model.offload_device)
                 if freed >= memory_to_free:
                     return False
-        self.model.detach(unpatch_weights)
+        if self.model is not None:
+            modules_to_offload = self.model.detach(unpatch_weights)
+            offload_modules(modules_to_offload, self.model.offload_device)
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -592,6 +597,13 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

+def offload_modules(modules, offload_device):
+    for module in modules:
+        if module() is None:
+            continue
+        module().to(offload_device)
+        free_ram()
+
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     unloaded_model = []


@@ -24,6 +24,7 @@ import inspect
 import logging
 import math
 import uuid
+import weakref
 from typing import Callable, Optional

 import torch
@@ -830,6 +831,7 @@ class ModelPatcher:

     def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
+        modules_to_move = []
         if unpatch_weights:
             self.unpatch_hooks()
             self.unpin_all_weights()
@@ -854,7 +856,8 @@ class ModelPatcher:
             self.backup.clear()

             if device_to is not None:
-                self.model.to(device_to)
+                modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
+                modules_to_move.append(weakref.ref(self.model))
                 self.model.device = device_to
                 self.model.model_loaded_weight_memory = 0
                 self.model.model_offload_buffer_memory = 0
@@ -868,12 +871,14 @@
                 comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

         self.object_patches_backup.clear()
+        return modules_to_move

     def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
             patch_counter = 0
+            modules_to_move = []
             unload_list = self._load_list()
             unload_list.sort()
             offload_buffer = self.model.model_offload_buffer_memory
@@ -910,7 +915,7 @@
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
-                    m.to(device_to)
+                    modules_to_move.append(weakref.ref(m))
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
@@ -946,20 +951,22 @@
             self.model.model_loaded_weight_memory -= memory_freed
             self.model.model_offload_buffer_memory = offload_buffer
             logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
-            return memory_freed
+            return memory_freed, modules_to_move

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         with self.use_ejected(skip_and_inject_on_exit_only=True):
             unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
             # TODO: force_patch_weights should not unload + reload full model
             used = self.model.model_loaded_weight_memory
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
             if unpatch_weights:
                 extra_memory += (used - self.model.model_loaded_weight_memory)

             self.patch_model(load_weights=False)
             if extra_memory < 0 and not unpatch_weights:
-                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
                 return 0
             full_load = False
             if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@@ -971,7 +978,7 @@
             try:
                 self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
             except Exception as e:
-                self.detach()
+                comfy.model_management.offload_modules(self.detach(), self.offload_device)
                 raise e

             return self.model.model_loaded_weight_memory - current_used
@@ -979,11 +986,12 @@
     def detach(self, unpatch_all=True):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        modules_to_offload = []
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
-        return self.model
+        return modules_to_offload

     def current_loaded_device(self):
         return self.model.device