Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-06 19:42:34 +08:00)
Fix disk weight device handling and cache accounting

commit 91809e83ff
parent 82e70aa3c2
@@ -939,7 +939,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
     elif getattr(module, "comfy_cast_weights", False):
-        target_device = torch.device("cpu")
+        target_device = input_device
     else:
         target_device = input_device
     ensure_module_materialized(
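Note on the hunk above: the only functional change is the comfy_cast_weights branch, which used to pin materialization to the CPU and now follows the device the inputs arrive on, like the other two branches. A minimal, hypothetical sketch of that device-selection step (everything except the two module flags is illustrative, not the repository's API):

import torch

def pick_target_device(module: torch.nn.Module, args) -> torch.device:
    # Derive the device the inputs live on; fall back to CPU if no tensor is found.
    input_device = next((a.device for a in args if isinstance(a, torch.Tensor)),
                        torch.device("cpu"))
    if getattr(module, "comfy_patched_weights", False):
        return input_device
    elif getattr(module, "comfy_cast_weights", False):
        # Before this commit, this branch returned torch.device("cpu").
        return input_device
    else:
        return input_device

# Example: a cast-weights module now materializes next to its input.
lin = torch.nn.Linear(4, 4)
lin.comfy_cast_weights = True
print(pick_target_device(lin, (torch.zeros(1),)))  # cpu here, cuda:0 if the input were on GPU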
@@ -1082,6 +1082,13 @@ def move_module_tensors(
         if tensor is None or tensor.device.type == "meta":
             return tensor
         target_dtype = dtype_override or tensor.dtype
+        if (
+            tensor.device.type == "cpu"
+            and tensor.data_ptr() in model_management.PINNED_MEMORY
+            and (device_to.type != "cpu" or target_dtype != tensor.dtype)
+        ):
+            model_management.wait_for_pinned_tensor(tensor)
+            model_management.unpin_memory(tensor)
         if tensor.device == device_to and tensor.dtype == target_dtype:
             return tensor
         return model_management.cast_to(
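The added guard deals with pinned host memory: if a CPU tensor is registered in model_management.PINNED_MEMORY and is about to change device or dtype, any in-flight asynchronous copy that still reads from it is waited on and the allocation is unpinned before cast_to runs. A self-contained sketch of the same pattern in plain PyTorch; PINNED_PTRS and safe_move below are illustrative stand-ins, not the functions the diff calls:

import torch

# data_ptr -> CUDA event recorded after the last async copy out of that pinned buffer
PINNED_PTRS: dict = {}

def safe_move(tensor: torch.Tensor, device_to: torch.device, dtype_to: torch.dtype) -> torch.Tensor:
    # Wait for and forget a pinned CPU buffer before it changes device or dtype,
    # so nothing still streams out of memory we are about to stop tracking.
    if (
        tensor.device.type == "cpu"
        and tensor.data_ptr() in PINNED_PTRS
        and (device_to.type != "cpu" or dtype_to != tensor.dtype)
    ):
        event = PINNED_PTRS.pop(tensor.data_ptr())
        if event is not None:
            event.synchronize()
    if tensor.device == device_to and tensor.dtype == dtype_to:
        return tensor
    return tensor.to(device=device_to, dtype=dtype_to)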
@@ -1093,6 +1100,20 @@ def move_module_tensors(
         )

     module._apply(apply_fn)
+    if disk_weights_enabled():
+        for submodule in module.modules():
+            refs = REGISTRY.get(submodule)
+            if not refs:
+                continue
+            for name, disk_ref in refs.items():
+                if disk_ref.is_buffer:
+                    tensor = submodule._buffers.get(name)
+                else:
+                    tensor = submodule._parameters.get(name)
+                if tensor is None or tensor.device.type == "meta":
+                    CACHE.remove_entry(submodule, name)
+                    continue
+                CACHE.record(submodule, name, tensor, is_buffer=disk_ref.is_buffer)
     if non_blocking and offload_stream is not None:
         model_management.sync_stream(device_to, offload_stream)
     return module
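The loop added after module._apply(apply_fn) is the cache-accounting half of the commit: once tensors have physically moved, every disk-backed parameter or buffer is either re-recorded in the RAM cache at its new location or, if it is gone or still on the meta device, removed from the accounting so the cache size stays truthful. A rough sketch of that bookkeeping with a toy cache; WeightCache and refresh_cache are hypothetical, only the _parameters/_buffers access mirrors the diff:

import torch

class WeightCache:
    # Toy accounting structure: bytes of materialized weights per (module, name).
    def __init__(self):
        self.entries = {}

    def record(self, module, name, tensor, is_buffer=False):
        self.entries[(id(module), name)] = tensor.numel() * tensor.element_size()

    def remove_entry(self, module, name):
        self.entries.pop((id(module), name), None)

    def total_bytes(self):
        return sum(self.entries.values())

def refresh_cache(module: torch.nn.Module, registry: dict, cache: WeightCache):
    # registry maps submodule -> {tensor_name: is_buffer}; same walk as the hunk above.
    for submodule in module.modules():
        refs = registry.get(submodule)
        if not refs:
            continue
        for name, is_buffer in refs.items():
            store = submodule._buffers if is_buffer else submodule._parameters
            tensor = store.get(name)
            if tensor is None or tensor.device.type == "meta":
                cache.remove_entry(submodule, name)  # offloaded weights no longer count
                continue
            cache.record(submodule, name, tensor, is_buffer=is_buffer)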
@@ -1140,12 +1161,11 @@ def module_to(
     target_device = _find_existing_device(module) or torch.device("cpu")
     dtype_override = dtype or arg_dtype
     if target_device.type == "meta":
-        cpu_device = torch.device("cpu")
         for submodule in module.modules():
             offload_module_weights(submodule)
             move_module_tensors(
                 submodule,
-                cpu_device,
+                target_device,
                 dtype_override=dtype_override,
                 non_blocking=non_blocking,
             )
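The module_to hunk drops the hard-coded cpu_device: when the module's existing tensors resolve to the meta device, move_module_tensors is now called with that same target, which (given the early return on meta tensors shown above) leaves offloaded weights offloaded instead of re-materializing them on the CPU. A small sketch of the placement lookup; find_existing_device is a hypothetical stand-in for _find_existing_device:

import torch

def find_existing_device(module: torch.nn.Module):
    # First device seen among parameters and buffers, or None if the module has neither.
    for t in list(module.parameters()) + list(module.buffers()):
        return t.device
    return None

# Example: a module already offloaded to meta keeps meta as its target
# instead of being forced back into CPU RAM.
m = torch.nn.Linear(8, 8).to_empty(device="meta")
target_device = find_existing_device(m) or torch.device("cpu")
assert target_device.type == "meta"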
@@ -586,6 +586,7 @@ def minimum_inference_memory():

 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
+    gpu_deficit = 0
     if comfy.disk_weights.disk_weights_enabled():
         free_before = get_free_memory(device)
         if is_device_cpu(device):
@@ -598,9 +599,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
                 headroom,
                 device,
             )
-            freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
-            if freed_cache < memory_required:
-                evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
+            comfy.disk_weights.evict_ram_cache(memory_required)
         elif is_device_cuda(device) or is_device_xpu(device):
             if free_before < memory_required:
                 logging.debug(
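On the CPU branch, free_memory no longer calls evict_ram_to_disk itself: it just drops what it can from the RAM weight cache and lets the ordinary unload loop further down handle any remaining shortfall (that loop offloads to the meta device on CPU when disk weights are enabled, see the -633 hunk). A compressed sketch of the before/after control flow; both helper callables are placeholders passed in for illustration:

def free_cpu_memory_old(memory_required, evict_ram_cache, evict_ram_to_disk):
    # Pre-commit: cache eviction first, then an explicit RAM-to-disk pass for the rest.
    freed_cache = evict_ram_cache(memory_required)
    if freed_cache < memory_required:
        evict_ram_to_disk(memory_required - freed_cache)

def free_cpu_memory_new(memory_required, evict_ram_cache, unload_loop_to_meta):
    # Post-commit: best-effort cache eviction; the shared unload loop
    # (offloading to the meta device) covers whatever is still missing.
    evict_ram_cache(memory_required)
    unload_loop_to_meta(memory_required)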
@@ -610,6 +609,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
                     device,
                 )
                 comfy.disk_weights.evict_for_budget(device, memory_required)
+                free_after_vram = get_free_memory(device)
+                if free_after_vram < memory_required:
+                    gpu_deficit = memory_required - free_after_vram
+                    comfy.disk_weights.evict_ram_cache(gpu_deficit)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
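The VRAM branch gains the accounting that the new gpu_deficit variable exists for: after evict_for_budget, free VRAM is re-measured, and if the request is still not met, the shortfall is recorded and that many bytes are evicted from the CPU-side weight cache so the weights about to be pushed off the GPU have room to land. The helper below is a hypothetical restatement of that arithmetic, not code from the repository:

def plan_gpu_eviction(memory_required: int, free_vram_after_evict: int, evict_ram_cache) -> int:
    # Remaining VRAM shortfall after device-level eviction; pre-free the same amount of RAM.
    gpu_deficit = 0
    if free_vram_after_evict < memory_required:
        gpu_deficit = memory_required - free_vram_after_evict
        evict_ram_cache(gpu_deficit)
    return gpu_deficit

# Example: 2 GiB needed, 0.5 GiB free after eviction -> a 1.5 GiB deficit is
# both remembered (for the unload loop) and cleared out of the RAM cache.
assert plan_gpu_eviction(2 << 30, 512 << 20, lambda n: n) == (2 << 30) - (512 << 20)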
@@ -633,6 +636,13 @@ def free_memory(memory_required, device, keep_loaded=[]):
         offload_device = None
         if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
             offload_device = torch.device("meta")
+        elif comfy.disk_weights.disk_weights_enabled() and gpu_deficit > 0 and (is_device_cuda(device) or is_device_xpu(device)):
+            cpu = torch.device("cpu")
+            headroom = comfy.disk_weights.ram_headroom_bytes()
+            required_cpu = current_loaded_models[i].model_loaded_memory() + headroom
+            free_cpu = get_free_memory(cpu)
+            if free_cpu < required_cpu:
+                offload_device = torch.device("meta")
         if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
             unloaded_model.append(i)

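The second new use of gpu_deficit is in the unload loop above: when freeing VRAM with disk weights enabled, a model is sent straight to the meta device (i.e. back to disk) if the CPU would not have enough free RAM, including the configured headroom, to hold it; otherwise it offloads to CPU as before. The function below restates that decision with its inputs passed explicitly; it is a sketch, not the loop's actual signature:

import torch

def choose_offload_device(disk_weights_enabled: bool, freeing_cpu: bool, freeing_gpu: bool,
                          gpu_deficit: int, model_bytes: int, ram_headroom: int, free_ram: int):
    # None means "offload to CPU as usual"; meta means "drop the weights back to disk".
    offload_device = None
    if disk_weights_enabled and freeing_cpu:
        offload_device = torch.device("meta")
    elif disk_weights_enabled and gpu_deficit > 0 and freeing_gpu:
        required_cpu = model_bytes + ram_headroom
        if free_ram < required_cpu:
            offload_device = torch.device("meta")
    return offload_device

# Example: a 6 GiB model with 1 GiB headroom but only 4 GiB of free RAM goes to meta.
print(choose_offload_device(True, False, True, 1 << 30, 6 << 30, 1 << 30, 4 << 30))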
@@ -658,42 +668,6 @@ def free_memory(memory_required, device, keep_loaded=[]):
         )
     return unloaded_models

-
-def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
-    if memory_to_free <= 0:
-        return 0
-    if not comfy.disk_weights.disk_weights_enabled():
-        return 0
-
-    free_before = get_free_memory(torch.device("cpu"))
-    freed = 0
-    can_unload = []
-    for i in range(len(current_loaded_models) - 1, -1, -1):
-        shift_model = current_loaded_models[i]
-        if shift_model not in keep_loaded and not shift_model.is_dead():
-            loaded_memory = shift_model.model_loaded_memory()
-            if loaded_memory > 0:
-                can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
-
-    for x in sorted(can_unload):
-        i = x[-1]
-        memory_needed = memory_to_free - freed
-        if memory_needed <= 0:
-            break
-        logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
-        freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
-
-    if freed > 0:
-        free_after = get_free_memory(torch.device("cpu"))
-        freed_total = max(0, free_after - free_before)
-        logging.debug(
-            "RAM evicted to disk: required=%d free=%d freed=%d",
-            memory_to_free,
-            free_before,
-            freed_total if freed_total > 0 else freed,
-        )
-    return freed
-
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
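With the meta-offload path wired into free_memory itself, the standalone evict_ram_to_disk helper becomes dead code and is removed. Its one piece of non-obvious logic was the eviction order: candidates were sorted largest resident model first (negated loaded memory), with reference count and total size as tie-breakers. A toy restatement of that ordering on dummy records, with the refcount term omitted for simplicity:

# Dummy stand-ins for loaded models; only the largest-first sort idea mirrors the removed helper.
models = [
    {"name": "text_encoder", "loaded": 1 << 30, "total": 2 << 30},
    {"name": "unet",         "loaded": 6 << 30, "total": 8 << 30},
    {"name": "vae",          "loaded": 1 << 30, "total": 1 << 30},
]
order = sorted(range(len(models)),
               key=lambda i: (-models[i]["loaded"], models[i]["total"], i))
print([models[i]["name"] for i in order])  # ['unet', 'vae', 'text_encoder']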