mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-01 09:10:16 +08:00
Refine disk weight offload integration
This commit is contained in:
parent
b93f165c2e
commit
95ca11fe25
@ -349,12 +349,12 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
|
||||
| `--low-ram` | Enable disk weight offloading. Sets RAM headroom to 1024MB and uses `--reserve-vram` as the VRAM headroom. |
|
||||
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |
|
||||
|
||||
### Disk tier for model weights
|
||||
|
||||
When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
|
||||
When `--low-ram` is enabled, ComfyUI streams safetensors weights from disk and offloads weights to disk when RAM headroom is reached.
|
||||
|
||||
If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ class AudioEncoderModel():
|
||||
self.model_sample_rate = 16000
|
||||
|
||||
def load_sd(self, sd):
|
||||
return comfy.utils.load_state_dict(self.model, sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -114,7 +114,7 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
|
||||
parser.add_argument("--low-ram", action="store_true", help="Enable disk weight offloading. Sets RAM headroom to 1024MB and uses --reserve-vram as the VRAM headroom.")
|
||||
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
|
||||
@ -6,7 +6,6 @@ import logging
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
|
||||
@ -48,7 +47,7 @@ class ClipVisionModel():
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return comfy.utils.load_state_dict(self.model, sd, strict=False)
|
||||
return self.model.load_state_dict(sd, strict=False)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -25,7 +25,6 @@ import logging
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_detection
|
||||
import comfy.disk_weights
|
||||
import comfy.model_patcher
|
||||
import comfy.ops
|
||||
import comfy.latent_formats
|
||||
@ -386,7 +385,7 @@ class ControlLora(ControlNet):
|
||||
controlnet_config["operations"] = control_lora_ops
|
||||
controlnet_config["dtype"] = dtype
|
||||
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
||||
comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
|
||||
self.control_model.to(comfy.model_management.get_torch_device())
|
||||
diffusion_model = model.diffusion_model
|
||||
sd = diffusion_model.state_dict()
|
||||
|
||||
@ -440,7 +439,7 @@ def controlnet_config(sd, model_options={}):
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
@ -474,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
|
||||
class ControlNetSD35(ControlNet):
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
if self.control_model.double_y_emb:
|
||||
missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
|
||||
else:
|
||||
missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def copy(self):
|
||||
@ -749,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.control_model = control_model
|
||||
missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
|
||||
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
||||
else:
|
||||
missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
|
||||
if len(missing) > 0:
|
||||
logging.warning("missing controlnet keys: {}".format(missing))
|
||||
@ -817,8 +816,8 @@ class T2IAdapter(ControlBase):
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
if self.control_input is None:
|
||||
comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
|
||||
comfy.disk_weights.module_to(self.t2i_model, self.device)
|
||||
self.t2i_model.to(dtype=x_noisy.dtype)
|
||||
self.t2i_model.to(self.device)
|
||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||
self.t2i_model.cpu()
|
||||
|
||||
@ -875,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
|
||||
else:
|
||||
return None
|
||||
|
||||
missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
|
||||
missing, unexpected = model_ad.load_state_dict(t2i_data, strict=True)
|
||||
if len(missing) > 0:
|
||||
logging.warning("t2i missing {}".format(missing))
|
||||
|
||||
|
||||
@ -33,7 +33,11 @@ from . import safetensors_stream
|
||||
ALLOW_GDS = False
|
||||
PIN_IF_CPU = False
|
||||
DISK_WEIGHTS_ENABLED = False
|
||||
RAM_HEADROOM_BYTES = 0
|
||||
BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
|
||||
BASE_MODULE_TO = torch.nn.Module.to
|
||||
BASE_LOAD_STATE_DICT = torch.nn.Module.load_state_dict
|
||||
_MONKEYPATCHED = False
|
||||
LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
|
||||
DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
|
||||
_MISSING = object()
|
||||
@ -169,13 +173,17 @@ CACHE = DiskWeightCache(0)
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
|
||||
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
|
||||
def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, enabled: bool = True):
|
||||
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED, RAM_HEADROOM_BYTES
|
||||
ALLOW_GDS = allow_gds
|
||||
PIN_IF_CPU = pin_if_cpu
|
||||
DISK_WEIGHTS_ENABLED = enabled
|
||||
CACHE.set_limit(cache_bytes if enabled else 0)
|
||||
if not enabled:
|
||||
RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
|
||||
CACHE.set_limit(0 if enabled else 0)
|
||||
if enabled:
|
||||
install_monkeypatches()
|
||||
else:
|
||||
uninstall_monkeypatches()
|
||||
CACHE._entries.clear()
|
||||
CACHE.current_bytes = 0
|
||||
|
||||
@ -184,6 +192,66 @@ def disk_weights_enabled() -> bool:
|
||||
return DISK_WEIGHTS_ENABLED
|
||||
|
||||
|
||||
def ram_headroom_bytes() -> int:
|
||||
return RAM_HEADROOM_BYTES
|
||||
|
||||
|
||||
def _is_stream_state_dict(state_dict) -> bool:
|
||||
return (
|
||||
getattr(state_dict, "is_stream_state_dict", False)
|
||||
and hasattr(state_dict, "get_tensor")
|
||||
and hasattr(state_dict, "meta")
|
||||
)
|
||||
|
||||
|
||||
def patched_to(self: torch.nn.Module, *args, **kwargs):
|
||||
if not disk_weights_enabled():
|
||||
return BASE_MODULE_TO(self, *args, **kwargs)
|
||||
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
module_to(
|
||||
self,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
non_blocking=non_blocking,
|
||||
memory_format=memory_format,
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
def patched_load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
|
||||
if not disk_weights_enabled():
|
||||
if _is_stream_state_dict(state_dict):
|
||||
return safetensors_stream.stream_load_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
strict=strict,
|
||||
assign=assign,
|
||||
)
|
||||
return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
|
||||
if _is_stream_state_dict(state_dict):
|
||||
missing_keys, unexpected_keys = lazy_load_state_dict(self, state_dict, strict=strict)
|
||||
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
|
||||
|
||||
|
||||
def install_monkeypatches():
|
||||
global _MONKEYPATCHED
|
||||
if _MONKEYPATCHED:
|
||||
return
|
||||
torch.nn.Module.to = patched_to
|
||||
torch.nn.Module.load_state_dict = patched_load_state_dict
|
||||
_MONKEYPATCHED = True
|
||||
|
||||
|
||||
def uninstall_monkeypatches():
|
||||
global _MONKEYPATCHED
|
||||
if not _MONKEYPATCHED:
|
||||
return
|
||||
torch.nn.Module.to = BASE_MODULE_TO
|
||||
torch.nn.Module.load_state_dict = BASE_LOAD_STATE_DICT
|
||||
_MONKEYPATCHED = False
|
||||
|
||||
|
||||
def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
|
||||
if not disk_weights_enabled():
|
||||
return
|
||||
@ -369,33 +437,29 @@ def _device_free_memory(device: torch.device) -> int:
|
||||
return int(model_management.get_free_memory(device))
|
||||
|
||||
|
||||
def _evict_ram_for_budget(required_bytes: int) -> int:
|
||||
if required_bytes <= 0:
|
||||
return 0
|
||||
freed = evict_ram_cache(required_bytes)
|
||||
if freed < required_bytes:
|
||||
def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_bytes: int) -> int:
|
||||
free_before = _device_free_memory(device)
|
||||
if free_before < required_bytes + headroom_bytes:
|
||||
LOGGER.debug(
|
||||
"Disk weight memory pressure: required=%d free=%d headroom=%d device=%s",
|
||||
required_bytes,
|
||||
free_before,
|
||||
headroom_bytes,
|
||||
device,
|
||||
)
|
||||
safetensors_stream._reap_pinned_inflight()
|
||||
from . import model_management
|
||||
freed += model_management.evict_ram_to_disk(required_bytes - freed)
|
||||
return freed
|
||||
|
||||
|
||||
def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
|
||||
free_mem = _device_free_memory(device)
|
||||
if device.type == "cpu" and free_mem < required_bytes:
|
||||
_evict_ram_for_budget(required_bytes - free_mem)
|
||||
free_mem = _device_free_memory(device)
|
||||
return free_mem
|
||||
|
||||
|
||||
def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
|
||||
from . import model_management
|
||||
if device.type == "cpu":
|
||||
alt = model_management.get_torch_device()
|
||||
if alt.type != "cpu":
|
||||
return alt
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
return None
|
||||
model_management.free_memory(required_bytes + headroom_bytes, device)
|
||||
free_after = _device_free_memory(device)
|
||||
freed = max(0, free_after - free_before)
|
||||
LOGGER.debug(
|
||||
"Disk weight memory freed: freed=%d free=%d device=%s",
|
||||
freed,
|
||||
free_after,
|
||||
device,
|
||||
)
|
||||
return free_after
|
||||
return free_before
|
||||
|
||||
|
||||
class _BudgetedStateDict(MutableMapping):
|
||||
@ -527,8 +591,10 @@ class _BudgetedStateDict(MutableMapping):
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
|
||||
def meta(self, key: str):
|
||||
return self._get_meta(key)
|
||||
@ -564,6 +630,7 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):
|
||||
|
||||
|
||||
def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
|
||||
safetensors_stream._reap_pinned_inflight()
|
||||
lazy_state = LAZY_MODULE_STATE.get(module)
|
||||
if lazy_state is not None:
|
||||
CACHE.remove_module(module)
|
||||
@ -671,7 +738,6 @@ def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype:
|
||||
def ensure_module_materialized(
|
||||
module: torch.nn.Module,
|
||||
target_device: torch.device,
|
||||
fallback_device: Optional[torch.device] = None,
|
||||
dtype_override: Optional[torch.dtype] = None,
|
||||
):
|
||||
lazy_state = LAZY_MODULE_STATE.get(module)
|
||||
@ -692,7 +758,6 @@ def ensure_module_materialized(
|
||||
_set_future_dtype(module, name, dtype_override)
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
free_mem_start = _device_free_memory(target_device)
|
||||
remaining_budget = free_mem_start
|
||||
for name in sorted(refs.keys()):
|
||||
disk_ref = refs[name]
|
||||
if name in module._parameters:
|
||||
@ -717,19 +782,11 @@ def ensure_module_materialized(
|
||||
continue
|
||||
required_bytes = meta_nbytes
|
||||
if target_device.type == "cpu":
|
||||
free_mem = _maybe_free_ram_budget(target_device, required_bytes)
|
||||
remaining_budget = min(remaining_budget, free_mem)
|
||||
if required_bytes > remaining_budget:
|
||||
if fallback_device is not None and fallback_device != target_device:
|
||||
fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
|
||||
if fallback_free >= required_bytes:
|
||||
target_for_load = fallback_device
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
_ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
|
||||
else:
|
||||
target_for_load = target_device
|
||||
from . import model_management
|
||||
_ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
|
||||
target_for_load = target_device
|
||||
if current.device.type == "meta":
|
||||
tensor = disk_ref.load(
|
||||
target_for_load,
|
||||
@ -748,7 +805,6 @@ def ensure_module_materialized(
|
||||
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
||||
if tensor.device.type == "cpu":
|
||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||
remaining_budget = max(0, remaining_budget - required_bytes)
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
|
||||
|
||||
@ -761,14 +817,11 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
|
||||
dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
|
||||
if getattr(module, "comfy_cast_weights", False):
|
||||
target_device = torch.device("cpu")
|
||||
fallback_device = _find_tensor_device(args, kwargs)
|
||||
else:
|
||||
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
|
||||
fallback_device = None
|
||||
ensure_module_materialized(
|
||||
module,
|
||||
target_device,
|
||||
fallback_device=fallback_device,
|
||||
dtype_override=dtype_override,
|
||||
)
|
||||
|
||||
@ -786,6 +839,7 @@ def attach_disk_weight_hooks(model: torch.nn.Module):
|
||||
def evict_ram_cache(bytes_to_free: int):
|
||||
if bytes_to_free <= 0:
|
||||
return 0
|
||||
safetensors_stream._reap_pinned_inflight()
|
||||
return CACHE.evict_bytes(bytes_to_free)
|
||||
|
||||
|
||||
@ -827,16 +881,7 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
|
||||
|
||||
|
||||
def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
|
||||
def _move(tensor):
|
||||
if tensor is None:
|
||||
return None
|
||||
if tensor.device.type == "meta":
|
||||
return tensor
|
||||
if dtype_override is not None and tensor.dtype != dtype_override:
|
||||
return tensor.to(device=device_to, dtype=dtype_override)
|
||||
return tensor.to(device=device_to)
|
||||
|
||||
module._apply(_move)
|
||||
ensure_module_materialized(module, device_to, dtype_override=dtype_override)
|
||||
return module
|
||||
|
||||
|
||||
@ -864,10 +909,20 @@ def offload_module_weights(module: torch.nn.Module) -> int:
|
||||
return offloaded_bytes
|
||||
|
||||
|
||||
def module_to(module: torch.nn.Module, *args, **kwargs):
|
||||
def module_to(
|
||||
module: torch.nn.Module,
|
||||
*args,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
memory_format=None,
|
||||
**kwargs,
|
||||
):
|
||||
allow_materialize = kwargs.pop("allow_materialize", True)
|
||||
arg_device = _extract_to_device(args, kwargs)
|
||||
arg_dtype = _extract_to_dtype(args, kwargs)
|
||||
if disk_weights_enabled():
|
||||
target_device = _extract_to_device(args, kwargs)
|
||||
target_device = device or arg_device
|
||||
if target_device is None:
|
||||
target_device = _find_existing_device(module) or torch.device("cpu")
|
||||
if target_device.type == "meta":
|
||||
@ -875,10 +930,28 @@ def module_to(module: torch.nn.Module, *args, **kwargs):
|
||||
return module
|
||||
if allow_materialize:
|
||||
materialize_module_tree(module, target_device)
|
||||
return module.to(*args, **kwargs)
|
||||
dtype_override = _extract_to_dtype(args, kwargs)
|
||||
base_kwargs = dict(kwargs)
|
||||
if device is not None and arg_device is None:
|
||||
base_kwargs["device"] = device
|
||||
if dtype is not None and arg_dtype is None:
|
||||
base_kwargs["dtype"] = dtype
|
||||
if non_blocking:
|
||||
base_kwargs["non_blocking"] = non_blocking
|
||||
if memory_format is not None:
|
||||
base_kwargs["memory_format"] = memory_format
|
||||
return BASE_MODULE_TO(module, *args, **base_kwargs)
|
||||
dtype_override = dtype or arg_dtype
|
||||
return move_module_tensors(module, target_device, dtype_override=dtype_override)
|
||||
return module.to(*args, **kwargs)
|
||||
base_kwargs = dict(kwargs)
|
||||
if device is not None and arg_device is None:
|
||||
base_kwargs["device"] = device
|
||||
if dtype is not None and arg_dtype is None:
|
||||
base_kwargs["dtype"] = dtype
|
||||
if non_blocking:
|
||||
base_kwargs["non_blocking"] = non_blocking
|
||||
if memory_format is not None:
|
||||
base_kwargs["memory_format"] = memory_format
|
||||
return BASE_MODULE_TO(module, *args, **base_kwargs)
|
||||
|
||||
|
||||
def load_module_tensor(
|
||||
@ -886,7 +959,6 @@ def load_module_tensor(
|
||||
name: str,
|
||||
device: torch.device,
|
||||
*,
|
||||
allow_alternate: bool = True,
|
||||
record_cache: bool = True,
|
||||
temporary: bool = False,
|
||||
dtype_override: Optional[torch.dtype] = None,
|
||||
@ -909,6 +981,9 @@ def load_module_tensor(
|
||||
_set_future_dtype(module, name, dtype_override)
|
||||
if current.device.type != "meta":
|
||||
if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
|
||||
from . import model_management
|
||||
headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
|
||||
_ensure_free_memory(device, _tensor_nbytes(current), headroom)
|
||||
if target_dtype is not None and current.dtype != target_dtype:
|
||||
tensor = current.to(device=device, dtype=target_dtype)
|
||||
else:
|
||||
@ -926,41 +1001,11 @@ def load_module_tensor(
|
||||
required_bytes = _meta_nbytes(disk_ref.meta)
|
||||
if required_bytes is None:
|
||||
return current
|
||||
free_mem_start = _device_free_memory(device)
|
||||
free_mem = _maybe_free_ram_budget(device, required_bytes)
|
||||
load_device = device
|
||||
if free_mem < required_bytes and allow_alternate:
|
||||
alt = _choose_alternate_device(device)
|
||||
if alt is not None:
|
||||
alt_free = _maybe_free_ram_budget(alt, required_bytes)
|
||||
if alt_free >= required_bytes:
|
||||
load_device = alt
|
||||
else:
|
||||
state = _get_materialization_state(module)
|
||||
if name not in state.deferred_keys:
|
||||
state.deferred_keys.add(name)
|
||||
state.deferred_bytes += required_bytes
|
||||
_update_disk_state_attrs(module, state)
|
||||
_log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
|
||||
return current
|
||||
else:
|
||||
state = _get_materialization_state(module)
|
||||
if name not in state.deferred_keys:
|
||||
state.deferred_keys.add(name)
|
||||
state.deferred_bytes += required_bytes
|
||||
_update_disk_state_attrs(module, state)
|
||||
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||
return current
|
||||
elif free_mem < required_bytes:
|
||||
state = _get_materialization_state(module)
|
||||
if name not in state.deferred_keys:
|
||||
state.deferred_keys.add(name)
|
||||
state.deferred_bytes += required_bytes
|
||||
_update_disk_state_attrs(module, state)
|
||||
_log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
|
||||
return current
|
||||
from . import model_management
|
||||
headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
|
||||
_ensure_free_memory(device, required_bytes, headroom)
|
||||
|
||||
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
|
||||
tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
|
||||
if temporary:
|
||||
return tensor
|
||||
if is_buffer:
|
||||
@ -971,7 +1016,7 @@ def load_module_tensor(
|
||||
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
||||
state = _get_materialization_state(module)
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
_log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
|
||||
_log_materialization(module, device, _device_free_memory(device), refs, state, "Disk weight loaded")
|
||||
return tensor
|
||||
|
||||
|
||||
@ -1015,22 +1060,15 @@ def _materialize_module_from_state_dict(
|
||||
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
|
||||
existing[key] = buf
|
||||
free_mem_start = _device_free_memory(target_device)
|
||||
remaining_budget = free_mem_start
|
||||
allowed = set(existing.keys())
|
||||
allowed = set(keys)
|
||||
from . import model_management
|
||||
headroom = RAM_HEADROOM_BYTES if target_device.type == "cpu" else model_management.extra_reserved_memory()
|
||||
for key in keys:
|
||||
if key in allowed:
|
||||
continue
|
||||
meta = _state_dict_meta(lazy_state.state_dict, key)
|
||||
required = _meta_nbytes(meta)
|
||||
if required is None:
|
||||
continue
|
||||
if target_device.type == "cpu":
|
||||
free_mem = _maybe_free_ram_budget(target_device, required)
|
||||
remaining_budget = min(remaining_budget, free_mem)
|
||||
if required <= remaining_budget:
|
||||
allowed.add(key)
|
||||
remaining_budget = max(0, remaining_budget - required)
|
||||
deferred_state_dict_keys = {key for key in keys if key not in allowed}
|
||||
_ensure_free_memory(target_device, required, headroom)
|
||||
state_dict = _BudgetedStateDict(
|
||||
lazy_state.state_dict,
|
||||
allowed_keys=allowed,
|
||||
@ -1065,7 +1103,7 @@ def _materialize_module_from_state_dict(
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
_rebuild_materialization_state(module, refs, state)
|
||||
lazy_state.loaded = len(deferred_state_dict_keys) == 0
|
||||
lazy_state.loaded = True
|
||||
_log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
if param.device.type == "cpu":
|
||||
|
||||
@ -283,7 +283,7 @@ def load_gligen(sd):
|
||||
|
||||
gated = GatedSelfAttentionDense(
|
||||
query_dim, key_dim, n_heads, d_head)
|
||||
comfy.utils.load_state_dict(gated, n_sd, strict=False)
|
||||
gated.load_state_dict(n_sd, strict=False)
|
||||
output_list.append(gated)
|
||||
|
||||
if "position_net.null_positive_feature" in sd_k:
|
||||
@ -294,7 +294,7 @@ def load_gligen(sd):
|
||||
pass
|
||||
w = WeightsLoader()
|
||||
w.position_net = PositionNet(in_dim, out_dim)
|
||||
comfy.utils.load_state_dict(w, sd, strict=False)
|
||||
w.load_state_dict(sd, strict=False)
|
||||
|
||||
gligen = Gligen(output_list, w.position_net, key_dim)
|
||||
return gligen
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import comfy.utils
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
@ -113,7 +112,7 @@ class HunyuanVideo15SRModel():
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return comfy.utils.load_state_dict(self.model, sd, strict=True)
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
@ -2,7 +2,6 @@ import json
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
import torch
|
||||
import comfy.utils
|
||||
import torchaudio
|
||||
|
||||
import comfy.model_management
|
||||
@ -154,8 +153,8 @@ class AudioVAE(torch.nn.Module):
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
|
||||
comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
autoencoder_config = self.autoencoder.get_config()
|
||||
self.normalizer = AudioLatentNormalizer(
|
||||
|
||||
@ -2,7 +2,6 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.utils
|
||||
import torch.nn as nn
|
||||
|
||||
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||
@ -153,7 +152,7 @@ class VAE(nn.Module):
|
||||
return dec, posterior
|
||||
|
||||
def load_weights(self, src_dict) -> None:
|
||||
comfy.utils.load_state_dict(self, src_dict, strict=True)
|
||||
self.load_state_dict(src_dict, strict=True)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
|
||||
@ -309,7 +309,7 @@ class BaseModel(torch.nn.Module):
|
||||
else:
|
||||
to_load = sd
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("unet missing: {}".format(m))
|
||||
|
||||
@ -753,8 +753,8 @@ class StableAudio1(BaseModel):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
|
||||
self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
|
||||
utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
|
||||
utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
|
||||
self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights, strict=True)
|
||||
self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights, strict=True)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
@ -590,10 +590,26 @@ def minimum_inference_memory():
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
|
||||
logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
if freed_cache < memory_required:
|
||||
evict_ram_to_disk(memory_required - freed_cache)
|
||||
free_before = get_free_memory(device)
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
"RAM pressure: required=%d free=%d headroom=%d",
|
||||
memory_required,
|
||||
free_before,
|
||||
headroom,
|
||||
)
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
freed_disk = 0
|
||||
if freed_cache < memory_required:
|
||||
freed_disk = evict_ram_to_disk(memory_required - freed_cache)
|
||||
free_after = get_free_memory(device)
|
||||
freed_total = max(0, free_after - free_before)
|
||||
logging.debug(
|
||||
"RAM freed: freed=%d free=%d",
|
||||
freed_total if freed_total > 0 else freed_cache + freed_disk,
|
||||
free_after,
|
||||
)
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -636,6 +652,7 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
|
||||
if not comfy.disk_weights.disk_weights_enabled():
|
||||
return 0
|
||||
|
||||
free_before = get_free_memory(torch.device("cpu"))
|
||||
freed = 0
|
||||
can_unload = []
|
||||
for i in range(len(current_loaded_models) - 1, -1, -1):
|
||||
@ -654,7 +671,14 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
|
||||
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
|
||||
|
||||
if freed > 0:
|
||||
logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
|
||||
free_after = get_free_memory(torch.device("cpu"))
|
||||
freed_total = max(0, free_after - free_before)
|
||||
logging.debug(
|
||||
"RAM evicted to disk: required=%d free=%d freed=%d",
|
||||
memory_to_free,
|
||||
free_before,
|
||||
freed_total if freed_total > 0 else freed,
|
||||
)
|
||||
return freed
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
@ -802,6 +826,8 @@ def dtype_size(dtype):
|
||||
return dtype_size
|
||||
|
||||
def unet_offload_device():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return torch.device("meta")
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -906,6 +932,8 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return torch.device("meta")
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -966,6 +994,8 @@ def vae_device():
|
||||
return get_torch_device()
|
||||
|
||||
def vae_offload_device():
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return torch.device("meta")
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -1163,14 +1193,13 @@ if not args.disable_pinned_memory:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
WEIGHTS_RAM_CACHE_BYTES = 0
|
||||
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
|
||||
if args.weights_ram_cache_gb is not None:
|
||||
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
|
||||
if args.low_ram:
|
||||
comfy.disk_weights.configure(
|
||||
WEIGHTS_RAM_CACHE_BYTES,
|
||||
allow_gds=WEIGHTS_GDS_ENABLED,
|
||||
pin_if_cpu=not args.disable_pinned_memory,
|
||||
ram_headroom_bytes=1024 * 1024 * 1024,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
22
comfy/ops.py
22
comfy/ops.py
@ -19,7 +19,6 @@
|
||||
import torch
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.disk_weights
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
@ -101,27 +100,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
|
||||
weight_source = s.weight
|
||||
bias_source = s.bias
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
if weight_source.device.type == "meta":
|
||||
loaded = comfy.disk_weights.load_module_tensor(
|
||||
s,
|
||||
"weight",
|
||||
device,
|
||||
temporary=True,
|
||||
dtype_override=dtype,
|
||||
)
|
||||
if loaded is not None:
|
||||
weight_source = loaded
|
||||
if bias_source is not None and bias_source.device.type == "meta":
|
||||
loaded_bias = comfy.disk_weights.load_module_tensor(
|
||||
s,
|
||||
"bias",
|
||||
device,
|
||||
temporary=True,
|
||||
dtype_override=bias_dtype,
|
||||
)
|
||||
if loaded_bias is not None:
|
||||
bias_source = loaded_bias
|
||||
|
||||
weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
|
||||
|
||||
@ -37,6 +37,19 @@ _FST_LOADED = False
|
||||
_GDS_INITIALIZED = False
|
||||
_MISSING = object()
|
||||
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
|
||||
_PINNED_INFLIGHT = collections.deque()
|
||||
|
||||
|
||||
def _reap_pinned_inflight():
|
||||
if not _PINNED_INFLIGHT:
|
||||
return
|
||||
pending = collections.deque()
|
||||
while _PINNED_INFLIGHT:
|
||||
event, tensor = _PINNED_INFLIGHT.popleft()
|
||||
if event.query():
|
||||
continue
|
||||
pending.append((event, tensor))
|
||||
_PINNED_INFLIGHT.extend(pending)
|
||||
|
||||
|
||||
def _require_fastsafetensors():
|
||||
@ -204,14 +217,22 @@ class _SafeTensorFile:
|
||||
)
|
||||
return tensor
|
||||
|
||||
target_dtype = dtype
|
||||
if device_is_cuda and pin_if_cpu:
|
||||
_reap_pinned_inflight()
|
||||
target_dtype = None
|
||||
cpu_tensor = self._read_tensor_nogds(
|
||||
fst, framework, meta, torch.device("cpu"), dtype
|
||||
fst, framework, meta, torch.device("cpu"), target_dtype, pin_memory=bool(device_is_cuda and pin_if_cpu)
|
||||
)
|
||||
if device_is_cuda:
|
||||
if pin_if_cpu:
|
||||
cpu_tensor = cpu_tensor.pin_memory()
|
||||
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
|
||||
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
|
||||
if pin_if_cpu:
|
||||
event = torch.cuda.Event()
|
||||
event.record(torch.cuda.current_stream(device))
|
||||
_PINNED_INFLIGHT.append((event, cpu_tensor))
|
||||
if dtype is not None and dtype != gpu_tensor.dtype:
|
||||
gpu_tensor = gpu_tensor.to(dtype=dtype)
|
||||
return gpu_tensor
|
||||
return cpu_tensor
|
||||
|
||||
@ -233,6 +254,8 @@ class _SafeTensorFile:
|
||||
meta: TensorMeta,
|
||||
device: torch.device,
|
||||
dtype: Optional[torch.dtype],
|
||||
*,
|
||||
pin_memory: bool = False,
|
||||
) -> torch.Tensor:
|
||||
fd = self._ensure_fd()
|
||||
reader = self._ensure_nogds_reader(use_cuda=False)
|
||||
@ -241,7 +264,7 @@ class _SafeTensorFile:
|
||||
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
|
||||
chunk_bytes = max(1, chunk_bytes)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
|
||||
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu", pin_memory=pin_memory)
|
||||
buffer_length = 0
|
||||
buf_ptr = None
|
||||
gbuf = None
|
||||
@ -250,6 +273,8 @@ class _SafeTensorFile:
|
||||
while chunk_offset < length:
|
||||
chunk_len = min(length - chunk_offset, chunk_bytes)
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
|
||||
if aligned_offset + aligned_length > self.index.size_bytes:
|
||||
aligned_length = max(0, self.index.size_bytes - aligned_offset)
|
||||
needed = aligned_length + ptr_align
|
||||
if buf_ptr is None or needed > buffer_length:
|
||||
if buf_ptr is not None:
|
||||
@ -272,7 +297,6 @@ class _SafeTensorFile:
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
if dtype is not None and dtype != dest_tensor.dtype:
|
||||
_validate_dtype_conversion(dest_tensor.dtype, dtype)
|
||||
dest_tensor = dest_tensor.to(dtype=dtype)
|
||||
return dest_tensor
|
||||
|
||||
@ -289,6 +313,8 @@ class _SafeTensorFile:
|
||||
abs_start = self.index.header_length + meta.data_offsets[0]
|
||||
length = meta.data_offsets[1] - meta.data_offsets[0]
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
|
||||
if aligned_offset + aligned_length > self.index.size_bytes:
|
||||
aligned_length = max(0, self.index.size_bytes - aligned_offset)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
buffer_length = aligned_length + ptr_align
|
||||
fst_device = _fst_device_from_torch(fst, device)
|
||||
@ -306,7 +332,6 @@ class _SafeTensorFile:
|
||||
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
|
||||
)
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
_validate_dtype_conversion(tensor.dtype, dtype)
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
return tensor
|
||||
|
||||
@ -348,8 +373,7 @@ def _dlpack_tensor_from_buffer(
|
||||
|
||||
|
||||
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
|
||||
if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
|
||||
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
|
||||
return
|
||||
|
||||
|
||||
def _get_gds_o_direct(framework) -> bool:
|
||||
@ -523,8 +547,10 @@ class StreamStateDict(collections.abc.MutableMapping):
|
||||
raise KeyError(key)
|
||||
return default
|
||||
if self._index.has(key):
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
@ -636,8 +662,10 @@ class _BaseViewStateDict(MutableMapping):
|
||||
if default is _MISSING:
|
||||
raise
|
||||
return default
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
|
||||
def meta(self, key: str):
|
||||
if key in self._overrides:
|
||||
@ -768,8 +796,10 @@ class DeviceViewStateDict(_BaseViewStateDict):
|
||||
if default is _MISSING:
|
||||
raise
|
||||
return default
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
|
||||
|
||||
class FilterViewStateDict(_BaseViewStateDict):
|
||||
@ -932,3 +962,48 @@ class MappedStateDict(_BaseViewStateDict):
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
return self._view_to_base.keys()
|
||||
|
||||
|
||||
def stream_load_state_dict(model, state_dict, strict: bool = False, assign: bool = False):
|
||||
if getattr(state_dict, "is_stream_state_dict", False) and hasattr(state_dict, "copy"):
|
||||
state_dict = state_dict.copy()
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
|
||||
def load(module, local_state_dict, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
if assign:
|
||||
local_metadata["assign_to_params_buffers"] = assign
|
||||
module._load_from_state_dict(
|
||||
local_state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
True,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
child_prefix = f"{prefix}{name}."
|
||||
child_state_dict = FilterViewStateDict(
|
||||
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
|
||||
)
|
||||
load(child, child_state_dict, child_prefix)
|
||||
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
for hook in module._load_state_dict_post_hooks.values():
|
||||
out = hook(module, incompatible)
|
||||
if out is not None:
|
||||
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
|
||||
|
||||
load(model, state_dict)
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
13
comfy/sd.py
13
comfy/sd.py
@ -26,7 +26,6 @@ import os
|
||||
|
||||
import comfy.utils
|
||||
import comfy.safetensors_stream
|
||||
import comfy.disk_weights
|
||||
|
||||
from . import clip_vision
|
||||
from . import gligen
|
||||
@ -126,7 +125,7 @@ class CLIP:
|
||||
if not model_management.supports_cast(load_device, dt):
|
||||
load_device = offload_device
|
||||
if params['device'] != offload_device:
|
||||
comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
|
||||
self.cond_stage_model.to(offload_device)
|
||||
logging.warning("Had to shift TE back.")
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
@ -290,7 +289,7 @@ class CLIP:
|
||||
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
else:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
|
||||
@ -658,7 +657,7 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
|
||||
@ -672,7 +671,7 @@ class VAE:
|
||||
if dtype is None:
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
|
||||
self.first_stage_model.to(dtype=self.vae_dtype)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
@ -979,7 +978,7 @@ def load_style_model(ckpt_path):
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
comfy.utils.load_state_dict(model, model_data, strict=True)
|
||||
model.load_state_dict(model_data, strict=True)
|
||||
return StyleModel(model)
|
||||
|
||||
def sd_shape(state_dict, key):
|
||||
@ -1547,7 +1546,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = comfy.disk_weights.module_to(model, offload_device)
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
|
||||
@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
|
||||
@ -56,9 +56,9 @@ class TAESD(nn.Module):
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||
if encoder_path is not None:
|
||||
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
|
||||
if decoder_path is not None:
|
||||
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
|
||||
@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
return comfy.utils.load_state_dict(self, sdo, strict=False)
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
|
||||
@ -32,7 +32,6 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
from . import safetensors_stream
|
||||
import comfy.disk_weights
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@ -166,62 +165,6 @@ def state_dict_meta(state_dict, key):
|
||||
)
|
||||
|
||||
|
||||
def load_state_dict(model, state_dict, strict=False, assign=False):
|
||||
if is_stream_state_dict(state_dict):
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
|
||||
comfy.disk_weights.register_module_weights(model, state_dict)
|
||||
comfy.disk_weights.attach_disk_weight_hooks(model)
|
||||
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
|
||||
return missing, unexpected
|
||||
return model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def stream_load_state_dict(model, state_dict, strict=False, assign=False):
|
||||
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
|
||||
state_dict = state_dict.copy()
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
|
||||
def load(module, local_state_dict, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
if assign:
|
||||
local_metadata["assign_to_params_buffers"] = assign
|
||||
module._load_from_state_dict(
|
||||
local_state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
True,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
child_prefix = f"{prefix}{name}."
|
||||
child_state_dict = safetensors_stream.FilterViewStateDict(
|
||||
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
|
||||
)
|
||||
load(child, child_state_dict, child_prefix)
|
||||
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
for hook in module._load_state_dict_post_hooks.values():
|
||||
out = hook(module, incompatible)
|
||||
if out is not None:
|
||||
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
|
||||
|
||||
load(model, state_dict)
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||
return missing_keys, unexpected_keys
|
||||
|
||||
|
||||
def transformers_convert(sd, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}positional_embedding": "{}embeddings.position_embedding.weight",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user