Refine disk weight offload integration

This commit is contained in:
ifilipis 2026-01-18 01:08:07 +02:00
parent b93f165c2e
commit 95ca11fe25
19 changed files with 307 additions and 250 deletions

View File

@ -349,12 +349,12 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
| `--low-ram` | Enable disk weight offloading. Sets RAM headroom to 1024MB and uses `--reserve-vram` as the VRAM headroom. |
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
### Disk tier for model weights
When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
When `--low-ram` is enabled, ComfyUI streams safetensors weights from disk and offloads weights to disk when RAM headroom is reached.
If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.

View File

@ -29,7 +29,7 @@ class AudioEncoderModel():
self.model_sample_rate = 16000
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.model, sd, strict=False)
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@ -114,7 +114,7 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
parser.add_argument("--low-ram", action="store_true", help="Enable disk weight offloading. Sets RAM headroom to 1024MB and uses --reserve-vram as the VRAM headroom.")
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
attn_group = parser.add_mutually_exclusive_group()

View File

@ -6,7 +6,6 @@ import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
@ -48,7 +47,7 @@ class ClipVisionModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.model, sd, strict=False)
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@ -25,7 +25,6 @@ import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.disk_weights
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
@ -386,7 +385,7 @@ class ControlLora(ControlNet):
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
@ -440,7 +439,7 @@ def controlnet_config(sd, model_options={}):
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
missing, unexpected = control_model.load_state_dict(sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -474,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
@ -749,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -817,8 +816,8 @@ class T2IAdapter(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
comfy.disk_weights.module_to(self.t2i_model, self.device)
self.t2i_model.to(dtype=x_noisy.dtype)
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
@ -875,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
else:
return None
missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
missing, unexpected = model_ad.load_state_dict(t2i_data, strict=True)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))

View File

@ -33,7 +33,11 @@ from . import safetensors_stream
ALLOW_GDS = False
PIN_IF_CPU = False
DISK_WEIGHTS_ENABLED = False
RAM_HEADROOM_BYTES = 0
BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
BASE_MODULE_TO = torch.nn.Module.to
BASE_LOAD_STATE_DICT = torch.nn.Module.load_state_dict
_MONKEYPATCHED = False
LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
_MISSING = object()
@ -169,13 +173,17 @@ CACHE = DiskWeightCache(0)
LOGGER = logging.getLogger(__name__)
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, enabled: bool = True):
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED, RAM_HEADROOM_BYTES
ALLOW_GDS = allow_gds
PIN_IF_CPU = pin_if_cpu
DISK_WEIGHTS_ENABLED = enabled
CACHE.set_limit(cache_bytes if enabled else 0)
if not enabled:
RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
CACHE.set_limit(0 if enabled else 0)
if enabled:
install_monkeypatches()
else:
uninstall_monkeypatches()
CACHE._entries.clear()
CACHE.current_bytes = 0
@ -184,6 +192,66 @@ def disk_weights_enabled() -> bool:
return DISK_WEIGHTS_ENABLED
def ram_headroom_bytes() -> int:
return RAM_HEADROOM_BYTES
def _is_stream_state_dict(state_dict) -> bool:
return (
getattr(state_dict, "is_stream_state_dict", False)
and hasattr(state_dict, "get_tensor")
and hasattr(state_dict, "meta")
)
def patched_to(self: torch.nn.Module, *args, **kwargs):
if not disk_weights_enabled():
return BASE_MODULE_TO(self, *args, **kwargs)
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
module_to(
self,
device=device,
dtype=dtype,
non_blocking=non_blocking,
memory_format=memory_format,
)
return self
def patched_load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
if not disk_weights_enabled():
if _is_stream_state_dict(state_dict):
return safetensors_stream.stream_load_state_dict(
self,
state_dict,
strict=strict,
assign=assign,
)
return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
if _is_stream_state_dict(state_dict):
missing_keys, unexpected_keys = lazy_load_state_dict(self, state_dict, strict=strict)
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
def install_monkeypatches():
global _MONKEYPATCHED
if _MONKEYPATCHED:
return
torch.nn.Module.to = patched_to
torch.nn.Module.load_state_dict = patched_load_state_dict
_MONKEYPATCHED = True
def uninstall_monkeypatches():
global _MONKEYPATCHED
if not _MONKEYPATCHED:
return
torch.nn.Module.to = BASE_MODULE_TO
torch.nn.Module.load_state_dict = BASE_LOAD_STATE_DICT
_MONKEYPATCHED = False
def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
if not disk_weights_enabled():
return
@ -369,33 +437,29 @@ def _device_free_memory(device: torch.device) -> int:
return int(model_management.get_free_memory(device))
def _evict_ram_for_budget(required_bytes: int) -> int:
if required_bytes <= 0:
return 0
freed = evict_ram_cache(required_bytes)
if freed < required_bytes:
def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_bytes: int) -> int:
free_before = _device_free_memory(device)
if free_before < required_bytes + headroom_bytes:
LOGGER.debug(
"Disk weight memory pressure: required=%d free=%d headroom=%d device=%s",
required_bytes,
free_before,
headroom_bytes,
device,
)
safetensors_stream._reap_pinned_inflight()
from . import model_management
freed += model_management.evict_ram_to_disk(required_bytes - freed)
return freed
def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
free_mem = _device_free_memory(device)
if device.type == "cpu" and free_mem < required_bytes:
_evict_ram_for_budget(required_bytes - free_mem)
free_mem = _device_free_memory(device)
return free_mem
def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
from . import model_management
if device.type == "cpu":
alt = model_management.get_torch_device()
if alt.type != "cpu":
return alt
else:
return torch.device("cpu")
return None
model_management.free_memory(required_bytes + headroom_bytes, device)
free_after = _device_free_memory(device)
freed = max(0, free_after - free_before)
LOGGER.debug(
"Disk weight memory freed: freed=%d free=%d device=%s",
freed,
free_after,
device,
)
return free_after
return free_before
class _BudgetedStateDict(MutableMapping):
@ -527,8 +591,10 @@ class _BudgetedStateDict(MutableMapping):
if default is _MISSING:
raise KeyError(key)
return default
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
def meta(self, key: str):
return self._get_meta(key)
@ -564,6 +630,7 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):
def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
safetensors_stream._reap_pinned_inflight()
lazy_state = LAZY_MODULE_STATE.get(module)
if lazy_state is not None:
CACHE.remove_module(module)
@ -671,7 +738,6 @@ def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype:
def ensure_module_materialized(
module: torch.nn.Module,
target_device: torch.device,
fallback_device: Optional[torch.device] = None,
dtype_override: Optional[torch.dtype] = None,
):
lazy_state = LAZY_MODULE_STATE.get(module)
@ -692,7 +758,6 @@ def ensure_module_materialized(
_set_future_dtype(module, name, dtype_override)
_rebuild_materialization_state(module, refs, state)
free_mem_start = _device_free_memory(target_device)
remaining_budget = free_mem_start
for name in sorted(refs.keys()):
disk_ref = refs[name]
if name in module._parameters:
@ -717,19 +782,11 @@ def ensure_module_materialized(
continue
required_bytes = meta_nbytes
if target_device.type == "cpu":
free_mem = _maybe_free_ram_budget(target_device, required_bytes)
remaining_budget = min(remaining_budget, free_mem)
if required_bytes > remaining_budget:
if fallback_device is not None and fallback_device != target_device:
fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
if fallback_free >= required_bytes:
target_for_load = fallback_device
else:
continue
else:
continue
_ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
else:
target_for_load = target_device
from . import model_management
_ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
target_for_load = target_device
if current.device.type == "meta":
tensor = disk_ref.load(
target_for_load,
@ -748,7 +805,6 @@ def ensure_module_materialized(
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu":
CACHE.record(module, name, tensor, is_buffer=is_buffer)
remaining_budget = max(0, remaining_budget - required_bytes)
_rebuild_materialization_state(module, refs, state)
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
@ -761,14 +817,11 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
if getattr(module, "comfy_cast_weights", False):
target_device = torch.device("cpu")
fallback_device = _find_tensor_device(args, kwargs)
else:
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
fallback_device = None
ensure_module_materialized(
module,
target_device,
fallback_device=fallback_device,
dtype_override=dtype_override,
)
@ -786,6 +839,7 @@ def attach_disk_weight_hooks(model: torch.nn.Module):
def evict_ram_cache(bytes_to_free: int):
if bytes_to_free <= 0:
return 0
safetensors_stream._reap_pinned_inflight()
return CACHE.evict_bytes(bytes_to_free)
@ -827,16 +881,7 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
def _move(tensor):
if tensor is None:
return None
if tensor.device.type == "meta":
return tensor
if dtype_override is not None and tensor.dtype != dtype_override:
return tensor.to(device=device_to, dtype=dtype_override)
return tensor.to(device=device_to)
module._apply(_move)
ensure_module_materialized(module, device_to, dtype_override=dtype_override)
return module
@ -864,10 +909,20 @@ def offload_module_weights(module: torch.nn.Module) -> int:
return offloaded_bytes
def module_to(module: torch.nn.Module, *args, **kwargs):
def module_to(
module: torch.nn.Module,
*args,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
memory_format=None,
**kwargs,
):
allow_materialize = kwargs.pop("allow_materialize", True)
arg_device = _extract_to_device(args, kwargs)
arg_dtype = _extract_to_dtype(args, kwargs)
if disk_weights_enabled():
target_device = _extract_to_device(args, kwargs)
target_device = device or arg_device
if target_device is None:
target_device = _find_existing_device(module) or torch.device("cpu")
if target_device.type == "meta":
@ -875,10 +930,28 @@ def module_to(module: torch.nn.Module, *args, **kwargs):
return module
if allow_materialize:
materialize_module_tree(module, target_device)
return module.to(*args, **kwargs)
dtype_override = _extract_to_dtype(args, kwargs)
base_kwargs = dict(kwargs)
if device is not None and arg_device is None:
base_kwargs["device"] = device
if dtype is not None and arg_dtype is None:
base_kwargs["dtype"] = dtype
if non_blocking:
base_kwargs["non_blocking"] = non_blocking
if memory_format is not None:
base_kwargs["memory_format"] = memory_format
return BASE_MODULE_TO(module, *args, **base_kwargs)
dtype_override = dtype or arg_dtype
return move_module_tensors(module, target_device, dtype_override=dtype_override)
return module.to(*args, **kwargs)
base_kwargs = dict(kwargs)
if device is not None and arg_device is None:
base_kwargs["device"] = device
if dtype is not None and arg_dtype is None:
base_kwargs["dtype"] = dtype
if non_blocking:
base_kwargs["non_blocking"] = non_blocking
if memory_format is not None:
base_kwargs["memory_format"] = memory_format
return BASE_MODULE_TO(module, *args, **base_kwargs)
def load_module_tensor(
@ -886,7 +959,6 @@ def load_module_tensor(
name: str,
device: torch.device,
*,
allow_alternate: bool = True,
record_cache: bool = True,
temporary: bool = False,
dtype_override: Optional[torch.dtype] = None,
@ -909,6 +981,9 @@ def load_module_tensor(
_set_future_dtype(module, name, dtype_override)
if current.device.type != "meta":
if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
from . import model_management
headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
_ensure_free_memory(device, _tensor_nbytes(current), headroom)
if target_dtype is not None and current.dtype != target_dtype:
tensor = current.to(device=device, dtype=target_dtype)
else:
@ -926,41 +1001,11 @@ def load_module_tensor(
required_bytes = _meta_nbytes(disk_ref.meta)
if required_bytes is None:
return current
free_mem_start = _device_free_memory(device)
free_mem = _maybe_free_ram_budget(device, required_bytes)
load_device = device
if free_mem < required_bytes and allow_alternate:
alt = _choose_alternate_device(device)
if alt is not None:
alt_free = _maybe_free_ram_budget(alt, required_bytes)
if alt_free >= required_bytes:
load_device = alt
else:
state = _get_materialization_state(module)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
return current
else:
state = _get_materialization_state(module)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current
elif free_mem < required_bytes:
state = _get_materialization_state(module)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current
from . import model_management
headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
_ensure_free_memory(device, required_bytes, headroom)
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
if temporary:
return tensor
if is_buffer:
@ -971,7 +1016,7 @@ def load_module_tensor(
CACHE.record(module, name, tensor, is_buffer=is_buffer)
state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state)
_log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
_log_materialization(module, device, _device_free_memory(device), refs, state, "Disk weight loaded")
return tensor
@ -1015,22 +1060,15 @@ def _materialize_module_from_state_dict(
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
existing[key] = buf
free_mem_start = _device_free_memory(target_device)
remaining_budget = free_mem_start
allowed = set(existing.keys())
allowed = set(keys)
from . import model_management
headroom = RAM_HEADROOM_BYTES if target_device.type == "cpu" else model_management.extra_reserved_memory()
for key in keys:
if key in allowed:
continue
meta = _state_dict_meta(lazy_state.state_dict, key)
required = _meta_nbytes(meta)
if required is None:
continue
if target_device.type == "cpu":
free_mem = _maybe_free_ram_budget(target_device, required)
remaining_budget = min(remaining_budget, free_mem)
if required <= remaining_budget:
allowed.add(key)
remaining_budget = max(0, remaining_budget - required)
deferred_state_dict_keys = {key for key in keys if key not in allowed}
_ensure_free_memory(target_device, required, headroom)
state_dict = _BudgetedStateDict(
lazy_state.state_dict,
allowed_keys=allowed,
@ -1065,7 +1103,7 @@ def _materialize_module_from_state_dict(
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
_rebuild_materialization_state(module, refs, state)
lazy_state.loaded = len(deferred_state_dict_keys) == 0
lazy_state.loaded = True
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
for name, param in module.named_parameters(recurse=False):
if param.device.type == "cpu":

View File

@ -283,7 +283,7 @@ def load_gligen(sd):
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
comfy.utils.load_state_dict(gated, n_sd, strict=False)
gated.load_state_dict(n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
@ -294,7 +294,7 @@ def load_gligen(sd):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
comfy.utils.load_state_dict(w, sd, strict=False)
w.load_state_dict(sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen

View File

@ -1,5 +1,4 @@
import torch
import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@ -113,7 +112,7 @@ class HunyuanVideo15SRModel():
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.model, sd, strict=True)
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()

View File

@ -2,7 +2,6 @@ import json
from dataclasses import dataclass
import math
import torch
import comfy.utils
import torchaudio
import comfy.model_management
@ -154,8 +153,8 @@ class AudioVAE(torch.nn.Module):
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(

View File

@ -2,7 +2,6 @@ import logging
from typing import Optional
import torch
import comfy.utils
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@ -153,7 +152,7 @@ class VAE(nn.Module):
return dec, posterior
def load_weights(self, src_dict) -> None:
comfy.utils.load_state_dict(self, src_dict, strict=True)
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:

View File

@ -309,7 +309,7 @@ class BaseModel(torch.nn.Module):
else:
to_load = sd
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
@ -753,8 +753,8 @@ class StableAudio1(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights, strict=True)
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights, strict=True)
def extra_conds(self, **kwargs):
out = {}

View File

@ -590,10 +590,26 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
if freed_cache < memory_required:
evict_ram_to_disk(memory_required - freed_cache)
free_before = get_free_memory(device)
headroom = comfy.disk_weights.ram_headroom_bytes()
if free_before < memory_required:
logging.debug(
"RAM pressure: required=%d free=%d headroom=%d",
memory_required,
free_before,
headroom,
)
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
freed_disk = 0
if freed_cache < memory_required:
freed_disk = evict_ram_to_disk(memory_required - freed_cache)
free_after = get_free_memory(device)
freed_total = max(0, free_after - free_before)
logging.debug(
"RAM freed: freed=%d free=%d",
freed_total if freed_total > 0 else freed_cache + freed_disk,
free_after,
)
unloaded_model = []
can_unload = []
unloaded_models = []
@ -636,6 +652,7 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
if not comfy.disk_weights.disk_weights_enabled():
return 0
free_before = get_free_memory(torch.device("cpu"))
freed = 0
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
@ -654,7 +671,14 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
if freed > 0:
logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
free_after = get_free_memory(torch.device("cpu"))
freed_total = max(0, free_after - free_before)
logging.debug(
"RAM evicted to disk: required=%d free=%d freed=%d",
memory_to_free,
free_before,
freed_total if freed_total > 0 else freed,
)
return freed
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
@ -802,6 +826,8 @@ def dtype_size(dtype):
return dtype_size
def unet_offload_device():
if comfy.disk_weights.disk_weights_enabled():
return torch.device("meta")
if vram_state == VRAMState.HIGH_VRAM:
return get_torch_device()
else:
@ -906,6 +932,8 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return torch.float32
def text_encoder_offload_device():
if comfy.disk_weights.disk_weights_enabled():
return torch.device("meta")
if args.gpu_only:
return get_torch_device()
else:
@ -966,6 +994,8 @@ def vae_device():
return get_torch_device()
def vae_offload_device():
if comfy.disk_weights.disk_weights_enabled():
return torch.device("meta")
if args.gpu_only:
return get_torch_device()
else:
@ -1163,14 +1193,13 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
WEIGHTS_RAM_CACHE_BYTES = 0
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
if args.weights_ram_cache_gb is not None:
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
if args.low_ram:
comfy.disk_weights.configure(
WEIGHTS_RAM_CACHE_BYTES,
allow_gds=WEIGHTS_GDS_ENABLED,
pin_if_cpu=not args.disable_pinned_memory,
ram_headroom_bytes=1024 * 1024 * 1024,
enabled=True,
)
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])

View File

@ -19,7 +19,6 @@
import torch
import logging
import comfy.model_management
import comfy.disk_weights
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
@ -101,27 +100,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_source = s.weight
bias_source = s.bias
if comfy.disk_weights.disk_weights_enabled():
if weight_source.device.type == "meta":
loaded = comfy.disk_weights.load_module_tensor(
s,
"weight",
device,
temporary=True,
dtype_override=dtype,
)
if loaded is not None:
weight_source = loaded
if bias_source is not None and bias_source.device.type == "meta":
loaded_bias = comfy.disk_weights.load_module_tensor(
s,
"bias",
device,
temporary=True,
dtype_override=bias_dtype,
)
if loaded_bias is not None:
bias_source = loaded_bias
weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)

View File

@ -37,6 +37,19 @@ _FST_LOADED = False
_GDS_INITIALIZED = False
_MISSING = object()
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
_PINNED_INFLIGHT = collections.deque()
def _reap_pinned_inflight():
if not _PINNED_INFLIGHT:
return
pending = collections.deque()
while _PINNED_INFLIGHT:
event, tensor = _PINNED_INFLIGHT.popleft()
if event.query():
continue
pending.append((event, tensor))
_PINNED_INFLIGHT.extend(pending)
def _require_fastsafetensors():
@ -204,14 +217,22 @@ class _SafeTensorFile:
)
return tensor
target_dtype = dtype
if device_is_cuda and pin_if_cpu:
_reap_pinned_inflight()
target_dtype = None
cpu_tensor = self._read_tensor_nogds(
fst, framework, meta, torch.device("cpu"), dtype
fst, framework, meta, torch.device("cpu"), target_dtype, pin_memory=bool(device_is_cuda and pin_if_cpu)
)
if device_is_cuda:
if pin_if_cpu:
cpu_tensor = cpu_tensor.pin_memory()
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
if pin_if_cpu:
event = torch.cuda.Event()
event.record(torch.cuda.current_stream(device))
_PINNED_INFLIGHT.append((event, cpu_tensor))
if dtype is not None and dtype != gpu_tensor.dtype:
gpu_tensor = gpu_tensor.to(dtype=dtype)
return gpu_tensor
return cpu_tensor
@ -233,6 +254,8 @@ class _SafeTensorFile:
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
*,
pin_memory: bool = False,
) -> torch.Tensor:
fd = self._ensure_fd()
reader = self._ensure_nogds_reader(use_cuda=False)
@ -241,7 +264,7 @@ class _SafeTensorFile:
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
chunk_bytes = max(1, chunk_bytes)
ptr_align = framework.get_device_ptr_align()
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu", pin_memory=pin_memory)
buffer_length = 0
buf_ptr = None
gbuf = None
@ -250,6 +273,8 @@ class _SafeTensorFile:
while chunk_offset < length:
chunk_len = min(length - chunk_offset, chunk_bytes)
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
if aligned_offset + aligned_length > self.index.size_bytes:
aligned_length = max(0, self.index.size_bytes - aligned_offset)
needed = aligned_length + ptr_align
if buf_ptr is None or needed > buffer_length:
if buf_ptr is not None:
@ -272,7 +297,6 @@ class _SafeTensorFile:
if buf_ptr is not None:
fst.cpp.cpu_free(buf_ptr)
if dtype is not None and dtype != dest_tensor.dtype:
_validate_dtype_conversion(dest_tensor.dtype, dtype)
dest_tensor = dest_tensor.to(dtype=dtype)
return dest_tensor
@ -289,6 +313,8 @@ class _SafeTensorFile:
abs_start = self.index.header_length + meta.data_offsets[0]
length = meta.data_offsets[1] - meta.data_offsets[0]
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
if aligned_offset + aligned_length > self.index.size_bytes:
aligned_length = max(0, self.index.size_bytes - aligned_offset)
ptr_align = framework.get_device_ptr_align()
buffer_length = aligned_length + ptr_align
fst_device = _fst_device_from_torch(fst, device)
@ -306,7 +332,6 @@ class _SafeTensorFile:
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
)
if dtype is not None and dtype != tensor.dtype:
_validate_dtype_conversion(tensor.dtype, dtype)
tensor = tensor.to(dtype=dtype)
return tensor
@ -348,8 +373,7 @@ def _dlpack_tensor_from_buffer(
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
return
def _get_gds_o_direct(framework) -> bool:
@ -523,8 +547,10 @@ class StreamStateDict(collections.abc.MutableMapping):
raise KeyError(key)
return default
if self._index.has(key):
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
if default is _MISSING:
raise KeyError(key)
return default
@ -636,8 +662,10 @@ class _BaseViewStateDict(MutableMapping):
if default is _MISSING:
raise
return default
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
def meta(self, key: str):
if key in self._overrides:
@ -768,8 +796,10 @@ class DeviceViewStateDict(_BaseViewStateDict):
if default is _MISSING:
raise
return default
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
class FilterViewStateDict(_BaseViewStateDict):
@ -932,3 +962,48 @@ class MappedStateDict(_BaseViewStateDict):
def _iter_base_keys(self) -> Iterable[str]:
return self._view_to_base.keys()
def stream_load_state_dict(model, state_dict, strict: bool = False, assign: bool = False):
if getattr(state_dict, "is_stream_state_dict", False) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)

View File

@ -26,7 +26,6 @@ import os
import comfy.utils
import comfy.safetensors_stream
import comfy.disk_weights
from . import clip_vision
from . import gligen
@ -126,7 +125,7 @@ class CLIP:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@ -290,7 +289,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@ -658,7 +657,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@ -672,7 +671,7 @@ class VAE:
if dtype is None:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
self.first_stage_model.to(dtype=self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@ -979,7 +978,7 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
comfy.utils.load_state_dict(model, model_data, strict=True)
model.load_state_dict(model_data, strict=True)
return StyleModel(model)
def sd_shape(state_dict, key):
@ -1547,7 +1546,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = comfy.disk_weights.module_to(model, offload_device)
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:

View File

@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
result = []

View File

@ -56,9 +56,9 @@ class TAESD(nn.Module):
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
if decoder_path is not None:
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
@staticmethod
def scale_latents(x):

View File

@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
if len(sdo) == 0:
sdo = sd
return comfy.utils.load_state_dict(self, sdo, strict=False)
return self.load_state_dict(sdo, strict=False)
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@ -32,7 +32,6 @@ from einops import rearrange
from comfy.cli_args import args
import json
from . import safetensors_stream
import comfy.disk_weights
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -166,62 +165,6 @@ def state_dict_meta(state_dict, key):
)
def load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict):
if comfy.disk_weights.disk_weights_enabled():
return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
comfy.disk_weights.register_module_weights(model, state_dict)
comfy.disk_weights.attach_disk_weight_hooks(model)
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
return missing, unexpected
return model.load_state_dict(state_dict, strict=strict)
def stream_load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = safetensors_stream.FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return missing_keys, unexpected_keys
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",