Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)
Fix disk weight dtype materialization
commit 557e4ee341
parent 5c60954448
@@ -44,9 +44,6 @@
- [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
- [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
- [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
- [x] Add budgeted module materialization (`DiskMaterializationState`) with per-module loaded/deferred tracking, deterministic ordering, and RAM/VRAM free-memory checks (no insufficient-memory exceptions).
- [x] Add dtype-aware disk loads (override based on forward input/manual cast) to avoid matmul dtype mismatches in on-demand materialization.
- [x] Add disk-tier logging with destination, load size, free memory, partial/full state, and per-device byte breakdowns.

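The first three items above describe the core mechanism: weights stay on disk, modules hold zero-storage `meta` tensors, and a side registry remembers how to reload each one before compute. A minimal sketch of that shape in plain PyTorch; `DiskRef`, `REGISTRY`, and the hook below are hypothetical simplifications of what `comfy/disk_weights.py` actually implements:

```python
import torch
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

@dataclass
class DiskRef:  # hypothetical, simplified
    shape: Tuple[int, ...]
    dtype: torch.dtype
    load_fn: Callable[[], torch.Tensor]  # re-reads the tensor from disk
    requires_grad: bool = False

# (module id, param name) -> DiskRef; the real registry is keyed per module object.
REGISTRY: Dict[Tuple[int, str], DiskRef] = {}

def offload_param(module: torch.nn.Module, name: str, load_fn) -> None:
    """Swap a parameter for a zero-storage meta tensor, remembering how to reload it."""
    param = module._parameters[name]
    REGISTRY[(id(module), name)] = DiskRef(tuple(param.shape), param.dtype, load_fn, param.requires_grad)
    meta = torch.empty(param.shape, dtype=param.dtype, device="meta")
    module._parameters[name] = torch.nn.Parameter(meta, requires_grad=param.requires_grad)

def materialize_pre_hook(module: torch.nn.Module, args):
    """forward_pre_hook: reload any meta parameters before compute runs."""
    for name, param in module._parameters.items():
        ref = REGISTRY.get((id(module), name))
        if ref is not None and param is not None and param.device.type == "meta":
            module._parameters[name] = torch.nn.Parameter(ref.load_fn(), requires_grad=ref.requires_grad)

# module.register_forward_pre_hook(materialize_pre_hook) attaches the reload step.
```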
### Pipeline refactors
- [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
@@ -54,8 +51,6 @@
- [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
- [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
- [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
- [x] Add chunked nogds safetensors reads with a configurable staging size and incremental CPU tensor fill to cap staging buffers.
- [x] Restore full `MutableMapping` behavior (including `meta`) for view wrappers like `RenameViewStateDict`.

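Several of these refactors depend on a state dict that can answer shape queries without reading tensor data. A minimal sketch of the idea using the `safetensors` package's public `safe_open` API; the class is a hypothetical stand-in for `StreamStateDict`, not its real implementation:

```python
import torch
from collections.abc import MutableMapping
from safetensors import safe_open

class StreamingStateDict(MutableMapping):
    """Hypothetical stand-in: reads a tensor only when its key is indexed."""

    def __init__(self, path: str):
        self._handle = safe_open(path, framework="pt", device="cpu")
        self._keys = list(self._handle.keys())

    def meta_shape(self, key):
        # Shape without loading tensor data; recent safetensors versions also
        # expose get_dtype() on slices (check your installed version).
        return self._handle.get_slice(key).get_shape()

    def __getitem__(self, key) -> torch.Tensor:
        return self._handle.get_tensor(key)  # loads just this tensor

    def __iter__(self):
        return iter(self._keys)

    def __len__(self):
        return len(self._keys)

    def __setitem__(self, key, value):
        raise TypeError("read-only view")

    def __delitem__(self, key):
        self._keys.remove(key)  # drops the key from the view; file untouched
```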
### Tests and docs
- [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.

@@ -222,6 +222,7 @@ class DiskMaterializationState:
     deferred_keys: Set[str] = field(default_factory=set)
     loaded_bytes: int = 0
     deferred_bytes: int = 0
+    future_dtypes: Dict[str, torch.dtype] = field(default_factory=dict)
 
 
 def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState:
@@ -232,6 +233,21 @@ def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationSt
     return state
 
 
+def _set_future_dtype(module: torch.nn.Module, name: str, dtype: Optional[torch.dtype]):
+    state = _get_materialization_state(module)
+    if dtype is None:
+        state.future_dtypes.pop(name, None)
+    else:
+        state.future_dtypes[name] = dtype
+
+
+def _get_future_dtype(module: torch.nn.Module, name: str) -> Optional[torch.dtype]:
+    state = DISK_MATERIALIZATION_STATE.get(module)
+    if state is None:
+        return None
+    return state.future_dtypes.get(name)
+
+
 def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState):
     module.disk_loaded_weight_memory = state.loaded_bytes
     module.disk_offload_buffer_memory = state.deferred_bytes
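These helpers give eviction and reload a per-weight memory of the dtype the next forward will want. A hedged usage sketch (the `linear` module and dtypes are illustrative; the helpers must be in scope, as in `comfy/disk_weights.py`):

```python
import torch

# Record that linear.weight should come back as fp16 on its next
# materialization (e.g. because the forward input arrived as fp16):
_set_future_dtype(linear, "weight", torch.float16)
assert _get_future_dtype(linear, "weight") == torch.float16

# Passing None clears the override, falling back to the checkpoint dtype:
_set_future_dtype(linear, "weight", None)
assert _get_future_dtype(linear, "weight") is None
```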
@@ -388,6 +404,7 @@ class _BudgetedStateDict(MutableMapping):
         device: torch.device,
         allow_gds: Optional[bool] = None,
         pin_if_cpu: bool = False,
+        dtype_override: Optional[torch.dtype] = None,
         overrides: Optional[Dict[str, torch.Tensor]] = None,
     ):
         self._base = base
@@ -395,6 +412,7 @@ class _BudgetedStateDict(MutableMapping):
         self._device = device
         self._allow_gds = allow_gds
         self._pin_if_cpu = pin_if_cpu
+        self._dtype_override = dtype_override
         self._overrides = overrides or {}
         self._deleted: Set[str] = set()
 
@@ -437,32 +455,33 @@ class _BudgetedStateDict(MutableMapping):
         allow_gds: Optional[bool] = None,
         pin_if_cpu: bool = False,
     ) -> torch.Tensor:
+        requested_dtype = dtype if dtype is not None else self._dtype_override
         if key in self._overrides:
             t = self._overrides[key]
             if device is not None and t.device != device:
                 t = t.to(device=device)
-            if dtype is not None and t.dtype != dtype:
-                t = t.to(dtype=dtype)
+            if requested_dtype is not None and t.dtype != requested_dtype:
+                t = t.to(dtype=requested_dtype)
             return t
         if key in self._deleted:
             raise KeyError(key)
         if key not in self._allowed_keys:
             meta = self._get_meta(key)
-            target_dtype = dtype or meta.dtype
+            target_dtype = requested_dtype or meta.dtype
             return _meta_tensor(meta, dtype_override=target_dtype)
         if hasattr(self._base, "get_tensor"):
             return self._base.get_tensor(
                 key,
                 device=self._device if device is None else device,
-                dtype=dtype,
+                dtype=requested_dtype,
                 allow_gds=self._allow_gds if allow_gds is None else allow_gds,
                 pin_if_cpu=self._pin_if_cpu if not pin_if_cpu else pin_if_cpu,
             )
         t = self._base[key]
         if device is not None and t.device != device:
             t = t.to(device=device)
-        if dtype is not None and t.dtype != dtype:
-            t = t.to(dtype=dtype)
+        if requested_dtype is not None and t.dtype != requested_dtype:
+            t = t.to(dtype=requested_dtype)
         return t
 
     def __getitem__(self, key: str) -> torch.Tensor:
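The change above resolves the effective dtype once (`requested_dtype`) and reuses it on every exit path, so the per-call `dtype` argument, the wrapper-level `_dtype_override`, and the on-disk dtype compose under one precedence rule. A standalone distillation of that rule (the function name is illustrative, not the module's API):

```python
from typing import Optional
import torch

def resolve_dtype(explicit: Optional[torch.dtype],
                  override: Optional[torch.dtype],
                  stored: torch.dtype) -> torch.dtype:
    # Precedence: per-call argument > wrapper-level override > dtype on disk.
    if explicit is not None:
        return explicit
    if override is not None:
        return override
    return stored

# A bf16 checkpoint read through a wrapper configured for fp16:
assert resolve_dtype(None, torch.float16, torch.bfloat16) == torch.float16
# An explicit per-call dtype still wins:
assert resolve_dtype(torch.float32, torch.float16, torch.bfloat16) == torch.float32
```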
@@ -549,7 +568,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     state = _get_materialization_state(module)
     for ref_name, disk_ref in refs.items():
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = getattr(disk_ref.meta, "dtype", None)
+        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -572,7 +591,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         return
     disk_ref = ref[name]
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = getattr(disk_ref.meta, "dtype", None)
+    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
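Both eviction paths now build the placeholder in the weight's *future* dtype rather than the checkpoint dtype. Otherwise a weight loaded as fp16 would revert to a bf16 placeholder on eviction and look like a dtype change on the next materialization. A small illustration of the invariant (values are hypothetical):

```python
import torch

checkpoint_dtype = torch.bfloat16  # what the file stores
future_dtype = torch.float16       # what _get_future_dtype would report

# Eviction keeps only a zero-storage placeholder, in the dtype the next
# load should produce; the future dtype wins over the on-disk dtype.
placeholder = torch.empty((4, 4), dtype=future_dtype or checkpoint_dtype, device="meta")
assert placeholder.device.type == "meta" and placeholder.dtype == torch.float16
```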
@@ -643,12 +662,20 @@ def ensure_module_materialized(
 ):
     lazy_state = LAZY_MODULE_STATE.get(module)
     if lazy_state is not None:
-        _materialize_module_from_state_dict(module, lazy_state, target_device)
+        _materialize_module_from_state_dict(
+            module,
+            lazy_state,
+            target_device,
+            dtype_override=dtype_override,
+        )
         return
     refs = REGISTRY.get(module)
     if not refs:
         return
     state = _get_materialization_state(module)
+    if dtype_override is not None:
+        for name in refs.keys():
+            _set_future_dtype(module, name, dtype_override)
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     remaining_budget = free_mem_start
@@ -664,18 +691,12 @@
             continue
         if current is None:
             continue
-        if current.device.type != "meta" and current.device == target_device:
-            if dtype_override is not None and current.dtype != dtype_override:
-                tensor = current.to(device=target_device, dtype=dtype_override)
-                if is_buffer:
-                    module._buffers[name] = tensor
-                else:
-                    module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
-                if tensor.device.type == "cpu":
-                    CACHE.record(module, name, tensor, is_buffer=is_buffer)
-            else:
-                if current.device.type == "cpu":
-                    CACHE.touch(module, name)
+        target_dtype = dtype_override or _get_future_dtype(module, name)
+        if current.device.type != "meta" and current.device == target_device and (
+            target_dtype is None or current.dtype == target_dtype
+        ):
+            if current.device.type == "cpu":
+                CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
         if meta_nbytes is None:
@@ -696,10 +717,15 @@
         else:
             target_for_load = target_device
         if current.device.type == "meta":
-            tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
+            tensor = disk_ref.load(
+                target_for_load,
+                ALLOW_GDS,
+                PIN_IF_CPU,
+                dtype_override=target_dtype,
+            )
         else:
-            if dtype_override is not None and current.dtype != dtype_override:
-                tensor = current.to(device=target_for_load, dtype=dtype_override)
+            if target_dtype is not None and current.dtype != target_dtype:
+                tensor = current.to(device=target_for_load, dtype=target_dtype)
             else:
                 tensor = current.to(device=target_for_load)
         if is_buffer:
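Taken together, the hunks above replace the test "does the resident tensor match `dtype_override`?" with "does it match the resolved target dtype (explicit override, else recorded future dtype)?". A hedged distillation of that skip decision (`needs_load` is an illustrative name, not the module's API):

```python
from typing import Optional
import torch

def needs_load(current: torch.Tensor,
               target_device: torch.device,
               dtype_override: Optional[torch.dtype],
               future_dtype: Optional[torch.dtype]) -> bool:
    """True if the weight must be (re)materialized from disk or recast."""
    target_dtype = dtype_override or future_dtype
    on_device = current.device.type != "meta" and current.device == target_device
    dtype_ok = target_dtype is None or current.dtype == target_dtype
    return not (on_device and dtype_ok)

# A meta placeholder always needs loading; a resident fp16 weight with a
# recorded fp16 future dtype does not.
meta = torch.empty((2, 2), dtype=torch.bfloat16, device="meta")
assert needs_load(meta, torch.device("cpu"), None, torch.float16)
resident = torch.zeros((2, 2), dtype=torch.float16)
assert not needs_load(resident, torch.device("cpu"), None, torch.float16)
```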
@@ -713,7 +739,7 @@
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
 
 
-def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
+def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
     if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
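Dropping the mutable `kwargs={}` default matches how PyTorch invokes pre-hooks registered with `with_kwargs=True`: both `args` and `kwargs` are always supplied. A runnable sketch of that registration (the hook body here is a stand-in for the real one):

```python
import torch

def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
    # Stand-in body: the real hook inspects args/kwargs for an input dtype
    # and materializes any meta+DiskRef weights before the forward runs.
    return None

linear = torch.nn.Linear(4, 4)
# with_kwargs=True makes PyTorch pass (module, args, kwargs) to the hook,
# so kwargs never needs a default value.
handle = linear.register_forward_pre_hook(disk_weight_pre_hook, with_kwargs=True)
linear(torch.zeros(1, 4))  # hook fires here
handle.remove()
```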
@@ -749,15 +775,11 @@ def evict_ram_cache(bytes_to_free: int):
     return CACHE.evict_bytes(bytes_to_free)
 
 
-def materialize_module_tree(
-    module: torch.nn.Module,
-    target_device: torch.device,
-    dtype_override: Optional[torch.dtype] = None,
-):
+def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
     if not disk_weights_enabled():
         return
     for submodule in module.modules():
-        ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
+        ensure_module_materialized(submodule, target_device)
 
 
 def _extract_to_device(args, kwargs) -> Optional[torch.device]:
@@ -771,15 +793,6 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
     return None
 
 
-def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
-    if "dtype" in kwargs and kwargs["dtype"] is not None:
-        return kwargs["dtype"]
-    for arg in args:
-        if isinstance(arg, torch.dtype):
-            return arg
-    return None
-
-
 def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     for param in module.parameters(recurse=True):
         if param is not None and param.device.type != "meta":
@@ -793,12 +806,9 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
 def module_to(module: torch.nn.Module, *args, **kwargs):
     if disk_weights_enabled():
         target_device = _extract_to_device(args, kwargs)
-        dtype_override = _extract_to_dtype(args, kwargs)
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
-        if dtype_override is None:
-            dtype_override = getattr(module, "manual_cast_dtype", None)
-        materialize_module_tree(module, target_device, dtype_override=dtype_override)
+        materialize_module_tree(module, target_device)
     return module.to(*args, **kwargs)
 
 
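After this change `module_to` only decides *where* to materialize; any dtype in the `.to()` call is applied by the trailing `module.to(*args, **kwargs)`, with per-weight casts left to the future-dtype bookkeeping. A hedged usage sketch (the `Linear` stand-in has no disk-backed weights, so materialization is a no-op for it):

```python
import torch
from comfy.disk_weights import module_to  # module path as used by the tests below

model = torch.nn.Linear(8, 8)  # stand-in; real callers pass disk-backed models

# Materializes the module tree on the target device at recorded dtypes,
# then defers to torch's own .to() semantics:
model = module_to(model, torch.device("cpu"))

# A dtype in the .to() call is no longer turned into a disk-load override;
# it is applied as an ordinary cast after materialization:
model = module_to(model, dtype=torch.float16)
```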
@@ -825,9 +835,15 @@ def load_module_tensor(
         return None
     if current is None:
         return None
+    target_dtype = dtype_override or _get_future_dtype(module, name)
+    if dtype_override is not None:
+        _set_future_dtype(module, name, dtype_override)
     if current.device.type != "meta":
-        if current.device != device or (dtype_override is not None and current.dtype != dtype_override):
-            tensor = current.to(device=device, dtype=dtype_override)
+        if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
+            if target_dtype is not None and current.dtype != target_dtype:
+                tensor = current.to(device=device, dtype=target_dtype)
+            else:
+                tensor = current.to(device=device)
             if not temporary:
                 if is_buffer:
                     module._buffers[name] = tensor
@@ -875,7 +891,7 @@
         _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
         return current
 
-    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
+    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
     if temporary:
         return tensor
     if is_buffer:
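`load_module_tensor` now records an explicit `dtype_override` as the weight's future dtype and loads through the resolved `target_dtype`, so a one-off on-demand load and a later full materialization agree on the dtype. A hedged call sketch; the argument order is inferred from the hunk context and partly assumed:

```python
import torch
import comfy.disk_weights

# One-off load of a single disk-backed weight onto CPU as fp16, without
# keeping it attached to the module (argument order partly assumed):
w = comfy.disk_weights.load_module_tensor(
    module, "weight", torch.device("cpu"),
    dtype_override=torch.float16, temporary=True,
)
# The override is remembered via _set_future_dtype, so a later
# ensure_module_materialized(module, device) reloads this weight as fp16 too.
```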
@@ -902,13 +918,21 @@ def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_
         module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
 
 
-def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
+def _materialize_module_from_state_dict(
+    module: torch.nn.Module,
+    lazy_state: LazyModuleState,
+    target_device: torch.device,
+    dtype_override: Optional[torch.dtype] = None,
+):
     missing_keys = []
     unexpected_keys = []
     error_msgs = []
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
+    if dtype_override is not None:
+        for name in refs.keys():
+            _set_future_dtype(module, name, dtype_override)
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())
@@ -944,6 +968,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
             device=target_device,
             allow_gds=ALLOW_GDS,
             pin_if_cpu=PIN_IF_CPU,
+            dtype_override=dtype_override,
             overrides=existing,
         )
         factory_device = None
@@ -1001,21 +1026,18 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals
     if error_msgs:
         raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
 
-    dtype_override = getattr(model, "manual_cast_dtype", None)
     for name, param in model.named_parameters(recurse=True):
         if name not in state_keys:
             continue
         meta = state_dict.meta(name)
-        meta_dtype = dtype_override or meta.dtype
-        meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta")
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
         _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
 
     for name, buf in model.named_buffers(recurse=True):
         if buf is None or name not in state_keys:
             continue
         meta = state_dict.meta(name)
-        meta_dtype = dtype_override or meta.dtype
-        meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta")
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
         _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
 
     register_module_weights(model, state_dict)
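With the manual-cast dtype no longer baked into placeholders, a lazily loaded model mirrors the checkpoint exactly until a forward (or an explicit override) records a future dtype. A small illustration of what the loop above now does per parameter (the checkpoint dtype is hypothetical):

```python
import torch

model = torch.nn.Linear(4, 4)
checkpoint_dtype = torch.bfloat16  # hypothetical on-disk dtype

# Per parameter, lazy loading installs a placeholder in the checkpoint's
# dtype (not manual_cast_dtype); casting happens later, once a future
# dtype is recorded.
for name, param in list(model.named_parameters()):
    placeholder = torch.empty(param.shape, dtype=checkpoint_dtype, device="meta")
    setattr(model, name, torch.nn.Parameter(placeholder, requires_grad=param.requires_grad))

assert model.weight.device.type == "meta" and model.weight.dtype == torch.bfloat16
```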
@@ -180,31 +180,3 @@ def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
         assert len(calls) == 2
     finally:
         comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
-
-
-def test_lazy_disk_weights_respects_dtype_override(tmp_path):
-    if importlib.util.find_spec("fastsafetensors") is None:
-        pytest.skip("fastsafetensors not installed")
-    import comfy.utils
-    import comfy.disk_weights
-
-    prev_cache = comfy.disk_weights.CACHE.max_bytes
-    prev_gds = comfy.disk_weights.ALLOW_GDS
-    prev_pin = comfy.disk_weights.PIN_IF_CPU
-    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
-    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
-
-    try:
-        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.bfloat16), "bias": torch.zeros((4,), dtype=torch.bfloat16)})
-        sd = comfy.utils.load_torch_file(path, safe_load=True)
-        model = torch.nn.Linear(4, 4, bias=True)
-        comfy.utils.load_state_dict(model, sd, strict=True)
-        assert model.weight.device.type == "meta"
-
-        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
-        assert model.weight.dtype == torch.bfloat16
-
-        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"), dtype_override=torch.float16)
-        assert model.weight.dtype == torch.float16
-    finally:
-        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)