mm: make model offloading deferred with weakrefs

RAMPressure caching may need to purge the same model that you are
currently trying to offload for VRAM freeing. In this case, RAMPressure
cache takes priority and needs to be able to pull the trigger on dumping
the whole model and freeing the ModelPatcher in question. To do this,
defer the actual transfer of model weights from GPU to RAM to
model_management state and not as part of ModelPatcher. This is done as
a list of weakrefs.

If RAM cache decides to free the model you are currently unloading, then
the ModelPatcher and refs simply disappear in the middle of the
unloading process, and both RAM and VRAM will be freed.

The unpatcher now queues the individual leaf modules to be offloaded
one-by-one so that RAM levels can be monitored.

Note that the UnloadPartially that is potentially done as part of a
load will not be freeable this way; however, it shouldn't be anyway, as
that is the currently active model and RAM cache cannot save you if
you can't even fit the one model you are currently trying to use.
This commit is contained in:
Rattus 2025-11-18 09:42:56 +10:00
parent 07d7cd9618
commit 99bed5e19f
2 changed files with 30 additions and 10 deletions

View File

@ -535,12 +535,17 @@ class LoadedModel:
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if self.model is None:
return True
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
offload_modules(modules_to_offload, self.model.offload_device)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
if self.model is not None:
modules_to_offload = self.model.detach(unpatch_weights)
offload_modules(modules_to_offload, self.model.offload_device)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@ -592,6 +597,13 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def offload_modules(modules, offload_device):
    """Move a list of weakly-referenced modules to ``offload_device``.

    ``modules`` is a list of ``weakref.ref`` objects produced by
    ``ModelPatcher`` during unpatch/partial-unload.  Because the
    RAM-pressure cache may free the owning model (and thus drop the
    referents) at any point during this loop, each weakref must be
    dereferenced exactly once and the strong reference held while the
    module is moved.

    Args:
        modules: iterable of ``weakref.ref`` to torch modules (or the
            root model) queued for offload.
        offload_device: target device (typically the CPU offload device)
            passed to ``Module.to``.
    """
    for module_ref in modules:
        # Dereference once and keep a strong reference for the duration of
        # the move.  Calling module_ref() twice (check, then use) races with
        # the RAM-pressure cache freeing the model in between, which would
        # make the second call return None and crash on .to().
        module = module_ref()
        if module is None:
            continue
        module.to(offload_device)
        # Offload leaf-by-leaf so RAM levels can be monitored between moves.
        free_ram()
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []

View File

@ -24,6 +24,7 @@ import inspect
import logging
import math
import uuid
import weakref
from typing import Callable, Optional
import torch
@ -830,6 +831,7 @@ class ModelPatcher:
def unpatch_model(self, device_to=None, unpatch_weights=True):
self.eject_model()
modules_to_move = []
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
@ -854,7 +856,8 @@ class ModelPatcher:
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
modules_to_move.append(weakref.ref(self.model))
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@ -868,12 +871,14 @@ class ModelPatcher:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup.clear()
return modules_to_move
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
modules_to_move = []
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
@ -910,7 +915,7 @@ class ModelPatcher:
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
modules_to_move.append(weakref.ref(m))
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@ -946,20 +951,22 @@ class ModelPatcher:
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
return memory_freed, modules_to_move
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
_, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@ -971,7 +978,7 @@ class ModelPatcher:
try:
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
except Exception as e:
self.detach()
comfy.model_management.offload_modules(self.detach(), self.offload_device)
raise e
return self.model.model_loaded_weight_memory - current_used
@ -979,11 +986,12 @@ class ModelPatcher:
def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
modules_to_offload = []
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model
return modules_to_offload
def current_loaded_device(self):
return self.model.device