From 557e4ee34158208510da7f49f53a2a32f6ffbf35 Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Thu, 8 Jan 2026 23:49:57 +0200
Subject: [PATCH] Fix disk weight dtype materialization

---
 DESIGN.md                                   |   5 -
 comfy/disk_weights.py                       | 128 ++++++++++++--------
 tests-unit/utils/safetensors_stream_test.py |  28 -----
 3 files changed, 75 insertions(+), 86 deletions(-)

diff --git a/DESIGN.md b/DESIGN.md
index 38e4e6eec..19a1bfc5d 100644
--- a/DESIGN.md
+++ b/DESIGN.md
@@ -44,9 +44,6 @@
 - [x] Represent disk-resident weights as meta tensors (`device='meta'`) plus a `DiskRef` registry that stores `(module, param_name) -> TensorMeta + loader handle`.
 - [x] Add an LRU cache for RAM-resident weights loaded from disk with configurable max bytes. Eviction replaces RAM tensors with meta tensors and keeps `DiskRef` for reload.
 - [x] Add a general `forward_pre_hook` to materialize any meta+DiskRef weights before compute; this covers modules that bypass `comfy.ops`.
-- [x] Add budgeted module materialization (`DiskMaterializationState`) with per-module loaded/deferred tracking, deterministic ordering, and RAM/VRAM free-memory checks (no insufficient-memory exceptions).
-- [x] Add dtype-aware disk loads (override based on forward input/manual cast) to avoid matmul dtype mismatches in on-demand materialization.
-- [x] Add disk-tier logging with destination, load size, free memory, partial/full state, and per-device byte breakdowns.
 
 ### Pipeline refactors
 - [x] Update `load_torch_file` to return `StreamStateDict` for `.safetensors`/`.sft` and return metadata without loading.
@@ -54,8 +51,6 @@
 - [x] Update `BaseModel.load_model_weights` and other load paths to avoid building large dicts; use streaming mappings + view wrappers instead.
 - [x] Update model detection (`comfy/model_detection.py`) to use metadata-based shape/dtype access (no tensor reads).
 - [x] Update direct safetensors loaders (e.g., `comfy/sd1_clip.py`) to go through `load_torch_file` so everything uses the same streaming loader.
-- [x] Add chunked nogds safetensors reads with a configurable staging size and incremental CPU tensor fill to cap staging buffers.
-- [x] Restore full `MutableMapping` behavior (including `meta`) for view wrappers like `RenameViewStateDict`.
 
 ### Tests and docs
 - [x] Add unit tests for metadata correctness, single-tensor loading, and lazy views (no full materialization), plus integration tests for load behavior and GDS failure path.
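The checklist above describes the tiering pattern the following hunks adjust: disk-resident weights are represented as meta tensors while a DiskRef registry keeps shape/dtype metadata plus a loader handle, and eviction/materialization swap between the two. A minimal sketch of that pattern, using hypothetical simplified names (the real DiskRef, REGISTRY, and CACHE also carry budgets, GDS flags, and buffer handling):

    import torch

    # Illustrative stand-ins, not the actual comfy.disk_weights API.
    class DiskRef:
        def __init__(self, load_fn, shape, dtype):
            self.load_fn = load_fn  # callable that reads the tensor from disk
            self.shape = shape
            self.dtype = dtype

    REGISTRY = {}  # (id(module), param_name) -> DiskRef

    def evict(module, name):
        # Eviction: swap the RAM tensor for a meta placeholder, keep the DiskRef.
        ref = REGISTRY[(id(module), name)]
        meta = torch.empty(ref.shape, dtype=ref.dtype, device="meta")
        module._parameters[name] = torch.nn.Parameter(meta, requires_grad=False)

    def materialize(module, name, device):
        # forward_pre_hook site: reload any meta parameter before compute.
        param = module._parameters[name]
        if param.device.type == "meta":
            ref = REGISTRY[(id(module), name)]
            tensor = ref.load_fn().to(device=device, dtype=ref.dtype)
            module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=False)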
diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index 2b6ac57ef..bceef4c30 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -222,6 +222,7 @@ class DiskMaterializationState:
     deferred_keys: Set[str] = field(default_factory=set)
     loaded_bytes: int = 0
     deferred_bytes: int = 0
+    future_dtypes: Dict[str, torch.dtype] = field(default_factory=dict)
 
 
 def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState:
@@ -232,6 +233,21 @@ def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationSt
     return state
 
 
+def _set_future_dtype(module: torch.nn.Module, name: str, dtype: Optional[torch.dtype]):
+    state = _get_materialization_state(module)
+    if dtype is None:
+        state.future_dtypes.pop(name, None)
+    else:
+        state.future_dtypes[name] = dtype
+
+
+def _get_future_dtype(module: torch.nn.Module, name: str) -> Optional[torch.dtype]:
+    state = DISK_MATERIALIZATION_STATE.get(module)
+    if state is None:
+        return None
+    return state.future_dtypes.get(name)
+
+
 def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState):
     module.disk_loaded_weight_memory = state.loaded_bytes
     module.disk_offload_buffer_memory = state.deferred_bytes
@@ -388,6 +404,7 @@ class _BudgetedStateDict(MutableMapping):
         device: torch.device,
         allow_gds: Optional[bool] = None,
         pin_if_cpu: bool = False,
+        dtype_override: Optional[torch.dtype] = None,
         overrides: Optional[Dict[str, torch.Tensor]] = None,
     ):
         self._base = base
@@ -395,6 +412,7 @@
         self._device = device
         self._allow_gds = allow_gds
         self._pin_if_cpu = pin_if_cpu
+        self._dtype_override = dtype_override
         self._overrides = overrides or {}
         self._deleted: Set[str] = set()
 
@@ -437,32 +455,33 @@
         allow_gds: Optional[bool] = None,
         pin_if_cpu: bool = False,
     ) -> torch.Tensor:
+        requested_dtype = dtype if dtype is not None else self._dtype_override
         if key in self._overrides:
             t = self._overrides[key]
             if device is not None and t.device != device:
                 t = t.to(device=device)
-            if dtype is not None and t.dtype != dtype:
-                t = t.to(dtype=dtype)
+            if requested_dtype is not None and t.dtype != requested_dtype:
+                t = t.to(dtype=requested_dtype)
             return t
         if key in self._deleted:
             raise KeyError(key)
         if key not in self._allowed_keys:
             meta = self._get_meta(key)
-            target_dtype = dtype or meta.dtype
+            target_dtype = requested_dtype or meta.dtype
             return _meta_tensor(meta, dtype_override=target_dtype)
         if hasattr(self._base, "get_tensor"):
             return self._base.get_tensor(
                 key,
                 device=self._device if device is None else device,
-                dtype=dtype,
+                dtype=requested_dtype,
                 allow_gds=self._allow_gds if allow_gds is None else allow_gds,
                 pin_if_cpu=self._pin_if_cpu if not pin_if_cpu else pin_if_cpu,
             )
         t = self._base[key]
         if device is not None and t.device != device:
             t = t.to(device=device)
-        if dtype is not None and t.dtype != dtype:
-            t = t.to(dtype=dtype)
+        if requested_dtype is not None and t.dtype != requested_dtype:
+            t = t.to(dtype=requested_dtype)
         return t
 
     def __getitem__(self, key: str) -> torch.Tensor:
@@ -549,7 +568,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     state = _get_materialization_state(module)
     for ref_name, disk_ref in refs.items():
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = getattr(disk_ref.meta, "dtype", None)
+        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -572,7 +591,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         return
     disk_ref = ref[name]
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = getattr(disk_ref.meta, "dtype", None)
+    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -643,12 +662,20 @@ def ensure_module_materialized(
 ):
     lazy_state = LAZY_MODULE_STATE.get(module)
     if lazy_state is not None:
-        _materialize_module_from_state_dict(module, lazy_state, target_device)
+        _materialize_module_from_state_dict(
+            module,
+            lazy_state,
+            target_device,
+            dtype_override=dtype_override,
+        )
         return
     refs = REGISTRY.get(module)
     if not refs:
         return
     state = _get_materialization_state(module)
+    if dtype_override is not None:
+        for name in refs.keys():
+            _set_future_dtype(module, name, dtype_override)
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     remaining_budget = free_mem_start
@@ -664,18 +691,12 @@
             continue
         if current is None:
             continue
-        if current.device.type != "meta" and current.device == target_device:
-            if dtype_override is not None and current.dtype != dtype_override:
-                tensor = current.to(device=target_device, dtype=dtype_override)
-                if is_buffer:
-                    module._buffers[name] = tensor
-                else:
-                    module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
-                if tensor.device.type == "cpu":
-                    CACHE.record(module, name, tensor, is_buffer=is_buffer)
-            else:
-                if current.device.type == "cpu":
-                    CACHE.touch(module, name)
+        target_dtype = dtype_override or _get_future_dtype(module, name)
+        if current.device.type != "meta" and current.device == target_device and (
+            target_dtype is None or current.dtype == target_dtype
+        ):
+            if current.device.type == "cpu":
+                CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
         if meta_nbytes is None:
@@ -696,10 +717,15 @@
         else:
             target_for_load = target_device
         if current.device.type == "meta":
-            tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
+            tensor = disk_ref.load(
+                target_for_load,
+                ALLOW_GDS,
+                PIN_IF_CPU,
+                dtype_override=target_dtype,
+            )
         else:
-            if dtype_override is not None and current.dtype != dtype_override:
-                tensor = current.to(device=target_for_load, dtype=dtype_override)
+            if target_dtype is not None and current.dtype != target_dtype:
+                tensor = current.to(device=target_for_load, dtype=target_dtype)
             else:
                 tensor = current.to(device=target_for_load)
         if is_buffer:
@@ -713,7 +739,7 @@
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
 
 
-def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
+def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
    if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
        return
    input_dtype = _find_tensor_dtype(args, kwargs)
@@ -749,15 +775,11 @@
     return CACHE.evict_bytes(bytes_to_free)
 
 
-def materialize_module_tree(
-    module: torch.nn.Module,
-    target_device: torch.device,
-    dtype_override: Optional[torch.dtype] = None,
-):
+def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
     if not disk_weights_enabled():
         return
     for submodule in module.modules():
-        ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
+        ensure_module_materialized(submodule, target_device)
 
 
 def _extract_to_device(args, kwargs) -> Optional[torch.device]:
@@ -771,15 +793,6 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
     return None
 
 
-def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
-    if "dtype" in kwargs and kwargs["dtype"] is not None:
-        return kwargs["dtype"]
-    for arg in args:
-        if isinstance(arg, torch.dtype):
-            return arg
-    return None
-
-
 def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     for param in module.parameters(recurse=True):
         if param is not None and param.device.type != "meta":
@@ -793,12 +806,9 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
 def module_to(module: torch.nn.Module, *args, **kwargs):
     if disk_weights_enabled():
         target_device = _extract_to_device(args, kwargs)
-        dtype_override = _extract_to_dtype(args, kwargs)
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
-        if dtype_override is None:
-            dtype_override = getattr(module, "manual_cast_dtype", None)
-        materialize_module_tree(module, target_device, dtype_override=dtype_override)
+        materialize_module_tree(module, target_device)
     return module.to(*args, **kwargs)
 
 
@@ -825,9 +835,15 @@
             return None
     if current is None:
         return None
+    target_dtype = dtype_override or _get_future_dtype(module, name)
+    if dtype_override is not None:
+        _set_future_dtype(module, name, dtype_override)
     if current.device.type != "meta":
-        if current.device != device or (dtype_override is not None and current.dtype != dtype_override):
-            tensor = current.to(device=device, dtype=dtype_override)
+        if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
+            if target_dtype is not None and current.dtype != target_dtype:
+                tensor = current.to(device=device, dtype=target_dtype)
+            else:
+                tensor = current.to(device=device)
             if not temporary:
                 if is_buffer:
                     module._buffers[name] = tensor
@@ -875,7 +891,7 @@
         _log_materialization(module, device, free_mem_start, refs, state, "Disk weight deferred")
         return current
 
-    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=dtype_override)
+    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
     if temporary:
         return tensor
     if is_buffer:
@@ -902,13 +918,21 @@ def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_
         module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
 
 
-def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: LazyModuleState, target_device: torch.device):
+def _materialize_module_from_state_dict(
+    module: torch.nn.Module,
+    lazy_state: LazyModuleState,
+    target_device: torch.device,
+    dtype_override: Optional[torch.dtype] = None,
+):
     missing_keys = []
     unexpected_keys = []
     error_msgs = []
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
+    if dtype_override is not None:
+        for name in refs.keys():
+            _set_future_dtype(module, name, dtype_override)
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())
@@ -944,6 +968,7 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
             device=target_device,
             allow_gds=ALLOW_GDS,
             pin_if_cpu=PIN_IF_CPU,
+            dtype_override=dtype_override,
             overrides=existing,
         )
         factory_device = None
@@ -1001,21 +1026,18 @@
     if error_msgs:
         raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs)))
 
-    dtype_override = getattr(model, "manual_cast_dtype", None)
     for name, param in model.named_parameters(recurse=True):
         if name not in state_keys:
             continue
         meta = state_dict.meta(name)
-        meta_dtype = dtype_override or meta.dtype
-        meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta")
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
         _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
 
     for name, buf in model.named_buffers(recurse=True):
         if buf is None or name not in state_keys:
             continue
         meta = state_dict.meta(name)
-        meta_dtype = dtype_override or meta.dtype
-        meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta")
+        meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
         _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
 
     register_module_weights(model, state_dict)
diff --git a/tests-unit/utils/safetensors_stream_test.py b/tests-unit/utils/safetensors_stream_test.py
index d7911f517..60d36142d 100644
--- a/tests-unit/utils/safetensors_stream_test.py
+++ b/tests-unit/utils/safetensors_stream_test.py
@@ -180,31 +180,3 @@ def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
         assert len(calls) == 2
     finally:
         comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
-
-
-def test_lazy_disk_weights_respects_dtype_override(tmp_path):
-    if importlib.util.find_spec("fastsafetensors") is None:
-        pytest.skip("fastsafetensors not installed")
-    import comfy.utils
-    import comfy.disk_weights
-
-    prev_cache = comfy.disk_weights.CACHE.max_bytes
-    prev_gds = comfy.disk_weights.ALLOW_GDS
-    prev_pin = comfy.disk_weights.PIN_IF_CPU
-    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
-    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
-
-    try:
-        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.bfloat16), "bias": torch.zeros((4,), dtype=torch.bfloat16)})
-        sd = comfy.utils.load_torch_file(path, safe_load=True)
-        model = torch.nn.Linear(4, 4, bias=True)
-        comfy.utils.load_state_dict(model, sd, strict=True)
-        assert model.weight.device.type == "meta"
-
-        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
-        assert model.weight.dtype == torch.bfloat16
-
-        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"), dtype_override=torch.float16)
-        assert model.weight.dtype == torch.float16
-    finally:
-        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
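Net effect of the disk_weights.py hunks: a requested dtype is no longer baked into the lazy meta tensors (lazy_load_state_dict now keeps the on-disk dtype) or guessed from module.to(...) arguments; instead, ensure_module_materialized, load_module_tensor, and _materialize_module_from_state_dict record any dtype_override in the module's future_dtypes map, and the eviction and reload paths consult it so an overridden weight keeps its dtype across evict/rematerialize cycles. A minimal sketch of that bookkeeping under simplified assumptions (a flat dict instead of the per-module DISK_MATERIALIZATION_STATE; names are illustrative, not the real API):

    import torch
    from typing import Dict, Optional, Tuple

    FUTURE_DTYPES: Dict[str, torch.dtype] = {}  # param name -> last requested dtype

    def record_override(name: str, dtype: Optional[torch.dtype]) -> None:
        # Mirrors _set_future_dtype: remember (or clear) a requested override.
        if dtype is None:
            FUTURE_DTYPES.pop(name, None)
        else:
            FUTURE_DTYPES[name] = dtype

    def eviction_placeholder(name: str, shape: Tuple[int, ...], disk_dtype: torch.dtype) -> torch.Tensor:
        # Mirrors _evict_module_weight: the meta placeholder uses the recorded
        # future dtype when one exists, else the dtype stored on disk.
        dtype = FUTURE_DTYPES.get(name, disk_dtype)
        return torch.empty(shape, dtype=dtype, device="meta")

    # A bfloat16 weight materialized once with dtype_override=torch.float16
    # stays float16 after it is evicted and reloaded:
    record_override("weight", torch.float16)
    assert eviction_placeholder("weight", (4, 4), torch.bfloat16).dtype == torch.float16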