Implement partial disk weight materialization

ifilipis 2026-01-08 19:06:49 +02:00
parent 5f2188e31b
commit 45a77073ac
2 changed files with 473 additions and 44 deletions
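The core idea of the commit, as a minimal self-contained sketch (illustrative names, not this commit's API): tensors are admitted in sorted-key order while their byte size still fits the device's free-memory budget; everything over budget stays behind as a meta tensor.

import torch

def plan_materialization(metas, free_bytes):
    # metas: key -> (shape, dtype). Greedily admit keys in sorted order
    # while the running budget can still hold the tensor's bytes.
    loaded, deferred, budget = [], [], free_bytes
    for key in sorted(metas):
        shape, dtype = metas[key]
        t = torch.empty(shape, dtype=dtype, device="meta")  # no real allocation
        nbytes = t.numel() * t.element_size()
        if nbytes <= budget:
            loaded.append(key)
            budget -= nbytes
        else:
            deferred.append(key)  # stays on disk, represented by a meta tensor
    return loaded, deferred

# A 4096x4096 fp16 weight is 32 MiB, so a 16 MiB budget defers it:
# plan_materialization({"w": ((4096, 4096), torch.float16)}, 16 * 2**20)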

comfy/disk_weights.py

@@ -20,8 +20,9 @@ from __future__ import annotations
import collections
import weakref
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Dict, MutableMapping, Optional, Set
import torch
@@ -33,6 +34,8 @@ PIN_IF_CPU = False
DISK_WEIGHTS_ENABLED = False
BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
_MISSING = object()
@dataclass
@@ -178,22 +181,24 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
return
if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
return
for module_name, submodule in module.named_modules():
module_prefix = f"{prefix}{module_name}." if module_name else prefix
for name, param in submodule.named_parameters(recurse=False):
key = f"{module_prefix}{name}" if module_prefix else name
if key in state_dict:
meta = state_dict.meta(key)
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
REGISTRY.register(submodule, name, ref)
if param.device.type == "cpu":
CACHE.record(submodule, name, param, is_buffer=False)
for name, buf in submodule.named_buffers(recurse=False):
key = f"{module_prefix}{name}" if module_prefix else name
if key in state_dict and buf is not None:
meta = state_dict.meta(key)
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
REGISTRY.register(submodule, name, ref)
if buf.device.type == "cpu":
CACHE.record(submodule, name, buf, is_buffer=True)
@dataclass
@@ -203,6 +208,245 @@ class LazyModuleState:
loaded: bool = False
@dataclass
class DiskMaterializationState:
loaded_keys: Set[str] = field(default_factory=set)
deferred_keys: Set[str] = field(default_factory=set)
loaded_bytes: int = 0
deferred_bytes: int = 0
def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState:
state = DISK_MATERIALIZATION_STATE.get(module)
if state is None:
state = DiskMaterializationState()
DISK_MATERIALIZATION_STATE[module] = state
return state
def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState):
module.disk_loaded_weight_memory = state.loaded_bytes
module.disk_offload_buffer_memory = state.deferred_bytes
def _tensor_nbytes(tensor: torch.Tensor) -> int:
return tensor.numel() * tensor.element_size()
def _meta_nbytes(meta) -> Optional[int]:
return getattr(meta, "nbytes", None)
def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Tensor:
dtype = dtype_override or getattr(meta, "dtype", None)
shape = getattr(meta, "shape", None)
if dtype is None or shape is None:
raise KeyError("Missing metadata for meta tensor")
return torch.empty(shape, dtype=dtype, device="meta")
def _state_dict_meta(state_dict: MutableMapping, key: str):
if hasattr(state_dict, "meta"):
return state_dict.meta(key)
if hasattr(state_dict, "get_tensor"):
t = state_dict.get_tensor(key, device=torch.device("meta"))
else:
t = state_dict[key]
numel = t.numel()
return SimpleNamespace(
dtype=t.dtype,
shape=tuple(t.shape),
numel=numel,
nbytes=numel * t.element_size(),
)
def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, DiskTensorRef], state: DiskMaterializationState):
state.loaded_keys.clear()
state.deferred_keys.clear()
state.loaded_bytes = 0
state.deferred_bytes = 0
for name, ref in refs.items():
if name in module._parameters:
tensor = module._parameters[name]
elif name in module._buffers:
tensor = module._buffers[name]
else:
continue
if tensor is None:
continue
nbytes = _meta_nbytes(ref.meta) or _tensor_nbytes(tensor)
if tensor.device.type == "meta":
state.deferred_keys.add(name)
state.deferred_bytes += nbytes
else:
state.loaded_keys.add(name)
state.loaded_bytes += nbytes
_update_disk_state_attrs(module, state)
def _device_free_memory(device: torch.device) -> int:
from . import model_management
return int(model_management.get_free_memory(device))
def _evict_ram_for_budget(required_bytes: int) -> int:
if required_bytes <= 0:
return 0
return evict_ram_cache(required_bytes)
def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
free_mem = _device_free_memory(device)
if device.type == "cpu" and free_mem < required_bytes:
_evict_ram_for_budget(required_bytes - free_mem)
free_mem = _device_free_memory(device)
return free_mem
def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
from . import model_management
if device.type == "cpu":
alt = model_management.get_torch_device()
if alt.type != "cpu":
return alt
else:
return torch.device("cpu")
return None
class _BudgetedStateDict(MutableMapping):
is_stream_state_dict = True
def __init__(
self,
base: MutableMapping,
allowed_keys: Set[str],
device: torch.device,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
overrides: Optional[Dict[str, torch.Tensor]] = None,
):
self._base = base
self._allowed_keys = allowed_keys
self._device = device
self._allow_gds = allow_gds
self._pin_if_cpu = pin_if_cpu
self._overrides = overrides or {}
self._deleted: Set[str] = set()
def _get_meta(self, key: str):
if key in self._overrides:
t = self._overrides[key]
return safetensors_stream.TensorMeta(
dtype=t.dtype,
shape=tuple(t.shape),
numel=t.numel(),
nbytes=_tensor_nbytes(t),
data_offsets=(0, _tensor_nbytes(t)),
filename="<override>",
fst_dtype=None,
strides=tuple(t.stride()),
)
if hasattr(self._base, "meta"):
return self._base.meta(key)
if hasattr(self._base, "get_tensor"):
t = self._base.get_tensor(key, device=torch.device("meta"))
else:
t = self._base[key]
return safetensors_stream.TensorMeta(
dtype=t.dtype,
shape=tuple(t.shape),
numel=t.numel(),
nbytes=_tensor_nbytes(t),
data_offsets=(0, _tensor_nbytes(t)),
filename="<tensor>",
fst_dtype=None,
strides=tuple(t.stride()),
)
def get_tensor(
self,
key: str,
*,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
allow_gds: Optional[bool] = None,
pin_if_cpu: bool = False,
) -> torch.Tensor:
if key in self._overrides:
t = self._overrides[key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
t = t.to(dtype=dtype)
return t
if key in self._deleted:
raise KeyError(key)
if key not in self._allowed_keys:
meta = self._get_meta(key)
target_dtype = dtype or meta.dtype
return _meta_tensor(meta, dtype_override=target_dtype)
if hasattr(self._base, "get_tensor"):
return self._base.get_tensor(
key,
device=self._device if device is None else device,
dtype=dtype,
allow_gds=self._allow_gds if allow_gds is None else allow_gds,
                pin_if_cpu=pin_if_cpu or self._pin_if_cpu,
)
t = self._base[key]
if device is not None and t.device != device:
t = t.to(device=device)
if dtype is not None and t.dtype != dtype:
t = t.to(dtype=dtype)
return t
def __getitem__(self, key: str) -> torch.Tensor:
return self.get_tensor(key)
def __setitem__(self, key: str, value: torch.Tensor) -> None:
self._overrides[key] = value
self._deleted.discard(key)
def __delitem__(self, key: str) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._deleted:
raise KeyError(key)
self._deleted.add(key)
    def __iter__(self):
        for k in self._base.keys():
            if k not in self._deleted:
                yield k
        # Overrides may shadow base keys; emit only the genuinely new ones
        # so shadowed keys are not yielded twice.
        for k in self._overrides.keys():
            if k not in self._deleted and k not in self._base:
                yield k
    def __len__(self) -> int:
        # Count the union so overridden base keys are not counted twice.
        keys = set(self._base.keys()) | set(self._overrides.keys())
        return len(keys - self._deleted)
    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
        if key in self._overrides:
            return self._overrides.pop(key)
        if key in self._deleted or key not in self._base:
            if default is _MISSING:
                raise KeyError(key)
            return default
        # Materialize before marking deleted: get_tensor raises for deleted keys.
        value = self.get_tensor(key)
        self._deleted.add(key)
        return value
def meta(self, key: str):
return self._get_meta(key)
def _has_custom_load(module: torch.nn.Module) -> bool:
return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT
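A toy stand-in for the behavior of _BudgetedStateDict.get_tensor (a plain dict base, not the class itself): keys outside the allowed set come back as meta tensors that preserve shape and dtype, so _load_from_state_dict can still size parameters without reading any bytes from disk.

import torch

base = {"a": torch.ones(4), "b": torch.ones(8)}
allowed = {"a"}

def get_tensor(key):
    t = base[key]
    if key not in allowed:
        # Over budget: hand back a shape/dtype-only placeholder.
        return torch.empty(t.shape, dtype=t.dtype, device="meta")
    return t

assert get_tensor("a").device.type == "cpu"
assert get_tensor("b").device.type == "meta"  # deferred, no bytes read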
@@ -239,6 +483,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
CACHE.remove_module(module)
refs = REGISTRY.get(module)
if refs:
state = _get_materialization_state(module)
for ref_name, disk_ref in refs.items():
shape = getattr(disk_ref.meta, "shape", None)
dtype = getattr(disk_ref.meta, "dtype", None)
@@ -249,6 +494,14 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
module._buffers[ref_name] = meta_tensor
else:
module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
nbytes = _meta_nbytes(disk_ref.meta)
if nbytes is not None:
state.loaded_keys.discard(ref_name)
if ref_name not in state.deferred_keys:
state.deferred_keys.add(ref_name)
state.deferred_bytes += nbytes
state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
_update_disk_state_attrs(module, state)
lazy_state.loaded = False
return
ref = REGISTRY.get(module)
@@ -264,6 +517,15 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
module._buffers[name] = meta_tensor
else:
module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
state = _get_materialization_state(module)
nbytes = _meta_nbytes(disk_ref.meta)
if nbytes is not None:
state.loaded_keys.discard(name)
if name not in state.deferred_keys:
state.deferred_keys.add(name)
state.deferred_bytes += nbytes
state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
_update_disk_state_attrs(module, state)
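The eviction path above boils down to swapping a resident parameter for a same-shape meta placeholder and moving its bytes from the loaded counter to the deferred counter. A minimal sketch on a bare nn.Linear:

import torch

lin = torch.nn.Linear(4, 4)
nbytes = lin.weight.numel() * lin.weight.element_size()
# Replace the resident weight with a meta placeholder, as the hunk does.
meta_w = torch.empty_like(lin.weight, device="meta")
lin._parameters["weight"] = torch.nn.Parameter(meta_w, requires_grad=False)
assert lin.weight.device.type == "meta"
# The accounting then moves `nbytes` from loaded_bytes to deferred_bytes.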
def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
@@ -288,15 +550,23 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
return check(kwargs)
def ensure_module_materialized(
module: torch.nn.Module,
target_device: torch.device,
fallback_device: Optional[torch.device] = None,
):
lazy_state = LAZY_MODULE_STATE.get(module)
if lazy_state is not None:
_materialize_module_from_state_dict(module, lazy_state, target_device)
return
refs = REGISTRY.get(module)
if not refs:
return
state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state)
remaining_budget = _device_free_memory(target_device)
for name in sorted(refs.keys()):
disk_ref = refs[name]
if name in module._parameters:
current = module._parameters[name]
is_buffer = False
@@ -307,30 +577,52 @@ def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
continue
if current is None:
continue
if current.device.type == "meta":
tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
elif current.device != target_device:
tensor = current.to(device=target_device)
else:
if current.device.type != "meta" and current.device == target_device:
if current.device.type == "cpu":
CACHE.touch(module, name)
continue
meta_nbytes = _meta_nbytes(disk_ref.meta)
if meta_nbytes is None:
continue
required_bytes = meta_nbytes
if target_device.type == "cpu":
free_mem = _maybe_free_ram_budget(target_device, required_bytes)
remaining_budget = min(remaining_budget, free_mem)
if required_bytes > remaining_budget:
if fallback_device is not None and fallback_device != target_device:
fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
if fallback_free >= required_bytes:
target_for_load = fallback_device
else:
continue
else:
continue
else:
target_for_load = target_device
if current.device.type == "meta":
tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU)
else:
tensor = current.to(device=target_for_load)
if is_buffer:
module._buffers[name] = tensor
else:
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu":
CACHE.record(module, name, tensor, is_buffer=is_buffer)
remaining_budget = max(0, remaining_budget - required_bytes)
_rebuild_materialization_state(module, refs, state)
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
return
if getattr(module, "comfy_cast_weights", False):
target_device = torch.device("cpu")
fallback_device = _find_tensor_device(args, kwargs)
else:
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
fallback_device = None
ensure_module_materialized(module, target_device, fallback_device=fallback_device)
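For context, the torch mechanism this hook relies on: a forward pre-hook with the three-argument (module, args, kwargs) signature must be registered with with_kwargs=True. A sketch of the wiring, not the commit's attach_disk_weight_hooks:

import torch

def pre_hook(module, args, kwargs):
    # Runs before forward; a disk-weight hook would materialize weights
    # here, onto the device of the incoming activations.
    return None  # returning None leaves args/kwargs unchanged

lin = torch.nn.Linear(4, 4)
handle = lin.register_forward_pre_hook(pre_hook, with_kwargs=True)
lin(torch.randn(2, 4))
handle.remove()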
def attach_disk_weight_hooks(model: torch.nn.Module):
@@ -386,6 +678,87 @@ def module_to(module: torch.nn.Module, *args, **kwargs):
return module.to(*args, **kwargs)
def load_module_tensor(
module: torch.nn.Module,
name: str,
device: torch.device,
*,
allow_alternate: bool = True,
record_cache: bool = True,
temporary: bool = False,
) -> Optional[torch.Tensor]:
refs = REGISTRY.get(module)
if not refs or name not in refs:
return None
if name in module._parameters:
current = module._parameters[name]
is_buffer = False
elif name in module._buffers:
current = module._buffers[name]
is_buffer = True
else:
return None
if current is None:
return None
if current.device.type != "meta":
if current.device != device:
tensor = current.to(device=device)
if not temporary:
if is_buffer:
module._buffers[name] = tensor
else:
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=refs[name].requires_grad)
_rebuild_materialization_state(module, refs, _get_materialization_state(module))
return tensor
return current
disk_ref = refs[name]
required_bytes = _meta_nbytes(disk_ref.meta)
if required_bytes is None:
return current
free_mem = _maybe_free_ram_budget(device, required_bytes)
load_device = device
    def _mark_deferred() -> None:
        # Shared bookkeeping for every "cannot load right now" exit path.
        state = _get_materialization_state(module)
        if name not in state.deferred_keys:
            state.deferred_keys.add(name)
            state.deferred_bytes += required_bytes
        _update_disk_state_attrs(module, state)
    if free_mem < required_bytes:
        alt = _choose_alternate_device(device) if allow_alternate else None
        if alt is None or _maybe_free_ram_budget(alt, required_bytes) < required_bytes:
            _mark_deferred()
            return current
        load_device = alt
tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
if temporary:
return tensor
if is_buffer:
module._buffers[name] = tensor
else:
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
if tensor.device.type == "cpu" and record_cache:
CACHE.record(module, name, tensor, is_buffer=is_buffer)
_rebuild_materialization_state(module, refs, _get_materialization_state(module))
return tensor
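A self-contained analog of load_module_tensor's core path: a meta parameter is replaced in place by the real tensor fetched from a backing store (here a plain dict standing in for the disk state dict):

import torch

def load_tensor(module, name, device, store):
    current = module._parameters.get(name)
    if current is None or current.device.type != "meta":
        return current  # already resident; nothing to page in
    tensor = store[name].to(device)
    module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=False)
    return tensor

lin = torch.nn.Linear(2, 2)
store = {"weight": lin.weight.detach().clone()}
lin._parameters["weight"] = torch.nn.Parameter(
    torch.empty_like(lin.weight, device="meta"), requires_grad=False)
load_tensor(lin, "weight", torch.device("cpu"), store)
assert lin.weight.device.type == "cpu"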
def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
parts = name.split(".")
module = model
@@ -404,12 +777,42 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
error_msgs = []
metadata = getattr(lazy_state.state_dict, "_metadata", None)
local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
refs = REGISTRY.get(module) or {}
state = _get_materialization_state(module)
_rebuild_materialization_state(module, refs, state)
keys = sorted(lazy_state.state_dict.keys())
existing = {}
for name, param in module.named_parameters(recurse=False):
key = f"{lazy_state.prefix}{name}"
if key in lazy_state.state_dict and param is not None and param.device.type != "meta":
existing[key] = param
for name, buf in module.named_buffers(recurse=False):
key = f"{lazy_state.prefix}{name}"
if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
existing[key] = buf
remaining_budget = _device_free_memory(target_device)
allowed = set(existing.keys())
for key in keys:
if key in allowed:
continue
meta = _state_dict_meta(lazy_state.state_dict, key)
required = _meta_nbytes(meta)
if required is None:
continue
if target_device.type == "cpu":
free_mem = _maybe_free_ram_budget(target_device, required)
remaining_budget = min(remaining_budget, free_mem)
if required <= remaining_budget:
allowed.add(key)
remaining_budget = max(0, remaining_budget - required)
deferred_state_dict_keys = {key for key in keys if key not in allowed}
state_dict = _BudgetedStateDict(
lazy_state.state_dict,
allowed_keys=allowed,
device=target_device,
allow_gds=ALLOW_GDS,
pin_if_cpu=PIN_IF_CPU,
overrides=existing,
)
factory_device = None
if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
@@ -435,7 +838,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
module.factory_kwargs["device"] = factory_device
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
_rebuild_materialization_state(module, refs, state)
lazy_state.loaded = len(deferred_state_dict_keys) == 0
for name, param in module.named_parameters(recurse=False):
if param.device.type == "cpu":
CACHE.record(module, name, param, is_buffer=False)
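Since the counters are exposed as plain module attributes (disk_loaded_weight_memory, disk_offload_buffer_memory), callers can sum them over a model to see how many bytes are resident versus still parked on disk. A small helper sketch:

import torch

def disk_memory_report(model):
    loaded = deferred = 0
    for m in model.modules():
        loaded += getattr(m, "disk_loaded_weight_memory", 0)
        deferred += getattr(m, "disk_offload_buffer_memory", 0)
    return loaded, deferred

# disk_memory_report(torch.nn.Linear(4, 4)) -> (0, 0) when nothing is disk-backed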

comfy/ops.py

@@ -19,6 +19,7 @@
import torch
import logging
import comfy.model_management
import comfy.disk_weights
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
@@ -98,11 +99,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight_source = s.weight
bias_source = s.bias
if comfy.disk_weights.disk_weights_enabled():
if weight_source.device.type == "meta":
loaded = comfy.disk_weights.load_module_tensor(s, "weight", device, temporary=True)
if loaded is not None:
weight_source = loaded
if bias_source is not None and bias_source.device.type == "meta":
loaded_bias = comfy.disk_weights.load_module_tensor(s, "bias", device, temporary=True)
if loaded_bias is not None:
bias_source = loaded_bias
weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if bias_source is not None:
bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
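The pattern in this hunk, reduced to a self-contained sketch (fetch stands in for load_module_tensor(..., temporary=True)): a meta weight is paged in first, then cast and moved like any resident tensor:

import torch

def cast_weight(weight, device, dtype, fetch):
    if weight.device.type == "meta":
        loaded = fetch()  # temporary load; not installed on the module
        if loaded is not None:
            weight = loaded
    return weight.to(device=device, dtype=dtype)

w_meta = torch.empty(4, 4, device="meta", dtype=torch.float16)
real = torch.zeros(4, 4, dtype=torch.float16)
out = cast_weight(w_meta, torch.device("cpu"), torch.float32, lambda: real)
assert out.dtype == torch.float32 and out.device.type == "cpu"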
@@ -532,9 +545,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
if value.device.type != "meta":
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
@@ -551,11 +565,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None and layer_conf.device.type != "meta":
layer_conf = json.loads(layer_conf.numpy().tobytes())
elif layer_conf is not None:
layer_conf = None
if layer_conf is None:
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@@ -601,10 +620,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
if weight.device.type == "meta":
self.weight = torch.nn.Parameter(weight, requires_grad=False)
else:
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
@@ -614,7 +636,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_v = state_dict.pop(param_key, None)
if _v is None:
continue
if _v.device.type == "meta":
self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
else:
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
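All of the ops changes above follow one rule: skip device and dtype transforms while a tensor is still on the meta device, and let materialization redo them later. As a self-contained sketch:

import torch

def maybe_to(value, device, dtype=None):
    if value.device.type == "meta":
        return value  # no data to move; conversion happens after materialization
    value = value.to(device=device)
    if dtype is not None:
        value = value.view(dtype=dtype)
    return value

v = torch.empty(8, device="meta", dtype=torch.uint8)
assert maybe_to(v, torch.device("cpu"), torch.float16).device.type == "meta"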