From 5c609544483f044efc8a4ea59fe5257ffdbafce8 Mon Sep 17 00:00:00 2001 From: ifilipis <40601736+ifilipis@users.noreply.github.com> Date: Thu, 8 Jan 2026 20:43:08 +0200 Subject: [PATCH] Recast disk weights with dtype overrides --- comfy/disk_weights.py | 44 +++++++++++++++++---- tests-unit/utils/safetensors_stream_test.py | 28 +++++++++++++ 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py index 6ec0d9a01..2b6ac57ef 100644 --- a/comfy/disk_weights.py +++ b/comfy/disk_weights.py @@ -665,8 +665,17 @@ def ensure_module_materialized( if current is None: continue if current.device.type != "meta" and current.device == target_device: - if current.device.type == "cpu": - CACHE.touch(module, name) + if dtype_override is not None and current.dtype != dtype_override: + tensor = current.to(device=target_device, dtype=dtype_override) + if is_buffer: + module._buffers[name] = tensor + else: + module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad) + if tensor.device.type == "cpu": + CACHE.record(module, name, tensor, is_buffer=is_buffer) + else: + if current.device.type == "cpu": + CACHE.touch(module, name) continue meta_nbytes = _meta_nbytes(disk_ref.meta) if meta_nbytes is None: @@ -704,7 +713,7 @@ def ensure_module_materialized( _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized") -def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs): +def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}): if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE: return input_dtype = _find_tensor_dtype(args, kwargs) @@ -740,11 +749,15 @@ def evict_ram_cache(bytes_to_free: int): return CACHE.evict_bytes(bytes_to_free) -def materialize_module_tree(module: torch.nn.Module, target_device: torch.device): +def materialize_module_tree( + module: torch.nn.Module, + target_device: torch.device, + dtype_override: Optional[torch.dtype] = None, +): if not disk_weights_enabled(): return for submodule in module.modules(): - ensure_module_materialized(submodule, target_device) + ensure_module_materialized(submodule, target_device, dtype_override=dtype_override) def _extract_to_device(args, kwargs) -> Optional[torch.device]: @@ -758,6 +771,15 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]: return None +def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]: + if "dtype" in kwargs and kwargs["dtype"] is not None: + return kwargs["dtype"] + for arg in args: + if isinstance(arg, torch.dtype): + return arg + return None + + def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]: for param in module.parameters(recurse=True): if param is not None and param.device.type != "meta": @@ -771,9 +793,12 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]: def module_to(module: torch.nn.Module, *args, **kwargs): if disk_weights_enabled(): target_device = _extract_to_device(args, kwargs) + dtype_override = _extract_to_dtype(args, kwargs) if target_device is None: target_device = _find_existing_device(module) or torch.device("cpu") - materialize_module_tree(module, target_device) + if dtype_override is None: + dtype_override = getattr(module, "manual_cast_dtype", None) + materialize_module_tree(module, target_device, dtype_override=dtype_override) return module.to(*args, **kwargs) @@ -976,18 +1001,21 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals if error_msgs: raise RuntimeError("Error(s) in loading state_dict:\n\t{}".format("\n\t".join(error_msgs))) + dtype_override = getattr(model, "manual_cast_dtype", None) for name, param in model.named_parameters(recurse=True): if name not in state_keys: continue meta = state_dict.meta(name) - meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta") + meta_dtype = dtype_override or meta.dtype + meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta") _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad) for name, buf in model.named_buffers(recurse=True): if buf is None or name not in state_keys: continue meta = state_dict.meta(name) - meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta") + meta_dtype = dtype_override or meta.dtype + meta_tensor = torch.empty(meta.shape, dtype=meta_dtype, device="meta") _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False) register_module_weights(model, state_dict) diff --git a/tests-unit/utils/safetensors_stream_test.py b/tests-unit/utils/safetensors_stream_test.py index 60d36142d..d7911f517 100644 --- a/tests-unit/utils/safetensors_stream_test.py +++ b/tests-unit/utils/safetensors_stream_test.py @@ -180,3 +180,31 @@ def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch): assert len(calls) == 2 finally: comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled) + + +def test_lazy_disk_weights_respects_dtype_override(tmp_path): + if importlib.util.find_spec("fastsafetensors") is None: + pytest.skip("fastsafetensors not installed") + import comfy.utils + import comfy.disk_weights + + prev_cache = comfy.disk_weights.CACHE.max_bytes + prev_gds = comfy.disk_weights.ALLOW_GDS + prev_pin = comfy.disk_weights.PIN_IF_CPU + prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED + comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True) + + try: + path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.bfloat16), "bias": torch.zeros((4,), dtype=torch.bfloat16)}) + sd = comfy.utils.load_torch_file(path, safe_load=True) + model = torch.nn.Linear(4, 4, bias=True) + comfy.utils.load_state_dict(model, sd, strict=True) + assert model.weight.device.type == "meta" + + comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu")) + assert model.weight.dtype == torch.bfloat16 + + comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"), dtype_override=torch.float16) + assert model.weight.dtype == torch.float16 + finally: + comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)