diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index e1507c266..f3b3ec31f 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -179,7 +179,24 @@ def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, ena
     PIN_IF_CPU = pin_if_cpu
     DISK_WEIGHTS_ENABLED = enabled
     RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
-    CACHE.set_limit(0 if enabled else 0)
+    if enabled:
+        from . import model_management
+        cpu_capacity_bytes = max(0, model_management.get_total_memory(torch.device("cpu")) - RAM_HEADROOM_BYTES)
+        CACHE.set_limit(cpu_capacity_bytes)
+        LOGGER.debug(
+            "Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
+            enabled,
+            RAM_HEADROOM_BYTES,
+            cpu_capacity_bytes,
+        )
+    else:
+        CACHE.set_limit(0)
+        LOGGER.debug(
+            "Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
+            enabled,
+            RAM_HEADROOM_BYTES,
+            0,
+        )
     if enabled:
         install_monkeypatches()
     else:
         remove_monkeypatches()
@@ -449,12 +466,14 @@ def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_byte
     )
     safetensors_stream._reap_pinned_inflight()
     from . import model_management
+    model_management._reap_pinned_inflight()
     model_management.free_memory(required_bytes + headroom_bytes, device)
     free_after = _device_free_memory(device)
     freed = max(0, free_after - free_before)
     LOGGER.debug(
-        "Disk weight memory freed: freed=%d free=%d device=%s",
+        "Disk weight memory freed: freed=%d free_before=%d free_after=%d device=%s",
         freed,
+        free_before,
         free_after,
         device,
     )
@@ -840,6 +859,8 @@ def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
     safetensors_stream._reap_pinned_inflight()
+    from . import model_management
+    model_management._reap_pinned_inflight()
     return CACHE.evict_bytes(bytes_to_free)
 
 
@@ -926,22 +947,41 @@ def module_to(
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
         if target_device.type == "meta":
-            offload_module_weights(module)
+            for submodule in module.modules():
+                offload_module_weights(submodule)
             return module
-        if allow_materialize:
-            materialize_module_tree(module, target_device)
-            base_kwargs = dict(kwargs)
-            if device is not None and arg_device is None:
-                base_kwargs["device"] = device
-            if dtype is not None and arg_dtype is None:
-                base_kwargs["dtype"] = dtype
-            if non_blocking:
-                base_kwargs["non_blocking"] = non_blocking
-            if memory_format is not None:
-                base_kwargs["memory_format"] = memory_format
-            return BASE_MODULE_TO(module, *args, **base_kwargs)
         dtype_override = dtype or arg_dtype
-        return move_module_tensors(module, target_device, dtype_override=dtype_override)
+        to_kwargs = {}
+        if non_blocking:
+            to_kwargs["non_blocking"] = non_blocking
+        if memory_format is not None:
+            to_kwargs["memory_format"] = memory_format
+        for submodule in module.modules():
+            ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
+            refs = REGISTRY.get(submodule) or {}
+            for name, param in submodule.named_parameters(recurse=False):
+                if name in refs:
+                    continue
+                if param is None or param.device.type == "meta":
+                    continue
+                if param.device != target_device or (dtype_override is not None and param.dtype != dtype_override):
+                    if dtype_override is not None:
+                        tensor = param.to(device=target_device, dtype=dtype_override, **to_kwargs)
+                    else:
+                        tensor = param.to(device=target_device, **to_kwargs)
+                    submodule._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
+            for name, buf in submodule.named_buffers(recurse=False):
+                if name in refs:
+                    continue
+                if buf is None or buf.device.type == "meta":
+                    continue
+                if buf.device != target_device or (dtype_override is not None and buf.dtype != dtype_override):
+                    if dtype_override is not None:
+                        tensor = buf.to(device=target_device, dtype=dtype_override, **to_kwargs)
+                    else:
+                        tensor = buf.to(device=target_device, **to_kwargs)
+                    submodule._buffers[name] = tensor
+        return module
     base_kwargs = dict(kwargs)
     if device is not None and arg_device is None:
         base_kwargs["device"] = device
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e9c42294d..a2f5a62ef 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -18,6 +18,7 @@
 
 import psutil
 import logging
+import collections
 from enum import Enum
 from comfy.cli_args import args, PerformanceFeature
 import torch
@@ -524,18 +525,14 @@ class LoadedModel:
             return True
         return False
 
-    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True, offload_device=None):
+        target_offload_device = self.model.offload_device if offload_device is None else offload_device
         if memory_to_free is not None:
             if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                freed = self.model.partially_unload(target_offload_device, memory_to_free)
                 if freed >= memory_to_free:
                     return False
-        offload_device = None
-        if comfy.disk_weights.disk_weights_enabled():
-            offload_device = torch.device("meta")
-        self.model.detach(unpatch_weights, offload_device=offload_device)
-        if offload_device is not None and offload_device.type == "meta":
-            logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
+        self.model.detach(unpatch_weights, offload_device=target_offload_device)
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -589,27 +586,30 @@ def minimum_inference_memory():
 
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
-    if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
+    if comfy.disk_weights.disk_weights_enabled():
         free_before = get_free_memory(device)
-        headroom = comfy.disk_weights.ram_headroom_bytes()
-        if free_before < memory_required:
-            logging.debug(
-                "RAM pressure: required=%d free=%d headroom=%d",
-                memory_required,
-                free_before,
-                headroom,
-            )
-            freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
-            freed_disk = 0
-            if freed_cache < memory_required:
-                freed_disk = evict_ram_to_disk(memory_required - freed_cache)
-            free_after = get_free_memory(device)
-            freed_total = max(0, free_after - free_before)
-            logging.debug(
-                "RAM freed: freed=%d free=%d",
-                freed_total if freed_total > 0 else freed_cache + freed_disk,
-                free_after,
-            )
+        if is_device_cpu(device):
+            headroom = comfy.disk_weights.ram_headroom_bytes()
+            if free_before < memory_required:
+                logging.debug(
+                    "Disk weights RAM pressure: required=%d free=%d headroom=%d device=%s",
+                    memory_required,
+                    free_before,
+                    headroom,
+                    device,
+                )
+                freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+                if freed_cache < memory_required:
+                    evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
+        elif is_device_cuda(device) or is_device_xpu(device):
+            if free_before < memory_required:
+                logging.debug(
+                    "Disk weights VRAM pressure: required=%d free=%d device=%s",
+                    memory_required,
+                    free_before,
+                    device,
+                )
+                _reap_pinned_inflight()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -630,7 +630,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
                 break
             memory_to_free = memory_required - free_mem
         logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        offload_device = None
+        if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
+            offload_device = torch.device("meta")
+        if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
             unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
@@ -643,6 +646,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
             mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()
+    if comfy.disk_weights.disk_weights_enabled():
+        free_after = get_free_memory(device)
+        freed_total = max(0, free_after - free_before)
+        logging.debug(
+            "Disk weights free_memory: device=%s free_before=%d free_after=%d freed=%d",
+            device,
+            free_before,
+            free_after,
+            freed_total,
+        )
     return unloaded_models
 
 
@@ -826,8 +839,6 @@ def dtype_size(dtype):
     return dtype_size
 
 def unet_offload_device():
-    if comfy.disk_weights.disk_weights_enabled():
-        return torch.device("meta")
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
     else:
@@ -932,8 +943,6 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     return torch.float32
 
 def text_encoder_offload_device():
-    if comfy.disk_weights.disk_weights_enabled():
-        return torch.device("meta")
     if args.gpu_only:
         return get_torch_device()
     else:
@@ -994,8 +1003,6 @@ def vae_device():
     return get_torch_device()
 
 def vae_offload_device():
-    if comfy.disk_weights.disk_weights_enabled():
-        return torch.device("meta")
     if args.gpu_only:
         return get_torch_device()
     else:
@@ -1175,6 +1182,15 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
     else:
         r = torch.empty_like(weight, dtype=dtype, device=device)
         r.copy_(weight, non_blocking=non_blocking)
+    if (
+        non_blocking
+        and is_device_cuda(device)
+        and is_device_cpu(weight.device)
+        and weight.is_pinned()
+    ):
+        record_stream = stream if stream is not None else current_stream(device)
+        if record_stream is not None:
+            _track_pinned_inflight(record_stream, weight)
     return r
 
 def cast_to_device(tensor, device, dtype, copy=False):
@@ -1185,6 +1201,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
 PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
+PINNED_INFLIGHT = collections.deque()
+DEFERRED_UNPIN = collections.deque()
 if not args.disable_pinned_memory:
     if is_nvidia() or is_amd():
         if WINDOWS:
@@ -1252,7 +1270,19 @@ def pin_memory(tensor):
     return False
 
 
-def unpin_memory(tensor):
+def _track_pinned_inflight(stream, tensor):
+    event = torch.cuda.Event()
+    event.record(stream)
+    PINNED_INFLIGHT.append((event, tensor))
+
+def _tensor_inflight(tensor):
+    ptr = tensor.data_ptr()
+    for _, inflight_tensor in PINNED_INFLIGHT:
+        if inflight_tensor.data_ptr() == ptr:
+            return True
+    return False
+
+def _unpin_memory_now(tensor):
     global TOTAL_PINNED_MEMORY
     if MAX_PINNED_MEMORY <= 0:
         return False
@@ -1283,6 +1313,46 @@
     return False
 
 
+def _retry_deferred_unpins():
+    if not DEFERRED_UNPIN:
+        return
+    remaining = collections.deque()
+    while DEFERRED_UNPIN:
+        tensor = DEFERRED_UNPIN.popleft()
+        if _tensor_inflight(tensor):
+            remaining.append(tensor)
+            continue
+        if not _unpin_memory_now(tensor):
+            remaining.append(tensor)
+    DEFERRED_UNPIN.extend(remaining)
+
+def _reap_pinned_inflight():
+    if not PINNED_INFLIGHT:
+        _retry_deferred_unpins()
+        return
+    remaining = collections.deque()
+    while PINNED_INFLIGHT:
+        event, tensor = PINNED_INFLIGHT.popleft()
+        if event.query():
+            continue
+        remaining.append((event, tensor))
+    PINNED_INFLIGHT.extend(remaining)
+    _retry_deferred_unpins()
+
+def unpin_memory(tensor):
+    global TOTAL_PINNED_MEMORY
+    if MAX_PINNED_MEMORY <= 0:
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    _reap_pinned_inflight()
+    if _tensor_inflight(tensor):
+        DEFERRED_UNPIN.append(tensor)
+        return False
+    return _unpin_memory_now(tensor)
+
 def sage_attention_enabled():
     return args.use_sage_attention
 
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index ffc22389d..497060b92 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -621,6 +621,16 @@ class ModelPatcher:
         weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update
 
+        if comfy.disk_weights.disk_weights_enabled() and weight is not None and weight.device.type == "meta":
+            parts = key.split(".")
+            param_name = parts[-1]
+            module = self.model
+            for part in parts[:-1]:
+                module = getattr(module, part)
+            target_device = device_to or self.offload_device or torch.device("cpu")
+            comfy.disk_weights.load_module_tensor(module, param_name, device=target_device)
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
 
@@ -650,8 +660,8 @@ class ModelPatcher:
     def unpin_weight(self, key):
         if key in self.pinned:
             weight, set_func, convert_func = get_key_weight(self.model, key)
-            comfy.model_management.unpin_memory(weight)
-            self.pinned.remove(key)
+            if comfy.model_management.unpin_memory(weight):
+                self.pinned.remove(key)
 
     def unpin_all_weights(self):
         for key in list(self.pinned):
@@ -885,9 +895,16 @@ class ModelPatcher:
             NS = comfy.model_management.NUM_STREAMS
             offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
             remaining_ram = None
+            cpu_device = torch.device("cpu")
             if device_to is not None and comfy.model_management.is_device_cpu(device_to):
                 remaining_ram = comfy.model_management.get_free_memory(device_to)
 
+            def offload_module_tree(module):
+                freed = 0
+                for submodule in module.modules():
+                    freed += comfy.disk_weights.offload_module_weights(submodule)
+                return freed
+
             for unload in unload_list:
                 if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                     break
@@ -922,15 +939,28 @@ class ModelPatcher:
                        cast_weight = self.force_cast_weights
                        freed_bytes = module_mem
                        if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
-                           freed_bytes = comfy.disk_weights.offload_module_weights(m)
+                           freed_bytes = offload_module_tree(m)
                            if freed_bytes == 0:
                                freed_bytes = module_mem
                        else:
-                           if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
-                               logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
-                               freed_bytes = comfy.disk_weights.offload_module_weights(m)
-                               if freed_bytes == 0:
-                                   freed_bytes = module_mem
+                           if remaining_ram is not None and comfy.disk_weights.disk_weights_enabled():
+                               required_bytes = module_mem
+                               headroom = comfy.disk_weights.ram_headroom_bytes()
+                               comfy.model_management.free_memory(required_bytes + headroom, cpu_device, keep_loaded=[self])
+                               remaining_ram = comfy.model_management.get_free_memory(cpu_device)
+                               if remaining_ram < required_bytes:
+                                   logging.info(
+                                       "Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.",
+                                       n,
+                                       required_bytes / (1024 * 1024),
+                                       remaining_ram / (1024 * 1024),
+                                   )
+                                   freed_bytes = offload_module_tree(m)
+                                   if freed_bytes == 0:
+                                       freed_bytes = module_mem
+                               else:
+                                   comfy.disk_weights.move_module_tensors(m, device_to)
+                                   remaining_ram = max(0, remaining_ram - required_bytes)
                            else:
                                if comfy.disk_weights.disk_weights_enabled():
                                    comfy.disk_weights.move_module_tensors(m, device_to)
@@ -980,16 +1010,22 @@ class ModelPatcher:
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         with self.use_ejected(skip_and_inject_on_exit_only=True):
+            offload_device = self.offload_device
+            if comfy.disk_weights.disk_weights_enabled() and device_to is not None:
+                if comfy.model_management.is_device_cpu(device_to):
+                    offload_device = torch.device("meta")
+                else:
+                    offload_device = torch.device("cpu")
             unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model
             used = self.model.model_loaded_weight_memory
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            self.unpatch_model(offload_device, unpatch_weights=unpatch_weights)
             if unpatch_weights:
                 extra_memory += (used - self.model.model_loaded_weight_memory)
 
             self.patch_model(load_weights=False)
             if extra_memory < 0 and not unpatch_weights:
-                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                self.partially_unload(offload_device, -extra_memory, force_patch_weights=force_patch_weights)
                 return 0
 
             full_load = False
             if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
diff --git a/comfy/ops.py b/comfy/ops.py
index 303f5ccd0..c260b6645 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -535,10 +535,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     key = f"{prefix}{param_name}"
                     value = state_dict.pop(key, None)
                     if value is not None:
-                        if value.device.type != "meta":
-                            value = value.to(device=device)
-                            if dtype is not None:
-                                value = value.view(dtype=dtype)
+                        value = value.to(device=device)
+                        if dtype is not None:
+                            value = value.view(dtype=dtype)
                         manually_loaded_keys.append(key)
                     return value
 
@@ -555,16 +554,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 manually_loaded_keys = [weight_key]
 
                 layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
-                if layer_conf is not None and layer_conf.device.type != "meta":
+                if layer_conf is not None:
                     layer_conf = json.loads(layer_conf.numpy().tobytes())
-                elif layer_conf is not None:
-                    layer_conf = None
 
                 if layer_conf is None:
-                    if weight.device.type == "meta":
-                        self.weight = torch.nn.Parameter(weight, requires_grad=False)
-                    else:
-                        self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
                 else:
                     self.quant_format = layer_conf.get("format", None)
                     self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@@ -610,13 +604,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     else:
                         raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
-                    if weight.device.type == "meta":
-                        self.weight = torch.nn.Parameter(weight, requires_grad=False)
-                    else:
-                        self.weight = torch.nn.Parameter(
-                            QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
-                            requires_grad=False
-                        )
+                    self.weight = torch.nn.Parameter(
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
+                        requires_grad=False
+                    )
 
                     for param_name in qconfig["parameters"]:
                         if param_name in {"weight_scale", "weight_scale_2"}:
@@ -626,10 +617,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                         _v = state_dict.pop(param_key, None)
                         if _v is None:
                             continue
-                        if _v.device.type == "meta":
-                            self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
-                        else:
-                            self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                        self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                         manually_loaded_keys.append(param_key)
 
                 super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)