Fix disk weight device handling and cache accounting

This commit is contained in:
ifilipis 2026-01-20 15:04:33 +02:00
parent 82e70aa3c2
commit 91809e83ff
2 changed files with 36 additions and 42 deletions

View File

@ -939,7 +939,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
if getattr(module, "comfy_patched_weights", False):
target_device = input_device
elif getattr(module, "comfy_cast_weights", False):
target_device = torch.device("cpu")
target_device = input_device
else:
target_device = input_device
ensure_module_materialized(
@ -1082,6 +1082,13 @@ def move_module_tensors(
if tensor is None or tensor.device.type == "meta":
return tensor
target_dtype = dtype_override or tensor.dtype
if (
tensor.device.type == "cpu"
and tensor.data_ptr() in model_management.PINNED_MEMORY
and (device_to.type != "cpu" or target_dtype != tensor.dtype)
):
model_management.wait_for_pinned_tensor(tensor)
model_management.unpin_memory(tensor)
if tensor.device == device_to and tensor.dtype == target_dtype:
return tensor
return model_management.cast_to(
@ -1093,6 +1100,20 @@ def move_module_tensors(
)
module._apply(apply_fn)
if disk_weights_enabled():
for submodule in module.modules():
refs = REGISTRY.get(submodule)
if not refs:
continue
for name, disk_ref in refs.items():
if disk_ref.is_buffer:
tensor = submodule._buffers.get(name)
else:
tensor = submodule._parameters.get(name)
if tensor is None or tensor.device.type == "meta":
CACHE.remove_entry(submodule, name)
continue
CACHE.record(submodule, name, tensor, is_buffer=disk_ref.is_buffer)
if non_blocking and offload_stream is not None:
model_management.sync_stream(device_to, offload_stream)
return module
@ -1140,12 +1161,11 @@ def module_to(
target_device = _find_existing_device(module) or torch.device("cpu")
dtype_override = dtype or arg_dtype
if target_device.type == "meta":
cpu_device = torch.device("cpu")
for submodule in module.modules():
offload_module_weights(submodule)
move_module_tensors(
submodule,
cpu_device,
target_device,
dtype_override=dtype_override,
non_blocking=non_blocking,
)

View File

@ -586,6 +586,7 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
gpu_deficit = 0
if comfy.disk_weights.disk_weights_enabled():
free_before = get_free_memory(device)
if is_device_cpu(device):
@ -598,9 +599,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
headroom,
device,
)
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
if freed_cache < memory_required:
evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
comfy.disk_weights.evict_ram_cache(memory_required)
elif is_device_cuda(device) or is_device_xpu(device):
if free_before < memory_required:
logging.debug(
@ -610,6 +609,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
device,
)
comfy.disk_weights.evict_for_budget(device, memory_required)
free_after_vram = get_free_memory(device)
if free_after_vram < memory_required:
gpu_deficit = memory_required - free_after_vram
comfy.disk_weights.evict_ram_cache(gpu_deficit)
unloaded_model = []
can_unload = []
unloaded_models = []
@ -633,6 +636,13 @@ def free_memory(memory_required, device, keep_loaded=[]):
offload_device = None
if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
offload_device = torch.device("meta")
elif comfy.disk_weights.disk_weights_enabled() and gpu_deficit > 0 and (is_device_cuda(device) or is_device_xpu(device)):
cpu = torch.device("cpu")
headroom = comfy.disk_weights.ram_headroom_bytes()
required_cpu = current_loaded_models[i].model_loaded_memory() + headroom
free_cpu = get_free_memory(cpu)
if free_cpu < required_cpu:
offload_device = torch.device("meta")
if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
unloaded_model.append(i)
@ -658,42 +668,6 @@ def free_memory(memory_required, device, keep_loaded=[]):
)
return unloaded_models
def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
if memory_to_free <= 0:
return 0
if not comfy.disk_weights.disk_weights_enabled():
return 0
free_before = get_free_memory(torch.device("cpu"))
freed = 0
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model not in keep_loaded and not shift_model.is_dead():
loaded_memory = shift_model.model_loaded_memory()
if loaded_memory > 0:
can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
for x in sorted(can_unload):
i = x[-1]
memory_needed = memory_to_free - freed
if memory_needed <= 0:
break
logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
if freed > 0:
free_after = get_free_memory(torch.device("cpu"))
freed_total = max(0, free_after - free_before)
logging.debug(
"RAM evicted to disk: required=%d free=%d freed=%d",
memory_to_free,
free_before,
freed_total if freed_total > 0 else freed,
)
return freed
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state