Better memory trimming and group_offloading logic

doctorpangloss 2025-10-21 14:27:26 -07:00
parent 7ed9292532
commit d9269785d3
2 changed files with 60 additions and 6 deletions

View File

@@ -398,7 +398,7 @@ try:
                ENABLE_PYTORCH_ATTENTION = True
        if rocm_version >= (7, 0):
            if any((a in arch) for a in ["gfx1201"]):
                ENABLE_PYTORCH_ATTENTION = True
        if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
            if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]):  # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
                SUPPORT_FP8_OPS = True
@@ -641,6 +641,42 @@ def minimum_inference_memory():
    return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

+def trim_memory() -> bool:
+    """
+    Trims memory usage, returning reserved memory to the system.
+
+    Only supported on Windows and Linux.
+    :return: True if reserved memory was released back to the operating system, False otherwise
+    """
+    try:
+        if sys.platform.startswith('linux'):
+            import ctypes
+            import ctypes.util
+            libc_path = ctypes.util.find_library('c')
+            if not libc_path:
+                return False
+            libc = ctypes.CDLL(libc_path)
+            if hasattr(libc, 'malloc_trim'):
+                # malloc_trim(0) returns 1 if any memory was released to the system
+                return libc.malloc_trim(0) == 1
+            else:
+                return False
+        elif sys.platform == 'win32':
+            import ctypes.wintypes
+            kernel32 = ctypes.WinDLL("kernel32")
+            # K32EmptyWorkingSet removes as many pages as possible from the working set;
+            # -1 is the pseudo-handle for the current process
+            EmptyWorkingSet = kernel32.K32EmptyWorkingSet
+            EmptyWorkingSet.argtypes = [ctypes.wintypes.HANDLE]
+            EmptyWorkingSet.restype = ctypes.wintypes.BOOL
+            handle = -1
+            success = EmptyWorkingSet(handle)
+            return bool(success)
+        else:
+            return False
+    except Exception as exc_info:
+        logger.warning("failed to trim memory", exc_info=exc_info)
+        return False
+
+
@tracer.start_as_current_span("Free Memory")
def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
    span = get_current_span()
@@ -1593,6 +1629,7 @@ def _soft_empty_cache(force=False):
def unload_all_models():
    with model_management_lock:
        free_memory(1e30, get_torch_device())
+        trim_memory()


@_deprecate_method(version="*", message="The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
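As a usage note, here is a minimal sketch (not part of the commit) of how these helpers fit together: unload_all_models() now frees every loaded model and then trims, and trim_memory() can also be called directly to hand reserved host memory back to the OS.

    # sketch only: return host memory to the OS between workflows
    from comfy.model_management import unload_all_models, trim_memory

    unload_all_models()        # frees all loaded models, then calls trim_memory() internally
    released = trim_memory()   # safe to call directly too; returns False on unsupported platforms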

View File

@@ -1,9 +1,10 @@
import torch
+import logging
from diffusers import HookRegistry
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
from comfy.language.transformers_model_management import TransformersManagedModel
-from comfy.model_management import vram_state, VRAMState
+from comfy.model_management import vram_state, VRAMState, unload_all_models, get_free_memory, get_torch_device
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
@@ -14,6 +15,8 @@ from comfy.rmsnorm import RMSNorm
_DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"

+logger = logging.getLogger(__name__)
+

class DisableComfyWeightCast(ModelHook):
    r"""
@@ -75,6 +78,10 @@ def prepare_group_offloading_factory(load_device: torch.device, offload_device:
    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # this model will now just be loaded to CPU, since diffusers will manage moving to gpu
        model.load_device = offload_device
+        # we'll have to unload everything to use pinning better, this includes trimming
+        unload_all_models()
+
        # loads the model, prepares everything
        inner_model, conds, models = executor(model, *args, **kwargs)
@@ -83,13 +90,23 @@
            raise ValueError("manual casting operations, where the model is loaded in different weights than inference will occur, is not supported")

        # weights are patched, ready to go, inner model will be correctly deleted at the end of sampling
+        model_size = model.model_size()
+        # pinning roughly doubles host memory usage, so require ~2x the model size free in system RAM
+        model_too_large = model_size * 2 > get_free_memory(torch.device("cpu"))
+        low_vram_state = vram_state in (VRAMState.LOW_VRAM,)
+        is_cuda_device = load_device.type == 'cuda'
+        if model_too_large or low_vram_state:
+            logger.error(f"group offloading did not use memory pinning because model_too_large={model_too_large} low_vram_state={low_vram_state}")
+        if not is_cuda_device:
+            logger.error(f"group offloading did not use a stream because load_device.type={load_device.type} != \"cuda\"")
        apply_group_offloading(
            inner_model.diffusion_model,
            load_device,
            offload_device,
-            use_stream=True,
-            record_stream=True,
-            low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
+            use_stream=is_cuda_device,
+            record_stream=is_cuda_device,
+            low_cpu_mem_usage=low_vram_state or model_too_large,
            num_blocks_per_group=1
        )
        # then the inputs will be ready on the correct device due to the wrapper factory
@@ -139,7 +156,7 @@ class GroupOffload(CustomNode):
                num_blocks_per_group=1
            )
        elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
-            model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
+            model.add_wrapper_with_key(WrappersMP.PREPARE_SAMPLING, "group_offload", prepare_group_offloading_factory(model.load_device, model.offload_device))
        return model,
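
For illustration, a minimal sketch of the same gating applied to a standalone module. The helper name offload_module and its parameters are hypothetical; the apply_group_offloading arguments mirror the ones used in the diff above.

    # sketch only: gate diffusers group offloading the way the wrapper above does
    import torch
    from diffusers.hooks import apply_group_offloading

    def offload_module(module: torch.nn.Module, load_device: torch.device,
                       model_bytes: int, free_cpu_bytes: int, low_vram: bool):
        offload_device = torch.device("cpu")
        is_cuda = load_device.type == "cuda"
        # pinned-memory prefetch needs roughly 2x the model size in free RAM
        too_large = model_bytes * 2 > free_cpu_bytes
        apply_group_offloading(
            module,
            load_device,
            offload_device,
            use_stream=is_cuda,                       # async prefetch requires a CUDA stream
            record_stream=is_cuda,
            low_cpu_mem_usage=too_large or low_vram,  # avoid pinned buffers when memory is tight
            num_blocks_per_group=1,
        )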