Fix disk-weight tiering and meta handling

This commit is contained in:
ifilipis 2026-01-18 04:53:06 +02:00
parent 95ca11fe25
commit c3eaea0429
4 changed files with 218 additions and 84 deletions

View File

@ -179,7 +179,24 @@ def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, ena
PIN_IF_CPU = pin_if_cpu
DISK_WEIGHTS_ENABLED = enabled
RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
CACHE.set_limit(0 if enabled else 0)
if enabled:
from . import model_management
cpu_capacity_bytes = max(0, model_management.get_total_memory(torch.device("cpu")) - RAM_HEADROOM_BYTES)
CACHE.set_limit(cpu_capacity_bytes)
LOGGER.debug(
"Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
enabled,
RAM_HEADROOM_BYTES,
cpu_capacity_bytes,
)
else:
CACHE.set_limit(0)
LOGGER.debug(
"Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
enabled,
RAM_HEADROOM_BYTES,
0,
)
if enabled:
install_monkeypatches()
else:
@ -449,12 +466,14 @@ def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_byte
)
safetensors_stream._reap_pinned_inflight()
from . import model_management
model_management._reap_pinned_inflight()
model_management.free_memory(required_bytes + headroom_bytes, device)
free_after = _device_free_memory(device)
freed = max(0, free_after - free_before)
LOGGER.debug(
"Disk weight memory freed: freed=%d free=%d device=%s",
"Disk weight memory freed: freed=%d free_before=%d free_after=%d device=%s",
freed,
free_before,
free_after,
device,
)
@ -840,6 +859,8 @@ def evict_ram_cache(bytes_to_free: int):
if bytes_to_free <= 0:
return 0
safetensors_stream._reap_pinned_inflight()
from . import model_management
model_management._reap_pinned_inflight()
return CACHE.evict_bytes(bytes_to_free)
@ -926,22 +947,41 @@ def module_to(
if target_device is None:
target_device = _find_existing_device(module) or torch.device("cpu")
if target_device.type == "meta":
offload_module_weights(module)
for submodule in module.modules():
offload_module_weights(submodule)
return module
if allow_materialize:
materialize_module_tree(module, target_device)
base_kwargs = dict(kwargs)
if device is not None and arg_device is None:
base_kwargs["device"] = device
if dtype is not None and arg_dtype is None:
base_kwargs["dtype"] = dtype
if non_blocking:
base_kwargs["non_blocking"] = non_blocking
if memory_format is not None:
base_kwargs["memory_format"] = memory_format
return BASE_MODULE_TO(module, *args, **base_kwargs)
dtype_override = dtype or arg_dtype
return move_module_tensors(module, target_device, dtype_override=dtype_override)
to_kwargs = {}
if non_blocking:
to_kwargs["non_blocking"] = non_blocking
if memory_format is not None:
to_kwargs["memory_format"] = memory_format
for submodule in module.modules():
ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
refs = REGISTRY.get(submodule) or {}
for name, param in submodule.named_parameters(recurse=False):
if name in refs:
continue
if param is None or param.device.type == "meta":
continue
if param.device != target_device or (dtype_override is not None and param.dtype != dtype_override):
if dtype_override is not None:
tensor = param.to(device=target_device, dtype=dtype_override, **to_kwargs)
else:
tensor = param.to(device=target_device, **to_kwargs)
submodule._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
for name, buf in submodule.named_buffers(recurse=False):
if name in refs:
continue
if buf is None or buf.device.type == "meta":
continue
if buf.device != target_device or (dtype_override is not None and buf.dtype != dtype_override):
if dtype_override is not None:
tensor = buf.to(device=target_device, dtype=dtype_override, **to_kwargs)
else:
tensor = buf.to(device=target_device, **to_kwargs)
submodule._buffers[name] = tensor
return module
base_kwargs = dict(kwargs)
if device is not None and arg_device is None:
base_kwargs["device"] = device

View File

@ -18,6 +18,7 @@
import psutil
import logging
import collections
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
@ -524,18 +525,14 @@ class LoadedModel:
return True
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
def model_unload(self, memory_to_free=None, unpatch_weights=True, offload_device=None):
target_offload_device = self.model.offload_device if offload_device is None else offload_device
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
freed = self.model.partially_unload(target_offload_device, memory_to_free)
if freed >= memory_to_free:
return False
offload_device = None
if comfy.disk_weights.disk_weights_enabled():
offload_device = torch.device("meta")
self.model.detach(unpatch_weights, offload_device=offload_device)
if offload_device is not None and offload_device.type == "meta":
logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
self.model.detach(unpatch_weights, offload_device=target_offload_device)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@ -589,27 +586,30 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
if comfy.disk_weights.disk_weights_enabled():
free_before = get_free_memory(device)
headroom = comfy.disk_weights.ram_headroom_bytes()
if free_before < memory_required:
logging.debug(
"RAM pressure: required=%d free=%d headroom=%d",
memory_required,
free_before,
headroom,
)
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
freed_disk = 0
if freed_cache < memory_required:
freed_disk = evict_ram_to_disk(memory_required - freed_cache)
free_after = get_free_memory(device)
freed_total = max(0, free_after - free_before)
logging.debug(
"RAM freed: freed=%d free=%d",
freed_total if freed_total > 0 else freed_cache + freed_disk,
free_after,
)
if is_device_cpu(device):
headroom = comfy.disk_weights.ram_headroom_bytes()
if free_before < memory_required:
logging.debug(
"Disk weights RAM pressure: required=%d free=%d headroom=%d device=%s",
memory_required,
free_before,
headroom,
device,
)
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
if freed_cache < memory_required:
evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
elif is_device_cuda(device) or is_device_xpu(device):
if free_before < memory_required:
logging.debug(
"Disk weights VRAM pressure: required=%d free=%d device=%s",
memory_required,
free_before,
device,
)
_reap_pinned_inflight()
unloaded_model = []
can_unload = []
unloaded_models = []
@ -630,7 +630,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
offload_device = None
if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
offload_device = torch.device("meta")
if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
@ -643,6 +646,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
if comfy.disk_weights.disk_weights_enabled():
free_after = get_free_memory(device)
freed_total = max(0, free_after - free_before)
logging.debug(
"Disk weights free_memory: device=%s free_before=%d free_after=%d freed=%d",
device,
free_before,
free_after,
freed_total,
)
return unloaded_models
@ -826,8 +839,6 @@ def dtype_size(dtype):
return dtype_size
def unet_offload_device():
if comfy.disk_weights.disk_weights_enabled():
return torch.device("meta")
if vram_state == VRAMState.HIGH_VRAM:
return get_torch_device()
else:
@ -932,8 +943,6 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return torch.float32
def text_encoder_offload_device():
if comfy.disk_weights.disk_weights_enabled():
return torch.device("meta")
if args.gpu_only:
return get_torch_device()
else:
@ -994,8 +1003,6 @@ def vae_device():
return get_torch_device()
def vae_offload_device():
if comfy.disk_weights.disk_weights_enabled():
return torch.device("meta")
if args.gpu_only:
return get_torch_device()
else:
@ -1175,6 +1182,15 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
else:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
if (
non_blocking
and is_device_cuda(device)
and is_device_cpu(weight.device)
and weight.is_pinned()
):
record_stream = stream if stream is not None else current_stream(device)
if record_stream is not None:
_track_pinned_inflight(record_stream, weight)
return r
def cast_to_device(tensor, device, dtype, copy=False):
@ -1185,6 +1201,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
PINNED_INFLIGHT = collections.deque()
DEFERRED_UNPIN = collections.deque()
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
@ -1252,7 +1270,19 @@ def pin_memory(tensor):
return False
def unpin_memory(tensor):
def _track_pinned_inflight(stream, tensor):
event = torch.cuda.Event()
event.record(stream)
PINNED_INFLIGHT.append((event, tensor))
def _tensor_inflight(tensor):
ptr = tensor.data_ptr()
for _, inflight_tensor in PINNED_INFLIGHT:
if inflight_tensor.data_ptr() == ptr:
return True
return False
def _unpin_memory_now(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
@ -1283,6 +1313,46 @@ def unpin_memory(tensor):
return False
def _retry_deferred_unpins():
if not DEFERRED_UNPIN:
return
remaining = collections.deque()
while DEFERRED_UNPIN:
tensor = DEFERRED_UNPIN.popleft()
if _tensor_inflight(tensor):
remaining.append(tensor)
continue
if not _unpin_memory_now(tensor):
remaining.append(tensor)
DEFERRED_UNPIN.extend(remaining)
def _reap_pinned_inflight():
if not PINNED_INFLIGHT:
_retry_deferred_unpins()
return
remaining = collections.deque()
while PINNED_INFLIGHT:
event, tensor = PINNED_INFLIGHT.popleft()
if event.query():
continue
remaining.append((event, tensor))
PINNED_INFLIGHT.extend(remaining)
_retry_deferred_unpins()
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
_reap_pinned_inflight()
if _tensor_inflight(tensor):
DEFERRED_UNPIN.append(tensor)
return False
return _unpin_memory_now(tensor)
def sage_attention_enabled():
return args.use_sage_attention

View File

@ -621,6 +621,16 @@ class ModelPatcher:
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
if comfy.disk_weights.disk_weights_enabled() and weight is not None and weight.device.type == "meta":
parts = key.split(".")
param_name = parts[-1]
module = self.model
for part in parts[:-1]:
module = getattr(module, part)
target_device = device_to or self.offload_device or torch.device("cpu")
comfy.disk_weights.load_module_tensor(module, param_name, device=target_device)
weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
@ -650,8 +660,8 @@ class ModelPatcher:
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
if comfy.model_management.unpin_memory(weight):
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
@ -885,9 +895,16 @@ class ModelPatcher:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
remaining_ram = None
cpu_device = torch.device("cpu")
if device_to is not None and comfy.model_management.is_device_cpu(device_to):
remaining_ram = comfy.model_management.get_free_memory(device_to)
def offload_module_tree(module):
freed = 0
for submodule in module.modules():
freed += comfy.disk_weights.offload_module_weights(submodule)
return freed
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
@ -922,15 +939,28 @@ class ModelPatcher:
cast_weight = self.force_cast_weights
freed_bytes = module_mem
if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
freed_bytes = comfy.disk_weights.offload_module_weights(m)
freed_bytes = offload_module_tree(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
freed_bytes = comfy.disk_weights.offload_module_weights(m)
if freed_bytes == 0:
freed_bytes = module_mem
if remaining_ram is not None and comfy.disk_weights.disk_weights_enabled():
required_bytes = module_mem
headroom = comfy.disk_weights.ram_headroom_bytes()
comfy.model_management.free_memory(required_bytes + headroom, cpu_device, keep_loaded=[self])
remaining_ram = comfy.model_management.get_free_memory(cpu_device)
if remaining_ram < required_bytes:
logging.info(
"Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.",
n,
required_bytes / (1024 * 1024),
remaining_ram / (1024 * 1024),
)
freed_bytes = offload_module_tree(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
comfy.disk_weights.move_module_tensors(m, device_to)
remaining_ram = max(0, remaining_ram - required_bytes)
else:
if comfy.disk_weights.disk_weights_enabled():
comfy.disk_weights.move_module_tensors(m, device_to)
@ -980,16 +1010,22 @@ class ModelPatcher:
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
offload_device = self.offload_device
if comfy.disk_weights.disk_weights_enabled() and device_to is not None:
if comfy.model_management.is_device_cpu(device_to):
offload_device = torch.device("meta")
else:
offload_device = torch.device("cpu")
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
self.unpatch_model(offload_device, unpatch_weights=unpatch_weights)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
self.partially_unload(offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:

View File

@ -535,10 +535,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
if value.device.type != "meta":
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
@ -555,16 +554,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None and layer_conf.device.type != "meta":
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
elif layer_conf is not None:
layer_conf = None
if layer_conf is None:
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@ -610,13 +604,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
@ -626,10 +617,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_v = state_dict.pop(param_key, None)
if _v is None:
continue
if _v.device.type == "meta":
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
else:
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)