Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-01 09:10:16 +08:00)
Fix weight casting double allocation

parent 91809e83ff
commit fcbd22b514
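Reading the hunks below, the removed `_get_future_dtype` / `_set_future_dtype` bookkeeping appears to be the source of the "double allocation": persisting a cast target into storage kept both the stored-dtype tensor and the casted copy resident. The toy snippet below is an illustration only, not code from this repository; sizes and variable names are made up.

# Toy illustration only, not code from this repository.
import torch

stored = torch.zeros(1024, 1024, dtype=torch.float16)   # weight in its stored dtype
persistent_cast = stored.to(torch.float32)              # persisted override -> second resident buffer
# Both buffers stay allocated for the lifetime of the module ("double allocation").

temporary_cast = stored.to(torch.float32)                # per-op cast instead
del temporary_cast                                       # freed as soon as the op is done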
@@ -719,7 +719,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
             model_management.wait_for_pinned_tensor(current)
             model_management.unpin_memory(current)
         shape = getattr(disk_ref.meta, "shape", None)
-        dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
+        dtype = getattr(disk_ref.meta, "dtype", None)
         if shape is None or dtype is None:
             continue
         meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
@@ -757,7 +757,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         model_management.wait_for_pinned_tensor(current)
         model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
-    dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
+    dtype = getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
         return
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
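Both `_evict_module_weight` hunks now take the placeholder dtype straight from `disk_ref.meta` instead of consulting a pending future dtype. A minimal sketch of that idea, with stand-in names rather than the repository's helper:

# Minimal sketch with stand-in names; the real code reads shape/dtype from disk_ref.meta.
import torch

def make_evicted_placeholder(stored_shape, stored_dtype):
    # device="meta" allocates no storage; the data stays on disk until rematerialized.
    return torch.empty(stored_shape, dtype=stored_dtype, device="meta")

placeholder = make_evicted_placeholder((1024, 1024), torch.float16)
assert placeholder.device.type == "meta" and placeholder.dtype == torch.float16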
@@ -851,9 +851,7 @@ def ensure_module_materialized(
     if not refs:
         return
     state = _get_materialization_state(module)
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
     from . import model_management
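The loop that wrote `dtype_override` into per-parameter state is gone; a persisted override would otherwise be picked up by every later materialization of the same weight. A toy illustration with a plain dict, not the repository's state tracking:

# Toy illustration, not the repository's state tracking.
import torch

future_dtype = {}                                  # old behaviour: survives across calls
future_dtype["layer.weight"] = torch.float32       # one fp32 request...

stored_dtype = torch.float16
next_load_dtype = future_dtype.get("layer.weight", stored_dtype)
assert next_load_dtype == torch.float32            # ...and every later load is widened too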
@@ -871,10 +869,9 @@ def ensure_module_materialized(
             continue
         if current is None:
             continue
-        target_dtype = dtype_override or _get_future_dtype(module, name)
-        if current.device.type != "meta" and current.device == target_device and (
-            target_dtype is None or current.dtype == target_dtype
-        ):
+        # Persistent tensors must remain in stored dtype.
+        target_dtype = None
+        if current.device.type != "meta" and current.device == target_device:
             CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
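With `target_dtype` pinned to `None`, the reuse check above reduces to a device comparison; a resident tensor is no longer reloaded just because its dtype differs from a requested cast. Assumed shape of that simplified check, not a verbatim copy:

# Assumed shape of the simplified reuse check; not a verbatim copy.
import torch

def can_reuse(current: torch.Tensor, target_device: torch.device) -> bool:
    return current.device.type != "meta" and current.device == target_device

w = torch.zeros(4, 4, dtype=torch.float16)
assert can_reuse(w, torch.device("cpu"))    # fp16 on the right device is kept as-is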
@@ -934,7 +931,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
+    dtype_override = None  # persistent tensors stay in stored dtype; per-op casting only
     input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
     if getattr(module, "comfy_patched_weights", False):
         target_device = input_device
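`disk_weight_pre_hook` no longer turns the input dtype or `manual_cast_dtype` into a persistent override; the new inline comment defers to per-op casting. A generic PyTorch sketch of per-op casting, not ComfyUI's ops classes:

# Generic per-op casting sketch; not ComfyUI's ops classes.
import torch
import torch.nn.functional as F

def linear_with_temp_cast(x: torch.Tensor, weight: torch.Tensor, bias=None):
    w = weight.to(device=x.device, dtype=x.dtype)                          # per-op copy
    b = bias.to(device=x.device, dtype=x.dtype) if bias is not None else None
    return F.linear(x, w, b)                                               # copy freed after the call

x = torch.randn(2, 16, dtype=torch.float32)
w = torch.zeros(8, 16, dtype=torch.float16)                                # stored dtype unchanged
y = linear_with_temp_cast(x, w)
assert w.dtype == torch.float16 and y.dtype == torch.float32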
@@ -1221,9 +1218,10 @@ def load_module_tensor(
         return None
     if current is None:
         return None
-    target_dtype = dtype_override or _get_future_dtype(module, name)
-    if dtype_override is not None:
-        _set_future_dtype(module, name, dtype_override)
+    # Persistent loads must not mix storage with dtype casting.
+    if not temporary:
+        dtype_override = None
+    target_dtype = dtype_override
     if current.device.type != "meta":
         if current.device != device or (target_dtype is not None and current.dtype != target_dtype):
             from . import model_management
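In `load_module_tensor`, only temporary loads may still cast; a persistent load resets `dtype_override` so the tensor lands in its stored dtype and no second, wider copy is kept around. A hedged sketch with stand-in bodies (only the guard mirrors the diff):

# Hedged sketch; only the temporary/persistent guard mirrors the diff.
import torch

def load_tensor(stored: torch.Tensor, device, dtype_override=None, temporary=False):
    if not temporary:
        dtype_override = None                    # persistent loads keep stored dtype
    target_dtype = dtype_override or stored.dtype
    return stored.to(device=device, dtype=target_dtype)

w = torch.zeros(4, 4, dtype=torch.float16)
assert load_tensor(w, "cpu", torch.float32, temporary=False).dtype == torch.float16
assert load_tensor(w, "cpu", torch.float32, temporary=True).dtype == torch.float32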
@@ -1316,9 +1314,7 @@ def _materialize_module_from_state_dict(
     metadata = getattr(lazy_state.state_dict, "_metadata", None)
     local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
     refs = REGISTRY.get(module) or {}
-    if dtype_override is not None:
-        for name in refs.keys():
-            _set_future_dtype(module, name, dtype_override)
+    # Do not persist dtype overrides into storage.
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
     keys = sorted(lazy_state.state_dict.keys())