Compare commits

...

6 Commits

Author SHA1 Message Date
ifilipis
c825bc526e Add functions for tensor input handling and casting 2026-01-21 20:00:35 +02:00
ifilipis
fcbd22b514 Fix weight casting double allocation 2026-01-21 20:00:35 +02:00
ifilipis
91809e83ff Fix disk weight device handling and cache accounting 2026-01-21 20:00:35 +02:00
ifilipis
82e70aa3c2 Fix disk weight movement and pinned inflight tracking 2026-01-21 20:00:35 +02:00
ifilipis
c3eaea0429 Fix disk-weight tiering and meta handling 2026-01-21 20:00:35 +02:00
ifilipis
95ca11fe25 Refine disk weight offload integration 2026-01-21 20:00:35 +02:00
20 changed files with 818 additions and 385 deletions

View File

@ -349,12 +349,12 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
| `--low-ram` | Enable disk weight offloading. Sets RAM headroom to 1024MB and uses `--reserve-vram` as the VRAM headroom. |
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
### Disk tier for model weights
When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
When `--low-ram` is enabled, ComfyUI streams safetensors weights from disk and offloads weights to disk when RAM headroom is reached.
If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.

View File

@ -29,7 +29,7 @@ class AudioEncoderModel():
self.model_sample_rate = 16000
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.model, sd, strict=False)
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@ -114,7 +114,7 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
parser.add_argument("--low-ram", action="store_true", help="Enable disk weight offloading. Sets RAM headroom to 1024MB and uses --reserve-vram as the VRAM headroom.")
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
attn_group = parser.add_mutually_exclusive_group()

View File

@ -6,7 +6,6 @@ import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
@ -48,7 +47,7 @@ class ClipVisionModel():
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.model, sd, strict=False)
return self.model.load_state_dict(sd, strict=False)
def get_sd(self):
return self.model.state_dict()

View File

@ -25,7 +25,6 @@ import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.disk_weights
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
@ -386,7 +385,7 @@ class ControlLora(ControlNet):
controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
@ -440,7 +439,7 @@ def controlnet_config(sd, model_options={}):
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
missing, unexpected = control_model.load_state_dict(sd, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -474,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
@ -749,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))
@ -817,8 +816,8 @@ class T2IAdapter(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None:
comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
comfy.disk_weights.module_to(self.t2i_model, self.device)
self.t2i_model.to(dtype=x_noisy.dtype)
self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu()
@ -875,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
else:
return None
missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
missing, unexpected = model_ad.load_state_dict(t2i_data, strict=True)
if len(missing) > 0:
logging.warning("t2i missing {}".format(missing))

View File

@ -33,7 +33,11 @@ from . import safetensors_stream
ALLOW_GDS = False
PIN_IF_CPU = False
DISK_WEIGHTS_ENABLED = False
RAM_HEADROOM_BYTES = 0
BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
BASE_MODULE_TO = torch.nn.Module.to
BASE_LOAD_STATE_DICT = torch.nn.Module.load_state_dict
_MONKEYPATCHED = False
LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
_MISSING = object()
@ -92,6 +96,7 @@ class CacheEntry:
name: str
size_bytes: int
is_buffer: bool
device_type: str
class DiskWeightCache:
@ -108,16 +113,25 @@ class DiskWeightCache:
return (id(module), name)
def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
if tensor.device.type != "cpu":
if tensor.device.type == "meta":
return
size_bytes = tensor.numel() * tensor.element_size()
key = self._entry_key(module, name)
if key in self._entries:
entry = self._entries.pop(key)
self.current_bytes -= entry.size_bytes
if entry.device_type == "cpu":
self.current_bytes -= entry.size_bytes
module_ref = weakref.ref(module, self._drop_module_entries)
self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
self.current_bytes += size_bytes
device_type = tensor.device.type
self._entries[key] = CacheEntry(
module_ref=module_ref,
name=name,
size_bytes=size_bytes,
is_buffer=is_buffer,
device_type=device_type,
)
if device_type == "cpu":
self.current_bytes += size_bytes
self._evict_if_needed()
def touch(self, module: torch.nn.Module, name: str):
@ -129,9 +143,10 @@ class DiskWeightCache:
def evict_bytes(self, bytes_to_free: int):
freed = 0
while self._entries and freed < bytes_to_free:
_, entry = self._entries.popitem(last=False)
entry = self.pop_lru(torch.device("cpu"))
if entry is None:
break
freed += entry.size_bytes
self.current_bytes -= entry.size_bytes
module = entry.module_ref()
if module is not None:
_evict_module_weight(module, entry.name, entry.is_buffer)
@ -144,8 +159,26 @@ class DiskWeightCache:
to_remove.append(key)
for key in to_remove:
entry = self._entries.pop(key)
if entry.device_type == "cpu":
self.current_bytes -= entry.size_bytes
def remove_entry(self, module: torch.nn.Module, name: str):
key = self._entry_key(module, name)
entry = self._entries.pop(key, None)
if entry is None:
return
if entry.device_type == "cpu":
self.current_bytes -= entry.size_bytes
def pop_lru(self, device: torch.device) -> Optional[CacheEntry]:
for key, entry in self._entries.items():
if entry.device_type == device.type:
self._entries.pop(key)
if entry.device_type == "cpu":
self.current_bytes -= entry.size_bytes
return entry
return None
def _drop_module_entries(self, module_ref: weakref.ReferenceType):
to_remove = []
for key, entry in self._entries.items():
@ -153,12 +186,14 @@ class DiskWeightCache:
to_remove.append(key)
for key in to_remove:
entry = self._entries.pop(key)
self.current_bytes -= entry.size_bytes
if entry.device_type == "cpu":
self.current_bytes -= entry.size_bytes
def _evict_if_needed(self):
while self._entries and self.current_bytes > self.max_bytes:
_, entry = self._entries.popitem(last=False)
self.current_bytes -= entry.size_bytes
entry = self.pop_lru(torch.device("cpu"))
if entry is None:
break
module = entry.module_ref()
if module is not None:
_evict_module_weight(module, entry.name, entry.is_buffer)
@ -169,13 +204,34 @@ CACHE = DiskWeightCache(0)
LOGGER = logging.getLogger(__name__)
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, enabled: bool = True):
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED, RAM_HEADROOM_BYTES
ALLOW_GDS = allow_gds
PIN_IF_CPU = pin_if_cpu
DISK_WEIGHTS_ENABLED = enabled
CACHE.set_limit(cache_bytes if enabled else 0)
if not enabled:
RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
if enabled:
from . import model_management
cpu_capacity_bytes = max(0, model_management.get_total_memory(torch.device("cpu")) - RAM_HEADROOM_BYTES)
CACHE.set_limit(cpu_capacity_bytes)
LOGGER.debug(
"Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
enabled,
RAM_HEADROOM_BYTES,
cpu_capacity_bytes,
)
else:
CACHE.set_limit(0)
LOGGER.debug(
"Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
enabled,
RAM_HEADROOM_BYTES,
0,
)
if enabled:
install_monkeypatches()
else:
uninstall_monkeypatches()
CACHE._entries.clear()
CACHE.current_bytes = 0
@ -184,6 +240,66 @@ def disk_weights_enabled() -> bool:
return DISK_WEIGHTS_ENABLED
def ram_headroom_bytes() -> int:
return RAM_HEADROOM_BYTES
def _is_stream_state_dict(state_dict) -> bool:
return (
getattr(state_dict, "is_stream_state_dict", False)
and hasattr(state_dict, "get_tensor")
and hasattr(state_dict, "meta")
)
def patched_to(self: torch.nn.Module, *args, **kwargs):
if not disk_weights_enabled():
return BASE_MODULE_TO(self, *args, **kwargs)
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
module_to(
self,
device=device,
dtype=dtype,
non_blocking=non_blocking,
memory_format=memory_format,
)
return self
def patched_load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
if not disk_weights_enabled():
if _is_stream_state_dict(state_dict):
return safetensors_stream.stream_load_state_dict(
self,
state_dict,
strict=strict,
assign=assign,
)
return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
if _is_stream_state_dict(state_dict):
missing_keys, unexpected_keys = lazy_load_state_dict(self, state_dict, strict=strict)
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
def install_monkeypatches():
global _MONKEYPATCHED
if _MONKEYPATCHED:
return
torch.nn.Module.to = patched_to
torch.nn.Module.load_state_dict = patched_load_state_dict
_MONKEYPATCHED = True
def uninstall_monkeypatches():
global _MONKEYPATCHED
if not _MONKEYPATCHED:
return
torch.nn.Module.to = BASE_MODULE_TO
torch.nn.Module.load_state_dict = BASE_LOAD_STATE_DICT
_MONKEYPATCHED = False
def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
if not disk_weights_enabled():
return
@ -197,7 +313,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
meta = state_dict.meta(key)
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
REGISTRY.register(submodule, name, ref)
if param.device.type == "cpu":
if param.device.type != "meta":
CACHE.record(submodule, name, param, is_buffer=False)
for name, buf in submodule.named_buffers(recurse=False):
key = f"{module_prefix}{name}" if module_prefix else name
@ -205,7 +321,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
meta = state_dict.meta(key)
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
REGISTRY.register(submodule, name, ref)
if buf.device.type == "cpu":
if buf.device.type != "meta":
CACHE.record(submodule, name, buf, is_buffer=True)
@ -269,6 +385,23 @@ def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Te
return torch.empty(shape, dtype=dtype, device="meta")
def _attach_disk_identity(tensor: torch.Tensor, module: torch.nn.Module, name: str, is_buffer: bool):
tensor._disk_weights_module_ref = weakref.ref(module)
tensor._disk_weights_name = name
tensor._disk_weights_is_buffer = is_buffer
def materialize_meta_tensor(tensor: torch.Tensor, target_device: torch.device, dtype_override: Optional[torch.dtype]):
module_ref = getattr(tensor, "_disk_weights_module_ref", None)
name = getattr(tensor, "_disk_weights_name", None)
if module_ref is None or name is None:
raise RuntimeError("Meta tensor missing disk weight identity")
module = module_ref()
if module is None:
raise RuntimeError("Disk weight module reference expired")
return load_module_tensor(module, name, target_device, dtype_override=dtype_override, temporary=False)
def _state_dict_meta(state_dict: MutableMapping, key: str):
if hasattr(state_dict, "meta"):
return state_dict.meta(key)
@ -369,33 +502,30 @@ def _device_free_memory(device: torch.device) -> int:
return int(model_management.get_free_memory(device))
def _evict_ram_for_budget(required_bytes: int) -> int:
if required_bytes <= 0:
return 0
freed = evict_ram_cache(required_bytes)
if freed < required_bytes:
def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_bytes: int) -> int:
free_before = _device_free_memory(device)
if free_before < required_bytes + headroom_bytes:
LOGGER.debug(
"Disk weight memory pressure: required=%d free=%d headroom=%d device=%s",
required_bytes,
free_before,
headroom_bytes,
device,
)
safetensors_stream._reap_pinned_inflight()
from . import model_management
freed += model_management.evict_ram_to_disk(required_bytes - freed)
return freed
def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
free_mem = _device_free_memory(device)
if device.type == "cpu" and free_mem < required_bytes:
_evict_ram_for_budget(required_bytes - free_mem)
free_mem = _device_free_memory(device)
return free_mem
def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
from . import model_management
if device.type == "cpu":
alt = model_management.get_torch_device()
if alt.type != "cpu":
return alt
else:
return torch.device("cpu")
return None
model_management.free_memory(required_bytes + headroom_bytes, device)
free_after = _device_free_memory(device)
freed = max(0, free_after - free_before)
LOGGER.debug(
"Disk weight memory freed: freed=%d free_before=%d free_after=%d device=%s",
freed,
free_before,
free_after,
device,
)
return free_after
return free_before
class _BudgetedStateDict(MutableMapping):
@ -527,8 +657,10 @@ class _BudgetedStateDict(MutableMapping):
if default is _MISSING:
raise KeyError(key)
return default
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
def meta(self, key: str):
return self._get_meta(key)
@ -564,6 +696,8 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):
def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
safetensors_stream._reap_pinned_inflight()
from . import model_management
lazy_state = LAZY_MODULE_STATE.get(module)
if lazy_state is not None:
CACHE.remove_module(module)
@ -571,15 +705,31 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
if refs:
state = _get_materialization_state(module)
for ref_name, disk_ref in refs.items():
if ref_name in module._parameters:
current = module._parameters[ref_name]
elif ref_name in module._buffers:
current = module._buffers[ref_name]
else:
current = None
if (
current is not None
and current.device.type == "cpu"
and current.data_ptr() in model_management.PINNED_MEMORY
):
model_management.wait_for_pinned_tensor(current)
model_management.unpin_memory(current)
shape = getattr(disk_ref.meta, "shape", None)
dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
dtype = getattr(disk_ref.meta, "dtype", None)
if shape is None or dtype is None:
continue
meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
if disk_ref.is_buffer:
module._buffers[ref_name] = meta_tensor
_attach_disk_identity(meta_tensor, module, ref_name, True)
else:
module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
module._parameters[ref_name] = param
_attach_disk_identity(param, module, ref_name, False)
nbytes = _meta_nbytes(disk_ref.meta)
if nbytes is not None:
state.loaded_keys.discard(ref_name)
@ -593,16 +743,31 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
ref = REGISTRY.get(module)
if not ref or name not in ref:
return
CACHE.remove_entry(module, name)
disk_ref = ref[name]
if is_buffer:
current = module._buffers.get(name)
else:
current = module._parameters.get(name)
if (
current is not None
and current.device.type == "cpu"
and current.data_ptr() in model_management.PINNED_MEMORY
):
model_management.wait_for_pinned_tensor(current)
model_management.unpin_memory(current)
shape = getattr(disk_ref.meta, "shape", None)
dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
dtype = getattr(disk_ref.meta, "dtype", None)
if shape is None or dtype is None:
return
meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
if is_buffer:
module._buffers[name] = meta_tensor
_attach_disk_identity(meta_tensor, module, name, True)
else:
module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
module._parameters[name] = param
_attach_disk_identity(param, module, name, False)
state = _get_materialization_state(module)
nbytes = _meta_nbytes(disk_ref.meta)
if nbytes is not None:
@ -671,7 +836,6 @@ def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype:
def ensure_module_materialized(
module: torch.nn.Module,
target_device: torch.device,
fallback_device: Optional[torch.device] = None,
dtype_override: Optional[torch.dtype] = None,
):
lazy_state = LAZY_MODULE_STATE.get(module)
@ -687,12 +851,12 @@ def ensure_module_materialized(
if not refs:
return
state = _get_materialization_state(module)
if dtype_override is not None:
for name in refs.keys():
_set_future_dtype(module, name, dtype_override)
# Do not persist dtype overrides into storage.
_rebuild_materialization_state(module, refs, state)
free_mem_start = _device_free_memory(target_device)
remaining_budget = free_mem_start
from . import model_management
non_blocking = model_management.device_supports_non_blocking(target_device)
offload_stream = model_management.get_offload_stream(target_device) if non_blocking else None
for name in sorted(refs.keys()):
disk_ref = refs[name]
if name in module._parameters:
@ -705,31 +869,20 @@ def ensure_module_materialized(
continue
if current is None:
continue
target_dtype = dtype_override or _get_future_dtype(module, name)
if current.device.type != "meta" and current.device == target_device and (
target_dtype is None or current.dtype == target_dtype
):
if current.device.type == "cpu":
CACHE.touch(module, name)
# Persistent tensors must remain in stored dtype.
target_dtype = None
if current.device.type != "meta" and current.device == target_device:
CACHE.touch(module, name)
continue
meta_nbytes = _meta_nbytes(disk_ref.meta)
if meta_nbytes is None:
continue
required_bytes = meta_nbytes
if target_device.type == "cpu":
free_mem = _maybe_free_ram_budget(target_device, required_bytes)
remaining_budget = min(remaining_budget, free_mem)
if required_bytes > remaining_budget:
if fallback_device is not None and fallback_device != target_device:
fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
if fallback_free >= required_bytes:
target_for_load = fallback_device
else:
continue
else:
continue
_ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
else:
target_for_load = target_device
_ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
target_for_load = target_device
if current.device.type == "meta":
tensor = disk_ref.load(
target_for_load,
@ -737,18 +890,38 @@ def ensure_module_materialized(
PIN_IF_CPU,
dtype_override=target_dtype,
)
if tensor.device != target_for_load or (target_dtype is not None and tensor.dtype != target_dtype):
tensor = model_management.cast_to(
tensor,
device=target_for_load,
dtype=target_dtype,
non_blocking=non_blocking,
stream=offload_stream,
)
if non_blocking and offload_stream is not None:
model_management.sync_stream(target_for_load, offload_stream)
else:
if target_dtype is not None and current.dtype != target_dtype:
tensor = current.to(device=target_for_load, dtype=target_dtype)
else:
tensor = current.to(device=target_for_load)
if (
current.device.type == "cpu"
and current.data_ptr() in model_management.PINNED_MEMORY
):
model_management.wait_for_pinned_tensor(current)
model_management.unpin_memory(current)
tensor = model_management.cast_to(
current,
device=target_for_load,
dtype=target_dtype if target_dtype is not None else current.dtype,
non_blocking=non_blocking,
stream=offload_stream,
)
if non_blocking and offload_stream is not None:
model_management.sync_stream(target_for_load, offload_stream)
if is_buffer:
module._buffers[name] = tensor
else:
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu":
if tensor.device.type != "meta":
CACHE.record(module, name, tensor, is_buffer=is_buffer)
remaining_budget = max(0, remaining_budget - required_bytes)
_rebuild_materialization_state(module, refs, state)
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
@ -758,17 +931,17 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
return
input_dtype = _find_tensor_dtype(args, kwargs)
manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
if getattr(module, "comfy_cast_weights", False):
target_device = torch.device("cpu")
fallback_device = _find_tensor_device(args, kwargs)
dtype_override = None # persistent tensors stay in stored dtype; per-op casting only
input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
if getattr(module, "comfy_patched_weights", False):
target_device = input_device
elif getattr(module, "comfy_cast_weights", False):
target_device = input_device
else:
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
fallback_device = None
target_device = input_device
ensure_module_materialized(
module,
target_device,
fallback_device=fallback_device,
dtype_override=dtype_override,
)
@ -786,9 +959,74 @@ def attach_disk_weight_hooks(model: torch.nn.Module):
def evict_ram_cache(bytes_to_free: int):
if bytes_to_free <= 0:
return 0
safetensors_stream._reap_pinned_inflight()
return CACHE.evict_bytes(bytes_to_free)
def _move_cache_entry_to_cpu(entry: CacheEntry):
module = entry.module_ref()
if module is None:
return
if entry.is_buffer:
current = module._buffers.get(entry.name)
else:
current = module._parameters.get(entry.name)
if current is None or current.device.type == "meta":
return
from . import model_management
non_blocking = model_management.device_supports_non_blocking(torch.device("cpu"))
offload_stream = model_management.get_offload_stream(torch.device("cpu")) if non_blocking else None
tensor = model_management.cast_to(
current,
device=torch.device("cpu"),
dtype=current.dtype,
non_blocking=non_blocking,
stream=offload_stream,
)
if non_blocking and offload_stream is not None:
model_management.sync_stream(current.device, offload_stream)
if entry.is_buffer:
module._buffers[entry.name] = tensor
else:
module._parameters[entry.name] = torch.nn.Parameter(tensor, requires_grad=current.requires_grad)
CACHE.record(module, entry.name, tensor, is_buffer=entry.is_buffer)
def _evict_cpu_entry_to_meta(entry: CacheEntry):
module = entry.module_ref()
if module is None:
return
_evict_module_weight(module, entry.name, entry.is_buffer)
CACHE.remove_entry(module, entry.name)
def evict_for_budget(target_device: torch.device, required_bytes: int):
if not disk_weights_enabled() or required_bytes <= 0:
return
from . import model_management
free = model_management.get_free_memory(target_device)
if free >= required_bytes:
return
cpu_device = torch.device("cpu")
if target_device.type != "cpu":
while free < required_bytes:
entry = CACHE.pop_lru(target_device)
if entry is None:
break
free_cpu = model_management.get_free_memory(cpu_device)
if free_cpu < RAM_HEADROOM_BYTES:
CACHE.evict_bytes(RAM_HEADROOM_BYTES - free_cpu)
_move_cache_entry_to_cpu(entry)
free = model_management.get_free_memory(target_device)
else:
while free < required_bytes:
entry = CACHE.pop_lru(cpu_device)
if entry is None:
break
_evict_cpu_entry_to_meta(entry)
free = model_management.get_free_memory(target_device)
def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
if not disk_weights_enabled():
return
@ -826,17 +1064,55 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
return None
def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
def _move(tensor):
if tensor is None:
return None
if tensor.device.type == "meta":
return tensor
if dtype_override is not None and tensor.dtype != dtype_override:
return tensor.to(device=device_to, dtype=dtype_override)
return tensor.to(device=device_to)
def move_module_tensors(
module: torch.nn.Module,
device_to: torch.device,
dtype_override: Optional[torch.dtype] = None,
non_blocking: bool = False,
):
from . import model_management
offload_stream = None
if non_blocking and model_management.device_supports_non_blocking(device_to):
offload_stream = model_management.get_offload_stream(device_to)
module._apply(_move)
def apply_fn(tensor):
if tensor is None or tensor.device.type == "meta":
return tensor
target_dtype = dtype_override or tensor.dtype
if (
tensor.device.type == "cpu"
and tensor.data_ptr() in model_management.PINNED_MEMORY
and (device_to.type != "cpu" or target_dtype != tensor.dtype)
):
model_management.wait_for_pinned_tensor(tensor)
model_management.unpin_memory(tensor)
if tensor.device == device_to and tensor.dtype == target_dtype:
return tensor
return model_management.cast_to(
tensor,
device=device_to,
dtype=target_dtype,
non_blocking=non_blocking,
stream=offload_stream,
)
module._apply(apply_fn)
if disk_weights_enabled():
for submodule in module.modules():
refs = REGISTRY.get(submodule)
if not refs:
continue
for name, disk_ref in refs.items():
if disk_ref.is_buffer:
tensor = submodule._buffers.get(name)
else:
tensor = submodule._parameters.get(name)
if tensor is None or tensor.device.type == "meta":
CACHE.remove_entry(submodule, name)
continue
CACHE.record(submodule, name, tensor, is_buffer=disk_ref.is_buffer)
if non_blocking and offload_stream is not None:
model_management.sync_stream(device_to, offload_stream)
return module
@ -864,21 +1140,60 @@ def offload_module_weights(module: torch.nn.Module) -> int:
return offloaded_bytes
def module_to(module: torch.nn.Module, *args, **kwargs):
def module_to(
module: torch.nn.Module,
*args,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
memory_format=None,
**kwargs,
):
allow_materialize = kwargs.pop("allow_materialize", True)
arg_device = _extract_to_device(args, kwargs)
arg_dtype = _extract_to_dtype(args, kwargs)
if disk_weights_enabled():
target_device = _extract_to_device(args, kwargs)
target_device = device or arg_device
if target_device is None:
target_device = _find_existing_device(module) or torch.device("cpu")
dtype_override = dtype or arg_dtype
if target_device.type == "meta":
offload_module_weights(module)
for submodule in module.modules():
offload_module_weights(submodule)
move_module_tensors(
submodule,
target_device,
dtype_override=dtype_override,
non_blocking=non_blocking,
)
return module
if allow_materialize:
materialize_module_tree(module, target_device)
return module.to(*args, **kwargs)
dtype_override = _extract_to_dtype(args, kwargs)
return move_module_tensors(module, target_device, dtype_override=dtype_override)
return module.to(*args, **kwargs)
if not allow_materialize:
move_module_tensors(
module,
target_device,
dtype_override=dtype_override,
non_blocking=non_blocking,
)
return module
for submodule in module.modules():
ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
move_module_tensors(
module,
target_device,
dtype_override=dtype_override,
non_blocking=non_blocking,
)
return module
base_kwargs = dict(kwargs)
if device is not None and arg_device is None:
base_kwargs["device"] = device
if dtype is not None and arg_dtype is None:
base_kwargs["dtype"] = dtype
if non_blocking:
base_kwargs["non_blocking"] = non_blocking
if memory_format is not None:
base_kwargs["memory_format"] = memory_format
return BASE_MODULE_TO(module, *args, **base_kwargs)
def load_module_tensor(
@ -886,7 +1201,6 @@ def load_module_tensor(
name: str,
device: torch.device,
*,
allow_alternate: bool = True,
record_cache: bool = True,
temporary: bool = False,
dtype_override: Optional[torch.dtype] = None,
@ -904,16 +1218,33 @@ def load_module_tensor(
return None
if current is None:
return None
target_dtype = dtype_override or _get_future_dtype(module, name)
if dtype_override is not None:
_set_future_dtype(module, name, dtype_override)
# Persistent loads must not mix storage with dtype casting.
if not temporary:
dtype_override = None
target_dtype = dtype_override
if current.device.type != "meta":
if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
if target_dtype is not None and current.dtype != target_dtype:
tensor = current.to(device=device, dtype=target_dtype)
else:
tensor = current.to(device=device)
from . import model_management
headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
_ensure_free_memory(device, _tensor_nbytes(current), headroom)
non_blocking = model_management.device_supports_non_blocking(device)
offload_stream = model_management.get_offload_stream(device) if non_blocking else None
tensor = model_management.cast_to(
current,
device=device,
dtype=target_dtype if target_dtype is not None else current.dtype,
non_blocking=non_blocking,
stream=offload_stream,
)
if non_blocking and offload_stream is not None:
model_management.sync_stream(device, offload_stream)
if not temporary:
if (
current.device.type == "cpu"
and current.data_ptr() in model_management.PINNED_MEMORY
):
model_management.wait_for_pinned_tensor(current)
model_management.unpin_memory(current)
if is_buffer:
module._buffers[name] = tensor
else:
@ -926,52 +1257,33 @@ def load_module_tensor(
required_bytes = _meta_nbytes(disk_ref.meta)
if required_bytes is None:
return current
free_mem_start = _device_free_memory(device)
free_mem = _maybe_free_ram_budget(device, required_bytes)
load_device = device
if free_mem < required_bytes and allow_alternate:
alt = _choose_alternate_device(device)
if alt is not None:
alt_free = _maybe_free_ram_budget(alt, required_bytes)
if alt_free >= required_bytes:
load_device = alt
else:
state = _get_materialization_state(module)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
return current
else:
state = _get_materialization_state(module)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current
elif free_mem < required_bytes:
state = _get_materialization_state(module)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += required_bytes
_update_disk_state_attrs(module, state)
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
return current
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
from . import model_management
headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
_ensure_free_memory(device, required_bytes, headroom)
non_blocking = model_management.device_supports_non_blocking(device)
offload_stream = model_management.get_offload_stream(device) if non_blocking else None
tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
if tensor.device != device or (target_dtype is not None and tensor.dtype != target_dtype):
tensor = model_management.cast_to(
tensor,
device=device,
dtype=target_dtype if target_dtype is not None else tensor.dtype,
non_blocking=non_blocking,
stream=offload_stream,
)
if non_blocking and offload_stream is not None:
model_management.sync_stream(device, offload_stream)
if temporary:
return tensor
if is_buffer:
module._buffers[name] = tensor
else:
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu" and record_cache:
if tensor.device.type != "meta" and record_cache:
CACHE.record(module, name, tensor, is_buffer=is_buffer)
state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state)
_log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
_log_materialization(module, device, _device_free_memory(device), refs, state, "Disk weight loaded")
return tensor
@ -983,8 +1295,11 @@ def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_
attr = parts[-1]
if is_buffer:
module._buffers[attr] = tensor
return tensor
else:
module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
param = torch.nn.Parameter(tensor, requires_grad=requires_grad)
module._parameters[attr] = param
return param
def _materialize_module_from_state_dict(
@ -999,9 +1314,7 @@ def _materialize_module_from_state_dict(
metadata = getattr(lazy_state.state_dict, "_metadata", None)
local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
refs = REGISTRY.get(module) or {}
if dtype_override is not None:
for name in refs.keys():
_set_future_dtype(module, name, dtype_override)
# Do not persist dtype overrides into storage.
state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state)
keys = sorted(lazy_state.state_dict.keys())
@ -1015,22 +1328,15 @@ def _materialize_module_from_state_dict(
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
existing[key] = buf
free_mem_start = _device_free_memory(target_device)
remaining_budget = free_mem_start
allowed = set(existing.keys())
allowed = set(keys)
from . import model_management
headroom = RAM_HEADROOM_BYTES if target_device.type == "cpu" else model_management.extra_reserved_memory()
for key in keys:
if key in allowed:
continue
meta = _state_dict_meta(lazy_state.state_dict, key)
required = _meta_nbytes(meta)
if required is None:
continue
if target_device.type == "cpu":
free_mem = _maybe_free_ram_budget(target_device, required)
remaining_budget = min(remaining_budget, free_mem)
if required <= remaining_budget:
allowed.add(key)
remaining_budget = max(0, remaining_budget - required)
deferred_state_dict_keys = {key for key in keys if key not in allowed}
_ensure_free_memory(target_device, required, headroom)
state_dict = _BudgetedStateDict(
lazy_state.state_dict,
allowed_keys=allowed,
@ -1064,14 +1370,25 @@ def _materialize_module_from_state_dict(
module.factory_kwargs["device"] = factory_device
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
for name, disk_ref in refs.items():
if name in module._parameters:
tensor = module._parameters[name]
is_buffer = False
elif name in module._buffers:
tensor = module._buffers[name]
is_buffer = True
else:
continue
if tensor is not None and tensor.device.type == "meta":
_attach_disk_identity(tensor, module, name, is_buffer)
_rebuild_materialization_state(module, refs, state)
lazy_state.loaded = len(deferred_state_dict_keys) == 0
lazy_state.loaded = True
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
for name, param in module.named_parameters(recurse=False):
if param.device.type == "cpu":
if param.device.type != "meta":
CACHE.record(module, name, param, is_buffer=False)
for name, buf in module.named_buffers(recurse=False):
if buf is not None and buf.device.type == "cpu":
if buf is not None and buf.device.type != "meta":
CACHE.record(module, name, buf, is_buffer=True)
@ -1100,14 +1417,16 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals
continue
meta = state_dict.meta(name)
meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
_replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
stored = _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
_attach_disk_identity(stored, model, name, False)
for name, buf in model.named_buffers(recurse=True):
if buf is None or name not in state_keys:
continue
meta = state_dict.meta(name)
meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
_replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
stored = _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
_attach_disk_identity(stored, model, name, True)
register_module_weights(model, state_dict)
register_lazy_modules(model, state_dict)

View File

@ -283,7 +283,7 @@ def load_gligen(sd):
gated = GatedSelfAttentionDense(
query_dim, key_dim, n_heads, d_head)
comfy.utils.load_state_dict(gated, n_sd, strict=False)
gated.load_state_dict(n_sd, strict=False)
output_list.append(gated)
if "position_net.null_positive_feature" in sd_k:
@ -294,7 +294,7 @@ def load_gligen(sd):
pass
w = WeightsLoader()
w.position_net = PositionNet(in_dim, out_dim)
comfy.utils.load_state_dict(w, sd, strict=False)
w.load_state_dict(sd, strict=False)
gligen = Gligen(output_list, w.position_net, key_dim)
return gligen

View File

@ -1,5 +1,4 @@
import torch
import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@ -113,7 +112,7 @@ class HunyuanVideo15SRModel():
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.model, sd, strict=True)
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()

View File

@ -2,7 +2,6 @@ import json
from dataclasses import dataclass
import math
import torch
import comfy.utils
import torchaudio
import comfy.model_management
@ -154,8 +153,8 @@ class AudioVAE(torch.nn.Module):
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(

View File

@ -2,7 +2,6 @@ import logging
from typing import Optional
import torch
import comfy.utils
import torch.nn as nn
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@ -153,7 +152,7 @@ class VAE(nn.Module):
return dec, posterior
def load_weights(self, src_dict) -> None:
comfy.utils.load_state_dict(self, src_dict, strict=True)
self.load_state_dict(src_dict, strict=True)
@property
def device(self) -> torch.device:

View File

@ -309,7 +309,7 @@ class BaseModel(torch.nn.Module):
else:
to_load = sd
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
@ -753,8 +753,8 @@ class StableAudio1(BaseModel):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights, strict=True)
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights, strict=True)
def extra_conds(self, **kwargs):
out = {}

View File

@ -18,6 +18,7 @@
import psutil
import logging
import collections
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
@ -524,18 +525,14 @@ class LoadedModel:
return True
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
def model_unload(self, memory_to_free=None, unpatch_weights=True, offload_device=None):
target_offload_device = self.model.offload_device if offload_device is None else offload_device
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
freed = self.model.partially_unload(target_offload_device, memory_to_free)
if freed >= memory_to_free:
return False
offload_device = None
if comfy.disk_weights.disk_weights_enabled():
offload_device = torch.device("meta")
self.model.detach(unpatch_weights, offload_device=offload_device)
if offload_device is not None and offload_device.type == "meta":
logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
self.model.detach(unpatch_weights, offload_device=target_offload_device)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@ -589,11 +586,33 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
if freed_cache < memory_required:
evict_ram_to_disk(memory_required - freed_cache)
gpu_deficit = 0
if comfy.disk_weights.disk_weights_enabled():
free_before = get_free_memory(device)
if is_device_cpu(device):
headroom = comfy.disk_weights.ram_headroom_bytes()
if free_before < memory_required:
logging.debug(
"Disk weights RAM pressure: required=%d free=%d headroom=%d device=%s",
memory_required,
free_before,
headroom,
device,
)
comfy.disk_weights.evict_ram_cache(memory_required)
elif is_device_cuda(device) or is_device_xpu(device):
if free_before < memory_required:
logging.debug(
"Disk weights VRAM pressure: required=%d free=%d device=%s",
memory_required,
free_before,
device,
)
comfy.disk_weights.evict_for_budget(device, memory_required)
free_after_vram = get_free_memory(device)
if free_after_vram < memory_required:
gpu_deficit = memory_required - free_after_vram
comfy.disk_weights.evict_ram_cache(gpu_deficit)
unloaded_model = []
can_unload = []
unloaded_models = []
@ -614,7 +633,17 @@ def free_memory(memory_required, device, keep_loaded=[]):
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
offload_device = None
if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
offload_device = torch.device("meta")
elif comfy.disk_weights.disk_weights_enabled() and gpu_deficit > 0 and (is_device_cuda(device) or is_device_xpu(device)):
cpu = torch.device("cpu")
headroom = comfy.disk_weights.ram_headroom_bytes()
required_cpu = current_loaded_models[i].model_loaded_memory() + headroom
free_cpu = get_free_memory(cpu)
if free_cpu < required_cpu:
offload_device = torch.device("meta")
if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
@ -627,36 +656,18 @@ def free_memory(memory_required, device, keep_loaded=[]):
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
if comfy.disk_weights.disk_weights_enabled():
free_after = get_free_memory(device)
freed_total = max(0, free_after - free_before)
logging.debug(
"Disk weights free_memory: device=%s free_before=%d free_after=%d freed=%d",
device,
free_before,
free_after,
freed_total,
)
return unloaded_models
def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
if memory_to_free <= 0:
return 0
if not comfy.disk_weights.disk_weights_enabled():
return 0
freed = 0
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model not in keep_loaded and not shift_model.is_dead():
loaded_memory = shift_model.model_loaded_memory()
if loaded_memory > 0:
can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
for x in sorted(can_unload):
i = x[-1]
memory_needed = memory_to_free - freed
if memory_needed <= 0:
break
logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
if freed > 0:
logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
return freed
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@ -1122,6 +1133,10 @@ def sync_stream(device, stream):
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if comfy.disk_weights.disk_weights_enabled() and weight.device.type == "meta":
target_device = device if device is not None else torch.device("cpu")
target_dtype = dtype if dtype is not None else weight.dtype
weight = comfy.disk_weights.materialize_meta_tensor(weight, target_device, target_dtype)
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@ -1134,7 +1149,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
if stream is not None:
wf_context = stream
if hasattr(wf_context, "as_context"):
@ -1145,6 +1159,13 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
else:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
if non_blocking and (is_device_cuda(device) or is_device_cuda(weight.device)):
record_stream = stream if stream is not None else current_stream(device if is_device_cuda(device) else weight.device)
if record_stream is not None:
if is_device_cpu(weight.device) and weight.is_pinned():
_record_pinned_event(weight.data_ptr(), record_stream)
if is_device_cpu(r.device) and r.is_pinned():
_record_pinned_event(r.data_ptr(), record_stream)
return r
def cast_to_device(tensor, device, dtype, copy=False):
@ -1155,6 +1176,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
PINNED_IN_FLIGHT = {}
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
@ -1163,14 +1185,13 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
WEIGHTS_RAM_CACHE_BYTES = 0
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
if args.weights_ram_cache_gb is not None:
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
if args.low_ram:
comfy.disk_weights.configure(
WEIGHTS_RAM_CACHE_BYTES,
allow_gds=WEIGHTS_GDS_ENABLED,
pin_if_cpu=not args.disable_pinned_memory,
ram_headroom_bytes=1024 * 1024 * 1024,
enabled=True,
)
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
@ -1223,7 +1244,19 @@ def pin_memory(tensor):
return False
def unpin_memory(tensor):
def _record_pinned_event(ptr, stream):
events = PINNED_IN_FLIGHT.setdefault(ptr, [])
event = torch.cuda.Event()
event.record(stream)
events.append(event)
def wait_for_pinned_tensor(tensor):
ptr = tensor.data_ptr()
events = PINNED_IN_FLIGHT.pop(ptr, [])
for event in events:
event.synchronize()
def _unpin_memory_now(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
@ -1254,6 +1287,17 @@ def unpin_memory(tensor):
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
wait_for_pinned_tensor(tensor)
return _unpin_memory_now(tensor)
def sage_attention_enabled():
return args.use_sage_attention

View File

@ -621,6 +621,16 @@ class ModelPatcher:
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
if comfy.disk_weights.disk_weights_enabled() and weight is not None and weight.device.type == "meta":
parts = key.split(".")
param_name = parts[-1]
module = self.model
for part in parts[:-1]:
module = getattr(module, part)
target_device = device_to or self.offload_device or torch.device("cpu")
comfy.disk_weights.load_module_tensor(module, param_name, device=target_device)
weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
@ -650,8 +660,8 @@ class ModelPatcher:
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
if comfy.model_management.unpin_memory(weight):
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
@ -885,9 +895,16 @@ class ModelPatcher:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
remaining_ram = None
cpu_device = torch.device("cpu")
if device_to is not None and comfy.model_management.is_device_cpu(device_to):
remaining_ram = comfy.model_management.get_free_memory(device_to)
def offload_module_tree(module):
freed = 0
for submodule in module.modules():
freed += comfy.disk_weights.offload_module_weights(submodule)
return freed
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
@ -922,15 +939,28 @@ class ModelPatcher:
cast_weight = self.force_cast_weights
freed_bytes = module_mem
if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
freed_bytes = comfy.disk_weights.offload_module_weights(m)
freed_bytes = offload_module_tree(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
freed_bytes = comfy.disk_weights.offload_module_weights(m)
if freed_bytes == 0:
freed_bytes = module_mem
if remaining_ram is not None and comfy.disk_weights.disk_weights_enabled():
required_bytes = module_mem
headroom = comfy.disk_weights.ram_headroom_bytes()
comfy.model_management.free_memory(required_bytes + headroom, cpu_device, keep_loaded=[self])
remaining_ram = comfy.model_management.get_free_memory(cpu_device)
if remaining_ram < required_bytes:
logging.info(
"Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.",
n,
required_bytes / (1024 * 1024),
remaining_ram / (1024 * 1024),
)
freed_bytes = offload_module_tree(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
comfy.disk_weights.move_module_tensors(m, device_to)
remaining_ram = max(0, remaining_ram - required_bytes)
else:
if comfy.disk_weights.disk_weights_enabled():
comfy.disk_weights.move_module_tensors(m, device_to)
@ -980,16 +1010,22 @@ class ModelPatcher:
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
offload_device = self.offload_device
if comfy.disk_weights.disk_weights_enabled() and device_to is not None:
if comfy.model_management.is_device_cpu(device_to):
offload_device = torch.device("meta")
else:
offload_device = torch.device("cpu")
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
self.unpatch_model(offload_device, unpatch_weights=unpatch_weights)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
self.partially_unload(offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:

View File

@ -19,12 +19,67 @@
import torch
import logging
import comfy.model_management
import comfy.disk_weights
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import json
def _dw_find_input_tensor(args, kwargs):
"""Return first tensor-like input (supports QuantizedTensor)."""
def check(obj):
if torch.is_tensor(obj):
return obj
if isinstance(obj, QuantizedTensor):
return obj
if isinstance(obj, (list, tuple)):
for it in obj:
r = check(it)
if r is not None:
return r
if isinstance(obj, dict):
for it in obj.values():
r = check(it)
if r is not None:
return r
return None
for a in args:
r = check(a)
if r is not None:
return r
return check(kwargs)
def _dw_disk_weights_enabled() -> bool:
# Delayed import avoids eager circular imports.
from comfy import disk_weights as _dw
return _dw.disk_weights_enabled()
def _dw_requires_temporary_cast(module, args, kwargs) -> bool:
"""
When disk_weights is enabled, route ops through the comfy_cast path when
weights/bias are not directly usable (dtype/device mismatch or meta).
"""
if not _dw_disk_weights_enabled():
return False
inp = _dw_find_input_tensor(args, kwargs)
if inp is None:
return False
w = getattr(module, "weight", None)
if w is None:
return False
if isinstance(inp, QuantizedTensor):
req_dtype = inp.params.orig_dtype
req_dev = inp.device
else:
req_dtype = inp.dtype
req_dev = inp.device
if w.device.type == "meta" or w.device != req_dev or w.dtype != req_dtype:
return True
b = getattr(module, "bias", None)
if b is not None and (b.device.type == "meta" or b.device != req_dev or b.dtype != req_dtype):
return True
return False
def run_every_op():
if torch.compiler.is_compiling():
return
@ -101,27 +156,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_source = s.weight
bias_source = s.bias
if comfy.disk_weights.disk_weights_enabled():
if weight_source.device.type == "meta":
loaded = comfy.disk_weights.load_module_tensor(
s,
"weight",
device,
temporary=True,
dtype_override=dtype,
)
if loaded is not None:
weight_source = loaded
if bias_source is not None and bias_source.device.type == "meta":
loaded_bias = comfy.disk_weights.load_module_tensor(
s,
"bias",
device,
temporary=True,
dtype_override=bias_dtype,
)
if loaded_bias is not None:
bias_source = loaded_bias
weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
@ -185,7 +219,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -202,7 +236,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -219,7 +253,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -245,7 +279,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -262,7 +296,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -284,7 +318,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -308,7 +342,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -332,7 +366,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -356,7 +390,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@ -378,7 +412,7 @@ class disable_weight_init:
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
@ -557,10 +591,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
if value.device.type != "meta":
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
@ -577,16 +610,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None and layer_conf.device.type != "meta":
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
elif layer_conf is not None:
layer_conf = None
if layer_conf is None:
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@ -632,13 +660,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
@ -648,10 +673,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_v = state_dict.pop(param_key, None)
if _v is None:
continue
if _v.device.type == "meta":
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
else:
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

View File

@ -37,6 +37,19 @@ _FST_LOADED = False
_GDS_INITIALIZED = False
_MISSING = object()
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
_PINNED_INFLIGHT = collections.deque()
def _reap_pinned_inflight():
if not _PINNED_INFLIGHT:
return
pending = collections.deque()
while _PINNED_INFLIGHT:
event, tensor = _PINNED_INFLIGHT.popleft()
if event.query():
continue
pending.append((event, tensor))
_PINNED_INFLIGHT.extend(pending)
def _require_fastsafetensors():
@ -204,14 +217,22 @@ class _SafeTensorFile:
)
return tensor
target_dtype = dtype
if device_is_cuda and pin_if_cpu:
_reap_pinned_inflight()
target_dtype = None
cpu_tensor = self._read_tensor_nogds(
fst, framework, meta, torch.device("cpu"), dtype
fst, framework, meta, torch.device("cpu"), target_dtype, pin_memory=bool(device_is_cuda and pin_if_cpu)
)
if device_is_cuda:
if pin_if_cpu:
cpu_tensor = cpu_tensor.pin_memory()
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
if pin_if_cpu:
event = torch.cuda.Event()
event.record(torch.cuda.current_stream(device))
_PINNED_INFLIGHT.append((event, cpu_tensor))
if dtype is not None and dtype != gpu_tensor.dtype:
gpu_tensor = gpu_tensor.to(dtype=dtype)
return gpu_tensor
return cpu_tensor
@ -233,6 +254,8 @@ class _SafeTensorFile:
meta: TensorMeta,
device: torch.device,
dtype: Optional[torch.dtype],
*,
pin_memory: bool = False,
) -> torch.Tensor:
fd = self._ensure_fd()
reader = self._ensure_nogds_reader(use_cuda=False)
@ -241,7 +264,7 @@ class _SafeTensorFile:
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
chunk_bytes = max(1, chunk_bytes)
ptr_align = framework.get_device_ptr_align()
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu", pin_memory=pin_memory)
buffer_length = 0
buf_ptr = None
gbuf = None
@ -250,6 +273,8 @@ class _SafeTensorFile:
while chunk_offset < length:
chunk_len = min(length - chunk_offset, chunk_bytes)
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
if aligned_offset + aligned_length > self.index.size_bytes:
aligned_length = max(0, self.index.size_bytes - aligned_offset)
needed = aligned_length + ptr_align
if buf_ptr is None or needed > buffer_length:
if buf_ptr is not None:
@ -272,7 +297,6 @@ class _SafeTensorFile:
if buf_ptr is not None:
fst.cpp.cpu_free(buf_ptr)
if dtype is not None and dtype != dest_tensor.dtype:
_validate_dtype_conversion(dest_tensor.dtype, dtype)
dest_tensor = dest_tensor.to(dtype=dtype)
return dest_tensor
@ -289,6 +313,8 @@ class _SafeTensorFile:
abs_start = self.index.header_length + meta.data_offsets[0]
length = meta.data_offsets[1] - meta.data_offsets[0]
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
if aligned_offset + aligned_length > self.index.size_bytes:
aligned_length = max(0, self.index.size_bytes - aligned_offset)
ptr_align = framework.get_device_ptr_align()
buffer_length = aligned_length + ptr_align
fst_device = _fst_device_from_torch(fst, device)
@ -306,7 +332,6 @@ class _SafeTensorFile:
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
)
if dtype is not None and dtype != tensor.dtype:
_validate_dtype_conversion(tensor.dtype, dtype)
tensor = tensor.to(dtype=dtype)
return tensor
@ -348,8 +373,7 @@ def _dlpack_tensor_from_buffer(
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
return
def _get_gds_o_direct(framework) -> bool:
@ -523,8 +547,10 @@ class StreamStateDict(collections.abc.MutableMapping):
raise KeyError(key)
return default
if self._index.has(key):
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
if default is _MISSING:
raise KeyError(key)
return default
@ -636,8 +662,10 @@ class _BaseViewStateDict(MutableMapping):
if default is _MISSING:
raise
return default
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
def meta(self, key: str):
if key in self._overrides:
@ -768,8 +796,10 @@ class DeviceViewStateDict(_BaseViewStateDict):
if default is _MISSING:
raise
return default
value = self.get_tensor(key)
self._deleted.add(key)
return self.get_tensor(key)
self._overrides.pop(key, None)
return value
class FilterViewStateDict(_BaseViewStateDict):
@ -932,3 +962,48 @@ class MappedStateDict(_BaseViewStateDict):
def _iter_base_keys(self) -> Iterable[str]:
return self._view_to_base.keys()
def stream_load_state_dict(model, state_dict, strict: bool = False, assign: bool = False):
if getattr(state_dict, "is_stream_state_dict", False) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)

View File

@ -26,7 +26,6 @@ import os
import comfy.utils
import comfy.safetensors_stream
import comfy.disk_weights
from . import clip_vision
from . import gligen
@ -126,7 +125,7 @@ class CLIP:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@ -290,7 +289,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@ -658,7 +657,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@ -672,7 +671,7 @@ class VAE:
if dtype is None:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
self.first_stage_model.to(dtype=self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@ -979,7 +978,7 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
comfy.utils.load_state_dict(model, model_data, strict=True)
model.load_state_dict(model_data, strict=True)
return StyleModel(model)
def sd_shape(state_dict, key):
@ -1547,7 +1546,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
model = comfy.disk_weights.module_to(model, offload_device)
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:

View File

@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
result = []

View File

@ -56,9 +56,9 @@ class TAESD(nn.Module):
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
if decoder_path is not None:
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
@staticmethod
def scale_latents(x):

View File

@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
if len(sdo) == 0:
sdo = sd
return comfy.utils.load_state_dict(self, sdo, strict=False)
return self.load_state_dict(sdo, strict=False)
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@ -32,7 +32,6 @@ from einops import rearrange
from comfy.cli_args import args
import json
from . import safetensors_stream
import comfy.disk_weights
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -166,62 +165,6 @@ def state_dict_meta(state_dict, key):
)
def load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict):
if comfy.disk_weights.disk_weights_enabled():
return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
comfy.disk_weights.register_module_weights(model, state_dict)
comfy.disk_weights.attach_disk_weight_hooks(model)
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
return missing, unexpected
return model.load_state_dict(state_dict, strict=strict)
def stream_load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = safetensors_stream.FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return missing_keys, unexpected_keys
def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",