diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index 5f75f6b09..950392a03 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -939,7 +939,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
     elif getattr(module, "comfy_cast_weights", False):
-        target_device = torch.device("cpu")
+        target_device = input_device
     else:
         target_device = input_device
     ensure_module_materialized(
@@ -1082,6 +1082,13 @@ def move_module_tensors(
         if tensor is None or tensor.device.type == "meta":
             return tensor
         target_dtype = dtype_override or tensor.dtype
+        if (
+            tensor.device.type == "cpu"
+            and tensor.data_ptr() in model_management.PINNED_MEMORY
+            and (device_to.type != "cpu" or target_dtype != tensor.dtype)
+        ):
+            model_management.wait_for_pinned_tensor(tensor)
+            model_management.unpin_memory(tensor)
         if tensor.device == device_to and tensor.dtype == target_dtype:
             return tensor
         return model_management.cast_to(
@@ -1093,6 +1100,20 @@ def move_module_tensors(
         )
 
     module._apply(apply_fn)
+    if disk_weights_enabled():
+        for submodule in module.modules():
+            refs = REGISTRY.get(submodule)
+            if not refs:
+                continue
+            for name, disk_ref in refs.items():
+                if disk_ref.is_buffer:
+                    tensor = submodule._buffers.get(name)
+                else:
+                    tensor = submodule._parameters.get(name)
+                if tensor is None or tensor.device.type == "meta":
+                    CACHE.remove_entry(submodule, name)
+                    continue
+                CACHE.record(submodule, name, tensor, is_buffer=disk_ref.is_buffer)
     if non_blocking and offload_stream is not None:
         model_management.sync_stream(device_to, offload_stream)
     return module
@@ -1140,12 +1161,11 @@ def module_to(
     target_device = _find_existing_device(module) or torch.device("cpu")
     dtype_override = dtype or arg_dtype
     if target_device.type == "meta":
-        cpu_device = torch.device("cpu")
         for submodule in module.modules():
             offload_module_weights(submodule)
             move_module_tensors(
                 submodule,
-                cpu_device,
+                target_device,
                 dtype_override=dtype_override,
                 non_blocking=non_blocking,
             )
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 27c723f1b..77772f3f6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -586,6 +586,7 @@ def minimum_inference_memory():
 
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
+    gpu_deficit = 0
     if comfy.disk_weights.disk_weights_enabled():
         free_before = get_free_memory(device)
         if is_device_cpu(device):
@@ -598,9 +599,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
                 headroom,
                 device,
             )
-            freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
-            if freed_cache < memory_required:
-                evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
+            comfy.disk_weights.evict_ram_cache(memory_required)
         elif is_device_cuda(device) or is_device_xpu(device):
            if free_before < memory_required:
                 logging.debug(
@@ -610,6 +609,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
                     device,
                 )
                 comfy.disk_weights.evict_for_budget(device, memory_required)
+                free_after_vram = get_free_memory(device)
+                if free_after_vram < memory_required:
+                    gpu_deficit = memory_required - free_after_vram
+                    comfy.disk_weights.evict_ram_cache(gpu_deficit)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -633,6 +636,13 @@ def free_memory(memory_required, device, keep_loaded=[]):
         offload_device = None
         if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
             offload_device = torch.device("meta")
+        elif comfy.disk_weights.disk_weights_enabled() and gpu_deficit > 0 and (is_device_cuda(device) or is_device_xpu(device)):
+            cpu = torch.device("cpu")
+            headroom = comfy.disk_weights.ram_headroom_bytes()
+            required_cpu = current_loaded_models[i].model_loaded_memory() + headroom
+            free_cpu = get_free_memory(cpu)
+            if free_cpu < required_cpu:
+                offload_device = torch.device("meta")
         if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
             unloaded_model.append(i)
 
@@ -658,42 +668,6 @@ def free_memory(memory_required, device, keep_loaded=[]):
     )
     return unloaded_models
 
-
-def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
-    if memory_to_free <= 0:
-        return 0
-    if not comfy.disk_weights.disk_weights_enabled():
-        return 0
-
-    free_before = get_free_memory(torch.device("cpu"))
-    freed = 0
-    can_unload = []
-    for i in range(len(current_loaded_models) - 1, -1, -1):
-        shift_model = current_loaded_models[i]
-        if shift_model not in keep_loaded and not shift_model.is_dead():
-            loaded_memory = shift_model.model_loaded_memory()
-            if loaded_memory > 0:
-                can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
-
-    for x in sorted(can_unload):
-        i = x[-1]
-        memory_needed = memory_to_free - freed
-        if memory_needed <= 0:
-            break
-        logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
-        freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
-
-    if freed > 0:
-        free_after = get_free_memory(torch.device("cpu"))
-        freed_total = max(0, free_after - free_before)
-        logging.debug(
-            "RAM evicted to disk: required=%d free=%d freed=%d",
-            memory_to_free,
-            free_before,
-            freed_total if freed_total > 0 else freed,
-        )
-    return freed
-
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state