Merge branch 'improved_memory' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-11-19 11:06:27 -08:00
commit 9d20de6462
5 changed files with 148 additions and 134 deletions

View File

@@ -573,7 +573,7 @@ class PromptExecutor:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
model_management.cleanup_models(keep_clone_weights_loaded=True)
model_management.cleanup_models_gc()
self.add_message("execution_cached",
{"nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)

View File

@@ -22,9 +22,10 @@ import logging
import platform
import sys
import warnings
import weakref
from enum import Enum
from threading import RLock
from typing import Literal, List, Sequence, Final
from typing import List, Sequence, Final
import psutil
import torch
@@ -338,11 +339,27 @@ def module_size(module):
class LoadedModel:
def __init__(self, model: ModelManageable):
self.model = model
self._set_model(model)
self.device = model.load_device
self.weights_loaded = False
self.real_model = None
self.currently_used = True
self.model_finalizer = None
self._patcher_finalizer = None
def _set_model(self, model):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
def _switch_parent(self):
model = self._parent_model()
if model is not None:
self._set_model(model)
@property
def model(self):
return self._model()
def model_memory(self):
return self.model.model_size()
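
LoadedModel no longer keeps a strong reference to the patcher: it stores weakref.ref(model) and resolves it through a property, so a patcher that is dropped elsewhere can actually be garbage collected. A minimal sketch of the pattern, with toy class names rather than the real API:

    import weakref

    class Patcher:
        """Toy stand-in for a model patcher; not the real ComfyUI class."""
        pass

    class Holder:
        """Keeps only a weak reference, like the reworked LoadedModel."""
        def __init__(self, patcher):
            self._patcher = weakref.ref(patcher)

        @property
        def patcher(self):
            # Resolves to the object while it is alive, and to None once
            # it has been garbage collected.
            return self._patcher()

    p = Patcher()
    h = Holder(p)
    assert h.patcher is p
    del p
    # In CPython an acyclic object is freed as soon as its last strong
    # reference disappears, so the property now yields None.
    assert h.patcher is None
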
@@ -357,32 +374,23 @@ class LoadedModel:
return self.model_memory()
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device
self.model.model_patches_to(self.device)
self.model.model_patches_to(self.model.model_dtype())
load_weights = not self.weights_loaded
# if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if self.model.loaded_size() > 0:
use_more_vram = lowvram_model_memory
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram)
else:
try:
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.weights_loaded = True
return self.real_model
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
return real_model
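
model_load now keeps only a weak reference to the underlying torch module and registers weakref.finalize(real_model, cleanup_models), so the bookkeeping is pruned automatically once the module is collected. A toy illustration of the mechanism (registry and class names are illustrative, not the real code):

    import weakref

    registry = ["loaded-entry"]

    def prune_registry():
        # Stand-in for cleanup_models(): drop bookkeeping for dead models.
        registry.clear()

    class FakeRealModel:
        """Toy stand-in for the underlying torch module."""
        pass

    real_model = FakeRealModel()
    weakref.finalize(real_model, prune_registry)  # runs when real_model is collected

    del real_model          # in CPython this frees the object right away...
    assert registry == []   # ...so the finalizer has already pruned the registry
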
def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter() > 0:
@@ -395,14 +403,14 @@ class LoadedModel:
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
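
On the explicit teardown path, model_unload() detaches the finalizer first, so cleanup_models() is not triggered a second time when the module object is eventually collected. A small sketch of disarming a finalizer without running it:

    import weakref

    class FakeRealModel:
        pass

    def late_cleanup():
        print("finalizer fired")

    m = FakeRealModel()
    fin = weakref.finalize(m, late_cleanup)

    # Deliberate teardown: disarm the finalizer so its callback does not
    # fire again when the object is finally garbage collected.
    fin.detach()
    del m        # nothing is printed; the finalizer was detached
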
def model_use_more_vram(self, extra_memory):
return self.model.partially_load(self.device, extra_memory)
def model_use_more_vram(self, extra_memory, force_patch_weights=False):
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
def __eq__(self, other):
return self.model is other.model
@@ -413,6 +421,10 @@ class LoadedModel:
else:
return f"<LoadedModel>"
def __del__(self):
if self._patcher_finalizer is not None:
self._patcher_finalizer.detach()
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
@@ -449,43 +461,6 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
with model_management_lock:
return _unload_model_clones(model, unload_weights_only, force_unload)
def _unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
to_unload = []
for i in range(len(current_loaded_models)):
if model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
if len(to_unload) == 0:
return True
same_weights = 0
for i in to_unload:
if model.clone_has_same_weights(current_loaded_models[i].model):
same_weights += 1
if same_weights == len(to_unload):
unload_weight = False
else:
unload_weight = True
if not force_unload:
if unload_weights_only and unload_weight == False:
return None
else:
unload_weight = True
for i in to_unload:
logging.debug("unload clone {} {}".format(i, unload_weight))
current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
return unload_weight
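
unload_model_clones() and its weight-comparison bookkeeping are removed; clones of a model being loaded are now simply detached inside _load_models_gpu (see the is_clone()/detach(unpatch_all=False) loop further down). Both the removed and the replacement code collect indices with to_unload = [i] + to_unload so the list ends up in descending order; a tiny sketch of why that ordering matters when popping from a list:

    items = ["a", "b", "c", "d"]
    to_unload = []
    for i in range(len(items)):
        if items[i] in ("b", "d"):
            to_unload = [i] + to_unload   # prepend so indices end up descending

    for i in to_unload:
        items.pop(i)   # popping larger indices first keeps the smaller ones valid

    assert items == ["a", "c"]
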
@tracer.start_as_current_span("Free Memory")
def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
span = get_current_span()
@@ -496,7 +471,8 @@ def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
return unloaded_models
def _free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
def _free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
can_unload = []
unloaded_models = []
@@ -546,6 +522,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
cleanup_models_gc()
global vram_state
inference_memory = minimum_inference_memory()
@@ -558,12 +535,9 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
models = set(models)
models_to_load = []
models_already_loaded = []
models_freed = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
try:
loaded_model_index = current_loaded_models.index(loaded_model)
except:
@@ -571,46 +545,34 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
loaded.currently_used = True
models_to_load.append(loaded)
else:
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
models_to_load = free_memory(minimum_memory_required, d)
models_freed += models_to_load
else:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0:
return
for loaded_model in models_to_load:
to_unload = []
for i in range(len(current_loaded_models)):
if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload
for i in to_unload:
current_loaded_models.pop(i).model.detach(unpatch_all=False)
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) # unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_already_loaded:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) # unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded
for device in total_memory_required:
if device != torch.device("cpu"):
models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
models_freed += free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
for device in total_memory_required:
if device != torch.device("cpu"):
free_mem = get_free_memory(device)
if free_mem < minimum_memory_required:
models_l = free_memory(minimum_memory_required, device)
models_freed += models_l
logging.debug("{} models unloaded.".format(len(models_l)))
for loaded_model in models_to_load:
model = loaded_model.model
@@ -633,13 +595,6 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
span = get_current_span()
span.set_attribute("models_to_load", list(map(str, models_to_load)))
span.set_attribute("models_freed", list(map(str, models_freed)))
@@ -662,23 +617,34 @@ def loaded_models(only_currently_used=False):
return output
def cleanup_models(keep_clone_weights_loaded=False):
with model_management_lock:
to_delete = []
for i in range(len(current_loaded_models)):
# TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
# TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: # references from .real_model + the .model
to_delete = [i] + to_delete
def cleanup_models_gc():
do_gc = False
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.real_model() is not None and cur.model is None:
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
do_gc = True
break
for i in to_delete:
x = current_loaded_models.pop(i)
x.model_unload()
del x
if do_gc:
gc.collect()
soft_empty_cache()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.real_model() is not None and cur.model is None:
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete
for i in to_delete:
x = current_loaded_models.pop(i)
del x
def dtype_size(dtype):
@@ -747,7 +713,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, tor
pass
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
if supports_fp8_compute(device): # if fp8 compute is supported the casting is most likely not expensive
return fp8_dtype
free_model_memory = maximum_vram_for_weights(device)
@@ -960,6 +926,7 @@ def force_channels_last():
# TODO
return False
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
@@ -971,12 +938,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
FLASH_ATTENTION_ENABLED = False
if not args.disable_flash_attn:
try:

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable
from typing import Protocol, Optional, TypeVar, runtime_checkable, Callable, Any
import torch
import torch.nn
@@ -71,11 +71,11 @@ class ModelManageable(Protocol):
def lowvram_patch_counter(self) -> int:
return 0
def partially_load(self, device_to: torch.device, extra_memory=0) -> int:
def partially_load(self, device_to: torch.device, extra_memory: int = 0, force_patch_weights: bool = False):
self.patch_model(device_to=device_to)
return self.model_size()
def partially_unload(self, device_to: torch.device, extra_memory=0) -> int:
def partially_unload(self, device_to: torch.device, memory_to_free: int = 0):
self.unpatch_model(device_to)
return self.model_size()
@@ -113,6 +113,16 @@ class ModelManageable(Protocol):
def __del__(self):
del self.model
@property
def parent(self) -> ModelManageableT | None:
return None
def detach(self, unpatch_all: bool = True):
self.model_patches_to(self.offload_device)
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
return self.model
@dataclasses.dataclass
class MemoryMeasurements:
@@ -120,6 +130,7 @@ class MemoryMeasurements:
model_loaded_weight_memory: int = 0
lowvram_patch_counter: int = 0
model_lowvram: bool = False
current_weight_patches_uuid: Any = None
_device: torch.device | None = None
@property

View File

@@ -146,6 +146,7 @@ class ModelPatcher(ModelManageable):
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self._parent: ModelManageable | None = None
self.patches_uuid: uuid.UUID = uuid.uuid4()
self.ckpt_name = ckpt_name
self._memory_measurements = MemoryMeasurements(self.model)
@@ -166,6 +167,18 @@ class ModelPatcher(ModelManageable):
def model_device(self, value: torch.device):
self._memory_measurements.device = value
@property
def current_weight_patches_uuid(self) -> Optional[uuid.UUID]:
return self._memory_measurements.current_weight_patches_uuid
@current_weight_patches_uuid.setter
def current_weight_patches_uuid(self, value):
self._memory_measurements.current_weight_patches_uuid = value
@property
def parent(self) -> Optional["ModelPatcher"]:
return self._parent
def lowvram_patch_counter(self):
return self._memory_measurements.lowvram_patch_counter
@@ -191,6 +204,7 @@ class ModelPatcher(ModelManageable):
n._model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n._parent = self
return n
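
clone() now records the patcher it was cloned from in _parent, and LoadedModel uses weakref.finalize on the clone so that, when a transient clone is garbage collected, tracking falls back to the parent patcher instead of silently losing the loaded weights. A toy sketch of that fallback (illustrative class names, not the real API):

    import weakref

    class Patcher:
        """Toy stand-in for ModelPatcher: clones remember their parent."""
        def __init__(self, parent=None):
            self.parent = parent

        def clone(self):
            return Patcher(parent=self)

    class Tracker:
        """Follows a patcher and falls back to its parent when the clone dies."""
        def __init__(self, patcher):
            self._set(patcher)

        def _set(self, patcher):
            self._ref = weakref.ref(patcher)
            if patcher.parent is not None:
                parent_ref = weakref.ref(patcher.parent)

                def switch_to_parent():
                    parent = parent_ref()
                    if parent is not None:
                        self._set(parent)

                # Fires once the tracked clone is garbage collected.
                weakref.finalize(patcher, switch_to_parent)

        @property
        def current(self):
            return self._ref()

    base = Patcher()
    clone = base.clone()
    tracker = Tracker(clone)
    del clone                       # the clone is collected (CPython refcounting)
    assert tracker.current is base  # tracking switched to the parent patcher
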
def is_clone(self, other):
@@ -484,6 +498,7 @@ class ModelPatcher(ModelManageable):
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = mem_counter
self._memory_measurements.current_weight_patches_uuid = self.patches_uuid
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
@@ -518,6 +533,7 @@ class ModelPatcher(ModelManageable):
else:
utils.set_attr_param(self.model, k, bk.weight)
self._memory_measurements.current_weight_patches_uuid = None
self.backup.clear()
if device_to is not None:
@@ -585,18 +601,35 @@ class ModelPatcher(ModelManageable):
self._memory_measurements.model_loaded_weight_memory -= memory_freed
return memory_freed
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
unpatch_weights = self._memory_measurements.current_weight_patches_uuid is not None and (self._memory_measurements.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self._memory_measurements.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
if unpatch_weights:
extra_memory += (used - self._memory_measurements.model_loaded_weight_memory)
self.patch_model(load_weights=False)
full_load = False
if not self._memory_measurements.model_lowvram:
if not self._memory_measurements.model_lowvram and self._memory_measurements.model_loaded_weight_memory > 0:
return 0
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self._memory_measurements.model_loaded_weight_memory
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
try:
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
except Exception as e:
self.detach()
raise e
return self._memory_measurements.model_loaded_weight_memory - current_used
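
partially_load() now derives unpatch_weights from the patch-set UUID: weights are only unpatched and re-applied when some patch set has already been baked into the weights and it is either stale or force_patch_weights is set. A condensed sketch of just that predicate (hypothetical helper name):

    import uuid

    def needs_repatch(applied_uuid, current_uuid, force_patch_weights=False):
        # Re-patch only if patches were applied already and they are stale,
        # or the caller explicitly forces a re-patch.
        return applied_uuid is not None and (applied_uuid != current_uuid or force_patch_weights)

    a, b = uuid.uuid4(), uuid.uuid4()
    assert needs_repatch(None, a) is False                        # nothing applied yet
    assert needs_repatch(a, a) is False                           # applied patches are current
    assert needs_repatch(a, b) is True                            # patches changed since apply
    assert needs_repatch(a, a, force_patch_weights=True) is True  # forced re-patch
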
def detach(self, unpatch_all=True):
self.model_patches_to(self.offload_device)
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
return self.model
def current_loaded_device(self):
return self.model_device
@@ -618,3 +651,6 @@ class ModelPatcher(ModelManageable):
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def __del__(self):
self.detach(unpatch_all=False)
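
The new __del__ performs only the cheap half of detach(): it moves patches back to the offload device without a full weight unpatch, so a patcher that merely goes out of scope does not keep holding device-side state. A toy sketch of hooking that kind of best-effort teardown into __del__ (assumed semantics, not the real class):

    class OffloadOnDrop:
        """Toy object that parks itself off-device when it is destroyed."""

        def __init__(self):
            self.device = "gpu"

        def detach(self, unpatch_all=True):
            self.device = "offload"   # cheap move
            if unpatch_all:
                pass                  # the expensive full unpatch would go here

        def __del__(self):
            # Keep __del__ cheap and exception-free: it can run late, or during
            # interpreter shutdown when globals may already be gone.
            try:
                self.detach(unpatch_all=False)
            except Exception:
                pass

    r = OffloadOnDrop()
    del r   # in CPython the object is destroyed here, and detach() runs
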

View File

@@ -2,7 +2,7 @@
"3": {
"inputs": {
"seed": 309794859719915,
"steps": 30,
"steps": 1,
"cfg": 4.5,
"sampler_name": "euler",
"scheduler": "simple",