Fix weight casting double allocation

ifilipis 2026-01-21 17:55:10 +00:00
parent 91809e83ff
commit fcbd22b514

@@ -719,7 +719,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         model_management.wait_for_pinned_tensor(current)
         model_management.unpin_memory(current)
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
+        dtype = getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
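
For orientation, here is a minimal, hypothetical sketch (plain torch; `evict_to_meta`, `stored_shape`, and `stored_dtype` are illustrative names, not this module's API) of the eviction pattern both eviction hunks converge on: the placeholder is rebuilt in the dtype recorded on disk, so the next materialization allocates exactly one tensor in that dtype instead of an extra cast copy.

```python
import torch

def evict_to_meta(module: torch.nn.Module, name: str,
                  stored_shape: torch.Size, stored_dtype: torch.dtype) -> None:
    # Replace the live parameter with a meta placeholder that mirrors the
    # on-disk shape/dtype; no cast dtype is remembered here.
    meta_tensor = torch.empty(stored_shape, dtype=stored_dtype, device="meta")
    module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=False)
```
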
@@ -757,7 +757,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     model_management.wait_for_pinned_tensor(current)
     model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
+    dtype = getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -851,9 +851,7 @@ def ensure_module_materialized(
     if not refs:
         return
     state = _get_materialization_state(module)
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     from . import model_management
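
As a rough mental model of what "persisting" meant here (a hypothetical stand-in, not the module's real state tracking), the old code recorded a future dtype per weight that later materialization honored; the fix stops writing that record, so the stored dtype alone decides the persistent allocation.

```python
import torch

# Hypothetical stand-in for the per-weight "future dtype" record that
# _set_future_dtype used to populate and that this commit leaves untouched.
future_dtypes: dict[str, torch.dtype] = {}

def materialization_dtype(name: str, stored_dtype: torch.dtype) -> torch.dtype:
    # With no override ever written, the stored dtype always wins and only
    # one persistent allocation is made per weight.
    return future_dtypes.get(name, stored_dtype)
```
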
@@ -871,10 +869,9 @@ def ensure_module_materialized(
             continue
         if current is None:
             continue
-        target_dtype = dtype_override or _get_future_dtype(module, name)
-        if current.device.type != "meta" and current.device == target_device and (
-            target_dtype is None or current.dtype == target_dtype
-        ):
+        # Persistent tensors must remain in stored dtype.
+        target_dtype = None
+        if current.device.type != "meta" and current.device == target_device:
             CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
@@ -934,7 +931,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
+    dtype_override = None  # persistent tensors stay in stored dtype; per-op casting only
     input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
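
The per-op casting the hook now relies on can be pictured with this hypothetical sketch (not the project's actual ops classes): the weight stays in its stored dtype, and a transient cast copy exists only for the duration of the call, which is what removes the double allocation named in the commit title.

```python
import torch

class CastOnUseLinear(torch.nn.Linear):
    # Hypothetical illustration: weight/bias keep their stored dtype; a
    # temporary cast copy is created per forward call and freed afterwards.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(dtype=x.dtype, device=x.device)
        bias = self.bias.to(dtype=x.dtype, device=x.device) if self.bias is not None else None
        return torch.nn.functional.linear(x, weight, bias)
```
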
@@ -1221,9 +1218,10 @@ def load_module_tensor(
         return None
     if current is None:
         return None
-    target_dtype = dtype_override or _get_future_dtype(module, name)
-    if dtype_override is not None:
-        _set_future_dtype(module, name, dtype_override)
+    # Persistent loads must not mix storage with dtype casting.
+    if not temporary:
+        dtype_override = None
+    target_dtype = dtype_override
     if current.device.type != "meta":
         if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
             from . import model_management
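
The new rule in load_module_tensor can be restated in isolation (hypothetical helper; the names dtype_override and temporary are taken from the diff): only temporary loads may honor a dtype override, while persistent loads keep the stored dtype.

```python
import torch
from typing import Optional

def resolve_target_dtype(dtype_override: Optional[torch.dtype],
                         temporary: bool) -> Optional[torch.dtype]:
    # Persistent loads must not mix storage with dtype casting, so the
    # override only applies to temporary (per-op) loads.
    if not temporary:
        return None        # keep the tensor in its stored dtype
    return dtype_override  # a transient cast copy is allowed
```
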
@@ -1316,9 +1314,7 @@ def _materialize_module_from_state_dict(
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())