Better memory trimming and group_offloading logic

doctorpangloss 2025-10-21 14:27:26 -07:00
parent 7ed9292532
commit d9269785d3
2 changed files with 60 additions and 6 deletions

View File

@ -398,7 +398,7 @@ try:
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
SUPPORT_FP8_OPS = True
@ -641,6 +641,42 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def trim_memory() -> bool:
    """
    Trims memory usage, returning reserved memory to the system.

    Only supported on Windows and Linux.

    :return: True if memory was returned to the system, False otherwise
    """
    try:
        if sys.platform.startswith('linux'):
            import ctypes
            import ctypes.util  # ctypes.util is not guaranteed to be loaded by `import ctypes` alone
            libc_path = ctypes.util.find_library('c')
            if not libc_path:
                return False
            libc = ctypes.CDLL(libc_path)
            # malloc_trim is a glibc extension; other libcs (e.g. musl) do not export it
            if hasattr(libc, 'malloc_trim'):
                return libc.malloc_trim(0) == 1
            else:
                return False
        elif sys.platform == 'win32':
            import ctypes.wintypes
            # PSAPI's EmptyWorkingSet is exported by kernel32 as K32EmptyWorkingSet
            kernel32 = ctypes.WinDLL("kernel32")
            EmptyWorkingSet = kernel32.K32EmptyWorkingSet
            EmptyWorkingSet.argtypes = [ctypes.wintypes.HANDLE]
            EmptyWorkingSet.restype = ctypes.wintypes.BOOL
            # -1 is the pseudo-handle for the current process (GetCurrentProcess())
            handle = -1
            success = EmptyWorkingSet(handle)
            return bool(success)
        else:
            return False
    except Exception as exc_info:
        logger.warning("failed to trim memory", exc_info=exc_info)
        return False
@tracer.start_as_current_span("Free Memory")
def free_memory(memory_required, device, keep_loaded=[]) -> List[LoadedModel]:
span = get_current_span()
@ -1593,6 +1629,7 @@ def _soft_empty_cache(force=False):
def unload_all_models():
    with model_management_lock:
        free_memory(1e30, get_torch_device())
        trim_memory()
@_deprecate_method(version="*", message="The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
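The Linux branch of trim_memory() relies on glibc's malloc_trim(3): CPython and C extensions hand freed blocks back to the allocator, but glibc may keep the underlying pages cached, so the process RSS stays inflated until they are explicitly trimmed. A minimal standalone sketch of that effect (not part of this commit; assumes Linux with glibc and the psutil package):

import ctypes
import ctypes.util
import gc

import psutil

proc = psutil.Process()


def rss_mib() -> int:
    return proc.memory_info().rss // (1024 * 1024)


# allocate roughly 800 MiB as many small blocks so they come from the main malloc arena
blobs = [bytes(16_384) for _ in range(50_000)]
print("allocated:", rss_mib(), "MiB")

del blobs
gc.collect()
print("after free:", rss_mib(), "MiB")  # often barely drops: glibc keeps the pages cached

libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.malloc_trim(0)  # the same call trim_memory() makes; hands cached pages back to the kernel
print("after malloc_trim:", rss_mib(), "MiB")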

View File

@ -1,9 +1,10 @@
import torch
import logging
from diffusers import HookRegistry
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management import vram_state, VRAMState
from comfy.model_management import vram_state, VRAMState, unload_all_models, get_free_memory, get_torch_device
from comfy.model_management_types import HooksSupport, ModelManageable
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
@ -14,6 +15,8 @@ from comfy.rmsnorm import RMSNorm
_DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"
logger = logging.getLogger(__name__)
class DisableComfyWeightCast(ModelHook):
r"""
@ -75,6 +78,10 @@ def prepare_group_offloading_factory(load_device: torch.device, offload_device:
    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # this model will now only be loaded onto the CPU; diffusers manages moving blocks to the GPU
        model.load_device = offload_device

        # we have to unload everything, including trimming freed host memory, to make the most of pinning
        unload_all_models()

        # loads the model and prepares everything
        inner_model, conds, models = executor(model, *args, **kwargs)
@ -83,13 +90,23 @@ def prepare_group_offloading_factory(load_device: torch.device, offload_device:
            raise ValueError("manual casting, where the model would be loaded with different weights (dtypes) than the ones used for inference, is not supported")

        # weights are patched and ready to go; the inner model will be correctly deleted at the end of sampling
        model_size = model.model_size()
        # conservative guard: pre-pinning host buffers needs substantial extra system RAM,
        # so require twice the model size free before enabling it
        model_too_large = model_size * 2 > get_free_memory(torch.device("cpu"))
        low_vram_state = vram_state in (VRAMState.LOW_VRAM,)
        is_cuda_device = load_device.type == 'cuda'
        if model_too_large or low_vram_state:
            logger.error(f"group offloading will not use memory pinning because model_too_large={model_too_large} low_vram_state={low_vram_state}")
        if not is_cuda_device:
            logger.error(f"group offloading will not use a CUDA stream because load_device.type={load_device.type} != \"cuda\"")
apply_group_offloading(
inner_model.diffusion_model,
load_device,
offload_device,
use_stream=True,
record_stream=True,
low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
use_stream=is_cuda_device,
record_stream=is_cuda_device,
low_cpu_mem_usage=low_vram_state or model_too_large,
num_blocks_per_group=1
)
# then the inputs will be ready on the correct device due to the wrapper factory
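For reference, a minimal standalone sketch of the diffusers group-offloading call that this wrapper configures (not part of the commit; the toy module, the assumption of an available CUDA device, and the parameter values are illustrative):

import torch
from diffusers.hooks import apply_group_offloading


class TinyBlockModel(torch.nn.Module):
    # stand-in for a diffusion model: a ModuleList of blocks, which block-level offloading groups and moves
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(8))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


onload_device = torch.device("cuda")   # streams and record_stream only help on CUDA devices
offload_device = torch.device("cpu")

model = TinyBlockModel()
apply_group_offloading(
    model,
    onload_device,
    offload_device,
    num_blocks_per_group=1,       # smallest groups: lowest peak VRAM, most transfers (matches the commit)
    use_stream=True,              # overlap host<->device copies with compute
    record_stream=True,
    low_cpu_mem_usage=False,      # pre-pin host buffers; the commit switches this to True when host RAM is tight
)

x = torch.randn(1, 256, device=onload_device)
y = model(x)  # hooks onload each block group just in time, then offload it again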
@ -139,7 +156,7 @@ class GroupOffload(CustomNode):
num_blocks_per_group=1
)
elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
model.add_wrapper_with_key(WrappersMP.PREPARE_SAMPLING, "group_offload", prepare_group_offloading_factory(model.load_device, model.offload_device))
return model,