mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 16:50:17 +08:00
Fix disk weight device handling and cache accounting
This commit is contained in:
parent
82e70aa3c2
commit
91809e83ff
@ -939,7 +939,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
|
||||
if getattr(module, "comfy_patched_weights", False):
|
||||
target_device = input_device
|
||||
elif getattr(module, "comfy_cast_weights", False):
|
||||
target_device = torch.device("cpu")
|
||||
target_device = input_device
|
||||
else:
|
||||
target_device = input_device
|
||||
ensure_module_materialized(
|
||||
@ -1082,6 +1082,13 @@ def move_module_tensors(
|
||||
if tensor is None or tensor.device.type == "meta":
|
||||
return tensor
|
||||
target_dtype = dtype_override or tensor.dtype
|
||||
if (
|
||||
tensor.device.type == "cpu"
|
||||
and tensor.data_ptr() in model_management.PINNED_MEMORY
|
||||
and (device_to.type != "cpu" or target_dtype != tensor.dtype)
|
||||
):
|
||||
model_management.wait_for_pinned_tensor(tensor)
|
||||
model_management.unpin_memory(tensor)
|
||||
if tensor.device == device_to and tensor.dtype == target_dtype:
|
||||
return tensor
|
||||
return model_management.cast_to(
|
||||
@ -1093,6 +1100,20 @@ def move_module_tensors(
|
||||
)
|
||||
|
||||
module._apply(apply_fn)
|
||||
if disk_weights_enabled():
|
||||
for submodule in module.modules():
|
||||
refs = REGISTRY.get(submodule)
|
||||
if not refs:
|
||||
continue
|
||||
for name, disk_ref in refs.items():
|
||||
if disk_ref.is_buffer:
|
||||
tensor = submodule._buffers.get(name)
|
||||
else:
|
||||
tensor = submodule._parameters.get(name)
|
||||
if tensor is None or tensor.device.type == "meta":
|
||||
CACHE.remove_entry(submodule, name)
|
||||
continue
|
||||
CACHE.record(submodule, name, tensor, is_buffer=disk_ref.is_buffer)
|
||||
if non_blocking and offload_stream is not None:
|
||||
model_management.sync_stream(device_to, offload_stream)
|
||||
return module
|
||||
@ -1140,12 +1161,11 @@ def module_to(
|
||||
target_device = _find_existing_device(module) or torch.device("cpu")
|
||||
dtype_override = dtype or arg_dtype
|
||||
if target_device.type == "meta":
|
||||
cpu_device = torch.device("cpu")
|
||||
for submodule in module.modules():
|
||||
offload_module_weights(submodule)
|
||||
move_module_tensors(
|
||||
submodule,
|
||||
cpu_device,
|
||||
target_device,
|
||||
dtype_override=dtype_override,
|
||||
non_blocking=non_blocking,
|
||||
)
|
||||
|
||||
@ -586,6 +586,7 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
gpu_deficit = 0
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
free_before = get_free_memory(device)
|
||||
if is_device_cpu(device):
|
||||
@ -598,9 +599,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
headroom,
|
||||
device,
|
||||
)
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
if freed_cache < memory_required:
|
||||
evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
|
||||
comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
elif is_device_cuda(device) or is_device_xpu(device):
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
@ -610,6 +609,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
device,
|
||||
)
|
||||
comfy.disk_weights.evict_for_budget(device, memory_required)
|
||||
free_after_vram = get_free_memory(device)
|
||||
if free_after_vram < memory_required:
|
||||
gpu_deficit = memory_required - free_after_vram
|
||||
comfy.disk_weights.evict_ram_cache(gpu_deficit)
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -633,6 +636,13 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
offload_device = None
|
||||
if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
|
||||
offload_device = torch.device("meta")
|
||||
elif comfy.disk_weights.disk_weights_enabled() and gpu_deficit > 0 and (is_device_cuda(device) or is_device_xpu(device)):
|
||||
cpu = torch.device("cpu")
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
required_cpu = current_loaded_models[i].model_loaded_memory() + headroom
|
||||
free_cpu = get_free_memory(cpu)
|
||||
if free_cpu < required_cpu:
|
||||
offload_device = torch.device("meta")
|
||||
if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
|
||||
unloaded_model.append(i)
|
||||
|
||||
@ -658,42 +668,6 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
)
|
||||
return unloaded_models
|
||||
|
||||
|
||||
def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
|
||||
if memory_to_free <= 0:
|
||||
return 0
|
||||
if not comfy.disk_weights.disk_weights_enabled():
|
||||
return 0
|
||||
|
||||
free_before = get_free_memory(torch.device("cpu"))
|
||||
freed = 0
|
||||
can_unload = []
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
loaded_memory = shift_model.model_loaded_memory()
|
||||
if loaded_memory > 0:
|
||||
can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
memory_needed = memory_to_free - freed
|
||||
if memory_needed <= 0:
|
||||
break
|
||||
logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
|
||||
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
|
||||
|
||||
if freed > 0:
|
||||
free_after = get_free_memory(torch.device("cpu"))
|
||||
freed_total = max(0, free_after - free_before)
|
||||
logging.debug(
|
||||
"RAM evicted to disk: required=%d free=%d freed=%d",
|
||||
memory_to_free,
|
||||
free_before,
|
||||
freed_total if freed_total > 0 else freed,
|
||||
)
|
||||
return freed
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
Loading…
Reference in New Issue
Block a user