Fix weight casting double allocation
commit fcbd22b514
parent 91809e83ff
@@ -719,7 +719,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         model_management.wait_for_pinned_tensor(current)
         model_management.unpin_memory(current)
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
+        dtype = getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -757,7 +757,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     model_management.wait_for_pinned_tensor(current)
     model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
+    dtype = getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
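Both hunks above make eviction take the placeholder dtype from the on-disk metadata alone, ignoring any pending cast. A minimal standalone sketch of the pattern (the helper name and Parameter handling are illustrative, not ComfyUI's actual internals): a meta-device tensor records shape and stored dtype without allocating storage, so the evicted slot never carries a cast copy.

import torch

def evict_to_meta(module: torch.nn.Module, name: str, shape: torch.Size,
                  stored_dtype: torch.dtype, is_buffer: bool) -> None:
    # A meta tensor holds only shape/dtype metadata; no bytes are allocated,
    # so the placeholder costs nothing regardless of which dtype it records.
    meta_tensor = torch.empty(shape, dtype=stored_dtype, device="meta")
    if is_buffer:
        module._buffers[name] = meta_tensor
    else:
        module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=False)

lin = torch.nn.Linear(4, 4).half()  # weight stored as fp16
evict_to_meta(lin, "weight", lin.weight.shape, torch.float16, is_buffer=False)
assert lin.weight.device.type == "meta" and lin.weight.dtype == torch.float16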
@@ -851,9 +851,7 @@ def ensure_module_materialized(
     if not refs:
         return
     state = _get_materialization_state(module)
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     from . import model_management
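The removed loop is what allowed a one-call override to outlive its caller. A hedged sketch of that bookkeeping (future_dtypes and materialize_dtype are hypothetical stand-ins for the module's real state): once the override is written into per-tensor state, every later materialization resolves to the cast dtype even though the pinned copy keeps the stored one.

import torch
from typing import Optional

# Hypothetical stand-in for the per-tensor "future dtype" state.
future_dtypes: dict[str, torch.dtype] = {}

def materialize_dtype(name: str, stored_dtype: torch.dtype,
                      dtype_override: Optional[torch.dtype]) -> torch.dtype:
    # Old behavior (the removed loop): future_dtypes[name] = dtype_override,
    # so a transient fp32 request stuck to the tensor forever.
    # New behavior: the override is consumed here and never recorded.
    return dtype_override or future_dtypes.get(name, stored_dtype)

print(materialize_dtype("weight", torch.float16, torch.float32))  # fp32 this call only
print(materialize_dtype("weight", torch.float16, None))           # back to fp16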
@@ -871,10 +869,9 @@ def ensure_module_materialized(
             continue
         if current is None:
             continue
-        target_dtype = dtype_override or _get_future_dtype(module, name)
-        if current.device.type != "meta" and current.device == target_device and (
-            target_dtype is None or current.dtype == target_dtype
-        ):
+        # Persistent tensors must remain in stored dtype.
+        target_dtype = None
+        if current.device.type != "meta" and current.device == target_device:
             CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
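With target_dtype pinned to None for persistent tensors, the residency test collapses to a pure device check. A standalone sketch of the simplified predicate (names illustrative):

import torch

def already_resident(current: torch.Tensor, target_device: torch.device) -> bool:
    # Persistent tensors stay in their stored dtype, so dtype no longer
    # participates in deciding whether a weight is already in place.
    return current.device.type != "meta" and current.device == target_device

w = torch.randn(8, 8, dtype=torch.float16)
print(already_resident(w, torch.device("cpu")))  # True: device match is enough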
@@ -934,7 +931,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
+    dtype_override = None  # persistent tensors stay in stored dtype; per-op casting only
     input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
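With dtype_override forced to None in the pre-hook, casting happens only inside each op, in the spirit of ComfyUI's manual-cast ops. A minimal sketch of that per-op pattern (the function below is illustrative, not the project's actual hook): the stored weight keeps its dtype and the cast copy is transient.

import torch

def per_op_cast_linear(module: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # "Per-op casting only": cast a throwaway copy of the weight for this
    # matmul; the fp16 original stays resident, the fp32 copy is freed after.
    w = module.weight
    if w.dtype != x.dtype:
        w = w.to(x.dtype)
    b = module.bias.to(x.dtype) if module.bias is not None else None
    return torch.nn.functional.linear(x, w, b)

lin = torch.nn.Linear(4, 4).half()              # stored dtype: fp16
y = per_op_cast_linear(lin, torch.randn(2, 4))  # fp32 activations
assert y.dtype == torch.float32 and lin.weight.dtype == torch.float16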
@@ -1221,9 +1218,10 @@ def load_module_tensor(
         return None
     if current is None:
         return None
-    target_dtype = dtype_override or _get_future_dtype(module, name)
-    if dtype_override is not None:
-        _set_future_dtype(module, name, dtype_override)
+    # Persistent loads must not mix storage with dtype casting.
+    if not temporary:
+        dtype_override = None
+    target_dtype = dtype_override
     if current.device.type != "meta":
         if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
             from . import model_management
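The temporary flag now decides whether an override is honored at all: persistent loads keep the stored dtype, throwaway copies may cast. A standalone sketch of that rule (the signature is illustrative, mirroring load_module_tensor's new logic):

import torch
from typing import Optional

def load_tensor(stored: torch.Tensor, device: torch.device,
                dtype_override: Optional[torch.dtype], temporary: bool) -> torch.Tensor:
    # Persistent loads must not mix storage with dtype casting: the override
    # only applies when the caller asked for a temporary copy.
    if not temporary:
        dtype_override = None
    target_dtype = dtype_override or stored.dtype
    return stored.to(device=device, dtype=target_dtype)

w = torch.randn(4, 4, dtype=torch.float16)
persistent = load_tensor(w, torch.device("cpu"), torch.float32, temporary=False)
scratch = load_tensor(w, torch.device("cpu"), torch.float32, temporary=True)
assert persistent.dtype == torch.float16 and scratch.dtype == torch.float32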
@@ -1316,9 +1314,7 @@ def _materialize_module_from_state_dict(
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())
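Taken together, the hunks explain the title: persisting a cast copy meant the stored (often pinned) tensor and its cast twin were resident at once. The arithmetic, assuming an fp16 weight and an fp32 override (sizes illustrative):

import torch

def nbytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

stored = torch.zeros(1024, 1024, dtype=torch.float16)  # pinned source: 2 MiB
persistent_cast = stored.to(torch.float32)             # old path kept this: 4 MiB
transient = stored.to(torch.float32)                   # new path casts per op...
del transient                                          # ...and frees it right after

print(nbytes(stored) + nbytes(persistent_cast))  # 6291456: the double allocation
print(nbytes(stored))                            # 2097152: stored dtype only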