mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Better memory trimming and group_offloading logic
This commit is contained in:
parent 7ed9292532
commit d9269785d3
@@ -398,7 +398,7 @@ try:
                 ENABLE_PYTORCH_ATTENTION = True
         if rocm_version >= (7, 0):
             if any((a in arch) for a in ["gfx1201"]):
-                ENABLE_PYTORCH_ATTENTION = True
+                ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]):  # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
                 SUPPORT_FP8_OPS = True
@@ -641,6 +641,42 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
 
+def trim_memory() -> bool:
+    """
+    Trims memory usage, returning reserved memory to the system
+
+    Only supported on Windows and Linux
+    :return:
+    """
+    try:
+        if sys.platform.startswith('linux'):
+            import ctypes
+            libc_path = ctypes.util.find_library('c')
+            if not libc_path:
+                return False
+
+            libc = ctypes.CDLL(libc_path)
+
+            if hasattr(libc, 'malloc_trim'):
+                return libc.malloc_trim(0) == 1
+            else:
+                return False
+        elif sys.platform == 'win32':
+            import ctypes.wintypes
+            kernel32 = ctypes.WinDLL("kernel32")
+            EmptyProcessWorkingSet = kernel32.EmptyProcessWorkingSet
+            EmptyProcessWorkingSet.argtypes = [ctypes.wintypes.HANDLE]
+            EmptyProcessWorkingSet.restype = ctypes.wintypes.BOOL
+            handle = -1
+            success = EmptyProcessWorkingSet(handle)
+            return bool(success)
+        else:
+            return False
+    except Exception as exc_info:
+        logger.warning("failed to trim", exc_info=exc_info)
+        return False
+
+
 @tracer.start_as_current_span("Free Memory")
 def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
     span = get_current_span()
@@ -1593,6 +1629,7 @@ def _soft_empty_cache(force=False):
 def unload_all_models():
     with model_management_lock:
         free_memory(1e30, get_torch_device())
+        trim_memory()
 
 
 @_deprecate_method(version="*", message="The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
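The new trim_memory() runs right after free_memory() in unload_all_models(), so heap pages that glibc is still holding after model weights are freed get handed back to the OS (or, on Windows, the working set is emptied). For readers who want to see the mechanism outside ComfyUI, here is a standalone sketch of the Linux path; the allocation pattern, sizes, and /proc parsing are illustrative assumptions, not code from the commit.

# Standalone sketch (not part of the commit): observing what malloc_trim(0) does on Linux.
import ctypes
import ctypes.util
import sys


def rss_kib() -> int:
    # Current resident set size in KiB, read from /proc/self/status (Linux only).
    with open("/proc/self/status") as f:
        for line in f:
            if line.startswith("VmRSS:"):
                return int(line.split()[1])
    return 0


if sys.platform.startswith("linux"):
    # Many small-ish allocations land in the malloc heap rather than in mmap'd
    # regions, so the memory freed below is a good candidate for trimming.
    blobs = [bytes(4096) for _ in range(100_000)]
    del blobs

    before = rss_kib()
    libc = ctypes.CDLL(ctypes.util.find_library("c"))
    if hasattr(libc, "malloc_trim"):
        libc.malloc_trim(0)  # the same call trim_memory() makes after models are unloaded
    print(f"RSS before trim: {before} KiB, after: {rss_kib()} KiB")

How much the resident set actually drops depends on the allocator's arena state, so treat the printed numbers as indicative only.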
@@ -1,9 +1,10 @@
 import torch
+import logging
 from diffusers import HookRegistry
 from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
 
 from comfy.language.transformers_model_management import TransformersManagedModel
-from comfy.model_management import vram_state, VRAMState
+from comfy.model_management import vram_state, VRAMState, unload_all_models, get_free_memory, get_torch_device
 from comfy.model_management_types import HooksSupport, ModelManageable
 from comfy.model_patcher import ModelPatcher
 from comfy.node_helpers import export_custom_nodes
@@ -14,6 +15,8 @@ from comfy.rmsnorm import RMSNorm
 
 _DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"
 
+logger = logging.getLogger(__name__)
+
 
 class DisableComfyWeightCast(ModelHook):
     r"""
@@ -75,6 +78,10 @@ def prepare_group_offloading_factory(load_device: torch.device, offload_device:
     def wrapper(executor, model: ModelPatcher, *args, **kwargs):
         # this model will now just be loaded to CPU, since diffusers will manage moving to gpu
         model.load_device = offload_device
+
+        # we'll have to unload everything to use pinning better, this includes trimming
+        unload_all_models()
+
         # loads the model, prepares everything
         inner_model, conds, models = executor(model, *args, **kwargs)
 
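The comment above explains why everything is unloaded (and, via trim_memory(), returned to the OS) before group offloading is applied: the stream-based offloader pins CPU weights, and page-locked allocations have to come out of RAM that is actually free. A minimal sketch of the mechanism the stream path depends on, assuming a CUDA device; the tensor names and sizes are illustrative, not from the commit.

# Illustration only: pinned (page-locked) host memory is what lets host-to-device
# copies run asynchronously on a side stream.
import torch

if torch.cuda.is_available():
    pageable = torch.randn(32, 1024, 1024)   # ordinary pageable host tensor (~128 MiB)
    pinned = pageable.pin_memory()           # page-locked copy; allocation fails if free RAM is scarce

    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # Only a pinned source makes this copy truly asynchronous with respect to the host.
        on_gpu = pinned.to("cuda", non_blocking=True)
    copy_stream.synchronize()
    print(on_gpu.device)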
@@ -83,13 +90,23 @@ def prepare_group_offloading_factory(load_device: torch.device, offload_device:
             raise ValueError("manual casting operations, where the model is loaded in different weights than inference will occur, is not supported")
 
         # weights are patched, ready to go, inner model will be correctly deleted at the end of sampling
+        model_size = model.model_size()
+
+        model_too_large = model_size * 2 > get_free_memory(torch.cpu)
+        low_vram_state = vram_state in (VRAMState.LOW_VRAM,)
+        is_cuda_device = load_device.type == 'cuda'
+
+        if model_too_large or low_vram_state:
+            logger.error(f"group offloading did not use memory pinning because model_too_large={model_too_large} low_vram_state={low_vram_state}")
+        if not is_cuda_device:
+            logger.error(f"group offloading did not use stream because load_device.type={load_device.type} != \"cuda\"")
         apply_group_offloading(
             inner_model.diffusion_model,
             load_device,
             offload_device,
-            use_stream=True,
-            record_stream=True,
-            low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
+            use_stream=is_cuda_device,
+            record_stream=is_cuda_device,
+            low_cpu_mem_usage=low_vram_state or model_too_large,
             num_blocks_per_group=1
         )
         # then the inputs will be ready on the correct device due to the wrapper factory
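The updated call only enables use_stream/record_stream when the load device is CUDA, and falls back to low_cpu_mem_usage when the model is too large to pin or the VRAM state is already low. For context, a self-contained sketch of the same diffusers call on a toy block-structured module, with the same CUDA gating; it assumes a diffusers release that ships apply_group_offloading, and the toy model and device probing are illustrative.

# Illustration only: diffusers group offloading on a toy ModuleList-based model.
import torch
from diffusers.hooks import apply_group_offloading


class ToyBlocks(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # block_level offloading groups ModuleList/Sequential children.
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(8))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


load_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")
is_cuda_device = load_device.type == "cuda"

model = ToyBlocks()
apply_group_offloading(
    model,
    load_device,
    offload_device,
    use_stream=is_cuda_device,        # streams and pinning only make sense on CUDA
    record_stream=is_cuda_device,
    low_cpu_mem_usage=False,          # the commit turns this on when the model is too large to pin
    num_blocks_per_group=1,
)

# Inputs must already live on the load device; in the commit this is handled
# by the wrapper factory ("the inputs will be ready on the correct device").
out = model(torch.randn(4, 256, device=load_device))
print(out.shape)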
@@ -139,7 +156,7 @@ class GroupOffload(CustomNode):
                 num_blocks_per_group=1
             )
         elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
-            model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
+            model.add_wrapper_with_key(WrappersMP.PREPARE_SAMPLING, "group_offload", prepare_group_offloading_factory(model.load_device, model.offload_device))
         return model,
 
 
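The node now registers its PREPARE_SAMPLING wrapper under an explicit "group_offload" key rather than anonymously. For readers unfamiliar with the wrapper convention the factory above relies on, here is a simplified plain-Python stand-in: it is not ComfyUI's WrapperExecutor, and the registry layout and names are illustrative only.

# Illustration only: a wrapper receives the next executor and delegates to it;
# registering it under a key keeps the entry identifiable later.
from typing import Callable, Dict, List


def prepare_sampling(model, noise):
    # Stand-in for the wrapped PREPARE_SAMPLING step.
    return f"prepared {model} with {noise}"


def group_offload_wrapper(executor, model, noise):
    # Mirrors the shape of the wrapper returned by prepare_group_offloading_factory:
    # do work before, delegate to the inner executor, do work after.
    print("before: unload models, trim memory, repoint load_device")
    result = executor(model, noise)
    print("after: apply_group_offloading to the loaded model")
    return result


# Keyed registry: wrapper_type -> key -> wrappers.
wrappers: Dict[str, Dict[str, List[Callable]]] = {
    "PREPARE_SAMPLING": {"group_offload": [group_offload_wrapper]},
}


def run_prepare_sampling(model, noise):
    executor = prepare_sampling
    for registered in (w for ws in wrappers["PREPARE_SAMPLING"].values() for w in ws):
        # Each wrapper closes over the executor built so far.
        executor = (lambda nxt, wrp: (lambda m, n: wrp(nxt, m, n)))(executor, registered)
    return executor(model, noise)


print(run_prepare_sampling("unet", "sigmas"))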