Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-19 11:03:00 +08:00)
mm: make model offloading deferred with weakrefs
RAMPressure caching may need to purge the same model that you are currently trying to offload for VRAM freeing. In this case, the RAMPressure cache takes priority and needs to be able to pull the trigger on dumping the whole model and freeing the ModelPatcher in question. To do this, defer the actual transfer of model weights from GPU to RAM to model_management state instead of doing it inside ModelPatcher. The deferral is done as a list of weakrefs: if the RAM cache decides to free the model you are currently unloading, the ModelPatcher and its refs simply disappear in the middle of the unloading process, and both RAM and VRAM are freed.

The unpatcher now queues the individual leaf modules to be offloaded one by one so that RAM levels can be monitored. Note that the partial unload that is potentially done as part of a load will not be freeable this way; however, it shouldn't need to be, since that is the currently active model, and the RAM cache cannot save you if you can't even fit the one model you are currently trying to use.
This commit is contained in:
parent 07d7cd9618
commit 99bed5e19f
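The pattern is easiest to see in isolation. Below is a minimal, self-contained sketch (not part of this commit) of the weakref deferral that the new offload_modules() helper in the diff relies on, using plain torch.nn modules; the helper name here is illustrative and the real helper's free_ram() bookkeeping is omitted. If the object holding the modules is garbage collected mid-unload (for example because the RAMPressure cache dropped the ModelPatcher), the weakrefs go dead and the deferred moves become no-ops.

import weakref

import torch

def offload_modules_sketch(module_refs, offload_device):
    # Dereference each weakref; a ref whose target was already garbage
    # collected returns None, so that module is skipped instead of moved.
    for ref in module_refs:
        module = ref()
        if module is None:
            continue
        module.to(offload_device)

# Collect weakrefs to the modules while "unpatching"...
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
refs = [weakref.ref(m) for m in model.modules()]

# ...then offload later. If the whole model is dropped in between (as the
# RAM cache may do), every ref dereferences to None and nothing is moved.
del model
offload_modules_sketch(refs, torch.device("cpu"))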
@@ -535,12 +535,17 @@ class LoadedModel:
         return False
 
     def model_unload(self, memory_to_free=None, unpatch_weights=True):
         if self.model is None:
             return True
         if memory_to_free is not None:
             if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                offload_modules(modules_to_offload, self.model.offload_device)
                 if freed >= memory_to_free:
                     return False
-        self.model.detach(unpatch_weights)
+        if self.model is not None:
+            modules_to_offload = self.model.detach(unpatch_weights)
+            offload_modules(modules_to_offload, self.model.offload_device)
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -592,6 +597,13 @@ def extra_reserved_memory():
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
+def offload_modules(modules, offload_device):
+    for module in modules:
+        if module() is None:
+            continue
+        module().to(offload_device)
+    free_ram()
+
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     unloaded_model = []
@@ -24,6 +24,7 @@ import inspect
 import logging
 import math
 import uuid
+import weakref
 from typing import Callable, Optional
 
 import torch
@@ -830,6 +831,7 @@ class ModelPatcher:
 
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
+        modules_to_move = []
         if unpatch_weights:
             self.unpatch_hooks()
             self.unpin_all_weights()
@@ -854,7 +856,8 @@ class ModelPatcher:
             self.backup.clear()
 
             if device_to is not None:
-                self.model.to(device_to)
+                modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
+                modules_to_move.append(weakref.ref(self.model))
                 self.model.device = device_to
                 self.model.model_loaded_weight_memory = 0
                 self.model.model_offload_buffer_memory = 0
@@ -868,12 +871,14 @@ class ModelPatcher:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
 
         self.object_patches_backup.clear()
+        return modules_to_move
 
     def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
             patch_counter = 0
+            modules_to_move = []
             unload_list = self._load_list()
             unload_list.sort()
             offload_buffer = self.model.model_offload_buffer_memory
@@ -910,7 +915,7 @@ class ModelPatcher:
                 bias_key = "{}.bias".format(n)
                 if move_weight:
                     cast_weight = self.force_cast_weights
-                    m.to(device_to)
+                    modules_to_move.append(weakref.ref(m))
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
@@ -946,20 +951,22 @@ class ModelPatcher:
             self.model.model_loaded_weight_memory -= memory_freed
             self.model.model_offload_buffer_memory = offload_buffer
             logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
-            return memory_freed
+            return memory_freed, modules_to_move
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         with self.use_ejected(skip_and_inject_on_exit_only=True):
             unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
             # TODO: force_patch_weights should not unload + reload full model
             used = self.model.model_loaded_weight_memory
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
             if unpatch_weights:
                 extra_memory += (used - self.model.model_loaded_weight_memory)
 
             self.patch_model(load_weights=False)
             if extra_memory < 0 and not unpatch_weights:
-                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
                 return 0
             full_load = False
             if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@@ -971,7 +978,7 @@ class ModelPatcher:
             try:
                 self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
             except Exception as e:
-                self.detach()
+                comfy.model_management.offload_modules(self.detach(), self.offload_device)
                 raise e
 
             return self.model.model_loaded_weight_memory - current_used
@@ -979,11 +986,12 @@
     def detach(self, unpatch_all=True):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        modules_to_offload = []
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
-        return self.model
+        return modules_to_offload
 
     def current_loaded_device(self):
         return self.model.device