Fix weight casting double allocation

ifilipis 2026-01-21 17:55:10 +00:00
parent 91809e83ff
commit fcbd22b514

@@ -719,7 +719,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
             model_management.wait_for_pinned_tensor(current)
             model_management.unpin_memory(current)
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
+        dtype = getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -757,7 +757,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         model_management.wait_for_pinned_tensor(current)
         model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
+    dtype = getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -851,9 +851,7 @@ def ensure_module_materialized(
     if not refs:
         return
     state = _get_materialization_state(module)
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     from . import model_management
@@ -871,10 +869,9 @@ def ensure_module_materialized(
             continue
         if current is None:
             continue
-        target_dtype = dtype_override or _get_future_dtype(module, name)
-        if current.device.type != "meta" and current.device == target_device and (
-            target_dtype is None or current.dtype == target_dtype
-        ):
+        # Persistent tensors must remain in stored dtype.
+        target_dtype = None
+        if current.device.type != "meta" and current.device == target_device:
             CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
@@ -934,7 +931,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
+    dtype_override = None  # persistent tensors stay in stored dtype; per-op casting only
     input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
@@ -1221,9 +1218,10 @@ def load_module_tensor(
         return None
     if current is None:
         return None
-    target_dtype = dtype_override or _get_future_dtype(module, name)
-    if dtype_override is not None:
-        _set_future_dtype(module, name, dtype_override)
+    # Persistent loads must not mix storage with dtype casting.
+    if not temporary:
+        dtype_override = None
+    target_dtype = dtype_override
     if current.device.type != "meta":
         if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
             from . import model_management
@@ -1316,9 +1314,7 @@ def _materialize_module_from_state_dict(
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())
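
For reference, a minimal sketch of the pattern the diff moves toward, written in plain PyTorch with a hypothetical DiskBackedLinear module (not this repository's API): the persistent parameter stays in its stored dtype, and any cast needed for compute happens per op on a temporary tensor, so a second, cast copy of the weight is not kept alive alongside the stored one.

    # Sketch only: illustrates "store in stored dtype, cast per op", not the repo's code.
    import torch

    class DiskBackedLinear(torch.nn.Module):
        def __init__(self, out_features: int, in_features: int, stored_dtype=torch.float16):
            super().__init__()
            # Persistent storage keeps the stored dtype; no dtype override is persisted.
            self.weight = torch.nn.Parameter(
                torch.zeros(out_features, in_features, dtype=stored_dtype),
                requires_grad=False,
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Per-op cast: a temporary copy in the input dtype, freed after the op.
            w = self.weight.to(dtype=x.dtype) if self.weight.dtype != x.dtype else self.weight
            return torch.nn.functional.linear(x, w)

    m = DiskBackedLinear(4, 8)
    y = m(torch.randn(2, 8, dtype=torch.float32))
    assert m.weight.dtype == torch.float16 and y.dtype == torch.float32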