diff --git a/README.md b/README.md
index 550482e5c..b16681be3 100644
--- a/README.md
+++ b/README.md
@@ -349,12 +349,12 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
 | `--enable-manager` | Enable ComfyUI-Manager |
 | `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
 | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
-| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
+| `--low-ram` | Enable disk weight offloading. Sets RAM headroom to 1024MB and uses `--reserve-vram` as the VRAM headroom. |
 | `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
 
 ### Disk tier for model weights
 
-When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
+When `--low-ram` is enabled, ComfyUI streams safetensors weights from disk and offloads weights back to disk when free RAM drops below the headroom threshold.
 
 If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
 
diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
index bf3abf834..46ef21c95 100644
--- a/comfy/audio_encoders/audio_encoders.py
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -29,7 +29,7 @@ class AudioEncoderModel():
         self.model_sample_rate = 16000
 
     def load_sd(self, sd):
-        return comfy.utils.load_state_dict(self.model, sd, strict=False)
+        return self.model.load_state_dict(sd, strict=False)
 
     def get_sd(self):
         return self.model.state_dict()
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index d75b9fe99..17ec5cf4d 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -114,7 +114,7 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
 cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default 4GB")
 
-parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
+parser.add_argument("--low-ram", action="store_true", help="Enable disk weight offloading. Sets RAM headroom to 1024MB and uses --reserve-vram as the VRAM headroom.")
 parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. 
Requires libcufile and GDS support.") attn_group = parser.add_mutually_exclusive_group() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 837eba013..341ae7103 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -6,7 +6,6 @@ import logging import comfy.ops import comfy.model_patcher import comfy.model_management -import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 @@ -48,7 +47,7 @@ class ClipVisionModel(): self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return comfy.utils.load_state_dict(self.model, sd, strict=False) + return self.model.load_state_dict(sd, strict=False) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 551f0bb18..053a2bf31 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -25,7 +25,6 @@ import logging import comfy.utils import comfy.model_management import comfy.model_detection -import comfy.disk_weights import comfy.model_patcher import comfy.ops import comfy.latent_formats @@ -386,7 +385,7 @@ class ControlLora(ControlNet): controlnet_config["operations"] = control_lora_ops controlnet_config["dtype"] = dtype self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) - comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device()) + self.control_model.to(comfy.model_management.get_torch_device()) diffusion_model = model.diffusion_model sd = diffusion_model.state_dict() @@ -440,7 +439,7 @@ def controlnet_config(sd, model_options={}): return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device def controlnet_load_state_dict(control_model, sd): - missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False) + missing, unexpected = control_model.load_state_dict(sd, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) @@ -474,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}): class ControlNetSD35(ControlNet): def pre_run(self, model, percent_to_timestep_function): if self.control_model.double_y_emb: - missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False) + missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False) else: - missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False) + missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False) super().pre_run(model, percent_to_timestep_function) def copy(self): @@ -749,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): pass w = WeightsLoader() w.control_model = control_model - missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False) + missing, unexpected = w.load_state_dict(controlnet_data, strict=False) else: - missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False) + missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) @@ -817,8 +816,8 @@ class T2IAdapter(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = 
broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_input is None: - comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype) - comfy.disk_weights.module_to(self.t2i_model, self.device) + self.t2i_model.to(dtype=x_noisy.dtype) + self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) self.t2i_model.cpu() @@ -875,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options else: return None - missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True) + missing, unexpected = model_ad.load_state_dict(t2i_data, strict=True) if len(missing) > 0: logging.warning("t2i missing {}".format(missing)) diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py index 07e9b5f11..e1507c266 100644 --- a/comfy/disk_weights.py +++ b/comfy/disk_weights.py @@ -33,7 +33,11 @@ from . import safetensors_stream ALLOW_GDS = False PIN_IF_CPU = False DISK_WEIGHTS_ENABLED = False +RAM_HEADROOM_BYTES = 0 BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict +BASE_MODULE_TO = torch.nn.Module.to +BASE_LOAD_STATE_DICT = torch.nn.Module.load_state_dict +_MONKEYPATCHED = False LAZY_MODULE_STATE = weakref.WeakKeyDictionary() DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary() _MISSING = object() @@ -169,13 +173,17 @@ CACHE = DiskWeightCache(0) LOGGER = logging.getLogger(__name__) -def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True): - global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED +def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, enabled: bool = True): + global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED, RAM_HEADROOM_BYTES ALLOW_GDS = allow_gds PIN_IF_CPU = pin_if_cpu DISK_WEIGHTS_ENABLED = enabled - CACHE.set_limit(cache_bytes if enabled else 0) - if not enabled: + RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes)) + CACHE.set_limit(0 if enabled else 0) + if enabled: + install_monkeypatches() + else: + uninstall_monkeypatches() CACHE._entries.clear() CACHE.current_bytes = 0 @@ -184,6 +192,66 @@ def disk_weights_enabled() -> bool: return DISK_WEIGHTS_ENABLED +def ram_headroom_bytes() -> int: + return RAM_HEADROOM_BYTES + + +def _is_stream_state_dict(state_dict) -> bool: + return ( + getattr(state_dict, "is_stream_state_dict", False) + and hasattr(state_dict, "get_tensor") + and hasattr(state_dict, "meta") + ) + + +def patched_to(self: torch.nn.Module, *args, **kwargs): + if not disk_weights_enabled(): + return BASE_MODULE_TO(self, *args, **kwargs) + device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs) + module_to( + self, + device=device, + dtype=dtype, + non_blocking=non_blocking, + memory_format=memory_format, + ) + return self + + +def patched_load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + if not disk_weights_enabled(): + if _is_stream_state_dict(state_dict): + return safetensors_stream.stream_load_state_dict( + self, + state_dict, + strict=strict, + assign=assign, + ) + return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign) + if _is_stream_state_dict(state_dict): + missing_keys, unexpected_keys = lazy_load_state_dict(self, state_dict, strict=strict) + return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign) + + +def install_monkeypatches(): + global _MONKEYPATCHED + if _MONKEYPATCHED: + return + torch.nn.Module.to = 
patched_to + torch.nn.Module.load_state_dict = patched_load_state_dict + _MONKEYPATCHED = True + + +def uninstall_monkeypatches(): + global _MONKEYPATCHED + if not _MONKEYPATCHED: + return + torch.nn.Module.to = BASE_MODULE_TO + torch.nn.Module.load_state_dict = BASE_LOAD_STATE_DICT + _MONKEYPATCHED = False + + def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""): if not disk_weights_enabled(): return @@ -369,33 +437,29 @@ def _device_free_memory(device: torch.device) -> int: return int(model_management.get_free_memory(device)) -def _evict_ram_for_budget(required_bytes: int) -> int: - if required_bytes <= 0: - return 0 - freed = evict_ram_cache(required_bytes) - if freed < required_bytes: +def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_bytes: int) -> int: + free_before = _device_free_memory(device) + if free_before < required_bytes + headroom_bytes: + LOGGER.debug( + "Disk weight memory pressure: required=%d free=%d headroom=%d device=%s", + required_bytes, + free_before, + headroom_bytes, + device, + ) + safetensors_stream._reap_pinned_inflight() from . import model_management - freed += model_management.evict_ram_to_disk(required_bytes - freed) - return freed - - -def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int: - free_mem = _device_free_memory(device) - if device.type == "cpu" and free_mem < required_bytes: - _evict_ram_for_budget(required_bytes - free_mem) - free_mem = _device_free_memory(device) - return free_mem - - -def _choose_alternate_device(device: torch.device) -> Optional[torch.device]: - from . import model_management - if device.type == "cpu": - alt = model_management.get_torch_device() - if alt.type != "cpu": - return alt - else: - return torch.device("cpu") - return None + model_management.free_memory(required_bytes + headroom_bytes, device) + free_after = _device_free_memory(device) + freed = max(0, free_after - free_before) + LOGGER.debug( + "Disk weight memory freed: freed=%d free=%d device=%s", + freed, + free_after, + device, + ) + return free_after + return free_before class _BudgetedStateDict(MutableMapping): @@ -527,8 +591,10 @@ class _BudgetedStateDict(MutableMapping): if default is _MISSING: raise KeyError(key) return default + value = self.get_tensor(key) self._deleted.add(key) - return self.get_tensor(key) + self._overrides.pop(key, None) + return value def meta(self, key: str): return self._get_meta(key) @@ -564,6 +630,7 @@ def register_lazy_modules(model: torch.nn.Module, state_dict): def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool): + safetensors_stream._reap_pinned_inflight() lazy_state = LAZY_MODULE_STATE.get(module) if lazy_state is not None: CACHE.remove_module(module) @@ -671,7 +738,6 @@ def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: def ensure_module_materialized( module: torch.nn.Module, target_device: torch.device, - fallback_device: Optional[torch.device] = None, dtype_override: Optional[torch.dtype] = None, ): lazy_state = LAZY_MODULE_STATE.get(module) @@ -692,7 +758,6 @@ def ensure_module_materialized( _set_future_dtype(module, name, dtype_override) _rebuild_materialization_state(module, refs, state) free_mem_start = _device_free_memory(target_device) - remaining_budget = free_mem_start for name in sorted(refs.keys()): disk_ref = refs[name] if name in module._parameters: @@ -717,19 +782,11 @@ def ensure_module_materialized( continue required_bytes = meta_nbytes if target_device.type == "cpu": - 
free_mem = _maybe_free_ram_budget(target_device, required_bytes) - remaining_budget = min(remaining_budget, free_mem) - if required_bytes > remaining_budget: - if fallback_device is not None and fallback_device != target_device: - fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes) - if fallback_free >= required_bytes: - target_for_load = fallback_device - else: - continue - else: - continue + _ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES) else: - target_for_load = target_device + from . import model_management + _ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory()) + target_for_load = target_device if current.device.type == "meta": tensor = disk_ref.load( target_for_load, @@ -748,7 +805,6 @@ def ensure_module_materialized( module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad) if tensor.device.type == "cpu": CACHE.record(module, name, tensor, is_buffer=is_buffer) - remaining_budget = max(0, remaining_budget - required_bytes) _rebuild_materialization_state(module, refs, state) _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized") @@ -761,14 +817,11 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}): dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype) if getattr(module, "comfy_cast_weights", False): target_device = torch.device("cpu") - fallback_device = _find_tensor_device(args, kwargs) else: target_device = _find_tensor_device(args, kwargs) or torch.device("cpu") - fallback_device = None ensure_module_materialized( module, target_device, - fallback_device=fallback_device, dtype_override=dtype_override, ) @@ -786,6 +839,7 @@ def attach_disk_weight_hooks(model: torch.nn.Module): def evict_ram_cache(bytes_to_free: int): if bytes_to_free <= 0: return 0 + safetensors_stream._reap_pinned_inflight() return CACHE.evict_bytes(bytes_to_free) @@ -827,16 +881,7 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]: def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None): - def _move(tensor): - if tensor is None: - return None - if tensor.device.type == "meta": - return tensor - if dtype_override is not None and tensor.dtype != dtype_override: - return tensor.to(device=device_to, dtype=dtype_override) - return tensor.to(device=device_to) - - module._apply(_move) + ensure_module_materialized(module, device_to, dtype_override=dtype_override) return module @@ -864,10 +909,20 @@ def offload_module_weights(module: torch.nn.Module) -> int: return offloaded_bytes -def module_to(module: torch.nn.Module, *args, **kwargs): +def module_to( + module: torch.nn.Module, + *args, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, + memory_format=None, + **kwargs, +): allow_materialize = kwargs.pop("allow_materialize", True) + arg_device = _extract_to_device(args, kwargs) + arg_dtype = _extract_to_dtype(args, kwargs) if disk_weights_enabled(): - target_device = _extract_to_device(args, kwargs) + target_device = device or arg_device if target_device is None: target_device = _find_existing_device(module) or torch.device("cpu") if target_device.type == "meta": @@ -875,10 +930,28 @@ def module_to(module: torch.nn.Module, *args, **kwargs): return module if allow_materialize: materialize_module_tree(module, target_device) - return module.to(*args, **kwargs) - dtype_override = 
_extract_to_dtype(args, kwargs) + base_kwargs = dict(kwargs) + if device is not None and arg_device is None: + base_kwargs["device"] = device + if dtype is not None and arg_dtype is None: + base_kwargs["dtype"] = dtype + if non_blocking: + base_kwargs["non_blocking"] = non_blocking + if memory_format is not None: + base_kwargs["memory_format"] = memory_format + return BASE_MODULE_TO(module, *args, **base_kwargs) + dtype_override = dtype or arg_dtype return move_module_tensors(module, target_device, dtype_override=dtype_override) - return module.to(*args, **kwargs) + base_kwargs = dict(kwargs) + if device is not None and arg_device is None: + base_kwargs["device"] = device + if dtype is not None and arg_dtype is None: + base_kwargs["dtype"] = dtype + if non_blocking: + base_kwargs["non_blocking"] = non_blocking + if memory_format is not None: + base_kwargs["memory_format"] = memory_format + return BASE_MODULE_TO(module, *args, **base_kwargs) def load_module_tensor( @@ -886,7 +959,6 @@ def load_module_tensor( name: str, device: torch.device, *, - allow_alternate: bool = True, record_cache: bool = True, temporary: bool = False, dtype_override: Optional[torch.dtype] = None, @@ -909,6 +981,9 @@ def load_module_tensor( _set_future_dtype(module, name, dtype_override) if current.device.type != "meta": if current.device != device or (target_dtype is not None and current.dtype != target_dtype): + from . import model_management + headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory() + _ensure_free_memory(device, _tensor_nbytes(current), headroom) if target_dtype is not None and current.dtype != target_dtype: tensor = current.to(device=device, dtype=target_dtype) else: @@ -926,41 +1001,11 @@ def load_module_tensor( required_bytes = _meta_nbytes(disk_ref.meta) if required_bytes is None: return current - free_mem_start = _device_free_memory(device) - free_mem = _maybe_free_ram_budget(device, required_bytes) - load_device = device - if free_mem < required_bytes and allow_alternate: - alt = _choose_alternate_device(device) - if alt is not None: - alt_free = _maybe_free_ram_budget(alt, required_bytes) - if alt_free >= required_bytes: - load_device = alt - else: - state = _get_materialization_state(module) - if name not in state.deferred_keys: - state.deferred_keys.add(name) - state.deferred_bytes += required_bytes - _update_disk_state_attrs(module, state) - _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred") - return current - else: - state = _get_materialization_state(module) - if name not in state.deferred_keys: - state.deferred_keys.add(name) - state.deferred_bytes += required_bytes - _update_disk_state_attrs(module, state) - _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred") - return current - elif free_mem < required_bytes: - state = _get_materialization_state(module) - if name not in state.deferred_keys: - state.deferred_keys.add(name) - state.deferred_bytes += required_bytes - _update_disk_state_attrs(module, state) - _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred") - return current + from . 
import model_management + headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory() + _ensure_free_memory(device, required_bytes, headroom) - tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype) + tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype) if temporary: return tensor if is_buffer: @@ -971,7 +1016,7 @@ def load_module_tensor( CACHE.record(module, name, tensor, is_buffer=is_buffer) state = _get_materialization_state(module) _rebuild_materialization_state(module, refs, state) - _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded") + _log_materialization(module, device, _device_free_memory(device), refs, state, "Disk weight loaded") return tensor @@ -1015,22 +1060,15 @@ def _materialize_module_from_state_dict( if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta": existing[key] = buf free_mem_start = _device_free_memory(target_device) - remaining_budget = free_mem_start - allowed = set(existing.keys()) + allowed = set(keys) + from . import model_management + headroom = RAM_HEADROOM_BYTES if target_device.type == "cpu" else model_management.extra_reserved_memory() for key in keys: - if key in allowed: - continue meta = _state_dict_meta(lazy_state.state_dict, key) required = _meta_nbytes(meta) if required is None: continue - if target_device.type == "cpu": - free_mem = _maybe_free_ram_budget(target_device, required) - remaining_budget = min(remaining_budget, free_mem) - if required <= remaining_budget: - allowed.add(key) - remaining_budget = max(0, remaining_budget - required) - deferred_state_dict_keys = {key for key in keys if key not in allowed} + _ensure_free_memory(target_device, required, headroom) state_dict = _BudgetedStateDict( lazy_state.state_dict, allowed_keys=allowed, @@ -1065,7 +1103,7 @@ def _materialize_module_from_state_dict( if len(error_msgs) > 0: raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs))) _rebuild_materialization_state(module, refs, state) - lazy_state.loaded = len(deferred_state_dict_keys) == 0 + lazy_state.loaded = True _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed") for name, param in module.named_parameters(recurse=False): if param.device.type == "cpu": diff --git a/comfy/gligen.py b/comfy/gligen.py index c2cf7c6db..f5483bee7 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -283,7 +283,7 @@ def load_gligen(sd): gated = GatedSelfAttentionDense( query_dim, key_dim, n_heads, d_head) - comfy.utils.load_state_dict(gated, n_sd, strict=False) + gated.load_state_dict(n_sd, strict=False) output_list.append(gated) if "position_net.null_positive_feature" in sd_k: @@ -294,7 +294,7 @@ def load_gligen(sd): pass w = WeightsLoader() w.position_net = PositionNet(in_dim, out_dim) - comfy.utils.load_state_dict(w, sd, strict=False) + w.load_state_dict(sd, strict=False) gligen = Gligen(output_list, w.position_net, key_dim) return gligen diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 0c8b9240f..d9e76922f 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -1,5 +1,4 @@ import torch -import comfy.utils import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d @@ -113,7 +112,7 @@ class HunyuanVideo15SRModel(): 
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return comfy.utils.load_state_dict(self.model, sd, strict=True) + return self.model.load_state_dict(sd, strict=True) def get_sd(self): return self.model.state_dict() diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index c61883ba3..a9111d3bd 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -2,7 +2,6 @@ import json from dataclasses import dataclass import math import torch -import comfy.utils import torchaudio import comfy.model_management @@ -154,8 +153,8 @@ class AudioVAE(torch.nn.Module): self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) self.vocoder = Vocoder(config=component_config.vocoder) - comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False) - comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False) + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( diff --git a/comfy/ldm/mmaudio/vae/vae.py b/comfy/ldm/mmaudio/vae/vae.py index 1009d2d76..831aa3973 100644 --- a/comfy/ldm/mmaudio/vae/vae.py +++ b/comfy/ldm/mmaudio/vae/vae.py @@ -2,7 +2,6 @@ import logging from typing import Optional import torch -import comfy.utils import torch.nn as nn from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D, @@ -153,7 +152,7 @@ class VAE(nn.Module): return dec, posterior def load_weights(self, src_dict) -> None: - comfy.utils.load_state_dict(self, src_dict, strict=True) + self.load_state_dict(src_dict, strict=True) @property def device(self) -> torch.device: diff --git a/comfy/model_base.py b/comfy/model_base.py index 84766591b..7b7a05454 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -309,7 +309,7 @@ class BaseModel(torch.nn.Module): else: to_load = sd to_load = self.model_config.process_unet_state_dict(to_load) - m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -753,8 +753,8 @@ class StableAudio1(BaseModel): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512) - utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True) - utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True) + self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights, strict=True) + self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights, strict=True) def extra_conds(self, **kwargs): out = {} diff --git a/comfy/model_management.py b/comfy/model_management.py index b6ef8aaa1..e9c42294d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -590,10 +590,26 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled(): - logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 
1024), get_free_memory(device) / (1024 * 1024)) - freed_cache = comfy.disk_weights.evict_ram_cache(memory_required) - if freed_cache < memory_required: - evict_ram_to_disk(memory_required - freed_cache) + free_before = get_free_memory(device) + headroom = comfy.disk_weights.ram_headroom_bytes() + if free_before < memory_required: + logging.debug( + "RAM pressure: required=%d free=%d headroom=%d", + memory_required, + free_before, + headroom, + ) + freed_cache = comfy.disk_weights.evict_ram_cache(memory_required) + freed_disk = 0 + if freed_cache < memory_required: + freed_disk = evict_ram_to_disk(memory_required - freed_cache) + free_after = get_free_memory(device) + freed_total = max(0, free_after - free_before) + logging.debug( + "RAM freed: freed=%d free=%d", + freed_total if freed_total > 0 else freed_cache + freed_disk, + free_after, + ) unloaded_model = [] can_unload = [] unloaded_models = [] @@ -636,6 +652,7 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]): if not comfy.disk_weights.disk_weights_enabled(): return 0 + free_before = get_free_memory(torch.device("cpu")) freed = 0 can_unload = [] for i in range(len(current_loaded_models) - 1, -1, -1): @@ -654,7 +671,14 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]): freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed) if freed > 0: - logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024))) + free_after = get_free_memory(torch.device("cpu")) + freed_total = max(0, free_after - free_before) + logging.debug( + "RAM evicted to disk: required=%d free=%d freed=%d", + memory_to_free, + free_before, + freed_total if freed_total > 0 else freed, + ) return freed def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): @@ -802,6 +826,8 @@ def dtype_size(dtype): return dtype_size def unet_offload_device(): + if comfy.disk_weights.disk_weights_enabled(): + return torch.device("meta") if vram_state == VRAMState.HIGH_VRAM: return get_torch_device() else: @@ -906,6 +932,8 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo return torch.float32 def text_encoder_offload_device(): + if comfy.disk_weights.disk_weights_enabled(): + return torch.device("meta") if args.gpu_only: return get_torch_device() else: @@ -966,6 +994,8 @@ def vae_device(): return get_torch_device() def vae_offload_device(): + if comfy.disk_weights.disk_weights_enabled(): + return torch.device("meta") if args.gpu_only: return get_torch_device() else: @@ -1163,14 +1193,13 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -WEIGHTS_RAM_CACHE_BYTES = 0 WEIGHTS_GDS_ENABLED = bool(args.weights_gds) -if args.weights_ram_cache_gb is not None: - WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3)) +if args.low_ram: comfy.disk_weights.configure( - WEIGHTS_RAM_CACHE_BYTES, allow_gds=WEIGHTS_GDS_ENABLED, pin_if_cpu=not args.disable_pinned_memory, + ram_headroom_bytes=1024 * 1024 * 1024, + enabled=True, ) PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) diff --git a/comfy/ops.py b/comfy/ops.py index 67f151381..303f5ccd0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,7 +19,6 @@ import torch import logging import comfy.model_management -import comfy.disk_weights from comfy.cli_args import args, PerformanceFeature import comfy.float 
import comfy.rmsnorm @@ -101,27 +100,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_source = s.weight bias_source = s.bias - if comfy.disk_weights.disk_weights_enabled(): - if weight_source.device.type == "meta": - loaded = comfy.disk_weights.load_module_tensor( - s, - "weight", - device, - temporary=True, - dtype_override=dtype, - ) - if loaded is not None: - weight_source = loaded - if bias_source is not None and bias_source.device.type == "meta": - loaded_bias = comfy.disk_weights.load_module_tensor( - s, - "bias", - device, - temporary=True, - dtype_override=bias_dtype, - ) - if loaded_bias is not None: - bias_source = loaded_bias weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) diff --git a/comfy/safetensors_stream.py b/comfy/safetensors_stream.py index 61943e53f..1d4015f75 100644 --- a/comfy/safetensors_stream.py +++ b/comfy/safetensors_stream.py @@ -37,6 +37,19 @@ _FST_LOADED = False _GDS_INITIALIZED = False _MISSING = object() _NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024 +_PINNED_INFLIGHT = collections.deque() + + +def _reap_pinned_inflight(): + if not _PINNED_INFLIGHT: + return + pending = collections.deque() + while _PINNED_INFLIGHT: + event, tensor = _PINNED_INFLIGHT.popleft() + if event.query(): + continue + pending.append((event, tensor)) + _PINNED_INFLIGHT.extend(pending) def _require_fastsafetensors(): @@ -204,14 +217,22 @@ class _SafeTensorFile: ) return tensor + target_dtype = dtype + if device_is_cuda and pin_if_cpu: + _reap_pinned_inflight() + target_dtype = None cpu_tensor = self._read_tensor_nogds( - fst, framework, meta, torch.device("cpu"), dtype + fst, framework, meta, torch.device("cpu"), target_dtype, pin_memory=bool(device_is_cuda and pin_if_cpu) ) if device_is_cuda: - if pin_if_cpu: - cpu_tensor = cpu_tensor.pin_memory() gpu_tensor = torch.empty_like(cpu_tensor, device=device) gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu) + if pin_if_cpu: + event = torch.cuda.Event() + event.record(torch.cuda.current_stream(device)) + _PINNED_INFLIGHT.append((event, cpu_tensor)) + if dtype is not None and dtype != gpu_tensor.dtype: + gpu_tensor = gpu_tensor.to(dtype=dtype) return gpu_tensor return cpu_tensor @@ -233,6 +254,8 @@ class _SafeTensorFile: meta: TensorMeta, device: torch.device, dtype: Optional[torch.dtype], + *, + pin_memory: bool = False, ) -> torch.Tensor: fd = self._ensure_fd() reader = self._ensure_nogds_reader(use_cuda=False) @@ -241,7 +264,7 @@ class _SafeTensorFile: chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT)) chunk_bytes = max(1, chunk_bytes) ptr_align = framework.get_device_ptr_align() - dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu") + dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu", pin_memory=pin_memory) buffer_length = 0 buf_ptr = None gbuf = None @@ -250,6 +273,8 @@ class _SafeTensorFile: while chunk_offset < length: chunk_len = min(length - chunk_offset, chunk_bytes) aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len) + if aligned_offset + aligned_length > self.index.size_bytes: + aligned_length = max(0, self.index.size_bytes - aligned_offset) needed = aligned_length + ptr_align if buf_ptr is None or needed > buffer_length: if buf_ptr is not None: @@ -272,7 +297,6 @@ class _SafeTensorFile: if buf_ptr is not None: 
fst.cpp.cpu_free(buf_ptr) if dtype is not None and dtype != dest_tensor.dtype: - _validate_dtype_conversion(dest_tensor.dtype, dtype) dest_tensor = dest_tensor.to(dtype=dtype) return dest_tensor @@ -289,6 +313,8 @@ class _SafeTensorFile: abs_start = self.index.header_length + meta.data_offsets[0] length = meta.data_offsets[1] - meta.data_offsets[0] aligned_offset, aligned_length, head = self._aligned_range(abs_start, length) + if aligned_offset + aligned_length > self.index.size_bytes: + aligned_length = max(0, self.index.size_bytes - aligned_offset) ptr_align = framework.get_device_ptr_align() buffer_length = aligned_length + ptr_align fst_device = _fst_device_from_torch(fst, device) @@ -306,7 +332,6 @@ class _SafeTensorFile: fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner ) if dtype is not None and dtype != tensor.dtype: - _validate_dtype_conversion(tensor.dtype, dtype) tensor = tensor.to(dtype=dtype) return tensor @@ -348,8 +373,7 @@ def _dlpack_tensor_from_buffer( def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype): - if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size(): - raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})") + return def _get_gds_o_direct(framework) -> bool: @@ -523,8 +547,10 @@ class StreamStateDict(collections.abc.MutableMapping): raise KeyError(key) return default if self._index.has(key): + value = self.get_tensor(key) self._deleted.add(key) - return self.get_tensor(key) + self._overrides.pop(key, None) + return value if default is _MISSING: raise KeyError(key) return default @@ -636,8 +662,10 @@ class _BaseViewStateDict(MutableMapping): if default is _MISSING: raise return default + value = self.get_tensor(key) self._deleted.add(key) - return self.get_tensor(key) + self._overrides.pop(key, None) + return value def meta(self, key: str): if key in self._overrides: @@ -768,8 +796,10 @@ class DeviceViewStateDict(_BaseViewStateDict): if default is _MISSING: raise return default + value = self.get_tensor(key) self._deleted.add(key) - return self.get_tensor(key) + self._overrides.pop(key, None) + return value class FilterViewStateDict(_BaseViewStateDict): @@ -932,3 +962,48 @@ class MappedStateDict(_BaseViewStateDict): def _iter_base_keys(self) -> Iterable[str]: return self._view_to_base.keys() + + +def stream_load_state_dict(model, state_dict, strict: bool = False, assign: bool = False): + if getattr(state_dict, "is_stream_state_dict", False) and hasattr(state_dict, "copy"): + state_dict = state_dict.copy() + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + metadata = getattr(state_dict, "_metadata", None) + + def load(module, local_state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + if assign: + local_metadata["assign_to_params_buffers"] = assign + module._load_from_state_dict( + local_state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + child_prefix = f"{prefix}{name}." 
+ child_state_dict = FilterViewStateDict( + local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False + ) + load(child, child_state_dict, child_prefix) + incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible) + if out is not None: + raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.") + + load(model, state_dict) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs))) + return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) diff --git a/comfy/sd.py b/comfy/sd.py index f988b0005..96eb04e51 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -26,7 +26,6 @@ import os import comfy.utils import comfy.safetensors_stream -import comfy.disk_weights from . import clip_vision from . import gligen @@ -126,7 +125,7 @@ class CLIP: if not model_management.supports_cast(load_device, dt): load_device = offload_device if params['device'] != offload_device: - comfy.disk_weights.module_to(self.cond_stage_model, offload_device) + self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) @@ -290,7 +289,7 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False) else: return self.cond_stage_model.load_sd(sd) @@ -658,7 +657,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False) + m, u = self.first_stage_model.load_state_dict(sd, strict=False) if len(m) > 0: logging.warning("Missing VAE keys {}".format(m)) @@ -672,7 +671,7 @@ class VAE: if dtype is None: dtype = model_management.vae_dtype(self.device, self.working_dtypes) self.vae_dtype = dtype - comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype) + self.first_stage_model.to(dtype=self.vae_dtype) self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) @@ -979,7 +978,7 @@ def load_style_model(ckpt_path): model = comfy.ldm.flux.redux.ReduxImageEncoder() else: raise Exception("invalid style model {}".format(ckpt_path)) - comfy.utils.load_state_dict(model, model_data, strict=True) + model.load_state_dict(model_data, strict=True) return StyleModel(model) def sd_shape(state_dict, key): @@ -1547,7 +1546,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = comfy.disk_weights.module_to(model, offload_device) + model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 
9fb8a8f5c..305555dea 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return comfy.utils.load_state_dict(self.transformer, sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): result = [] diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 7cf040588..cbebb3802 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -56,9 +56,9 @@ class TAESD(nn.Module): self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) self.vae_shift = torch.nn.Parameter(torch.tensor(0.0)) if encoder_path is not None: - comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True) + self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True) if decoder_path is not None: - comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True) + self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True) @staticmethod def scale_latents(x): diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index c998061b4..130ebaeae 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module): if len(sdo) == 0: sdo = sd - return comfy.utils.load_state_dict(self, sdo, strict=False) + return self.load_state_dict(sdo, strict=False) def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): diff --git a/comfy/utils.py b/comfy/utils.py index dda041625..af07275b1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -32,7 +32,6 @@ from einops import rearrange from comfy.cli_args import args import json from . import safetensors_stream -import comfy.disk_weights MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -166,62 +165,6 @@ def state_dict_meta(state_dict, key): ) -def load_state_dict(model, state_dict, strict=False, assign=False): - if is_stream_state_dict(state_dict): - if comfy.disk_weights.disk_weights_enabled(): - return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict) - comfy.disk_weights.register_module_weights(model, state_dict) - comfy.disk_weights.attach_disk_weight_hooks(model) - missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign) - return missing, unexpected - return model.load_state_dict(state_dict, strict=strict) - - -def stream_load_state_dict(model, state_dict, strict=False, assign=False): - if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"): - state_dict = state_dict.copy() - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - metadata = getattr(state_dict, "_metadata", None) - - def load(module, local_state_dict, prefix=""): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - if assign: - local_metadata["assign_to_params_buffers"] = assign - module._load_from_state_dict( - local_state_dict, - prefix, - local_metadata, - True, - missing_keys, - unexpected_keys, - error_msgs, - ) - for name, child in module._modules.items(): - if child is not None: - child_prefix = f"{prefix}{name}." 
- child_state_dict = safetensors_stream.FilterViewStateDict( - local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False - ) - load(child, child_state_dict, child_prefix) - incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) - for hook in module._load_state_dict_post_hooks.values(): - out = hook(module, incompatible) - if out is not None: - raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.") - - load(model, state_dict) - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys))) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs))) - return missing_keys, unexpected_keys - - def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight",
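Usage note: with this change applied, the disk tier is enabled from the command line with something like `python main.py --low-ram` (add `--reserve-vram <GB>` to adjust the VRAM headroom, and `--weights-gds` only where libcufile/GDS support is present). Code that embeds ComfyUI and bypasses the CLI can reach the same behaviour through the new keyword-only `comfy.disk_weights.configure()` entry point; the snippet below is a minimal sketch of that call under the defaults `--low-ram` applies, not part of the diff itself.

```python
# Minimal sketch (assumes this patch is applied and the comfy package is importable).
# Mirrors what comfy/model_management.py does when --low-ram is passed: enable disk
# weight offloading with a 1 GiB RAM headroom and without GPUDirect Storage.
import comfy.disk_weights

comfy.disk_weights.configure(
    allow_gds=False,                        # stage disk -> RAM -> GPU instead of using GDS
    pin_if_cpu=True,                        # pin CPU staging buffers so host-to-device copies can overlap
    ram_headroom_bytes=1024 * 1024 * 1024,  # keep this much RAM free before offloading weights to disk
    enabled=True,                           # also installs the torch.nn.Module.to / load_state_dict patches
)
```

Calling `configure(..., enabled=False)` reverses this: `uninstall_monkeypatches()` restores the stock `torch.nn.Module` methods and the RAM cache bookkeeping is cleared.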