Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-03 18:20:26 +08:00

Commit 95ca11fe25: Refine disk weight offload integration
Parent: b93f165c2e
@@ -349,12 +349,12 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
 | `--enable-manager` | Enable ComfyUI-Manager |
 | `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
 | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
-| `--weights-ram-cache-gb` | Enable a disk tier for model weights and keep up to N GB in RAM. Set to `0` to disable RAM caching while still allowing disk streaming. |
+| `--low-ram` | Enable disk weight offloading. Sets RAM headroom to 1024MB and uses `--reserve-vram` as the VRAM headroom. |
 | `--weights-gds` | Enable GPUDirect Storage (GDS) for disk→GPU weight loads. Requires libcufile and GDS support. |

 ### Disk tier for model weights

-When `--weights-ram-cache-gb` is set, ComfyUI streams safetensors weights from disk and keeps a bounded RAM cache. If the cache limit is exceeded, weights are evicted back to disk and reloaded on demand.
+When `--low-ram` is enabled, ComfyUI streams safetensors weights from disk and offloads weights to disk when RAM headroom is reached.

 If `--weights-gds` is enabled, ComfyUI attempts disk→GPU reads via GPUDirect Storage. If GDS is not available (missing libcufile or unsupported platform), the load will fail with a clear error. Disable GDS by omitting `--weights-gds` to use disk→RAM→GPU staging instead.

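As a usage sketch (the command line is illustrative; the flag names come from the table above), enabling the disk tier together with GDS maps onto the `configure()` call added further down in this commit:

# Hypothetical launch: python main.py --low-ram --weights-gds
# Internally (see the comfy/model_management.py hunk below) this becomes roughly:
import comfy.disk_weights

comfy.disk_weights.configure(
    allow_gds=True,                          # from --weights-gds
    pin_if_cpu=True,                         # pinned memory left enabled
    ram_headroom_bytes=1024 * 1024 * 1024,   # --low-ram fixes the RAM headroom at 1024MB
    enabled=True,
)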
@@ -29,7 +29,7 @@ class AudioEncoderModel():
         self.model_sample_rate = 16000

     def load_sd(self, sd):
-        return comfy.utils.load_state_dict(self.model, sd, strict=False)
+        return self.model.load_state_dict(sd, strict=False)

     def get_sd(self):
         return self.model.state_dict()
@@ -114,7 +114,7 @@ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU cachi
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
 cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

-parser.add_argument("--weights-ram-cache-gb", type=float, default=None, help="Enable a disk tier for model weights by keeping up to N GB in RAM. Set to 0 to disable RAM caching while keeping disk tier enabled.")
+parser.add_argument("--low-ram", action="store_true", help="Enable disk weight offloading. Sets RAM headroom to 1024MB and uses --reserve-vram as the VRAM headroom.")
 parser.add_argument("--weights-gds", action="store_true", help="Enable GPUDirect Storage (GDS) for disk->GPU weight loads. Requires libcufile and GDS support.")

 attn_group = parser.add_mutually_exclusive_group()
@@ -6,7 +6,6 @@ import logging
 import comfy.ops
 import comfy.model_patcher
 import comfy.model_management
-import comfy.utils
 import comfy.clip_model
 import comfy.image_encoders.dino2

@@ -48,7 +47,7 @@ class ClipVisionModel():
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
-        return comfy.utils.load_state_dict(self.model, sd, strict=False)
+        return self.model.load_state_dict(sd, strict=False)

     def get_sd(self):
         return self.model.state_dict()
@@ -25,7 +25,6 @@ import logging
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
-import comfy.disk_weights
 import comfy.model_patcher
 import comfy.ops
 import comfy.latent_formats
@@ -386,7 +385,7 @@ class ControlLora(ControlNet):
         controlnet_config["operations"] = control_lora_ops
         controlnet_config["dtype"] = dtype
         self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
-        comfy.disk_weights.module_to(self.control_model, comfy.model_management.get_torch_device())
+        self.control_model.to(comfy.model_management.get_torch_device())
         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()

@@ -440,7 +439,7 @@ def controlnet_config(sd, model_options={}):
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device

 def controlnet_load_state_dict(control_model, sd):
-    missing, unexpected = comfy.utils.load_state_dict(control_model, sd, strict=False)
+    missing, unexpected = control_model.load_state_dict(sd, strict=False)

     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))
@@ -474,9 +473,9 @@ def load_controlnet_mmdit(sd, model_options={}):
 class ControlNetSD35(ControlNet):
     def pre_run(self, model, percent_to_timestep_function):
         if self.control_model.double_y_emb:
-            missing, unexpected = comfy.utils.load_state_dict(self.control_model.orig_y_embedder, model.diffusion_model.y_embedder.state_dict(), strict=False)
+            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
         else:
-            missing, unexpected = comfy.utils.load_state_dict(self.control_model.x_embedder, model.diffusion_model.x_embedder.state_dict(), strict=False)
+            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
         super().pre_run(model, percent_to_timestep_function)

     def copy(self):
@@ -749,9 +748,9 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
             pass
         w = WeightsLoader()
         w.control_model = control_model
-        missing, unexpected = comfy.utils.load_state_dict(w, controlnet_data, strict=False)
+        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
     else:
-        missing, unexpected = comfy.utils.load_state_dict(control_model, controlnet_data, strict=False)
+        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))
@@ -817,8 +816,8 @@ class T2IAdapter(ControlBase):
             if x_noisy.shape[0] != self.cond_hint.shape[0]:
                 self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
         if self.control_input is None:
-            comfy.disk_weights.module_to(self.t2i_model, dtype=x_noisy.dtype)
-            comfy.disk_weights.module_to(self.t2i_model, self.device)
+            self.t2i_model.to(dtype=x_noisy.dtype)
+            self.t2i_model.to(self.device)
             self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
             self.t2i_model.cpu()

@@ -875,7 +874,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
     else:
         return None

-    missing, unexpected = comfy.utils.load_state_dict(model_ad, t2i_data, strict=True)
+    missing, unexpected = model_ad.load_state_dict(t2i_data, strict=True)
     if len(missing) > 0:
         logging.warning("t2i missing {}".format(missing))

@@ -33,7 +33,11 @@ from . import safetensors_stream
 ALLOW_GDS = False
 PIN_IF_CPU = False
 DISK_WEIGHTS_ENABLED = False
+RAM_HEADROOM_BYTES = 0
 BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
+BASE_MODULE_TO = torch.nn.Module.to
+BASE_LOAD_STATE_DICT = torch.nn.Module.load_state_dict
+_MONKEYPATCHED = False
 LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
 DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
 _MISSING = object()
@@ -169,13 +173,17 @@ CACHE = DiskWeightCache(0)
 LOGGER = logging.getLogger(__name__)


-def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
-    global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
+def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, enabled: bool = True):
+    global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED, RAM_HEADROOM_BYTES
     ALLOW_GDS = allow_gds
     PIN_IF_CPU = pin_if_cpu
     DISK_WEIGHTS_ENABLED = enabled
-    CACHE.set_limit(cache_bytes if enabled else 0)
-    if not enabled:
+    RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
+    CACHE.set_limit(0 if enabled else 0)
+    if enabled:
+        install_monkeypatches()
+    else:
+        uninstall_monkeypatches()
         CACHE._entries.clear()
         CACHE.current_bytes = 0

@@ -184,6 +192,66 @@ def disk_weights_enabled() -> bool:
     return DISK_WEIGHTS_ENABLED


+def ram_headroom_bytes() -> int:
+    return RAM_HEADROOM_BYTES
+
+
+def _is_stream_state_dict(state_dict) -> bool:
+    return (
+        getattr(state_dict, "is_stream_state_dict", False)
+        and hasattr(state_dict, "get_tensor")
+        and hasattr(state_dict, "meta")
+    )
+
+
+def patched_to(self: torch.nn.Module, *args, **kwargs):
+    if not disk_weights_enabled():
+        return BASE_MODULE_TO(self, *args, **kwargs)
+    device, dtype, non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+    module_to(
+        self,
+        device=device,
+        dtype=dtype,
+        non_blocking=non_blocking,
+        memory_format=memory_format,
+    )
+    return self
+
+
+def patched_load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
+    if not disk_weights_enabled():
+        if _is_stream_state_dict(state_dict):
+            return safetensors_stream.stream_load_state_dict(
+                self,
+                state_dict,
+                strict=strict,
+                assign=assign,
+            )
+        return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
+    if _is_stream_state_dict(state_dict):
+        missing_keys, unexpected_keys = lazy_load_state_dict(self, state_dict, strict=strict)
+        return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+    return BASE_LOAD_STATE_DICT(self, state_dict, strict=strict, assign=assign)
+
+
+def install_monkeypatches():
+    global _MONKEYPATCHED
+    if _MONKEYPATCHED:
+        return
+    torch.nn.Module.to = patched_to
+    torch.nn.Module.load_state_dict = patched_load_state_dict
+    _MONKEYPATCHED = True
+
+
+def uninstall_monkeypatches():
+    global _MONKEYPATCHED
+    if not _MONKEYPATCHED:
+        return
+    torch.nn.Module.to = BASE_MODULE_TO
+    torch.nn.Module.load_state_dict = BASE_LOAD_STATE_DICT
+    _MONKEYPATCHED = False
+
+
 def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
     if not disk_weights_enabled():
         return
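The patched `to()` above leans on torch's own argument parser so it can forward one canonical signature to `module_to()`. A small standalone illustration of that helper (values shown are what PyTorch returns):

import torch

# torch._C._nn._parse_to is the same helper torch.nn.Module.to uses internally; it
# normalises positional/keyword arguments into (device, dtype, non_blocking, memory_format).
device, dtype, non_blocking, memory_format = torch._C._nn._parse_to("cuda:0", torch.float16)
print(device, dtype, non_blocking, memory_format)  # cuda:0 torch.float16 False None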
@@ -369,33 +437,29 @@ def _device_free_memory(device: torch.device) -> int:
     return int(model_management.get_free_memory(device))


-def _evict_ram_for_budget(required_bytes: int) -> int:
-    if required_bytes <= 0:
-        return 0
-    freed = evict_ram_cache(required_bytes)
-    if freed < required_bytes:
-        from . import model_management
-        freed += model_management.evict_ram_to_disk(required_bytes - freed)
-    return freed
-
-
-def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
-    free_mem = _device_free_memory(device)
-    if device.type == "cpu" and free_mem < required_bytes:
-        _evict_ram_for_budget(required_bytes - free_mem)
-        free_mem = _device_free_memory(device)
-    return free_mem
-
-
-def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
-    from . import model_management
-    if device.type == "cpu":
-        alt = model_management.get_torch_device()
-        if alt.type != "cpu":
-            return alt
-    else:
-        return torch.device("cpu")
-    return None
+def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_bytes: int) -> int:
+    free_before = _device_free_memory(device)
+    if free_before < required_bytes + headroom_bytes:
+        LOGGER.debug(
+            "Disk weight memory pressure: required=%d free=%d headroom=%d device=%s",
+            required_bytes,
+            free_before,
+            headroom_bytes,
+            device,
+        )
+        safetensors_stream._reap_pinned_inflight()
+        from . import model_management
+        model_management.free_memory(required_bytes + headroom_bytes, device)
+        free_after = _device_free_memory(device)
+        freed = max(0, free_after - free_before)
+        LOGGER.debug(
+            "Disk weight memory freed: freed=%d free=%d device=%s",
+            freed,
+            free_after,
+            device,
+        )
+        return free_after
+    return free_before


 class _BudgetedStateDict(MutableMapping):
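For clarity, the rule _ensure_free_memory enforces can be stated on its own (illustrative helper, not part of the patch):

def needs_freeing(free_bytes: int, required_bytes: int, headroom_bytes: int) -> bool:
    # Freeing is triggered whenever free memory cannot cover the tensor about to be
    # materialised plus the configured headroom (RAM headroom on CPU, reserved VRAM on GPU).
    return free_bytes < required_bytes + headroom_bytes

# 2 GiB free is not enough for a 1.5 GiB tensor once a 1 GiB headroom is reserved:
assert needs_freeing(2 * 1024**3, 1536 * 1024**2, 1024**3)
assert not needs_freeing(4 * 1024**3, 1536 * 1024**2, 1024**3)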
@@ -527,8 +591,10 @@ class _BudgetedStateDict(MutableMapping):
             if default is _MISSING:
                 raise KeyError(key)
             return default
+        value = self.get_tensor(key)
         self._deleted.add(key)
-        return self.get_tensor(key)
+        self._overrides.pop(key, None)
+        return value

     def meta(self, key: str):
         return self._get_meta(key)
@@ -564,6 +630,7 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):


 def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
+    safetensors_stream._reap_pinned_inflight()
     lazy_state = LAZY_MODULE_STATE.get(module)
     if lazy_state is not None:
         CACHE.remove_module(module)
@@ -671,7 +738,6 @@ def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype:
 def ensure_module_materialized(
     module: torch.nn.Module,
     target_device: torch.device,
-    fallback_device: Optional[torch.device] = None,
     dtype_override: Optional[torch.dtype] = None,
 ):
     lazy_state = LAZY_MODULE_STATE.get(module)
@@ -692,7 +758,6 @@ def ensure_module_materialized(
             _set_future_dtype(module, name, dtype_override)
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
-    remaining_budget = free_mem_start
     for name in sorted(refs.keys()):
         disk_ref = refs[name]
         if name in module._parameters:
@@ -717,19 +782,11 @@ def ensure_module_materialized(
             continue
         required_bytes = meta_nbytes
         if target_device.type == "cpu":
-            free_mem = _maybe_free_ram_budget(target_device, required_bytes)
-            remaining_budget = min(remaining_budget, free_mem)
-        if required_bytes > remaining_budget:
-            if fallback_device is not None and fallback_device != target_device:
-                fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
-                if fallback_free >= required_bytes:
-                    target_for_load = fallback_device
-                else:
-                    continue
-            else:
-                continue
+            _ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
         else:
-            target_for_load = target_device
+            from . import model_management
+            _ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
+        target_for_load = target_device
         if current.device.type == "meta":
             tensor = disk_ref.load(
                 target_for_load,
@@ -748,7 +805,6 @@ def ensure_module_materialized(
             module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
         if tensor.device.type == "cpu":
             CACHE.record(module, name, tensor, is_buffer=is_buffer)
-        remaining_budget = max(0, remaining_budget - required_bytes)
     _rebuild_materialization_state(module, refs, state)
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
@@ -761,14 +817,11 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
     dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
     if getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
-        fallback_device = _find_tensor_device(args, kwargs)
     else:
         target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
-        fallback_device = None
     ensure_module_materialized(
         module,
         target_device,
-        fallback_device=fallback_device,
         dtype_override=dtype_override,
     )

@@ -786,6 +839,7 @@ def attach_disk_weight_hooks(model: torch.nn.Module):
 def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
+    safetensors_stream._reap_pinned_inflight()
     return CACHE.evict_bytes(bytes_to_free)

@@ -827,16 +881,7 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:


 def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
-    def _move(tensor):
-        if tensor is None:
-            return None
-        if tensor.device.type == "meta":
-            return tensor
-        if dtype_override is not None and tensor.dtype != dtype_override:
-            return tensor.to(device=device_to, dtype=dtype_override)
-        return tensor.to(device=device_to)
-
-    module._apply(_move)
+    ensure_module_materialized(module, device_to, dtype_override=dtype_override)
     return module

@@ -864,10 +909,20 @@ def offload_module_weights(module: torch.nn.Module) -> int:
     return offloaded_bytes


-def module_to(module: torch.nn.Module, *args, **kwargs):
+def module_to(
+    module: torch.nn.Module,
+    *args,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+    non_blocking: bool = False,
+    memory_format=None,
+    **kwargs,
+):
     allow_materialize = kwargs.pop("allow_materialize", True)
+    arg_device = _extract_to_device(args, kwargs)
+    arg_dtype = _extract_to_dtype(args, kwargs)
     if disk_weights_enabled():
-        target_device = _extract_to_device(args, kwargs)
+        target_device = device or arg_device
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
         if target_device.type == "meta":
@@ -875,10 +930,28 @@ def module_to(module: torch.nn.Module, *args, **kwargs):
             return module
         if allow_materialize:
             materialize_module_tree(module, target_device)
-            return module.to(*args, **kwargs)
-        dtype_override = _extract_to_dtype(args, kwargs)
+            base_kwargs = dict(kwargs)
+            if device is not None and arg_device is None:
+                base_kwargs["device"] = device
+            if dtype is not None and arg_dtype is None:
+                base_kwargs["dtype"] = dtype
+            if non_blocking:
+                base_kwargs["non_blocking"] = non_blocking
+            if memory_format is not None:
+                base_kwargs["memory_format"] = memory_format
+            return BASE_MODULE_TO(module, *args, **base_kwargs)
+        dtype_override = dtype or arg_dtype
         return move_module_tensors(module, target_device, dtype_override=dtype_override)
-    return module.to(*args, **kwargs)
+    base_kwargs = dict(kwargs)
+    if device is not None and arg_device is None:
+        base_kwargs["device"] = device
+    if dtype is not None and arg_dtype is None:
+        base_kwargs["dtype"] = dtype
+    if non_blocking:
+        base_kwargs["non_blocking"] = non_blocking
+    if memory_format is not None:
+        base_kwargs["memory_format"] = memory_format
+    return BASE_MODULE_TO(module, *args, **base_kwargs)


 def load_module_tensor(
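The reworked `module_to` keeps its positional `*args` form but also accepts explicit keywords, so callers such as `patched_to` can pass the values parsed by `_parse_to` directly. A hedged usage sketch (the Linear module is just a stand-in):

import torch
import comfy.disk_weights

module = torch.nn.Linear(8, 8)

# Keyword form, as used by patched_to:
comfy.disk_weights.module_to(module, device=torch.device("cpu"), dtype=torch.float32)

# The positional form still works and is resolved via _extract_to_device/_extract_to_dtype:
comfy.disk_weights.module_to(module, torch.device("cpu"))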
@@ -886,7 +959,6 @@ def load_module_tensor(
     name: str,
     device: torch.device,
     *,
-    allow_alternate: bool = True,
     record_cache: bool = True,
     temporary: bool = False,
     dtype_override: Optional[torch.dtype] = None,
@@ -909,6 +981,9 @@ def load_module_tensor(
         _set_future_dtype(module, name, dtype_override)
     if current.device.type != "meta":
         if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
+            from . import model_management
+            headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
+            _ensure_free_memory(device, _tensor_nbytes(current), headroom)
             if target_dtype is not None and current.dtype != target_dtype:
                 tensor = current.to(device=device, dtype=target_dtype)
             else:
@@ -926,41 +1001,11 @@ def load_module_tensor(
     required_bytes = _meta_nbytes(disk_ref.meta)
     if required_bytes is None:
         return current
-    free_mem_start = _device_free_memory(device)
-    free_mem = _maybe_free_ram_budget(device, required_bytes)
-    load_device = device
-    if free_mem < required_bytes and allow_alternate:
-        alt = _choose_alternate_device(device)
-        if alt is not None:
-            alt_free = _maybe_free_ram_budget(alt, required_bytes)
-            if alt_free >= required_bytes:
-                load_device = alt
-            else:
-                state = _get_materialization_state(module)
-                if name not in state.deferred_keys:
-                    state.deferred_keys.add(name)
-                    state.deferred_bytes += required_bytes
-                    _update_disk_state_attrs(module, state)
-                _log_materialization(module, device, free_mem_start, refs, _get_materialization_state(module), "Disk weight deferred")
-                return current
-        else:
-            state = _get_materialization_state(module)
-            if name not in state.deferred_keys:
-                state.deferred_keys.add(name)
-                state.deferred_bytes += required_bytes
-                _update_disk_state_attrs(module, state)
-            _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
-            return current
-    elif free_mem < required_bytes:
-        state = _get_materialization_state(module)
-        if name not in state.deferred_keys:
-            state.deferred_keys.add(name)
-            state.deferred_bytes += required_bytes
-            _update_disk_state_attrs(module, state)
-        _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
-        return current
+    from . import model_management
+    headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
+    _ensure_free_memory(device, required_bytes, headroom)

-    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
+    tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
     if temporary:
         return tensor
     if is_buffer:
@@ -971,7 +1016,7 @@ def load_module_tensor(
         CACHE.record(module, name, tensor, is_buffer=is_buffer)
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
-    _log_materialization(module, load_device, free_mem_start, refs, state, "Disk weight loaded")
+    _log_materialization(module, device, _device_free_memory(device), refs, state, "Disk weight loaded")
     return tensor

@@ -1015,22 +1060,15 @@ def _materialize_module_from_state_dict(
         if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
             existing[key] = buf
     free_mem_start = _device_free_memory(target_device)
-    remaining_budget = free_mem_start
-    allowed = set(existing.keys())
+    allowed = set(keys)
+    from . import model_management
+    headroom = RAM_HEADROOM_BYTES if target_device.type == "cpu" else model_management.extra_reserved_memory()
     for key in keys:
-        if key in allowed:
-            continue
         meta = _state_dict_meta(lazy_state.state_dict, key)
         required = _meta_nbytes(meta)
         if required is None:
             continue
-        if target_device.type == "cpu":
-            free_mem = _maybe_free_ram_budget(target_device, required)
-            remaining_budget = min(remaining_budget, free_mem)
-            if required <= remaining_budget:
-                allowed.add(key)
-                remaining_budget = max(0, remaining_budget - required)
-    deferred_state_dict_keys = {key for key in keys if key not in allowed}
+        _ensure_free_memory(target_device, required, headroom)
     state_dict = _BudgetedStateDict(
         lazy_state.state_dict,
         allowed_keys=allowed,
@@ -1065,7 +1103,7 @@ def _materialize_module_from_state_dict(
     if len(error_msgs) > 0:
         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
     _rebuild_materialization_state(module, refs, state)
-    lazy_state.loaded = len(deferred_state_dict_keys) == 0
+    lazy_state.loaded = True
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
     for name, param in module.named_parameters(recurse=False):
         if param.device.type == "cpu":
@@ -283,7 +283,7 @@ def load_gligen(sd):

             gated = GatedSelfAttentionDense(
                 query_dim, key_dim, n_heads, d_head)
-            comfy.utils.load_state_dict(gated, n_sd, strict=False)
+            gated.load_state_dict(n_sd, strict=False)
             output_list.append(gated)

     if "position_net.null_positive_feature" in sd_k:
@@ -294,7 +294,7 @@ def load_gligen(sd):
             pass
         w = WeightsLoader()
         w.position_net = PositionNet(in_dim, out_dim)
-        comfy.utils.load_state_dict(w, sd, strict=False)
+        w.load_state_dict(sd, strict=False)

     gligen = Gligen(output_list, w.position_net, key_dim)
     return gligen
@@ -1,5 +1,4 @@
 import torch
-import comfy.utils
 import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
@@ -113,7 +112,7 @@ class HunyuanVideo15SRModel():
         self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
-        return comfy.utils.load_state_dict(self.model, sd, strict=True)
+        return self.model.load_state_dict(sd, strict=True)

     def get_sd(self):
         return self.model.state_dict()
@@ -2,7 +2,6 @@ import json
 from dataclasses import dataclass
 import math
 import torch
-import comfy.utils
 import torchaudio

 import comfy.model_management
@@ -154,8 +153,8 @@ class AudioVAE(torch.nn.Module):
         self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
         self.vocoder = Vocoder(config=component_config.vocoder)

-        comfy.utils.load_state_dict(self.autoencoder, vae_sd, strict=False)
-        comfy.utils.load_state_dict(self.vocoder, vocoder_sd, strict=False)
+        self.autoencoder.load_state_dict(vae_sd, strict=False)
+        self.vocoder.load_state_dict(vocoder_sd, strict=False)

         autoencoder_config = self.autoencoder.get_config()
         self.normalizer = AudioLatentNormalizer(
@@ -2,7 +2,6 @@ import logging
 from typing import Optional

 import torch
-import comfy.utils
 import torch.nn as nn

 from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
@@ -153,7 +152,7 @@ class VAE(nn.Module):
         return dec, posterior

     def load_weights(self, src_dict) -> None:
-        comfy.utils.load_state_dict(self, src_dict, strict=True)
+        self.load_state_dict(src_dict, strict=True)

     @property
     def device(self) -> torch.device:
@@ -309,7 +309,7 @@ class BaseModel(torch.nn.Module):
         else:
             to_load = sd
         to_load = self.model_config.process_unet_state_dict(to_load)
-        m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))

@@ -753,8 +753,8 @@ class StableAudio1(BaseModel):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
         self.seconds_start_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
         self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=512)
-        utils.load_state_dict(self.seconds_start_embedder, seconds_start_embedder_weights, strict=True)
-        utils.load_state_dict(self.seconds_total_embedder, seconds_total_embedder_weights, strict=True)
+        self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights, strict=True)
+        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights, strict=True)

     def extra_conds(self, **kwargs):
         out = {}
@@ -590,10 +590,26 @@ def minimum_inference_memory():
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
-        logging.info("RAM pressure: requested %.2f MB, free %.2f MB", memory_required / (1024 * 1024), get_free_memory(device) / (1024 * 1024))
-        freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
-        if freed_cache < memory_required:
-            evict_ram_to_disk(memory_required - freed_cache)
+        free_before = get_free_memory(device)
+        headroom = comfy.disk_weights.ram_headroom_bytes()
+        if free_before < memory_required:
+            logging.debug(
+                "RAM pressure: required=%d free=%d headroom=%d",
+                memory_required,
+                free_before,
+                headroom,
+            )
+            freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+            freed_disk = 0
+            if freed_cache < memory_required:
+                freed_disk = evict_ram_to_disk(memory_required - freed_cache)
+            free_after = get_free_memory(device)
+            freed_total = max(0, free_after - free_before)
+            logging.debug(
+                "RAM freed: freed=%d free=%d",
+                freed_total if freed_total > 0 else freed_cache + freed_disk,
+                free_after,
+            )
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -636,6 +652,7 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
     if not comfy.disk_weights.disk_weights_enabled():
         return 0

+    free_before = get_free_memory(torch.device("cpu"))
     freed = 0
     can_unload = []
     for i in range(len(current_loaded_models) - 1, -1, -1):
@@ -654,7 +671,14 @@ def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
             freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)

     if freed > 0:
-        logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
+        free_after = get_free_memory(torch.device("cpu"))
+        freed_total = max(0, free_after - free_before)
+        logging.debug(
+            "RAM evicted to disk: required=%d free=%d freed=%d",
+            memory_to_free,
+            free_before,
+            freed_total if freed_total > 0 else freed,
+        )
     return freed

 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
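The CPU branch of free_memory now follows a two-stage eviction order; a standalone sketch of that order (the evictors are passed in here as parameters, not the module's real globals):

def relieve_ram_pressure(memory_required: int, evict_ram_cache, evict_ram_to_disk) -> int:
    # Stage 1: drop cached copies of disk-backed weights that can simply be re-read later.
    freed_cache = evict_ram_cache(memory_required)
    freed_disk = 0
    # Stage 2: only if that was not enough, push live module weights back out to disk.
    if freed_cache < memory_required:
        freed_disk = evict_ram_to_disk(memory_required - freed_cache)
    return freed_cache + freed_disk

# Example with stub evictors: the second stage is only asked for the remainder.
assert relieve_ram_pressure(100, lambda n: 60, lambda n: n) == 100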
@@ -802,6 +826,8 @@ def dtype_size(dtype):
     return dtype_size

 def unet_offload_device():
+    if comfy.disk_weights.disk_weights_enabled():
+        return torch.device("meta")
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
     else:
@@ -906,6 +932,8 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     return torch.float32

 def text_encoder_offload_device():
+    if comfy.disk_weights.disk_weights_enabled():
+        return torch.device("meta")
     if args.gpu_only:
         return get_torch_device()
     else:
@@ -966,6 +994,8 @@ def vae_device():
     return get_torch_device()

 def vae_offload_device():
+    if comfy.disk_weights.disk_weights_enabled():
+        return torch.device("meta")
     if args.gpu_only:
         return get_torch_device()
     else:
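With disk weights enabled every offload device becomes the meta device, so "offloading" a model drops its storage entirely and later reloads come from disk. A hedged sketch of that behaviour (calling configure() like this patches torch.nn.Module globally, so it is shown for illustration only):

import torch
import comfy.disk_weights
import comfy.model_management

comfy.disk_weights.configure(
    allow_gds=False,
    pin_if_cpu=False,
    ram_headroom_bytes=1024 * 1024 * 1024,
    enabled=True,
)
# All three offload helpers now report the meta device:
assert comfy.model_management.unet_offload_device() == torch.device("meta")
assert comfy.model_management.text_encoder_offload_device() == torch.device("meta")
assert comfy.model_management.vae_offload_device() == torch.device("meta")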
@@ -1163,14 +1193,13 @@ if not args.disable_pinned_memory:
     MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
     logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))

-WEIGHTS_RAM_CACHE_BYTES = 0
 WEIGHTS_GDS_ENABLED = bool(args.weights_gds)
-if args.weights_ram_cache_gb is not None:
-    WEIGHTS_RAM_CACHE_BYTES = int(max(0.0, args.weights_ram_cache_gb) * (1024 ** 3))
+if args.low_ram:
     comfy.disk_weights.configure(
-        WEIGHTS_RAM_CACHE_BYTES,
         allow_gds=WEIGHTS_GDS_ENABLED,
         pin_if_cpu=not args.disable_pinned_memory,
+        ram_headroom_bytes=1024 * 1024 * 1024,
+        enabled=True,
     )

 PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
comfy/ops.py (22 lines changed)
@@ -19,7 +19,6 @@
 import torch
 import logging
 import comfy.model_management
-import comfy.disk_weights
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
@@ -101,27 +100,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of

     weight_source = s.weight
     bias_source = s.bias
-    if comfy.disk_weights.disk_weights_enabled():
-        if weight_source.device.type == "meta":
-            loaded = comfy.disk_weights.load_module_tensor(
-                s,
-                "weight",
-                device,
-                temporary=True,
-                dtype_override=dtype,
-            )
-            if loaded is not None:
-                weight_source = loaded
-        if bias_source is not None and bias_source.device.type == "meta":
-            loaded_bias = comfy.disk_weights.load_module_tensor(
-                s,
-                "bias",
-                device,
-                temporary=True,
-                dtype_override=bias_dtype,
-            )
-            if loaded_bias is not None:
-                bias_source = loaded_bias

     weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
@@ -37,6 +37,19 @@ _FST_LOADED = False
 _GDS_INITIALIZED = False
 _MISSING = object()
 _NOGDS_CHUNK_BYTES_DEFAULT = 64 * 1024 * 1024
+_PINNED_INFLIGHT = collections.deque()
+
+
+def _reap_pinned_inflight():
+    if not _PINNED_INFLIGHT:
+        return
+    pending = collections.deque()
+    while _PINNED_INFLIGHT:
+        event, tensor = _PINNED_INFLIGHT.popleft()
+        if event.query():
+            continue
+        pending.append((event, tensor))
+    _PINNED_INFLIGHT.extend(pending)


 def _require_fastsafetensors():
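The deque above exists because non-blocking host→device copies only complete later on the CUDA stream, so the pinned staging tensor must stay alive until then. A self-contained sketch of the same pattern (assumes a CUDA device is present; names are illustrative, not the module's actual API):

import collections
import torch

inflight = collections.deque()

def stage_to_gpu(cpu_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    pinned = cpu_tensor.pin_memory()
    gpu = torch.empty_like(pinned, device=device)
    gpu.copy_(pinned, non_blocking=True)
    event = torch.cuda.Event()
    event.record(torch.cuda.current_stream(device))
    inflight.append((event, pinned))  # hold the pinned buffer until the copy has finished
    return gpu

def reap():
    pending = collections.deque()
    while inflight:
        event, pinned = inflight.popleft()
        if not event.query():         # copy not done yet, keep the buffer alive
            pending.append((event, pinned))
    inflight.extend(pending)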
@@ -204,14 +217,22 @@ class _SafeTensorFile:
             )
             return tensor

+        target_dtype = dtype
+        if device_is_cuda and pin_if_cpu:
+            _reap_pinned_inflight()
+            target_dtype = None
         cpu_tensor = self._read_tensor_nogds(
-            fst, framework, meta, torch.device("cpu"), dtype
+            fst, framework, meta, torch.device("cpu"), target_dtype, pin_memory=bool(device_is_cuda and pin_if_cpu)
         )
         if device_is_cuda:
-            if pin_if_cpu:
-                cpu_tensor = cpu_tensor.pin_memory()
             gpu_tensor = torch.empty_like(cpu_tensor, device=device)
             gpu_tensor.copy_(cpu_tensor, non_blocking=pin_if_cpu)
+            if pin_if_cpu:
+                event = torch.cuda.Event()
+                event.record(torch.cuda.current_stream(device))
+                _PINNED_INFLIGHT.append((event, cpu_tensor))
+            if dtype is not None and dtype != gpu_tensor.dtype:
+                gpu_tensor = gpu_tensor.to(dtype=dtype)
             return gpu_tensor
         return cpu_tensor
@@ -233,6 +254,8 @@ class _SafeTensorFile:
         meta: TensorMeta,
         device: torch.device,
         dtype: Optional[torch.dtype],
+        *,
+        pin_memory: bool = False,
     ) -> torch.Tensor:
         fd = self._ensure_fd()
         reader = self._ensure_nogds_reader(use_cuda=False)
@@ -241,7 +264,7 @@ class _SafeTensorFile:
         chunk_bytes = int(os.getenv("COMFY_SAFETENSORS_NOGDS_CHUNK_BYTES", _NOGDS_CHUNK_BYTES_DEFAULT))
         chunk_bytes = max(1, chunk_bytes)
         ptr_align = framework.get_device_ptr_align()
-        dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu")
+        dest_tensor = torch.empty_strided(meta.shape, meta.strides, dtype=meta.dtype, device="cpu", pin_memory=pin_memory)
         buffer_length = 0
         buf_ptr = None
         gbuf = None
@@ -250,6 +273,8 @@ class _SafeTensorFile:
         while chunk_offset < length:
             chunk_len = min(length - chunk_offset, chunk_bytes)
             aligned_offset, aligned_length, head = self._aligned_range(abs_start + chunk_offset, chunk_len)
+            if aligned_offset + aligned_length > self.index.size_bytes:
+                aligned_length = max(0, self.index.size_bytes - aligned_offset)
             needed = aligned_length + ptr_align
             if buf_ptr is None or needed > buffer_length:
                 if buf_ptr is not None:
@@ -272,7 +297,6 @@ class _SafeTensorFile:
         if buf_ptr is not None:
             fst.cpp.cpu_free(buf_ptr)
         if dtype is not None and dtype != dest_tensor.dtype:
-            _validate_dtype_conversion(dest_tensor.dtype, dtype)
             dest_tensor = dest_tensor.to(dtype=dtype)
         return dest_tensor

@@ -289,6 +313,8 @@ class _SafeTensorFile:
         abs_start = self.index.header_length + meta.data_offsets[0]
         length = meta.data_offsets[1] - meta.data_offsets[0]
         aligned_offset, aligned_length, head = self._aligned_range(abs_start, length)
+        if aligned_offset + aligned_length > self.index.size_bytes:
+            aligned_length = max(0, self.index.size_bytes - aligned_offset)
         ptr_align = framework.get_device_ptr_align()
         buffer_length = aligned_length + ptr_align
         fst_device = _fst_device_from_torch(fst, device)
@@ -306,7 +332,6 @@ class _SafeTensorFile:
             fst, framework, gbuf.get_base_address() + ptr_off + head, meta, device, owner
         )
         if dtype is not None and dtype != tensor.dtype:
-            _validate_dtype_conversion(tensor.dtype, dtype)
             tensor = tensor.to(dtype=dtype)
         return tensor

@@ -348,8 +373,7 @@ def _dlpack_tensor_from_buffer(


 def _validate_dtype_conversion(src: torch.dtype, dst: torch.dtype):
-    if torch.tensor([], dtype=dst).element_size() > torch.tensor([], dtype=src).element_size():
-        raise ValueError(f"Online type conversion to larger sizes is not supported ({src} -> {dst})")
+    return


 def _get_gds_o_direct(framework) -> bool:
@@ -523,8 +547,10 @@ class StreamStateDict(collections.abc.MutableMapping):
                 raise KeyError(key)
             return default
         if self._index.has(key):
+            value = self.get_tensor(key)
             self._deleted.add(key)
-            return self.get_tensor(key)
+            self._overrides.pop(key, None)
+            return value
         if default is _MISSING:
             raise KeyError(key)
         return default
@@ -636,8 +662,10 @@ class _BaseViewStateDict(MutableMapping):
             if default is _MISSING:
                 raise
             return default
+        value = self.get_tensor(key)
         self._deleted.add(key)
-        return self.get_tensor(key)
+        self._overrides.pop(key, None)
+        return value

     def meta(self, key: str):
         if key in self._overrides:
@@ -768,8 +796,10 @@ class DeviceViewStateDict(_BaseViewStateDict):
             if default is _MISSING:
                 raise
             return default
+        value = self.get_tensor(key)
         self._deleted.add(key)
-        return self.get_tensor(key)
+        self._overrides.pop(key, None)
+        return value


 class FilterViewStateDict(_BaseViewStateDict):
@@ -932,3 +962,48 @@ class MappedStateDict(_BaseViewStateDict):

     def _iter_base_keys(self) -> Iterable[str]:
         return self._view_to_base.keys()
+
+
+def stream_load_state_dict(model, state_dict, strict: bool = False, assign: bool = False):
+    if getattr(state_dict, "is_stream_state_dict", False) and hasattr(state_dict, "copy"):
+        state_dict = state_dict.copy()
+    missing_keys = []
+    unexpected_keys = []
+    error_msgs = []
+    metadata = getattr(state_dict, "_metadata", None)
+
+    def load(module, local_state_dict, prefix=""):
+        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        if assign:
+            local_metadata["assign_to_params_buffers"] = assign
+        module._load_from_state_dict(
+            local_state_dict,
+            prefix,
+            local_metadata,
+            True,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+        for name, child in module._modules.items():
+            if child is not None:
+                child_prefix = f"{prefix}{name}."
+                child_state_dict = FilterViewStateDict(
+                    local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
+                )
+                load(child, child_state_dict, child_prefix)
+        incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+        for hook in module._load_state_dict_post_hooks.values():
+            out = hook(module, incompatible)
+            if out is not None:
+                raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
+
+    load(model, state_dict)
+    if strict:
+        if len(unexpected_keys) > 0:
+            error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
+        if len(missing_keys) > 0:
+            error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
+    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
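For orientation, a minimal sketch of calling the stream_load_state_dict helper added above. The plain dict and the tiny Linear module are placeholders, not repository code; a real caller would pass a streaming state dict produced by comfy.safetensors_stream:

    import torch
    import comfy.safetensors_stream

    # Placeholder module and state dict; demonstrates only the call shape and return value.
    model = torch.nn.Linear(4, 4)
    sd = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
    result = comfy.safetensors_stream.stream_load_state_dict(model, sd, strict=True)
    print(result.missing_keys, result.unexpected_keys)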
comfy/sd.py
@@ -26,7 +26,6 @@ import os

 import comfy.utils
 import comfy.safetensors_stream
-import comfy.disk_weights

 from . import clip_vision
 from . import gligen
@@ -126,7 +125,7 @@ class CLIP:
                 if not model_management.supports_cast(load_device, dt):
                     load_device = offload_device
                     if params['device'] != offload_device:
-                        comfy.disk_weights.module_to(self.cond_stage_model, offload_device)
+                        self.cond_stage_model.to(offload_device)
                         logging.warning("Had to shift TE back.")

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
@@ -290,7 +289,7 @@ class CLIP:

     def load_sd(self, sd, full_model=False):
         if full_model:
-            return comfy.utils.load_state_dict(self.cond_stage_model, sd, strict=False)
+            return self.cond_stage_model.load_state_dict(sd, strict=False)
         else:
             return self.cond_stage_model.load_sd(sd)

@@ -658,7 +657,7 @@ class VAE:
                 self.first_stage_model = AutoencoderKL(**(config['params']))
             self.first_stage_model = self.first_stage_model.eval()

-            m, u = comfy.utils.load_state_dict(self.first_stage_model, sd, strict=False)
+            m, u = self.first_stage_model.load_state_dict(sd, strict=False)
             if len(m) > 0:
                 logging.warning("Missing VAE keys {}".format(m))

@@ -672,7 +671,7 @@ class VAE:
         if dtype is None:
             dtype = model_management.vae_dtype(self.device, self.working_dtypes)
         self.vae_dtype = dtype
-        comfy.disk_weights.module_to(self.first_stage_model, dtype=self.vae_dtype)
+        self.first_stage_model.to(dtype=self.vae_dtype)
         self.output_device = model_management.intermediate_device()

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
@@ -979,7 +978,7 @@ def load_style_model(ckpt_path):
         model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
-    comfy.utils.load_state_dict(model, model_data, strict=True)
+    model.load_state_dict(model_data, strict=True)
     return StyleModel(model)

 def sd_shape(state_dict, key):
@@ -1547,7 +1546,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
         model_config.optimizations["fp8"] = True

     model = model_config.get_model(new_sd, "")
-    model = comfy.disk_weights.module_to(model, offload_device)
+    model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
     if len(left_over) > 0:
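The sd.py call sites above all follow the same substitution: comfy.utils.load_state_dict(module, sd, ...) becomes module.load_state_dict(sd, ...), whose return value still unpacks into missing and unexpected key lists. A small illustrative sketch, with a placeholder module and state dict rather than repository code:

    import torch

    # Placeholder module/state dict; mirrors the m, u = ... pattern in the VAE hunk above.
    model = torch.nn.Linear(8, 8)
    sd = {"weight": torch.zeros(8, 8), "bias": torch.zeros(8)}
    m, u = model.load_state_dict(sd, strict=False)
    assert m == [] and u == []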
@@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         return self(tokens)

     def load_sd(self, sd):
-        return comfy.utils.load_state_dict(self.transformer, sd, strict=False)
+        return self.transformer.load_state_dict(sd, strict=False)

 def parse_parentheses(string):
     result = []
@@ -56,9 +56,9 @@ class TAESD(nn.Module):
         self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
         self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
         if encoder_path is not None:
-            comfy.utils.load_state_dict(self.taesd_encoder, comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
+            self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True), strict=True)
         if decoder_path is not None:
-            comfy.utils.load_state_dict(self.taesd_decoder, comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)
+            self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True), strict=True)

     @staticmethod
     def scale_latents(x):
@@ -116,7 +116,7 @@ class LTXAVTEModel(torch.nn.Module):
         if len(sdo) == 0:
             sdo = sd

-        return comfy.utils.load_state_dict(self, sdo, strict=False)
+        return self.load_state_dict(sdo, strict=False)


 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
@@ -32,7 +32,6 @@ from einops import rearrange
 from comfy.cli_args import args
 import json
 from . import safetensors_stream
-import comfy.disk_weights

 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap
@@ -166,62 +165,6 @@ def state_dict_meta(state_dict, key):
     )


-def load_state_dict(model, state_dict, strict=False, assign=False):
-    if is_stream_state_dict(state_dict):
-        if comfy.disk_weights.disk_weights_enabled():
-            return comfy.disk_weights.lazy_load_state_dict(model, state_dict, strict=strict)
-        comfy.disk_weights.register_module_weights(model, state_dict)
-        comfy.disk_weights.attach_disk_weight_hooks(model)
-        missing, unexpected = stream_load_state_dict(model, state_dict, strict=strict, assign=assign)
-        return missing, unexpected
-    return model.load_state_dict(state_dict, strict=strict)
-
-
-def stream_load_state_dict(model, state_dict, strict=False, assign=False):
-    if is_stream_state_dict(state_dict) and hasattr(state_dict, "copy"):
-        state_dict = state_dict.copy()
-    missing_keys = []
-    unexpected_keys = []
-    error_msgs = []
-    metadata = getattr(state_dict, "_metadata", None)
-
-    def load(module, local_state_dict, prefix=""):
-        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        if assign:
-            local_metadata["assign_to_params_buffers"] = assign
-        module._load_from_state_dict(
-            local_state_dict,
-            prefix,
-            local_metadata,
-            True,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-        for name, child in module._modules.items():
-            if child is not None:
-                child_prefix = f"{prefix}{name}."
-                child_state_dict = safetensors_stream.FilterViewStateDict(
-                    local_state_dict, lambda k, p=child_prefix: k.startswith(p), mutate_base=False
-                )
-                load(child, child_state_dict, child_prefix)
-        incompatible = torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
-        for hook in module._load_state_dict_post_hooks.values():
-            out = hook(module, incompatible)
-            if out is not None:
-                raise RuntimeError("load_state_dict post hook returned a value, which is unsupported.")
-
-    load(model, state_dict)
-    if strict:
-        if len(unexpected_keys) > 0:
-            error_msgs.insert(0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in unexpected_keys)))
-        if len(missing_keys) > 0:
-            error_msgs.insert(0, 'Missing key(s) in state_dict: {}. '.format(', '.join(f'"{k}"' for k in missing_keys)))
-        if len(error_msgs) > 0:
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(model.__class__.__name__, "\n\t".join(error_msgs)))
-    return missing_keys, unexpected_keys
-
-
 def transformers_convert(sd, prefix_from, prefix_to, number):
     keys_to_replace = {
         "{}positional_embedding": "{}embeddings.position_embedding.weight",