mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00

Compare commits
6 Commits: b93f165c2e ... c825bc526e

| SHA1 |
|---|
| c825bc526e |
| fcbd22b514 |
| 91809e83ff |
| 82e70aa3c2 |
| c3eaea0429 |
| 95ca11fe25 |
@@ -349,12 +349,12 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
| `--low-ram` | Enable disk weight offloading. Sets RAM headroom to 1024 MB and uses `--reserve-vram` as the VRAM headroom. |
| `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |

### Disk tier for model weights

When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
When `--low-ram` is enabled, ComfyUI streams safetensors weights from disk and offloads weights to disk when the RAM headroom is reached.

If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or an unsupported platform), the load fails with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.
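As a concrete illustration of the caching policy described above, here is a minimal sketch (not the actual ComfyUI implementation) of a bounded, LRU-ordered RAM cache for disk-backed weights; the class and method names are invented for the example. Launching with something like `--weights-ram-cache-gb 8 --weights-gds` would configure the real cache along these lines, with a roughly 8 GB RAM tier and direct disk→GPU loads.

```python
# Minimal sketch of a bounded RAM tier with LRU eviction back to disk.
# Illustrative only; names and structure are not taken from the ComfyUI code.
from collections import OrderedDict

class BoundedWeightCache:
    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.current_bytes = 0
        self._entries = OrderedDict()  # weight name -> size in bytes, oldest first

    def record(self, name: str, size_bytes: int):
        # Re-recording a weight refreshes its position in the LRU order.
        if name in self._entries:
            self.current_bytes -= self._entries.pop(name)
        self._entries[name] = size_bytes
        self.current_bytes += size_bytes
        self._evict_if_needed()

    def _evict_if_needed(self):
        while self._entries and self.current_bytes > self.max_bytes:
            name, size = self._entries.popitem(last=False)  # least recently used
            self.current_bytes -= size
            print(f"evicting {name} ({size} bytes) back to disk")

cache = BoundedWeightCache(max_bytes=8 * 1024**3)  # roughly --weights-ram-cache-gb 8
```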
@@ -29,7 +29,7 @@ class AudioEncoderModel():
        self.model_sample_rate = 16000

    def load_sd(self, sd):
        return comfy.utils.load_state_dict(self.model, sd, strict=False)
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()
@@ -114,7 +114,7 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default 4GB")

parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
parser.add_argument("--low-ram", action="store_true", help="Enable disk weight offloading. Sets RAM headroom to 1024MB and uses --reserve-vram as the VRAM headroom.")
parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")

attn_group = parser.add_mutually_exclusive_group()
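A side note on the `--cache-ram` definition above: with `nargs='?'` the value is optional, `const` is used when the flag is given without a value, and `default` applies when the flag is absent. A small standalone check of that argparse behaviour (illustrative, not part of the diff):

```python
# Standalone check of the nargs='?' / const / default combination used above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0)

print(parser.parse_args([]).cache_ram)                     # 0    (flag absent -> default)
print(parser.parse_args(["--cache-ram"]).cache_ram)        # 4.0  (flag without value -> const)
print(parser.parse_args(["--cache-ram", "8"]).cache_ram)   # 8.0  (explicit value)
```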
@@ -6,7 +6,6 @@ import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2

@@ -48,7 +47,7 @@ class ClipVisionModel():
        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return comfy.utils.load_state_dict(self.model, sd, strict=False)
        return self.model.load_state_dict(sd, strict=False)

    def get_sd(self):
        return self.model.state_dict()
@@ -25,7 +25,6 @@ import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
import comfy.disk_weights
import comfy.model_patcher
import comfy.ops
import comfy.latent_formats
@@ -386,7 +385,7 @@ class ControlLora(ControlNet):
        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
        self.control_model.to(comfy.model_management.get_torch_device())
        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()

@@ -440,7 +439,7 @@ def controlnet_config(sd, model_options={}):
    return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device

def controlnet_load_state_dict(control_model, sd):
    missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
    missing, unexpected = control_model.load_state_dict(sd, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))
@@ -474,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
class ControlNetSD35(ControlNet):
    def pre_run(self, model, percent_to_timestep_function):
        if self.control_model.double_y_emb:
            missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
        else:
            missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
        super().pre_run(model, percent_to_timestep_function)

    def copy(self):
@@ -749,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))
@@ -817,8 +816,8 @@ class T2IAdapter(ControlBase):
            if x_noisy.shape[0] != self.cond_hint.shape[0]:
                self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
            comfy.disk_weights.module_to(self.t2i_model, self.device)
            self.t2i_model.to(dtype=x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

@@ -875,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
    else:
        return None

    missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
    missing, unexpected = model_ad.load_state_dict(t2i_data, strict=True)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))
@@ -33,7 +33,11 @@ from . import safetensors_stream
ALLOW_GDS = False
PIN_IF_CPU = False
DISK_WEIGHTS_ENABLED = False
RAM_HEADROOM_BYTES = 0
BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
BASE_MODULE_TO = torch.nn.Module.to
BASE_LOAD_STATE_DICT = torch.nn.Module.load_state_dict
_MONKEYPATCHED = False
LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
_MISSING = object()
@@ -92,6 +96,7 @@ class CacheEntry:
    name: str
    size_bytes: int
    is_buffer: bool
    device_type: str


class DiskWeightCache:
@@ -108,16 +113,25 @@ class DiskWeightCache:
        return (id(module), name)

    def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
        if tensor.device.type != "cpu":
        if tensor.device.type == "meta":
            return
        size_bytes = tensor.numel() * tensor.element_size()
        key = self._entry_key(module, name)
        if key in self._entries:
            entry = self._entries.pop(key)
            self.current_bytes -= entry.size_bytes
            if entry.device_type == "cpu":
                self.current_bytes -= entry.size_bytes
        module_ref = weakref.ref(module, self._drop_module_entries)
        self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
        self.current_bytes += size_bytes
        device_type = tensor.device.type
        self._entries[key] = CacheEntry(
            module_ref=module_ref,
            name=name,
            size_bytes=size_bytes,
            is_buffer=is_buffer,
            device_type=device_type,
        )
        if device_type == "cpu":
            self.current_bytes += size_bytes
        self._evict_if_needed()

    def touch(self, module: torch.nn.Module, name: str):
@@ -129,9 +143,10 @@ class DiskWeightCache:
    def evict_bytes(self, bytes_to_free: int):
        freed = 0
        while self._entries and freed < bytes_to_free:
            _, entry = self._entries.popitem(last=False)
            entry = self.pop_lru(torch.device("cpu"))
            if entry is None:
                break
            freed += entry.size_bytes
            self.current_bytes -= entry.size_bytes
            module = entry.module_ref()
            if module is not None:
                _evict_module_weight(module, entry.name, entry.is_buffer)
@@ -144,8 +159,26 @@ class DiskWeightCache:
                to_remove.append(key)
        for key in to_remove:
            entry = self._entries.pop(key)
            if entry.device_type == "cpu":
                self.current_bytes -= entry.size_bytes

    def remove_entry(self, module: torch.nn.Module, name: str):
        key = self._entry_key(module, name)
        entry = self._entries.pop(key, None)
        if entry is None:
            return
        if entry.device_type == "cpu":
            self.current_bytes -= entry.size_bytes

    def pop_lru(self, device: torch.device) -> Optional[CacheEntry]:
        for key, entry in self._entries.items():
            if entry.device_type == device.type:
                self._entries.pop(key)
                if entry.device_type == "cpu":
                    self.current_bytes -= entry.size_bytes
                return entry
        return None

    def _drop_module_entries(self, module_ref: weakref.ReferenceType):
        to_remove = []
        for key, entry in self._entries.items():
@@ -153,12 +186,14 @@ class DiskWeightCache:
                to_remove.append(key)
        for key in to_remove:
            entry = self._entries.pop(key)
            self.current_bytes -= entry.size_bytes
            if entry.device_type == "cpu":
                self.current_bytes -= entry.size_bytes

    def _evict_if_needed(self):
        while self._entries and self.current_bytes > self.max_bytes:
            _, entry = self._entries.popitem(last=False)
            self.current_bytes -= entry.size_bytes
            entry = self.pop_lru(torch.device("cpu"))
            if entry is None:
                break
            module = entry.module_ref()
            if module is not None:
                _evict_module_weight(module, entry.name, entry.is_buffer)
@@ -169,13 +204,34 @@ CACHE = DiskWeightCache(0)
LOGGER = logging.getLogger(__name__)


def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
    global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, enabled: bool = True):
    global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED, RAM_HEADROOM_BYTES
    ALLOW_GDS = allow_gds
    PIN_IF_CPU = pin_if_cpu
    DISK_WEIGHTS_ENABLED = enabled
    CACHE.set_limit(cache_bytes if enabled else 0)
    if not enabled:
    RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
    if enabled:
        from . import model_management
        cpu_capacity_bytes = max(0, model_management.get_total_memory(torch.device("cpu")) - RAM_HEADROOM_BYTES)
        CACHE.set_limit(cpu_capacity_bytes)
        LOGGER.debug(
            "Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
            enabled,
            RAM_HEADROOM_BYTES,
            cpu_capacity_bytes,
        )
    else:
        CACHE.set_limit(0)
        LOGGER.debug(
            "Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
            enabled,
            RAM_HEADROOM_BYTES,
            0,
        )
    if enabled:
        install_monkeypatches()
    else:
        uninstall_monkeypatches()
        CACHE._entries.clear()
        CACHE.current_bytes = 0
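The new `configure()` above sizes the RAM tier as total system memory minus the requested headroom. A rough, self-contained restatement of that budget arithmetic (illustrative only; `psutil` and the helper name are assumptions of this example, the real code goes through `model_management.get_total_memory`):

```python
# Illustrative recomputation of the CPU cache budget derived in configure():
# capacity = total system RAM - configured headroom, clamped at zero.
import psutil

def cpu_cache_budget(ram_headroom_bytes: int) -> int:
    total_ram = psutil.virtual_memory().total
    return max(0, total_ram - ram_headroom_bytes)

# e.g. with the 1024 MB headroom that --low-ram sets:
budget = cpu_cache_budget(1024 * 1024 * 1024)
print(f"weights may occupy up to {budget / 1024**3:.1f} GiB of RAM")
```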
@@ -184,6 +240,66 @@ def disk_weights_enabled() -> bool:
    return DISK_WEIGHTS_ENABLED


def ram_headroom_bytes() -> int:
    return RAM_HEADROOM_BYTES


def _is_stream_state_dict(state_dict) -> bool:
    return (
        getattr(state_dict, "is_stream_state_dict", False)
        and hasattr(state_dict, "get_tensor")
        and hasattr(state_dict, "meta")
    )


def patched_to(self: torch.nn.Module, *args, **kwargs):
    if not disk_weights_enabled():
        return BASE_MODULE_TO(self, *args, **kwargs)
    device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
    module_to(
        self,
        device=device,
        dtype=dtype,
        non_blocking=non_blocking,
        memory_format=memory_format,
    )
    return self


def patched_load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
    if not disk_weights_enabled():
        if _is_stream_state_dict(state_dict):
            return safetensors_stream.stream_load_state_dict(
                self,
                state_dict,
                strict=strict,
                assign=assign,
            )
        return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
    if _is_stream_state_dict(state_dict):
        missing_keys, unexpected_keys = lazy_load_state_dict(self, state_dict, strict=strict)
        return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
    return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)


def install_monkeypatches():
    global _MONKEYPATCHED
    if _MONKEYPATCHED:
        return
    torch.nn.Module.to = patched_to
    torch.nn.Module.load_state_dict = patched_load_state_dict
    _MONKEYPATCHED = True


def uninstall_monkeypatches():
    global _MONKEYPATCHED
    if not _MONKEYPATCHED:
        return
    torch.nn.Module.to = BASE_MODULE_TO
    torch.nn.Module.load_state_dict = BASE_LOAD_STATE_DICT
    _MONKEYPATCHED = False


def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
    if not disk_weights_enabled():
        return
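`install_monkeypatches()`/`uninstall_monkeypatches()` above follow the usual save-and-restore patching pattern: capture the original `torch.nn.Module` methods once at import time, install wrappers that fall through to the originals whenever the feature is disabled, and put the originals back on uninstall. A toy, self-contained sketch of that pattern (using an invented class rather than `torch.nn.Module`):

```python
# Generic save/patch/restore pattern on a toy class. The wrapper defers to the
# saved original unless the feature flag is set, which keeps patching reversible.
class Widget:
    def to(self, device):
        return f"moved to {device}"

FEATURE_ENABLED = False
BASE_TO = Widget.to  # captured once, before any patching

def patched_to(self, device):
    if not FEATURE_ENABLED:
        return BASE_TO(self, device)  # transparent fall-through
    return f"custom handling for {device}"

def install():
    Widget.to = patched_to

def uninstall():
    Widget.to = BASE_TO

install()
print(Widget().to("cuda"))  # "moved to cuda" (flag off, falls through)
FEATURE_ENABLED = True
print(Widget().to("cuda"))  # "custom handling for cuda"
uninstall()
```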
@@ -197,7 +313,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
            meta = state_dict.meta(key)
            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
            REGISTRY.register(submodule, name, ref)
            if param.device.type == "cpu":
            if param.device.type != "meta":
                CACHE.record(submodule, name, param, is_buffer=False)
        for name, buf in submodule.named_buffers(recurse=False):
            key = f"{module_prefix}{name}" if module_prefix else name
@@ -205,7 +321,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
            meta = state_dict.meta(key)
            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
            REGISTRY.register(submodule, name, ref)
            if buf.device.type == "cpu":
            if buf.device.type != "meta":
                CACHE.record(submodule, name, buf, is_buffer=True)


@@ -269,6 +385,23 @@ def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Te
    return torch.empty(shape, dtype=dtype, device="meta")


def _attach_disk_identity(tensor: torch.Tensor, module: torch.nn.Module, name: str, is_buffer: bool):
    tensor._disk_weights_module_ref = weakref.ref(module)
    tensor._disk_weights_name = name
    tensor._disk_weights_is_buffer = is_buffer


def materialize_meta_tensor(tensor: torch.Tensor, target_device: torch.device, dtype_override: Optional[torch.dtype]):
    module_ref = getattr(tensor, "_disk_weights_module_ref", None)
    name = getattr(tensor, "_disk_weights_name", None)
    if module_ref is None or name is None:
        raise RuntimeError("Meta tensor missing disk weight identity")
    module = module_ref()
    if module is None:
        raise RuntimeError("Disk weight module reference expired")
    return load_module_tensor(module, name, target_device, dtype_override=dtype_override, temporary=False)


def _state_dict_meta(state_dict: MutableMapping, key: str):
    if hasattr(state_dict, "meta"):
        return state_dict.meta(key)
@@ -369,33 +502,30 @@ def _device_free_memory(device: torch.device) -> int:
    return int(model_management.get_free_memory(device))


def _evict_ram_for_budget(required_bytes: int) -> int:
    if required_bytes <= 0:
        return 0
    freed = evict_ram_cache(required_bytes)
    if freed < required_bytes:
def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_bytes: int) -> int:
    free_before = _device_free_memory(device)
    if free_before < required_bytes + headroom_bytes:
        LOGGER.debug(
            "Disk weight memory pressure: required=%d free=%d headroom=%d device=%s",
            required_bytes,
            free_before,
            headroom_bytes,
            device,
        )
        safetensors_stream._reap_pinned_inflight()
        from . import model_management
        freed += model_management.evict_ram_to_disk(required_bytes - freed)
    return freed


def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
    free_mem = _device_free_memory(device)
    if device.type == "cpu" and free_mem < required_bytes:
        _evict_ram_for_budget(required_bytes - free_mem)
        free_mem = _device_free_memory(device)
    return free_mem


def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
    from . import model_management
    if device.type == "cpu":
        alt = model_management.get_torch_device()
        if alt.type != "cpu":
            return alt
    else:
        return torch.device("cpu")
    return None
        model_management.free_memory(required_bytes + headroom_bytes, device)
        free_after = _device_free_memory(device)
        freed = max(0, free_after - free_before)
        LOGGER.debug(
            "Disk weight memory freed: freed=%d free_before=%d free_after=%d device=%s",
            freed,
            free_before,
            free_after,
            device,
        )
        return free_after
    return free_before

class _BudgetedStateDict(MutableMapping):
@@ -527,8 +657,10 @@ class _BudgetedStateDict(MutableMapping):
            if default is _MISSING:
                raise KeyError(key)
            return default
        value = self.get_tensor(key)
        self._deleted.add(key)
        return self.get_tensor(key)
        self._overrides.pop(key, None)
        return value

    def meta(self, key: str):
        return self._get_meta(key)
@@ -564,6 +696,8 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):


def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
    safetensors_stream._reap_pinned_inflight()
    from . import model_management
    lazy_state = LAZY_MODULE_STATE.get(module)
    if lazy_state is not None:
        CACHE.remove_module(module)
@@ -571,15 +705,31 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
        if refs:
            state = _get_materialization_state(module)
            for ref_name, disk_ref in refs.items():
                if ref_name in module._parameters:
                    current = module._parameters[ref_name]
                elif ref_name in module._buffers:
                    current = module._buffers[ref_name]
                else:
                    current = None
                if (
                    current is not None
                    and current.device.type == "cpu"
                    and current.data_ptr() in model_management.PINNED_MEMORY
                ):
                    model_management.wait_for_pinned_tensor(current)
                    model_management.unpin_memory(current)
                shape = getattr(disk_ref.meta, "shape", None)
                dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
                dtype = getattr(disk_ref.meta, "dtype", None)
                if shape is None or dtype is None:
                    continue
                meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
                if disk_ref.is_buffer:
                    module._buffers[ref_name] = meta_tensor
                    _attach_disk_identity(meta_tensor, module, ref_name, True)
                else:
                    module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
                    param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
                    module._parameters[ref_name] = param
                    _attach_disk_identity(param, module, ref_name, False)
                nbytes = _meta_nbytes(disk_ref.meta)
                if nbytes is not None:
                    state.loaded_keys.discard(ref_name)
@@ -593,16 +743,31 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
    ref = REGISTRY.get(module)
    if not ref or name not in ref:
        return
    CACHE.remove_entry(module, name)
    disk_ref = ref[name]
    if is_buffer:
        current = module._buffers.get(name)
    else:
        current = module._parameters.get(name)
    if (
        current is not None
        and current.device.type == "cpu"
        and current.data_ptr() in model_management.PINNED_MEMORY
    ):
        model_management.wait_for_pinned_tensor(current)
        model_management.unpin_memory(current)
    shape = getattr(disk_ref.meta, "shape", None)
    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
    dtype = getattr(disk_ref.meta, "dtype", None)
    if shape is None or dtype is None:
        return
    meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
    if is_buffer:
        module._buffers[name] = meta_tensor
        _attach_disk_identity(meta_tensor, module, name, True)
    else:
        module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
        param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
        module._parameters[name] = param
        _attach_disk_identity(param, module, name, False)
    state = _get_materialization_state(module)
    nbytes = _meta_nbytes(disk_ref.meta)
    if nbytes is not None:
@@ -671,7 +836,6 @@ def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype:
def ensure_module_materialized(
    module: torch.nn.Module,
    target_device: torch.device,
    fallback_device: Optional[torch.device] = None,
    dtype_override: Optional[torch.dtype] = None,
):
    lazy_state = LAZY_MODULE_STATE.get(module)
@@ -687,12 +851,12 @@ def ensure_module_materialized(
    if not refs:
        return
    state = _get_materialization_state(module)
    if dtype_override is not None:
        for name in refs.keys():
            _set_future_dtype(module, name, dtype_override)
    # Do not persist dtype overrides into storage.
    _rebuild_materialization_state(module, refs, state)
    free_mem_start = _device_free_memory(target_device)
    remaining_budget = free_mem_start
    from . import model_management
    non_blocking = model_management.device_supports_non_blocking(target_device)
    offload_stream = model_management.get_offload_stream(target_device) if non_blocking else None
    for name in sorted(refs.keys()):
        disk_ref = refs[name]
        if name in module._parameters:
@@ -705,31 +869,20 @@ def ensure_module_materialized(
            continue
        if current is None:
            continue
        target_dtype = dtype_override or _get_future_dtype(module, name)
        if current.device.type != "meta" and current.device == target_device and (
            target_dtype is None or current.dtype == target_dtype
        ):
            if current.device.type == "cpu":
                CACHE.touch(module, name)
        # Persistent tensors must remain in stored dtype.
        target_dtype = None
        if current.device.type != "meta" and current.device == target_device:
            CACHE.touch(module, name)
            continue
        meta_nbytes = _meta_nbytes(disk_ref.meta)
        if meta_nbytes is None:
            continue
        required_bytes = meta_nbytes
        if target_device.type == "cpu":
            free_mem = _maybe_free_ram_budget(target_device, required_bytes)
            remaining_budget = min(remaining_budget, free_mem)
        if required_bytes > remaining_budget:
            if fallback_device is not None and fallback_device != target_device:
                fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
                if fallback_free >= required_bytes:
                    target_for_load = fallback_device
                else:
                    continue
            else:
                continue
            _ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
        else:
            target_for_load = target_device
            _ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
        target_for_load = target_device
        if current.device.type == "meta":
            tensor = disk_ref.load(
                target_for_load,
@@ -737,18 +890,38 @@ def ensure_module_materialized(
                PIN_IF_CPU,
                dtype_override=target_dtype,
            )
            if tensor.device != target_for_load or (target_dtype is not None and tensor.dtype != target_dtype):
                tensor = model_management.cast_to(
                    tensor,
                    device=target_for_load,
                    dtype=target_dtype,
                    non_blocking=non_blocking,
                    stream=offload_stream,
                )
            if non_blocking and offload_stream is not None:
                model_management.sync_stream(target_for_load, offload_stream)
        else:
            if target_dtype is not None and current.dtype != target_dtype:
                tensor = current.to(device=target_for_load, dtype=target_dtype)
            else:
                tensor = current.to(device=target_for_load)
            if (
                current.device.type == "cpu"
                and current.data_ptr() in model_management.PINNED_MEMORY
            ):
                model_management.wait_for_pinned_tensor(current)
                model_management.unpin_memory(current)
            tensor = model_management.cast_to(
                current,
                device=target_for_load,
                dtype=target_dtype if target_dtype is not None else current.dtype,
                non_blocking=non_blocking,
                stream=offload_stream,
            )
            if non_blocking and offload_stream is not None:
                model_management.sync_stream(target_for_load, offload_stream)
        if is_buffer:
            module._buffers[name] = tensor
        else:
            module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
        if tensor.device.type == "cpu":
        if tensor.device.type != "meta":
            CACHE.record(module, name, tensor, is_buffer=is_buffer)
        remaining_budget = max(0, remaining_budget - required_bytes)
    _rebuild_materialization_state(module, refs, state)
    _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")

@@ -758,17 +931,17 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
        return
    input_dtype = _find_tensor_dtype(args, kwargs)
    manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
    if getattr(module, "comfy_cast_weights", False):
        target_device = torch.device("cpu")
        fallback_device = _find_tensor_device(args, kwargs)
    dtype_override = None  # persistent tensors stay in stored dtype; per-op casting only
    input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
    if getattr(module, "comfy_patched_weights", False):
        target_device = input_device
    elif getattr(module, "comfy_cast_weights", False):
        target_device = input_device
    else:
        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
        fallback_device = None
        target_device = input_device
    ensure_module_materialized(
        module,
        target_device,
        fallback_device=fallback_device,
        dtype_override=dtype_override,
    )
@@ -786,9 +959,74 @@ def attach_disk_weight_hooks(model: torch.nn.Module):
def evict_ram_cache(bytes_to_free: int):
    if bytes_to_free <= 0:
        return 0
    safetensors_stream._reap_pinned_inflight()
    return CACHE.evict_bytes(bytes_to_free)


def _move_cache_entry_to_cpu(entry: CacheEntry):
    module = entry.module_ref()
    if module is None:
        return
    if entry.is_buffer:
        current = module._buffers.get(entry.name)
    else:
        current = module._parameters.get(entry.name)
    if current is None or current.device.type == "meta":
        return
    from . import model_management
    non_blocking = model_management.device_supports_non_blocking(torch.device("cpu"))
    offload_stream = model_management.get_offload_stream(torch.device("cpu")) if non_blocking else None
    tensor = model_management.cast_to(
        current,
        device=torch.device("cpu"),
        dtype=current.dtype,
        non_blocking=non_blocking,
        stream=offload_stream,
    )
    if non_blocking and offload_stream is not None:
        model_management.sync_stream(current.device, offload_stream)
    if entry.is_buffer:
        module._buffers[entry.name] = tensor
    else:
        module._parameters[entry.name] = torch.nn.Parameter(tensor, requires_grad=current.requires_grad)
    CACHE.record(module, entry.name, tensor, is_buffer=entry.is_buffer)


def _evict_cpu_entry_to_meta(entry: CacheEntry):
    module = entry.module_ref()
    if module is None:
        return
    _evict_module_weight(module, entry.name, entry.is_buffer)
    CACHE.remove_entry(module, entry.name)


def evict_for_budget(target_device: torch.device, required_bytes: int):
    if not disk_weights_enabled() or required_bytes <= 0:
        return
    from . import model_management
    free = model_management.get_free_memory(target_device)
    if free >= required_bytes:
        return
    cpu_device = torch.device("cpu")
    if target_device.type != "cpu":
        while free < required_bytes:
            entry = CACHE.pop_lru(target_device)
            if entry is None:
                break
            free_cpu = model_management.get_free_memory(cpu_device)
            if free_cpu < RAM_HEADROOM_BYTES:
                CACHE.evict_bytes(RAM_HEADROOM_BYTES - free_cpu)
            _move_cache_entry_to_cpu(entry)
            free = model_management.get_free_memory(target_device)
    else:
        while free < required_bytes:
            entry = CACHE.pop_lru(cpu_device)
            if entry is None:
                break
            _evict_cpu_entry_to_meta(entry)
            free = model_management.get_free_memory(target_device)


def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
    if not disk_weights_enabled():
        return
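`evict_for_budget()` above spills in two tiers: under GPU pressure, least-recently-used cached weights are demoted to CPU RAM (after first making headroom there); under CPU pressure, RAM entries are demoted all the way to disk-backed meta placeholders. The toy below is a runnable restatement of that ordering with plain integers for sizes; every name in it is invented for the example.

```python
# Toy illustration of the two-tier eviction order: gpu -> cpu -> disk.
from collections import OrderedDict

capacity = {"gpu": 8, "cpu": 16}
tiers = {"gpu": OrderedDict(), "cpu": OrderedDict(), "disk": OrderedDict()}

def used(tier):
    return sum(tiers[tier].values())

def evict_for_budget(tier, required):
    # Demote least-recently-used entries until `required` units fit in `tier`.
    lower = {"gpu": "cpu", "cpu": "disk"}[tier]
    while used(tier) + required > capacity.get(tier, float("inf")) and tiers[tier]:
        name, size = tiers[tier].popitem(last=False)
        if lower == "cpu":
            evict_for_budget("cpu", size)  # keep CPU headroom before demoting into it
        tiers[lower][name] = size
        print(f"demoted {name} ({size}) from {tier} to {lower}")

def load(name, size, tier="gpu"):
    evict_for_budget(tier, size)
    tiers[tier][name] = size

for i, size in enumerate([4, 4, 4, 6]):
    load(f"weight{i}", size)
print({t: list(d) for t, d in tiers.items()})
```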
@@ -826,17 +1064,55 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
    return None


def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
    def _move(tensor):
        if tensor is None:
            return None
        if tensor.device.type == "meta":
            return tensor
        if dtype_override is not None and tensor.dtype != dtype_override:
            return tensor.to(device=device_to, dtype=dtype_override)
        return tensor.to(device=device_to)
def move_module_tensors(
    module: torch.nn.Module,
    device_to: torch.device,
    dtype_override: Optional[torch.dtype] = None,
    non_blocking: bool = False,
):
    from . import model_management
    offload_stream = None
    if non_blocking and model_management.device_supports_non_blocking(device_to):
        offload_stream = model_management.get_offload_stream(device_to)

    module._apply(_move)
    def apply_fn(tensor):
        if tensor is None or tensor.device.type == "meta":
            return tensor
        target_dtype = dtype_override or tensor.dtype
        if (
            tensor.device.type == "cpu"
            and tensor.data_ptr() in model_management.PINNED_MEMORY
            and (device_to.type != "cpu" or target_dtype != tensor.dtype)
        ):
            model_management.wait_for_pinned_tensor(tensor)
            model_management.unpin_memory(tensor)
        if tensor.device == device_to and tensor.dtype == target_dtype:
            return tensor
        return model_management.cast_to(
            tensor,
            device=device_to,
            dtype=target_dtype,
            non_blocking=non_blocking,
            stream=offload_stream,
        )

    module._apply(apply_fn)
    if disk_weights_enabled():
        for submodule in module.modules():
            refs = REGISTRY.get(submodule)
            if not refs:
                continue
            for name, disk_ref in refs.items():
                if disk_ref.is_buffer:
                    tensor = submodule._buffers.get(name)
                else:
                    tensor = submodule._parameters.get(name)
                if tensor is None or tensor.device.type == "meta":
                    CACHE.remove_entry(submodule, name)
                    continue
                CACHE.record(submodule, name, tensor, is_buffer=disk_ref.is_buffer)
    if non_blocking and offload_stream is not None:
        model_management.sync_stream(device_to, offload_stream)
    return module


@@ -864,21 +1140,60 @@ def offload_module_weights(module: torch.nn.Module) -> int:
    return offloaded_bytes


def module_to(module: torch.nn.Module, *args, **kwargs):
def module_to(
    module: torch.nn.Module,
    *args,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    non_blocking: bool = False,
    memory_format=None,
    **kwargs,
):
    allow_materialize = kwargs.pop("allow_materialize", True)
    arg_device = _extract_to_device(args, kwargs)
    arg_dtype = _extract_to_dtype(args, kwargs)
    if disk_weights_enabled():
        target_device = _extract_to_device(args, kwargs)
        target_device = device or arg_device
        if target_device is None:
            target_device = _find_existing_device(module) or torch.device("cpu")
        dtype_override = dtype or arg_dtype
        if target_device.type == "meta":
            offload_module_weights(module)
            for submodule in module.modules():
                offload_module_weights(submodule)
                move_module_tensors(
                    submodule,
                    target_device,
                    dtype_override=dtype_override,
                    non_blocking=non_blocking,
                )
            return module
        if allow_materialize:
            materialize_module_tree(module, target_device)
            return module.to(*args, **kwargs)
        dtype_override = _extract_to_dtype(args, kwargs)
        return move_module_tensors(module, target_device, dtype_override=dtype_override)
        return module.to(*args, **kwargs)
        if not allow_materialize:
            move_module_tensors(
                module,
                target_device,
                dtype_override=dtype_override,
                non_blocking=non_blocking,
            )
            return module
        for submodule in module.modules():
            ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
        move_module_tensors(
            module,
            target_device,
            dtype_override=dtype_override,
            non_blocking=non_blocking,
        )
        return module
    base_kwargs = dict(kwargs)
    if device is not None and arg_device is None:
        base_kwargs["device"] = device
    if dtype is not None and arg_dtype is None:
        base_kwargs["dtype"] = dtype
    if non_blocking:
        base_kwargs["non_blocking"] = non_blocking
    if memory_format is not None:
        base_kwargs["memory_format"] = memory_format
    return BASE_MODULE_TO(module, *args, **base_kwargs)


def load_module_tensor(
@@ -886,7 +1201,6 @@ def load_module_tensor(
    name: str,
    device: torch.device,
    *,
    allow_alternate: bool = True,
    record_cache: bool = True,
    temporary: bool = False,
    dtype_override: Optional[torch.dtype] = None,
@@ -904,16 +1218,33 @@ def load_module_tensor(
        return None
    if current is None:
        return None
    target_dtype = dtype_override or _get_future_dtype(module, name)
    if dtype_override is not None:
        _set_future_dtype(module, name, dtype_override)
    # Persistent loads must not mix storage with dtype casting.
    if not temporary:
        dtype_override = None
    target_dtype = dtype_override
    if current.device.type != "meta":
        if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
            if target_dtype is not None and current.dtype != target_dtype:
                tensor = current.to(device=device, dtype=target_dtype)
            else:
                tensor = current.to(device=device)
            from . import model_management
            headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
            _ensure_free_memory(device, _tensor_nbytes(current), headroom)
            non_blocking = model_management.device_supports_non_blocking(device)
            offload_stream = model_management.get_offload_stream(device) if non_blocking else None
            tensor = model_management.cast_to(
                current,
                device=device,
                dtype=target_dtype if target_dtype is not None else current.dtype,
                non_blocking=non_blocking,
                stream=offload_stream,
            )
            if non_blocking and offload_stream is not None:
                model_management.sync_stream(device, offload_stream)
            if not temporary:
                if (
                    current.device.type == "cpu"
                    and current.data_ptr() in model_management.PINNED_MEMORY
                ):
                    model_management.wait_for_pinned_tensor(current)
                    model_management.unpin_memory(current)
                if is_buffer:
                    module._buffers[name] = tensor
                else:
@@ -926,52 +1257,33 @@ def load_module_tensor(
    required_bytes = _meta_nbytes(disk_ref.meta)
    if required_bytes is None:
        return current
    free_mem_start = _device_free_memory(device)
    free_mem = _maybe_free_ram_budget(device, required_bytes)
    load_device = device
    if free_mem < required_bytes and allow_alternate:
        alt = _choose_alternate_device(device)
        if alt is not None:
            alt_free = _maybe_free_ram_budget(alt, required_bytes)
            if alt_free >= required_bytes:
                load_device = alt
            else:
                state = _get_materialization_state(module)
                if name not in state.deferred_keys:
                    state.deferred_keys.add(name)
                    state.deferred_bytes += required_bytes
                    _update_disk_state_attrs(module, state)
                _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
                return current
        else:
            state = _get_materialization_state(module)
            if name not in state.deferred_keys:
                state.deferred_keys.add(name)
                state.deferred_bytes += required_bytes
                _update_disk_state_attrs(module, state)
            _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
            return current
    elif free_mem < required_bytes:
        state = _get_materialization_state(module)
        if name not in state.deferred_keys:
            state.deferred_keys.add(name)
            state.deferred_bytes += required_bytes
            _update_disk_state_attrs(module, state)
        _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
        return current

    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
    from . import model_management
    headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
    _ensure_free_memory(device, required_bytes, headroom)
    non_blocking = model_management.device_supports_non_blocking(device)
    offload_stream = model_management.get_offload_stream(device) if non_blocking else None
    tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
    if tensor.device != device or (target_dtype is not None and tensor.dtype != target_dtype):
        tensor = model_management.cast_to(
            tensor,
            device=device,
            dtype=target_dtype if target_dtype is not None else tensor.dtype,
            non_blocking=non_blocking,
            stream=offload_stream,
        )
        if non_blocking and offload_stream is not None:
            model_management.sync_stream(device, offload_stream)
    if temporary:
        return tensor
    if is_buffer:
        module._buffers[name] = tensor
    else:
        module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
    if tensor.device.type == "cpu" and record_cache:
    if tensor.device.type != "meta" and record_cache:
        CACHE.record(module, name, tensor, is_buffer=is_buffer)
    state = _get_materialization_state(module)
    _rebuild_materialization_state(module, refs, state)
    _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
    _log_materialization(module, device, _device_free_memory(device), refs, state, "Disk weight loaded")
    return tensor


@@ -983,8 +1295,11 @@ def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_
    attr = parts[-1]
    if is_buffer:
        module._buffers[attr] = tensor
        return tensor
    else:
        module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
        param = torch.nn.Parameter(tensor, requires_grad=requires_grad)
        module._parameters[attr] = param
        return param


def _materialize_module_from_state_dict(
@@ -999,9 +1314,7 @@ def _materialize_module_from_state_dict(
    metadata = getattr(lazy_state.state_dict, "_metadata", None)
    local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
    refs = REGISTRY.get(module) or {}
    if dtype_override is not None:
        for name in refs.keys():
            _set_future_dtype(module, name, dtype_override)
    # Do not persist dtype overrides into storage.
    state = _get_materialization_state(module)
    _rebuild_materialization_state(module, refs, state)
    keys = sorted(lazy_state.state_dict.keys())
@@ -1015,22 +1328,15 @@ def _materialize_module_from_state_dict(
        if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
            existing[key] = buf
    free_mem_start = _device_free_memory(target_device)
    remaining_budget = free_mem_start
    allowed = set(existing.keys())
    allowed = set(keys)
    from . import model_management
    headroom = RAM_HEADROOM_BYTES if target_device.type == "cpu" else model_management.extra_reserved_memory()
    for key in keys:
        if key in allowed:
            continue
        meta = _state_dict_meta(lazy_state.state_dict, key)
        required = _meta_nbytes(meta)
        if required is None:
            continue
        if target_device.type == "cpu":
            free_mem = _maybe_free_ram_budget(target_device, required)
            remaining_budget = min(remaining_budget, free_mem)
        if required <= remaining_budget:
            allowed.add(key)
            remaining_budget = max(0, remaining_budget - required)
    deferred_state_dict_keys = {key for key in keys if key not in allowed}
        _ensure_free_memory(target_device, required, headroom)
    state_dict = _BudgetedStateDict(
        lazy_state.state_dict,
        allowed_keys=allowed,
@@ -1064,14 +1370,25 @@ def _materialize_module_from_state_dict(
        module.factory_kwargs["device"] = factory_device
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
    for name, disk_ref in refs.items():
        if name in module._parameters:
            tensor = module._parameters[name]
            is_buffer = False
        elif name in module._buffers:
            tensor = module._buffers[name]
            is_buffer = True
        else:
            continue
        if tensor is not None and tensor.device.type == "meta":
            _attach_disk_identity(tensor, module, name, is_buffer)
    _rebuild_materialization_state(module, refs, state)
    lazy_state.loaded = len(deferred_state_dict_keys) == 0
    lazy_state.loaded = True
    _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
    for name, param in module.named_parameters(recurse=False):
        if param.device.type == "cpu":
        if param.device.type != "meta":
            CACHE.record(module, name, param, is_buffer=False)
    for name, buf in module.named_buffers(recurse=False):
        if buf is not None and buf.device.type == "cpu":
        if buf is not None and buf.device.type != "meta":
            CACHE.record(module, name, buf, is_buffer=True)


@@ -1100,14 +1417,16 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals
            continue
        meta = state_dict.meta(name)
        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
        _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
        stored = _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
        _attach_disk_identity(stored, model, name, False)

    for name, buf in model.named_buffers(recurse=True):
        if buf is None or name not in state_keys:
            continue
        meta = state_dict.meta(name)
        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
        _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
        stored = _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
        _attach_disk_identity(stored, model, name, True)

    register_module_weights(model, state_dict)
    register_lazy_modules(model, state_dict)
@@ -283,7 +283,7 @@ def load_gligen(sd):

        gated = GatedSelfAttentionDense(
            query_dim, key_dim, n_heads, d_head)
        comfy.utils.load_state_dict(gated, n_sd, strict=False)
        gated.load_state_dict(n_sd, strict=False)
        output_list.append(gated)

    if "position_net.null_positive_feature" in sd_k:
@@ -294,7 +294,7 @@ def load_gligen(sd):
            pass
        w = WeightsLoader()
        w.position_net = PositionNet(in_dim, out_dim)
        comfy.utils.load_state_dict(w, sd, strict=False)
        w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
@@ -1,5 +1,4 @@
import torch
import comfy.utils
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@@ -113,7 +112,7 @@ class HunyuanVideo15SRModel():
        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return comfy.utils.load_state_dict(self.model, sd, strict=True)
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        return self.model.state_dict()
@@ -2,7 +2,6 @@ import json
from dataclasses import dataclass
import math
import torch
import comfy.utils
import torchaudio

import comfy.model_management
@@ -154,8 +153,8 @@ class AudioVAE(torch.nn.Module):
        self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
        self.vocoder = Vocoder(config=component_config.vocoder)

        comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
        comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
        self.autoencoder.load_state_dict(vae_sd, strict=False)
        self.vocoder.load_state_dict(vocoder_sd, strict=False)

        autoencoder_config = self.autoencoder.get_config()
        self.normalizer = AudioLatentNormalizer(
@@ -2,7 +2,6 @@ import logging
from typing import Optional

import torch
import comfy.utils
import torch.nn as nn

from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@@ -153,7 +152,7 @@ class VAE(nn.Module):
        return dec, posterior

    def load_weights(self, src_dict) -> None:
        comfy.utils.load_state_dict(self, src_dict, strict=True)
        self.load_state_dict(src_dict, strict=True)

    @property
    def device(self) -> torch.device:
@@ -309,7 +309,7 @@ class BaseModel(torch.nn.Module):
        else:
            to_load = sd
        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logging.warning("unet missing: {}".format(m))

@@ -753,8 +753,8 @@ class StableAudio1(BaseModel):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
        self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
        self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
        utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
        utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
        self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights, strict=True)
        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights, strict=True)

    def extra_conds(self, **kwargs):
        out = {}
@ -18,6 +18,7 @@
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
import collections
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import torch
|
||||
@ -524,18 +525,14 @@ class LoadedModel:
|
||||
return True
|
||||
return False
|
||||
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True, offload_device=None):
|
||||
target_offload_device = self.model.offload_device if offload_device is None else offload_device
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
freed = self.model.partially_unload(target_offload_device, memory_to_free)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
offload_device = None
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
offload_device = torch.device("meta")
|
||||
self.model.detach(unpatch_weights, offload_device=offload_device)
|
||||
if offload_device is not None and offload_device.type == "meta":
|
||||
logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
|
||||
self.model.detach(unpatch_weights, offload_device=target_offload_device)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
@ -589,11 +586,33 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
|
||||
logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
|
||||
freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
if freed_cache < memory_required:
|
||||
evict_ram_to_disk(memory_required - freed_cache)
|
||||
gpu_deficit = 0
|
||||
if comfy.disk_weights.disk_weights_enabled():
|
||||
free_before = get_free_memory(device)
|
||||
if is_device_cpu(device):
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
"Disk weights RAM pressure: required=%d free=%d headroom=%d device=%s",
|
||||
memory_required,
|
||||
free_before,
|
||||
headroom,
|
||||
device,
|
||||
)
|
||||
comfy.disk_weights.evict_ram_cache(memory_required)
|
||||
elif is_device_cuda(device) or is_device_xpu(device):
|
||||
if free_before < memory_required:
|
||||
logging.debug(
|
||||
"Disk weights VRAM pressure: required=%d free=%d device=%s",
|
||||
memory_required,
|
||||
free_before,
|
||||
device,
|
||||
)
|
||||
comfy.disk_weights.evict_for_budget(device, memory_required)
|
||||
free_after_vram = get_free_memory(device)
|
||||
if free_after_vram < memory_required:
|
||||
gpu_deficit = memory_required - free_after_vram
|
||||
comfy.disk_weights.evict_ram_cache(gpu_deficit)
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -614,7 +633,17 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
offload_device = None
|
||||
if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
|
||||
offload_device = torch.device("meta")
|
||||
elif comfy.disk_weights.disk_weights_enabled() and gpu_deficit > 0 and (is_device_cuda(device) or is_device_xpu(device)):
|
||||
cpu = torch.device("cpu")
|
||||
headroom = comfy.disk_weights.ram_headroom_bytes()
|
||||
required_cpu = current_loaded_models[i].model_loaded_memory() + headroom
|
||||
free_cpu = get_free_memory(cpu)
|
||||
if free_cpu < required_cpu:
|
||||
offload_device = torch.device("meta")
|
||||
if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
|
||||
unloaded_model.append(i)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
@ -627,36 +656,18 @@ def free_memory(memory_required, device, keep_loaded=[]):
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
if comfy.disk_weights.disk_weights_enabled():
free_after = get_free_memory(device)
freed_total = max(0, free_after - free_before)
logging.debug(
"Disk weights free_memory: device=%s free_before=%d free_after=%d freed=%d",
device,
free_before,
free_after,
freed_total,
)
return unloaded_models


def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
if memory_to_free <= 0:
return 0
if not comfy.disk_weights.disk_weights_enabled():
return 0

freed = 0
can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model not in keep_loaded and not shift_model.is_dead():
loaded_memory = shift_model.model_loaded_memory()
if loaded_memory > 0:
can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))

for x in sorted(can_unload):
i = x[-1]
memory_needed = memory_to_free - freed
if memory_needed <= 0:
break
logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)

if freed > 0:
logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
return freed
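For orientation, a minimal usage sketch of the new `evict_ram_to_disk` helper above; the 2 GiB figure and the calling context are illustrative, and the call is effectively a no-op unless disk weight offloading (`--low-ram`) is enabled:

```python
import comfy.model_management

# Hypothetical: push roughly 2 GiB of loaded weights from system RAM to the disk tier.
# Offloaded weights move to the "meta" device and are re-streamed from disk on demand.
freed = comfy.model_management.evict_ram_to_disk(2 * 1024 ** 3)
print("freed {:.2f} MB of RAM".format(freed / (1024 * 1024)))
```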
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@ -1122,6 +1133,10 @@ def sync_stream(device, stream):
current_stream(device).wait_stream(stream)

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if comfy.disk_weights.disk_weights_enabled() and weight.device.type == "meta":
target_device = device if device is not None else torch.device("cpu")
target_dtype = dtype if dtype is not None else weight.dtype
weight = comfy.disk_weights.materialize_meta_tensor(weight, target_device, target_dtype)
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@ -1134,7 +1149,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)

if stream is not None:
wf_context = stream
if hasattr(wf_context, "as_context"):
@ -1145,6 +1159,13 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
else:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
if non_blocking and (is_device_cuda(device) or is_device_cuda(weight.device)):
record_stream = stream if stream is not None else current_stream(device if is_device_cuda(device) else weight.device)
if record_stream is not None:
if is_device_cpu(weight.device) and weight.is_pinned():
_record_pinned_event(weight.data_ptr(), record_stream)
if is_device_cpu(r.device) and r.is_pinned():
_record_pinned_event(r.data_ptr(), record_stream)
return r

def cast_to_device(tensor, device, dtype, copy=False):
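A small hedged sketch of what the new meta-device branch in `cast_to` enables: a parameter whose data currently lives only on disk (device `meta`) is materialized before the usual dtype/device handling. The layer below is a placeholder, not a real ComfyUI module:

```python
import torch
import comfy.model_management as mm

# `some_layer` is hypothetical; with --low-ram its weight may still sit on the "meta" device.
w = some_layer.weight
# cast_to() first materializes the tensor from disk, then casts/moves it as usual.
w_gpu = mm.cast_to(w, dtype=torch.float16, device=torch.device("cuda"), non_blocking=True)
```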
@ -1155,6 +1176,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
PINNED_IN_FLIGHT = {}
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
@ -1163,14 +1185,13 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

WEIGHTS_RAM_CACHE_BYTES = 0
WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
if args.weights_ram_cache_gb is not None:
WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
if args.low_ram:
comfy.disk_weights.configure(
WEIGHTS_RAM_CACHE_BYTES,
allow_gds=WEIGHTS_GDS_ENABLED,
pin_if_cpu=not args.disable_pinned_memory,
ram_headroom_bytes=1024 * 1024 * 1024,
enabled=True,
)

PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
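In other words, `--low-ram` wires the disk tier up with a fixed 1 GiB RAM headroom. A hedged sketch of the resulting pressure check (the 512 MB allocation is illustrative):

```python
import torch
import comfy.disk_weights
import comfy.model_management as mm

required = 512 * 1024 ** 2  # hypothetical upcoming allocation
free_ram = mm.get_free_memory(torch.device("cpu"))
# Eviction to disk is considered once free RAM would drop below the configured headroom.
needs_eviction = free_ram < required + comfy.disk_weights.ram_headroom_bytes()
```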
@ -1223,7 +1244,19 @@ def pin_memory(tensor):

return False

def unpin_memory(tensor):
def _record_pinned_event(ptr, stream):
events = PINNED_IN_FLIGHT.setdefault(ptr, [])
event = torch.cuda.Event()
event.record(stream)
events.append(event)

def wait_for_pinned_tensor(tensor):
ptr = tensor.data_ptr()
events = PINNED_IN_FLIGHT.pop(ptr, [])
for event in events:
event.synchronize()

def _unpin_memory_now(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
@ -1254,6 +1287,17 @@ def unpin_memory(tensor):

return False

def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False

if not is_device_cpu(tensor.device):
return False

wait_for_pinned_tensor(tensor)
return _unpin_memory_now(tensor)

def sage_attention_enabled():
return args.use_sage_attention
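The `_record_pinned_event` / `wait_for_pinned_tensor` pair above exists so a pinned host buffer is never unpinned while an asynchronous copy that reads it is still in flight. A generic PyTorch sketch of that pattern (not ComfyUI-specific API):

```python
import torch

src = torch.empty(1024, pin_memory=True)          # pinned host staging buffer
dst = torch.empty(1024, device="cuda")
dst.copy_(src, non_blocking=True)                 # async H2D copy on the current stream
event = torch.cuda.Event()
event.record(torch.cuda.current_stream())         # mark the point the copy was enqueued
# ... later, before unpinning or reusing `src`:
event.synchronize()                               # guarantees the copy has finished
```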
@ -621,6 +621,16 @@ class ModelPatcher:
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update

if comfy.disk_weights.disk_weights_enabled() and weight is not None and weight.device.type == "meta":
parts = key.split(".")
param_name = parts[-1]
module = self.model
for part in parts[:-1]:
module = getattr(module, part)
target_device = device_to or self.offload_device or torch.device("cpu")
comfy.disk_weights.load_module_tensor(module, param_name, device=target_device)
weight, set_func, convert_func = get_key_weight(self.model, key)

if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
@ -650,8 +660,8 @@ class ModelPatcher:
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
comfy.model_management.unpin_memory(weight)
self.pinned.remove(key)
if comfy.model_management.unpin_memory(weight):
self.pinned.remove(key)

def unpin_all_weights(self):
for key in list(self.pinned):
@ -885,9 +895,16 @@ class ModelPatcher:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
remaining_ram = None
cpu_device = torch.device("cpu")
if device_to is not None and comfy.model_management.is_device_cpu(device_to):
remaining_ram = comfy.model_management.get_free_memory(device_to)

def offload_module_tree(module):
freed = 0
for submodule in module.modules():
freed += comfy.disk_weights.offload_module_weights(submodule)
return freed

for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
@ -922,15 +939,28 @@ class ModelPatcher:
cast_weight = self.force_cast_weights
freed_bytes = module_mem
if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
freed_bytes = comfy.disk_weights.offload_module_weights(m)
freed_bytes = offload_module_tree(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
freed_bytes = comfy.disk_weights.offload_module_weights(m)
if freed_bytes == 0:
freed_bytes = module_mem
if remaining_ram is not None and comfy.disk_weights.disk_weights_enabled():
required_bytes = module_mem
headroom = comfy.disk_weights.ram_headroom_bytes()
comfy.model_management.free_memory(required_bytes + headroom, cpu_device, keep_loaded=[self])
remaining_ram = comfy.model_management.get_free_memory(cpu_device)
if remaining_ram < required_bytes:
logging.info(
"Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.",
n,
required_bytes / (1024 * 1024),
remaining_ram / (1024 * 1024),
)
freed_bytes = offload_module_tree(m)
if freed_bytes == 0:
freed_bytes = module_mem
else:
comfy.disk_weights.move_module_tensors(m, device_to)
remaining_ram = max(0, remaining_ram - required_bytes)
else:
if comfy.disk_weights.disk_weights_enabled():
comfy.disk_weights.move_module_tensors(m, device_to)
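Condensed, the branch above implements this decision (a sketch with hypothetical `module`, `module_mem` and `patcher` names, assuming the disk tier is enabled):

```python
import torch
import comfy.disk_weights
import comfy.model_management

headroom = comfy.disk_weights.ram_headroom_bytes()
# Try to make room in system RAM first (other loaded models may get evicted or offloaded).
comfy.model_management.free_memory(module_mem + headroom, torch.device("cpu"), keep_loaded=[patcher])
if comfy.model_management.get_free_memory(torch.device("cpu")) < module_mem:
    comfy.disk_weights.offload_module_weights(module)                    # weights -> disk ("meta")
else:
    comfy.disk_weights.move_module_tensors(module, torch.device("cpu"))  # weights -> RAM
```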
@ -980,16 +1010,22 @@ class ModelPatcher:

def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
offload_device = self.offload_device
if comfy.disk_weights.disk_weights_enabled() and device_to is not None:
if comfy.model_management.is_device_cpu(device_to):
offload_device = torch.device("meta")
else:
offload_device = torch.device("cpu")
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
self.unpatch_model(offload_device, unpatch_weights=unpatch_weights)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)

self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
self.partially_unload(offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
130 comfy/ops.py
@ -19,12 +19,67 @@
import torch
import logging
import comfy.model_management
import comfy.disk_weights
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import json

def _dw_find_input_tensor(args, kwargs):
"""Return first tensor-like input (supports QuantizedTensor)."""
def check(obj):
if torch.is_tensor(obj):
return obj
if isinstance(obj, QuantizedTensor):
return obj
if isinstance(obj, (list, tuple)):
for it in obj:
r = check(it)
if r is not None:
return r
if isinstance(obj, dict):
for it in obj.values():
r = check(it)
if r is not None:
return r
return None
for a in args:
r = check(a)
if r is not None:
return r
return check(kwargs)

def _dw_disk_weights_enabled() -> bool:
# Delayed import avoids eager circular imports.
from comfy import disk_weights as _dw
return _dw.disk_weights_enabled()

def _dw_requires_temporary_cast(module, args, kwargs) -> bool:
"""
When disk_weights is enabled, route ops through the comfy_cast path when
weights/bias are not directly usable (dtype/device mismatch or meta).
"""
if not _dw_disk_weights_enabled():
return False
inp = _dw_find_input_tensor(args, kwargs)
if inp is None:
return False
w = getattr(module, "weight", None)
if w is None:
return False
if isinstance(inp, QuantizedTensor):
req_dtype = inp.params.orig_dtype
req_dev = inp.device
else:
req_dtype = inp.dtype
req_dev = inp.device
if w.device.type == "meta" or w.device != req_dev or w.dtype != req_dtype:
return True
b = getattr(module, "bias", None)
if b is not None and (b.device.type == "meta" or b.device != req_dev or b.dtype != req_dtype):
return True
return False

def run_every_op():
if torch.compiler.is_compiling():
return
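A brief hedged illustration of the check `_dw_requires_temporary_cast` performs, using a plain `torch.nn.Linear` whose weight was never materialized; it only returns True when disk weight offloading is enabled:

```python
import torch

layer = torch.nn.Linear(8, 8, device="meta")          # weight/bias still on "meta"
x = torch.randn(1, 8, dtype=torch.float16)
# With disk weights enabled, the meta-device weight forces the comfy_cast path,
# which loads the tensor from disk just for this op.
print(_dw_requires_temporary_cast(layer, (x,), {}))   # True (False when the feature is off)
```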
@ -101,27 +156,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of

weight_source = s.weight
bias_source = s.bias
if comfy.disk_weights.disk_weights_enabled():
if weight_source.device.type == "meta":
loaded = comfy.disk_weights.load_module_tensor(
s,
"weight",
device,
temporary=True,
dtype_override=dtype,
)
if loaded is not None:
weight_source = loaded
if bias_source is not None and bias_source.device.type == "meta":
loaded_bias = comfy.disk_weights.load_module_tensor(
s,
"bias",
device,
temporary=True,
dtype_override=bias_dtype,
)
if loaded_bias is not None:
bias_source = loaded_bias

weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
@ -185,7 +219,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -202,7 +236,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -219,7 +253,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -245,7 +279,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -262,7 +296,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -284,7 +318,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -308,7 +342,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -332,7 +366,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -356,7 +390,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
@ -378,7 +412,7 @@ class disable_weight_init:
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or _dw_requires_temporary_cast(self, args, kwargs):
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
if "out_dtype" in kwargs:
|
||||
@ -557,10 +591,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
if value.device.type != "meta":
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
@ -577,16 +610,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None and layer_conf.device.type != "meta":
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
elif layer_conf is not None:
|
||||
layer_conf = None
|
||||
|
||||
if layer_conf is None:
|
||||
if weight.device.type == "meta":
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
@ -632,13 +660,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
if weight.device.type == "meta":
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
@ -648,10 +673,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
if _v.device.type == "meta":
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
|
||||
else:
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
@ -37,6 +37,19 @@ _FST_LOADED = False
_GDS_INITIALIZED = False
_MISSING = object()
_NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
_PINNED_INFLIGHT = collections.deque()


def _reap_pinned_inflight():
if not _PINNED_INFLIGHT:
return
pending = collections.deque()
while _PINNED_INFLIGHT:
event, tensor = _PINNED_INFLIGHT.popleft()
if event.query():
continue
pending.append((event, tensor))
_PINNED_INFLIGHT.extend(pending)


def _require_fastsafetensors():
@ -204,14 +217,22 @@ class _SafeTensorFile:
|
||||
)
|
||||
return tensor
|
||||
|
||||
target_dtype = dtype
|
||||
if device_is_cuda and pin_if_cpu:
|
||||
_reap_pinned_inflight()
|
||||
target_dtype = None
|
||||
cpu_tensor = self._read_tensor_nogds(
|
||||
fst, framework, meta, torch.device("cpu"), dtype
|
||||
fst, framework, meta, torch.device("cpu"), target_dtype, pin_memory=bool(device_is_cuda and pin_if_cpu)
|
||||
)
|
||||
if device_is_cuda:
|
||||
if pin_if_cpu:
|
||||
cpu_tensor = cpu_tensor.pin_memory()
|
||||
gpu_tensor = torch.empty_like(cpu_tensor, device=device)
|
||||
gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
|
||||
if pin_if_cpu:
|
||||
event = torch.cuda.Event()
|
||||
event.record(torch.cuda.current_stream(device))
|
||||
_PINNED_INFLIGHT.append((event, cpu_tensor))
|
||||
if dtype is not None and dtype != gpu_tensor.dtype:
|
||||
gpu_tensor = gpu_tensor.to(dtype=dtype)
|
||||
return gpu_tensor
|
||||
return cpu_tensor
|
||||
|
||||
@ -233,6 +254,8 @@ class _SafeTensorFile:
|
||||
meta: TensorMeta,
|
||||
device: torch.device,
|
||||
dtype: Optional[torch.dtype],
|
||||
*,
|
||||
pin_memory: bool = False,
|
||||
) -> torch.Tensor:
|
||||
fd = self._ensure_fd()
|
||||
reader = self._ensure_nogds_reader(use_cuda=False)
|
||||
@ -241,7 +264,7 @@ class _SafeTensorFile:
|
||||
chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
|
||||
chunk_bytes = max(1, chunk_bytes)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
|
||||
dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu", pin_memory=pin_memory)
|
||||
buffer_length = 0
|
||||
buf_ptr = None
|
||||
gbuf = None
|
||||
@ -250,6 +273,8 @@ class _SafeTensorFile:
|
||||
while chunk_offset < length:
|
||||
chunk_len = min(length - chunk_offset, chunk_bytes)
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
|
||||
if aligned_offset + aligned_length > self.index.size_bytes:
|
||||
aligned_length = max(0, self.index.size_bytes - aligned_offset)
|
||||
needed = aligned_length + ptr_align
|
||||
if buf_ptr is None or needed > buffer_length:
|
||||
if buf_ptr is not None:
|
||||
@ -272,7 +297,6 @@ class _SafeTensorFile:
|
||||
if buf_ptr is not None:
|
||||
fst.cpp.cpu_free(buf_ptr)
|
||||
if dtype is not None and dtype != dest_tensor.dtype:
|
||||
_validate_dtype_conversion(dest_tensor.dtype, dtype)
|
||||
dest_tensor = dest_tensor.to(dtype=dtype)
|
||||
return dest_tensor
|
||||
|
||||
@ -289,6 +313,8 @@ class _SafeTensorFile:
|
||||
abs_start = self.index.header_length + meta.data_offsets[0]
|
||||
length = meta.data_offsets[1] - meta.data_offsets[0]
|
||||
aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
|
||||
if aligned_offset + aligned_length > self.index.size_bytes:
|
||||
aligned_length = max(0, self.index.size_bytes - aligned_offset)
|
||||
ptr_align = framework.get_device_ptr_align()
|
||||
buffer_length = aligned_length + ptr_align
|
||||
fst_device = _fst_device_from_torch(fst, device)
|
||||
@ -306,7 +332,6 @@ class _SafeTensorFile:
|
||||
fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
|
||||
)
|
||||
if dtype is not None and dtype != tensor.dtype:
|
||||
_validate_dtype_conversion(tensor.dtype, dtype)
|
||||
tensor = tensor.to(dtype=dtype)
|
||||
return tensor
|
||||
|
||||
@ -348,8 +373,7 @@ def _dlpack_tensor_from_buffer(
|
||||
|
||||
|
||||
def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
|
||||
if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
|
||||
raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
|
||||
return
|
||||
|
||||
|
||||
def _get_gds_o_direct(framework) -> bool:
|
||||
@ -523,8 +547,10 @@ class StreamStateDict(collections.abc.MutableMapping):
|
||||
raise KeyError(key)
|
||||
return default
|
||||
if self._index.has(key):
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
if default is _MISSING:
|
||||
raise KeyError(key)
|
||||
return default
|
||||
@ -636,8 +662,10 @@ class _BaseViewStateDict(MutableMapping):
|
||||
if default is _MISSING:
|
||||
raise
|
||||
return default
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
|
||||
def meta(self, key: str):
|
||||
if key in self._overrides:
|
||||
@ -768,8 +796,10 @@ class DeviceViewStateDict(_BaseViewStateDict):
|
||||
if default is _MISSING:
|
||||
raise
|
||||
return default
|
||||
value = self.get_tensor(key)
|
||||
self._deleted.add(key)
|
||||
return self.get_tensor(key)
|
||||
self._overrides.pop(key, None)
|
||||
return value
|
||||
|
||||
|
||||
class FilterViewStateDict(_BaseViewStateDict):
|
||||
@ -932,3 +962,48 @@ class MappedStateDict(_BaseViewStateDict):
|
||||
|
||||
def _iter_base_keys(self) -> Iterable[str]:
|
||||
return self._view_to_base.keys()
|
||||
|
||||
|
||||
def stream_load_state_dict(model, state_dict, strict: bool = False, assign: bool = False):
if getattr(state_dict, "is_stream_state_dict", False) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)

def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")

load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
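A minimal usage sketch, assuming `sd` is a `StreamStateDict` (or a compatible mapping) opened from a safetensors file elsewhere and `model` is the target module:

```python
# Tensors are pulled per-module as the hierarchy is walked, rather than loaded up front.
incompatible = stream_load_state_dict(model, sd, strict=False)
if incompatible.missing_keys:
    print("missing keys:", incompatible.missing_keys)
```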
13 comfy/sd.py
@ -26,7 +26,6 @@ import os

import comfy.utils
import comfy.safetensors_stream
import comfy.disk_weights

from . import clip_vision
from . import gligen
@ -126,7 +125,7 @@ class CLIP:
if not model_management.supports_cast(load_device, dt):
load_device = offload_device
if params['device'] != offload_device:
comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")

self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@ -290,7 +289,7 @@ class CLIP:

def load_sd(self, sd, full_model=False):
if full_model:
return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)

@ -658,7 +657,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()

m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))

@ -672,7 +671,7 @@ class VAE:
if dtype is None:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
self.first_stage_model.to(dtype=self.vae_dtype)
self.output_device = model_management.intermediate_device()

self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@ -979,7 +978,7 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
comfy.utils.load_state_dict(model, model_data, strict=True)
model.load_state_dict(model_data, strict=True)
return StyleModel(model)

def sd_shape(state_dict, key):
@ -1547,7 +1546,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True

model = model_config.get_model(new_sd, "")
model = comfy.disk_weights.module_to(model, offload_device)
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return self(tokens)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
|
||||
return self.transformer.load_state_dict(sd, strict=False)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
|
||||
@ -56,9 +56,9 @@ class TAESD(nn.Module):
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||
if encoder_path is not None:
|
||||
comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
|
||||
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
|
||||
if decoder_path is not None:
|
||||
comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
|
||||
@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
return comfy.utils.load_state_dict(self, sdo, strict=False)
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
|
||||
@ -32,7 +32,6 @@ from einops import rearrange
from comfy.cli_args import args
import json
from . import safetensors_stream
import comfy.disk_weights

MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -166,62 +165,6 @@ def state_dict_meta(state_dict, key):
)


def load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict):
if comfy.disk_weights.disk_weights_enabled():
return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
comfy.disk_weights.register_module_weights(model, state_dict)
comfy.disk_weights.attach_disk_weight_hooks(model)
missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
return missing, unexpected
return model.load_state_dict(state_dict, strict=strict)


def stream_load_state_dict(model, state_dict, strict=False, assign=False):
if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
state_dict = state_dict.copy()
missing_keys = []
unexpected_keys = []
error_msgs = []
metadata = getattr(state_dict, "_metadata", None)

def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = f"{prefix}{name}."
child_state_dict = safetensors_stream.FilterViewStateDict(
local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
)
load(child, child_state_dict, child_prefix)
incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
for hook in module._load_state_dict_post_hooks.values():
out = hook(module, incompatible)
if out is not None:
raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")

load(model, state_dict)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
return missing_keys, unexpected_keys


def transformers_convert(sd, prefix_from, prefix_to, number):
keys_to_replace = {
"{}positional_embedding": "{}embeddings.position_embedding.weight",
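With the streaming-aware wrapper removed from `comfy.utils`, call sites fall back to the stock PyTorch API, as the `comfy/sd.py` hunks above show. A minimal sketch (the path and `model` are placeholders):

```python
import comfy.utils

sd = comfy.utils.load_torch_file("model.safetensors", safe_load=True)
missing, unexpected = model.load_state_dict(sd, strict=False)  # `model` is any nn.Module
```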