From 45a77073ac3b15b821c703970c50019685eb1e17 Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Thu, 8 Jan 2026 19:06:49 +0200
Subject: [PATCH] Implement partial disk weight materialization

---
 comfy/disk_weights.py | 466 +++++++++++++++++++++++++++++++++++++++---
 comfy/ops.py          |  51 +++--
 2 files changed, 473 insertions(+), 44 deletions(-)

diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index 4342e9303..588e55518 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -20,8 +20,9 @@ from __future__ import annotations
 
 import collections
 import weakref
-from dataclasses import dataclass
-from typing import Dict, MutableMapping, Optional
+from dataclasses import dataclass, field
+from types import SimpleNamespace
+from typing import Dict, MutableMapping, Optional, Set
 
 import torch
 
@@ -33,6 +34,8 @@ PIN_IF_CPU = False
 DISK_WEIGHTS_ENABLED = False
 BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
 LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
+DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
+_MISSING = object()
 
 
 @dataclass
@@ -178,22 +181,24 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
         return
     if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
         return
-    for name, param in module.named_parameters(recurse=True):
-        key = f"{prefix}{name}" if prefix else name
-        if key in state_dict:
-            meta = state_dict.meta(key)
-            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
-            REGISTRY.register(module, name, ref)
-            if param.device.type == "cpu":
-                CACHE.record(module, name, param, is_buffer=False)
-    for name, buf in module.named_buffers(recurse=True):
-        key = f"{prefix}{name}" if prefix else name
-        if key in state_dict and buf is not None:
-            meta = state_dict.meta(key)
-            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
-            REGISTRY.register(module, name, ref)
-            if buf.device.type == "cpu":
-                CACHE.record(module, name, buf, is_buffer=True)
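+    # Register each parameter/buffer on the submodule that owns it, rather
+    # than on the root module, so weights can be materialized and evicted
+    # one submodule at a time.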
+    for module_name, submodule in module.named_modules():
+        module_prefix = f"{prefix}{module_name}." if module_name else prefix
+        for name, param in submodule.named_parameters(recurse=False):
+            key = f"{module_prefix}{name}" if module_prefix else name
+            if key in state_dict:
+                meta = state_dict.meta(key)
+                ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
+                REGISTRY.register(submodule, name, ref)
+                if param.device.type == "cpu":
+                    CACHE.record(submodule, name, param, is_buffer=False)
+        for name, buf in submodule.named_buffers(recurse=False):
+            key = f"{module_prefix}{name}" if module_prefix else name
+            if key in state_dict and buf is not None:
+                meta = state_dict.meta(key)
+                ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
+                REGISTRY.register(submodule, name, ref)
+                if buf.device.type == "cpu":
+                    CACHE.record(submodule, name, buf, is_buffer=True)
 
 
 @dataclass
@@ -203,6 +208,245 @@ class LazyModuleState:
     loaded: bool = False
 
+
+@dataclass
+class DiskMaterializationState:
+    loaded_keys: Set[str] = field(default_factory=set)
+    deferred_keys: Set[str] = field(default_factory=set)
+    loaded_bytes: int = 0
+    deferred_bytes: int = 0
+
+
+def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState:
+    state = DISK_MATERIALIZATION_STATE.get(module)
+    if state is None:
+        state = DiskMaterializationState()
+        DISK_MATERIALIZATION_STATE[module] = state
+    return state
+
+
+def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState):
+    module.disk_loaded_weight_memory = state.loaded_bytes
+    module.disk_offload_buffer_memory = state.deferred_bytes
+
+
+def _tensor_nbytes(tensor: torch.Tensor) -> int:
+    return tensor.numel() * tensor.element_size()
+
+
+def _meta_nbytes(meta) -> Optional[int]:
+    return getattr(meta, "nbytes", None)
+
+
+def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Tensor:
+    dtype = dtype_override or getattr(meta, "dtype", None)
+    shape = getattr(meta, "shape", None)
+    if dtype is None or shape is None:
+        raise KeyError("Missing metadata for meta tensor")
+    return torch.empty(shape, dtype=dtype, device="meta")
+
+
+def _state_dict_meta(state_dict: MutableMapping, key: str):
+    if hasattr(state_dict, "meta"):
+        return state_dict.meta(key)
+    if hasattr(state_dict, "get_tensor"):
+        t = state_dict.get_tensor(key, device=torch.device("meta"))
+    else:
+        t = state_dict[key]
+    numel = t.numel()
+    return SimpleNamespace(
+        dtype=t.dtype,
+        shape=tuple(t.shape),
+        numel=numel,
+        nbytes=numel * t.element_size(),
+    )
+
+
+def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, DiskTensorRef], state: DiskMaterializationState):
+    state.loaded_keys.clear()
+    state.deferred_keys.clear()
+    state.loaded_bytes = 0
+    state.deferred_bytes = 0
+    for name, ref in refs.items():
+        if name in module._parameters:
+            tensor = module._parameters[name]
+        elif name in module._buffers:
+            tensor = module._buffers[name]
+        else:
+            continue
+        if tensor is None:
+            continue
+        nbytes = _meta_nbytes(ref.meta) or _tensor_nbytes(tensor)
+        if tensor.device.type == "meta":
+            state.deferred_keys.add(name)
+            state.deferred_bytes += nbytes
+        else:
+            state.loaded_keys.add(name)
+            state.loaded_bytes += nbytes
+    _update_disk_state_attrs(module, state)
+
+
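+# Memory-budget helpers: query a device's free memory and evict entries from
+# the CPU weight cache when there is not enough room for a pending load.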
+def _device_free_memory(device: torch.device) -> int:
+    from . import model_management
+    return int(model_management.get_free_memory(device))
+
+
+def _evict_ram_for_budget(required_bytes: int) -> int:
+    if required_bytes <= 0:
+        return 0
+    return evict_ram_cache(required_bytes)
+
+
+def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
+    free_mem = _device_free_memory(device)
+    if device.type == "cpu" and free_mem < required_bytes:
+        _evict_ram_for_budget(required_bytes - free_mem)
+        free_mem = _device_free_memory(device)
+    return free_mem
+
+
+def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
+    from . import model_management
+    if device.type == "cpu":
+        alt = model_management.get_torch_device()
+        if alt.type != "cpu":
+            return alt
+    else:
+        return torch.device("cpu")
+    return None
+
+
+class _BudgetedStateDict(MutableMapping):
+    is_stream_state_dict = True
+
+    def __init__(
+        self,
+        base: MutableMapping,
+        allowed_keys: Set[str],
+        device: torch.device,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+        overrides: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        self._base = base
+        self._allowed_keys = allowed_keys
+        self._device = device
+        self._allow_gds = allow_gds
+        self._pin_if_cpu = pin_if_cpu
+        self._overrides = overrides or {}
+        self._deleted: Set[str] = set()
+
+    def _get_meta(self, key: str):
+        if key in self._overrides:
+            t = self._overrides[key]
+            return safetensors_stream.TensorMeta(
+                dtype=t.dtype,
+                shape=tuple(t.shape),
+                numel=t.numel(),
+                nbytes=_tensor_nbytes(t),
+                data_offsets=(0, _tensor_nbytes(t)),
+                filename="",
+                fst_dtype=None,
+                strides=tuple(t.stride()),
+            )
+        if hasattr(self._base, "meta"):
+            return self._base.meta(key)
+        if hasattr(self._base, "get_tensor"):
+            t = self._base.get_tensor(key, device=torch.device("meta"))
+        else:
+            t = self._base[key]
+        return safetensors_stream.TensorMeta(
+            dtype=t.dtype,
+            shape=tuple(t.shape),
+            numel=t.numel(),
+            nbytes=_tensor_nbytes(t),
+            data_offsets=(0, _tensor_nbytes(t)),
+            filename="",
+            fst_dtype=None,
+            strides=tuple(t.stride()),
+        )
+
+    def get_tensor(
+        self,
+        key: str,
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+    ) -> torch.Tensor:
+        if key in self._overrides:
+            t = self._overrides[key]
+            if device is not None and t.device != device:
+                t = t.to(device=device)
+            if dtype is not None and t.dtype != dtype:
+                t = t.to(dtype=dtype)
+            return t
+        if key in self._deleted:
+            raise KeyError(key)
+        if key not in self._allowed_keys:
+            meta = self._get_meta(key)
+            target_dtype = dtype or meta.dtype
+            return _meta_tensor(meta, dtype_override=target_dtype)
+        if hasattr(self._base, "get_tensor"):
+            return self._base.get_tensor(
+                key,
+                device=self._device if device is None else device,
+                dtype=dtype,
+                allow_gds=self._allow_gds if allow_gds is None else allow_gds,
+                pin_if_cpu=self._pin_if_cpu if not pin_if_cpu else pin_if_cpu,
+            )
+        t = self._base[key]
+        if device is not None and t.device != device:
+            t = t.to(device=device)
+        if dtype is not None and t.dtype != dtype:
+            t = t.to(dtype=dtype)
+        return t
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self.get_tensor(key)
+
+    def __setitem__(self, key: str, value: torch.Tensor) -> None:
+        self._overrides[key] = value
+        self._deleted.discard(key)
+
+    def __delitem__(self, key: str) -> None:
+        if key in self._overrides:
+            del self._overrides[key]
+            return
+        if key in self._deleted:
+            raise KeyError(key)
+        self._deleted.add(key)
+
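+    # The mapping views below merge base keys with in-memory overrides while
+    # hiding keys deleted through this wrapper.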
+    def __iter__(self):
+        for k in self._base.keys():
+            if k in self._deleted or k in self._overrides:
+                continue
+            yield k
+        for k in self._overrides.keys():
+            if k not in self._deleted:
+                yield k
+
+    def __len__(self) -> int:
+        keys = {k for k in self._base.keys() if k not in self._deleted}
+        return len(keys | set(self._overrides.keys()))
+
+    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+        if key in self._overrides:
+            return self._overrides.pop(key)
+        if key in self._deleted:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        if key not in self._base:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        self._deleted.add(key)
+        return self.get_tensor(key)
+
+    def meta(self, key: str):
+        return self._get_meta(key)
+
 
 def _has_custom_load(module: torch.nn.Module) -> bool:
     return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT
 
@@ -239,6 +483,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         CACHE.remove_module(module)
         refs = REGISTRY.get(module)
         if refs:
+            state = _get_materialization_state(module)
             for ref_name, disk_ref in refs.items():
                 shape = getattr(disk_ref.meta, "shape", None)
                 dtype = getattr(disk_ref.meta, "dtype", None)
@@ -249,6 +494,14 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
                     module._buffers[ref_name] = meta_tensor
                 else:
                     module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+                nbytes = _meta_nbytes(disk_ref.meta)
+                if nbytes is not None:
+                    state.loaded_keys.discard(ref_name)
+                    if ref_name not in state.deferred_keys:
+                        state.deferred_keys.add(ref_name)
+                        state.deferred_bytes += nbytes
+                    state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
+            _update_disk_state_attrs(module, state)
             lazy_state.loaded = False
             return
     ref = REGISTRY.get(module)
@@ -264,6 +517,15 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         module._buffers[name] = meta_tensor
     else:
         module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+    state = _get_materialization_state(module)
+    nbytes = _meta_nbytes(disk_ref.meta)
+    if nbytes is not None:
+        state.loaded_keys.discard(name)
+        if name not in state.deferred_keys:
+            state.deferred_keys.add(name)
+            state.deferred_bytes += nbytes
+        state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
+    _update_disk_state_attrs(module, state)
 
 
 def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
@@ -288,15 +550,23 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
     return check(kwargs)
 
 
-def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
+def ensure_module_materialized(
+    module: torch.nn.Module,
+    target_device: torch.device,
+    fallback_device: Optional[torch.device] = None,
+):
     lazy_state = LAZY_MODULE_STATE.get(module)
-    if lazy_state is not None and not lazy_state.loaded:
+    if lazy_state is not None:
         _materialize_module_from_state_dict(module, lazy_state, target_device)
         return
     refs = REGISTRY.get(module)
     if not refs:
         return
-    for name, disk_ref in refs.items():
+    state = _get_materialization_state(module)
+    _rebuild_materialization_state(module, refs, state)
+    remaining_budget = _device_free_memory(target_device)
+    for name in sorted(refs.keys()):
+        disk_ref = refs[name]
         if name in module._parameters:
             current = module._parameters[name]
             is_buffer = False
@@ -307,30 +577,52 @@ def ensure_module_materialized(module: torch.nn.Module, target_device: torch.dev
             continue
         if current is None:
             continue
-        if current.device.type == "meta":
-            tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
-        elif current.device != target_device:
-            tensor = current.to(device=target_device)
-        else:
+        if current.device.type != "meta" and current.device == target_device:
             if current.device.type == "cpu":
                 CACHE.touch(module, name)
             continue
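+        # Budgeted load: materialize this weight only if the target device has
+        # room for it, optionally spilling to fallback_device; otherwise leave
+        # it deferred on the meta device for a later attempt.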
+        meta_nbytes = _meta_nbytes(disk_ref.meta)
+        if meta_nbytes is None:
+            continue
+        required_bytes = meta_nbytes
+        if target_device.type == "cpu":
+            free_mem = _maybe_free_ram_budget(target_device, required_bytes)
+            remaining_budget = min(remaining_budget, free_mem)
+        if required_bytes > remaining_budget:
+            if fallback_device is not None and fallback_device != target_device:
+                fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
+                if fallback_free >= required_bytes:
+                    target_for_load = fallback_device
+                else:
+                    continue
+            else:
+                continue
+        else:
+            target_for_load = target_device
+        if current.device.type == "meta":
+            tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU)
+        else:
+            tensor = current.to(device=target_for_load)
         if is_buffer:
             module._buffers[name] = tensor
         else:
             module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
         if tensor.device.type == "cpu":
             CACHE.record(module, name, tensor, is_buffer=is_buffer)
+        remaining_budget = max(0, remaining_budget - required_bytes)
+    _rebuild_materialization_state(module, refs, state)
 
 
 def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
-    if not REGISTRY.has(module):
+    if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
         return
     if getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
+        fallback_device = _find_tensor_device(args, kwargs)
     else:
         target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
-    ensure_module_materialized(module, target_device)
+        fallback_device = None
+    ensure_module_materialized(module, target_device, fallback_device=fallback_device)
 
 
 def attach_disk_weight_hooks(model: torch.nn.Module):
@@ -386,6 +678,87 @@ def module_to(module: torch.nn.Module, *args, **kwargs):
     return module.to(*args, **kwargs)
 
 
+def load_module_tensor(
+    module: torch.nn.Module,
+    name: str,
+    device: torch.device,
+    *,
+    allow_alternate: bool = True,
+    record_cache: bool = True,
+    temporary: bool = False,
+) -> Optional[torch.Tensor]:
+    refs = REGISTRY.get(module)
+    if not refs or name not in refs:
+        return None
+    if name in module._parameters:
+        current = module._parameters[name]
+        is_buffer = False
+    elif name in module._buffers:
+        current = module._buffers[name]
+        is_buffer = True
+    else:
+        return None
+    if current is None:
+        return None
+    if current.device.type != "meta":
+        if current.device != device:
+            tensor = current.to(device=device)
+            if not temporary:
+                if is_buffer:
+                    module._buffers[name] = tensor
+                else:
+                    module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=refs[name].requires_grad)
+                _rebuild_materialization_state(module, refs, _get_materialization_state(module))
+            return tensor
+        return current
+
+    disk_ref = refs[name]
+    required_bytes = _meta_nbytes(disk_ref.meta)
+    if required_bytes is None:
+        return current
+    free_mem = _maybe_free_ram_budget(device, required_bytes)
+    load_device = device
+    if free_mem < required_bytes and allow_alternate:
+        alt = _choose_alternate_device(device)
+        if alt is not None:
+            alt_free = _maybe_free_ram_budget(alt, required_bytes)
+            if alt_free >= required_bytes:
+                load_device = alt
+            else:
+                state = _get_materialization_state(module)
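+                # No device has room: record the key as deferred so the
+                # module's memory accounting stays accurate.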
+                if name not in state.deferred_keys:
+                    state.deferred_keys.add(name)
+                    state.deferred_bytes += required_bytes
+                _update_disk_state_attrs(module, state)
+                return current
+        else:
+            state = _get_materialization_state(module)
+            if name not in state.deferred_keys:
+                state.deferred_keys.add(name)
+                state.deferred_bytes += required_bytes
+            _update_disk_state_attrs(module, state)
+            return current
+    elif free_mem < required_bytes:
+        state = _get_materialization_state(module)
+        if name not in state.deferred_keys:
+            state.deferred_keys.add(name)
+            state.deferred_bytes += required_bytes
+        _update_disk_state_attrs(module, state)
+        return current
+
+    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
+    if temporary:
+        return tensor
+    if is_buffer:
+        module._buffers[name] = tensor
+    else:
+        module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
+    if tensor.device.type == "cpu" and record_cache:
+        CACHE.record(module, name, tensor, is_buffer=is_buffer)
+    _rebuild_materialization_state(module, refs, _get_materialization_state(module))
+    return tensor
+
+
 def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
     parts = name.split(".")
     module = model
@@ -404,12 +777,42 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
     error_msgs = []
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
-    state_dict = safetensors_stream.DeviceViewStateDict(
+    refs = REGISTRY.get(module) or {}
+    state = _get_materialization_state(module)
+    _rebuild_materialization_state(module, refs, state)
+    keys = sorted(lazy_state.state_dict.keys())
+    existing = {}
+    for name, param in module.named_parameters(recurse=False):
+        key = f"{lazy_state.prefix}{name}"
+        if key in lazy_state.state_dict and param is not None and param.device.type != "meta":
+            existing[key] = param
+    for name, buf in module.named_buffers(recurse=False):
+        key = f"{lazy_state.prefix}{name}"
+        if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
+            existing[key] = buf
+    remaining_budget = _device_free_memory(target_device)
+    allowed = set(existing.keys())
+    for key in keys:
+        if key in allowed:
+            continue
+        meta = _state_dict_meta(lazy_state.state_dict, key)
+        required = _meta_nbytes(meta)
+        if required is None:
+            continue
+        if target_device.type == "cpu":
+            free_mem = _maybe_free_ram_budget(target_device, required)
+            remaining_budget = min(remaining_budget, free_mem)
+        if required <= remaining_budget:
+            allowed.add(key)
+            remaining_budget = max(0, remaining_budget - required)
+    deferred_state_dict_keys = {key for key in keys if key not in allowed}
+    state_dict = _BudgetedStateDict(
         lazy_state.state_dict,
+        allowed_keys=allowed,
         device=target_device,
         allow_gds=ALLOW_GDS,
         pin_if_cpu=PIN_IF_CPU,
-        mutate_base=False,
+        overrides=existing,
     )
     factory_device = None
     if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
@@ -435,7 +838,8 @@
         module.factory_kwargs["device"] = factory_device
     if len(error_msgs) > 0:
         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
-    lazy_state.loaded = True
+    _rebuild_materialization_state(module, refs, state)
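+    # Mark the module fully loaded only when nothing was deferred, so a later
+    # pass will retry the keys skipped by the budget.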
+    lazy_state.loaded = len(deferred_state_dict_keys) == 0
     for name, param in module.named_parameters(recurse=False):
         if param.device.type == "cpu":
             CACHE.record(module, name, param, is_buffer=False)
diff --git a/comfy/ops.py b/comfy/ops.py
index cd536e22d..06f28317a 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -19,6 +19,7 @@ import torch
 import logging
 
 import comfy.model_management
+import comfy.disk_weights
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
@@ -98,11 +99,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     weight_has_function = len(s.weight_function) > 0
     bias_has_function = len(s.bias_function) > 0
 
-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    weight_source = s.weight
+    bias_source = s.bias
+    if comfy.disk_weights.disk_weights_enabled():
+        if weight_source.device.type == "meta":
+            loaded = comfy.disk_weights.load_module_tensor(s, "weight", device, temporary=True)
+            if loaded is not None:
+                weight_source = loaded
+        if bias_source is not None and bias_source.device.type == "meta":
+            loaded_bias = comfy.disk_weights.load_module_tensor(s, "bias", device, temporary=True)
+            if loaded_bias is not None:
+                bias_source = loaded_bias
+
+    weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
 
     bias = None
-    if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+    if bias_source is not None:
+        bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
 
     comfy.model_management.sync_stream(device, offload_stream)
@@ -532,9 +545,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 key = f"{prefix}{param_name}"
                 value = state_dict.pop(key, None)
                 if value is not None:
-                    value = value.to(device=device)
-                    if dtype is not None:
-                        value = value.view(dtype=dtype)
+                    if value.device.type != "meta":
+                        value = value.to(device=device)
+                        if dtype is not None:
+                            value = value.view(dtype=dtype)
                     manually_loaded_keys.append(key)
                 return value
@@ -551,11 +565,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             manually_loaded_keys = [weight_key]
 
             layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
-            if layer_conf is not None:
+            if layer_conf is not None and layer_conf.device.type != "meta":
                 layer_conf = json.loads(layer_conf.numpy().tobytes())
+            elif layer_conf is not None:
+                layer_conf = None
 
             if layer_conf is None:
-                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                if weight.device.type == "meta":
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+                else:
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
             else:
                 self.quant_format = layer_conf.get("format", None)
                 self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@@ -601,10 +620,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 else:
                     raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
-                self.weight = torch.nn.Parameter(
-                    QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
-                    requires_grad=False
-                )
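+                # A meta weight has no data to quantize yet; keep it as-is
+                # until it is materialized from disk.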
== "meta": + self.weight = torch.nn.Parameter(weight, requires_grad=False) + else: + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), + requires_grad=False + ) for param_name in qconfig["parameters"]: if param_name in {"weight_scale", "weight_scale_2"}: @@ -614,7 +636,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec _v = state_dict.pop(param_key, None) if _v is None: continue - self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + if _v.device.type == "meta": + self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False)) + else: + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)