From fcbd22b51414e41fbbee373fac6eedbca8cb1d20 Mon Sep 17 00:00:00 2001
From: ifilipis <40601736+ifilipis@users.noreply.github.com>
Date: Wed, 21 Jan 2026 17:55:10 +0000
Subject: [PATCH] Fix weight casting double allocation

---
 comfy/disk_weights.py | 28 ++++++++++++----------------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/comfy/disk_weights.py b/comfy/disk_weights.py
index 950392a03..dadfca2f7 100644
--- a/comfy/disk_weights.py
+++ b/comfy/disk_weights.py
@@ -719,7 +719,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         model_management.wait_for_pinned_tensor(current)
         model_management.unpin_memory(current)
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
+        dtype = getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -757,7 +757,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     model_management.wait_for_pinned_tensor(current)
     model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
+    dtype = getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -851,9 +851,7 @@ def ensure_module_materialized(
     if not refs:
         return
     state = _get_materialization_state(module)
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     from . import model_management
@@ -871,10 +869,9 @@ def ensure_module_materialized(
             continue
         if current is None:
             continue
-        target_dtype = dtype_override or _get_future_dtype(module, name)
-        if current.device.type != "meta" and current.device == target_device and (
-            target_dtype is None or current.dtype == target_dtype
-        ):
+        # Persistent tensors must remain in stored dtype.
+        target_dtype = None
+        if current.device.type != "meta" and current.device == target_device:
             CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
@@ -934,7 +931,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
+    dtype_override = None  # persistent tensors stay in stored dtype; per-op casting only
     input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
@@ -1221,9 +1218,10 @@ def load_module_tensor(
         return None
     if current is None:
         return None
-    target_dtype = dtype_override or _get_future_dtype(module, name)
-    if dtype_override is not None:
-        _set_future_dtype(module, name, dtype_override)
+    # Persistent loads must not mix storage with dtype casting.
+    if not temporary:
+        dtype_override = None
+    target_dtype = dtype_override
     if current.device.type != "meta":
         if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
             from . import model_management
@@ -1316,9 +1314,7 @@ def _materialize_module_from_state_dict(
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())
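
Reviewer note: below is a minimal sketch of the per-op casting behavior this patch
enforces, written against plain PyTorch rather than the ComfyUI internals it touches;
cast_weight_for_op and ExampleLinear are hypothetical names used only for
illustration. The stored parameter keeps its original dtype, and a cast copy exists
only for the duration of the operation, so no second persistent allocation is created.

import torch
import torch.nn.functional as F

def cast_weight_for_op(weight: torch.Tensor, compute_dtype):
    # Return a temporary cast copy for this op only. The stored parameter is
    # never re-assigned, so its dtype (and its single allocation) stays untouched.
    if compute_dtype is None or weight.dtype == compute_dtype:
        return weight
    return weight.to(dtype=compute_dtype)  # short-lived copy, freed after the op

class ExampleLinear(torch.nn.Module):
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        # Stored in its original (e.g. fp16) dtype.
        self.weight = torch.nn.Parameter(weight, requires_grad=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = cast_weight_for_op(self.weight, x.dtype)  # per-op cast, not persisted
        return F.linear(x, w)

With this pattern, running the module at a compute dtype that differs from the stored
dtype never doubles resident weight memory: at most one transient cast copy exists per op.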