Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-31 16:50:17 +08:00)
Fix disk weight movement and pinned inflight tracking

parent c3eaea0429
commit 82e70aa3c2
@@ -96,6 +96,7 @@ class CacheEntry:
     name: str
     size_bytes: int
     is_buffer: bool
+    device_type: str


 class DiskWeightCache:
@@ -112,16 +113,25 @@ class DiskWeightCache:
         return (id(module), name)

     def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
-        if tensor.device.type != "cpu":
+        if tensor.device.type == "meta":
             return
         size_bytes = tensor.numel() * tensor.element_size()
         key = self._entry_key(module, name)
         if key in self._entries:
             entry = self._entries.pop(key)
-            self.current_bytes -= entry.size_bytes
+            if entry.device_type == "cpu":
+                self.current_bytes -= entry.size_bytes
         module_ref = weakref.ref(module, self._drop_module_entries)
-        self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
-        self.current_bytes += size_bytes
+        device_type = tensor.device.type
+        self._entries[key] = CacheEntry(
+            module_ref=module_ref,
+            name=name,
+            size_bytes=size_bytes,
+            is_buffer=is_buffer,
+            device_type=device_type,
+        )
+        if device_type == "cpu":
+            self.current_bytes += size_bytes
         self._evict_if_needed()

     def touch(self, module: torch.nn.Module, name: str):
@@ -133,9 +143,10 @@ class DiskWeightCache:
     def evict_bytes(self, bytes_to_free: int):
         freed = 0
         while self._entries and freed < bytes_to_free:
-            _, entry = self._entries.popitem(last=False)
+            entry = self.pop_lru(torch.device("cpu"))
+            if entry is None:
+                break
             freed += entry.size_bytes
-            self.current_bytes -= entry.size_bytes
             module = entry.module_ref()
             if module is not None:
                 _evict_module_weight(module, entry.name, entry.is_buffer)
@@ -148,8 +159,26 @@ class DiskWeightCache:
                 to_remove.append(key)
         for key in to_remove:
             entry = self._entries.pop(key)
+            if entry.device_type == "cpu":
                 self.current_bytes -= entry.size_bytes

+    def remove_entry(self, module: torch.nn.Module, name: str):
+        key = self._entry_key(module, name)
+        entry = self._entries.pop(key, None)
+        if entry is None:
+            return
+        if entry.device_type == "cpu":
+            self.current_bytes -= entry.size_bytes
+
+    def pop_lru(self, device: torch.device) -> Optional[CacheEntry]:
+        for key, entry in self._entries.items():
+            if entry.device_type == device.type:
+                self._entries.pop(key)
+                if entry.device_type == "cpu":
+                    self.current_bytes -= entry.size_bytes
+                return entry
+        return None
+
     def _drop_module_entries(self, module_ref: weakref.ReferenceType):
         to_remove = []
         for key, entry in self._entries.items():
@@ -157,12 +186,14 @@ class DiskWeightCache:
                 to_remove.append(key)
         for key in to_remove:
             entry = self._entries.pop(key)
-            self.current_bytes -= entry.size_bytes
+            if entry.device_type == "cpu":
+                self.current_bytes -= entry.size_bytes

     def _evict_if_needed(self):
         while self._entries and self.current_bytes > self.max_bytes:
-            _, entry = self._entries.popitem(last=False)
-            self.current_bytes -= entry.size_bytes
+            entry = self.pop_lru(torch.device("cpu"))
+            if entry is None:
+                break
             module = entry.module_ref()
             if module is not None:
                 _evict_module_weight(module, entry.name, entry.is_buffer)
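Taken together, the CacheEntry and DiskWeightCache changes above make the LRU cache device-aware: every entry remembers which device its tensor lives on, only CPU-resident entries count toward the RAM budget, and pop_lru lets callers evict the oldest entry for one specific device. A minimal standalone sketch of that accounting pattern (illustrative only, not the repository code; the Entry/DeviceLRU names are made up):

import collections
from dataclasses import dataclass

@dataclass
class Entry:
    name: str
    size_bytes: int
    device_type: str  # "cpu", "cuda", ...

class DeviceLRU:
    """LRU cache where only CPU-resident entries count toward the byte budget."""
    def __init__(self, max_bytes):
        self.max_bytes = max_bytes
        self.current_bytes = 0                    # bytes held by CPU entries only
        self._entries = collections.OrderedDict()

    def record(self, key, entry: Entry):
        old = self._entries.pop(key, None)
        if old is not None and old.device_type == "cpu":
            self.current_bytes -= old.size_bytes
        self._entries[key] = entry                # newest entry sits at the MRU end
        if entry.device_type == "cpu":
            self.current_bytes += entry.size_bytes

    def pop_lru(self, device_type):
        # Walk from least- to most-recently-used and pop the first entry on that device.
        for key, entry in self._entries.items():
            if entry.device_type == device_type:
                del self._entries[key]
                if entry.device_type == "cpu":
                    self.current_bytes -= entry.size_bytes
                return entry
        return None

Counting only CPU bytes means that promoting an entry to an accelerator immediately frees RAM budget, which matches the new record/remove_entry bookkeeping in the hunks above.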
@@ -282,7 +313,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
             meta = state_dict.meta(key)
             ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
             REGISTRY.register(submodule, name, ref)
-            if param.device.type == "cpu":
+            if param.device.type != "meta":
                 CACHE.record(submodule, name, param, is_buffer=False)
         for name, buf in submodule.named_buffers(recurse=False):
             key = f"{module_prefix}{name}" if module_prefix else name
@@ -290,7 +321,7 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
             meta = state_dict.meta(key)
             ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
             REGISTRY.register(submodule, name, ref)
-            if buf.device.type == "cpu":
+            if buf.device.type != "meta":
                 CACHE.record(submodule, name, buf, is_buffer=True)


@@ -354,6 +385,23 @@ def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Te
     return torch.empty(shape, dtype=dtype, device="meta")


+def _attach_disk_identity(tensor: torch.Tensor, module: torch.nn.Module, name: str, is_buffer: bool):
+    tensor._disk_weights_module_ref = weakref.ref(module)
+    tensor._disk_weights_name = name
+    tensor._disk_weights_is_buffer = is_buffer
+
+
+def materialize_meta_tensor(tensor: torch.Tensor, target_device: torch.device, dtype_override: Optional[torch.dtype]):
+    module_ref = getattr(tensor, "_disk_weights_module_ref", None)
+    name = getattr(tensor, "_disk_weights_name", None)
+    if module_ref is None or name is None:
+        raise RuntimeError("Meta tensor missing disk weight identity")
+    module = module_ref()
+    if module is None:
+        raise RuntimeError("Disk weight module reference expired")
+    return load_module_tensor(module, name, target_device, dtype_override=dtype_override, temporary=False)
+
+
 def _state_dict_meta(state_dict: MutableMapping, key: str):
     if hasattr(state_dict, "meta"):
         return state_dict.meta(key)
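_attach_disk_identity and materialize_meta_tensor give a meta tensor enough breadcrumbs (a weak module reference, the attribute name, and a buffer flag) to reload itself on demand. A rough self-contained illustration of the same idea, with a stubbed loader standing in for the real disk-backed load path (attach_identity/rematerialize are hypothetical names, not the module's API):

import weakref
import torch

def attach_identity(tensor, module, name, is_buffer):
    # A weak reference keeps the cache from pinning the module in memory.
    tensor._owner_ref = weakref.ref(module)
    tensor._owner_name = name
    tensor._owner_is_buffer = is_buffer

def rematerialize(tensor, load_fn, device):
    module_ref = getattr(tensor, "_owner_ref", None)
    name = getattr(tensor, "_owner_name", None)
    if module_ref is None or name is None:
        raise RuntimeError("meta tensor has no identity attached")
    module = module_ref()
    if module is None:
        raise RuntimeError("owning module was garbage collected")
    return load_fn(module, name, device)   # e.g. stream the weight back from disk

# Usage sketch: the lambda stands in for a real disk read.
lin = torch.nn.Linear(4, 4)
meta_w = torch.empty_like(lin.weight, device="meta")
attach_identity(meta_w, lin, "weight", is_buffer=False)
real_w = rematerialize(meta_w, lambda m, n, d: torch.zeros(m.weight.shape, device=d), "cpu")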
@@ -466,7 +514,6 @@ def _ensure_free_memory(device, required_bytes, headroom_byte
     )
     safetensors_stream._reap_pinned_inflight()
     from . import model_management
-    model_management._reap_pinned_inflight()
     model_management.free_memory(required_bytes + headroom_bytes, device)
     free_after = _device_free_memory(device)
     freed = max(0, free_after - free_before)
@@ -650,6 +697,7 @@ def register_lazy_modules(model: torch.nn.Module, state_dict):

 def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     safetensors_stream._reap_pinned_inflight()
+    from . import model_management
     lazy_state = LAZY_MODULE_STATE.get(module)
     if lazy_state is not None:
         CACHE.remove_module(module)
@@ -657,6 +705,19 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
         if refs:
             state = _get_materialization_state(module)
             for ref_name, disk_ref in refs.items():
+                if ref_name in module._parameters:
+                    current = module._parameters[ref_name]
+                elif ref_name in module._buffers:
+                    current = module._buffers[ref_name]
+                else:
+                    current = None
+                if (
+                    current is not None
+                    and current.device.type == "cpu"
+                    and current.data_ptr() in model_management.PINNED_MEMORY
+                ):
+                    model_management.wait_for_pinned_tensor(current)
+                    model_management.unpin_memory(current)
                 shape = getattr(disk_ref.meta, "shape", None)
                 dtype = _get_future_dtype(module, ref_name) or getattr(disk_ref.meta, "dtype", None)
                 if shape is None or dtype is None:
@@ -664,8 +725,11 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
                 meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
                 if disk_ref.is_buffer:
                     module._buffers[ref_name] = meta_tensor
+                    _attach_disk_identity(meta_tensor, module, ref_name, True)
                 else:
-                    module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+                    param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+                    module._parameters[ref_name] = param
+                    _attach_disk_identity(param, module, ref_name, False)
                 nbytes = _meta_nbytes(disk_ref.meta)
                 if nbytes is not None:
                     state.loaded_keys.discard(ref_name)
@@ -679,7 +743,19 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     ref = REGISTRY.get(module)
     if not ref or name not in ref:
         return
+    CACHE.remove_entry(module, name)
     disk_ref = ref[name]
+    if is_buffer:
+        current = module._buffers.get(name)
+    else:
+        current = module._parameters.get(name)
+    if (
+        current is not None
+        and current.device.type == "cpu"
+        and current.data_ptr() in model_management.PINNED_MEMORY
+    ):
+        model_management.wait_for_pinned_tensor(current)
+        model_management.unpin_memory(current)
     shape = getattr(disk_ref.meta, "shape", None)
     dtype = _get_future_dtype(module, name) or getattr(disk_ref.meta, "dtype", None)
     if shape is None or dtype is None:
@@ -687,8 +763,11 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
     meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
     if is_buffer:
         module._buffers[name] = meta_tensor
+        _attach_disk_identity(meta_tensor, module, name, True)
     else:
-        module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+        param = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+        module._parameters[name] = param
+        _attach_disk_identity(param, module, name, False)
     state = _get_materialization_state(module)
     nbytes = _meta_nbytes(disk_ref.meta)
     if nbytes is not None:
@@ -777,6 +856,9 @@ def ensure_module_materialized(
         _set_future_dtype(module, name, dtype_override)
     _rebuild_materialization_state(module, refs, state)
     free_mem_start = _device_free_memory(target_device)
+    from . import model_management
+    non_blocking = model_management.device_supports_non_blocking(target_device)
+    offload_stream = model_management.get_offload_stream(target_device) if non_blocking else None
     for name in sorted(refs.keys()):
         disk_ref = refs[name]
         if name in module._parameters:
@@ -793,8 +875,7 @@ def ensure_module_materialized(
         if current.device.type != "meta" and current.device == target_device and (
             target_dtype is None or current.dtype == target_dtype
         ):
-            if current.device.type == "cpu":
-                CACHE.touch(module, name)
+            CACHE.touch(module, name)
             continue
         meta_nbytes = _meta_nbytes(disk_ref.meta)
         if meta_nbytes is None:
@@ -803,7 +884,6 @@ def ensure_module_materialized(
         if target_device.type == "cpu":
             _ensure_free_memory(target_device, required_bytes, RAM_HEADROOM_BYTES)
         else:
-            from . import model_management
             _ensure_free_memory(target_device, required_bytes, model_management.extra_reserved_memory())
         target_for_load = target_device
         if current.device.type == "meta":
@@ -813,16 +893,37 @@ def ensure_module_materialized(
                 PIN_IF_CPU,
                 dtype_override=target_dtype,
             )
+            if tensor.device != target_for_load or (target_dtype is not None and tensor.dtype != target_dtype):
+                tensor = model_management.cast_to(
+                    tensor,
+                    device=target_for_load,
+                    dtype=target_dtype,
+                    non_blocking=non_blocking,
+                    stream=offload_stream,
+                )
+                if non_blocking and offload_stream is not None:
+                    model_management.sync_stream(target_for_load, offload_stream)
         else:
-            if target_dtype is not None and current.dtype != target_dtype:
-                tensor = current.to(device=target_for_load, dtype=target_dtype)
-            else:
-                tensor = current.to(device=target_for_load)
+            if (
+                current.device.type == "cpu"
+                and current.data_ptr() in model_management.PINNED_MEMORY
+            ):
+                model_management.wait_for_pinned_tensor(current)
+                model_management.unpin_memory(current)
+            tensor = model_management.cast_to(
+                current,
+                device=target_for_load,
+                dtype=target_dtype if target_dtype is not None else current.dtype,
+                non_blocking=non_blocking,
+                stream=offload_stream,
+            )
+            if non_blocking and offload_stream is not None:
+                model_management.sync_stream(target_for_load, offload_stream)
         if is_buffer:
             module._buffers[name] = tensor
         else:
             module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
-        if tensor.device.type == "cpu":
+        if tensor.device.type != "meta":
             CACHE.record(module, name, tensor, is_buffer=is_buffer)
     _rebuild_materialization_state(module, refs, state)
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight materialized")
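Several hunks in this commit funnel copies through cast_to with non_blocking=True plus a dedicated offload stream, then call sync_stream before the result is used. The underlying PyTorch pattern, reduced to plain CUDA streams (a sketch assuming a CUDA device is available; it does not use the repository's helpers):

import torch

def copy_to_gpu_async(cpu_tensor, device="cuda"):
    # Pinned host memory is required for a truly asynchronous host-to-device copy.
    pinned = cpu_tensor.pin_memory()
    copy_stream = torch.cuda.Stream(device)
    with torch.cuda.stream(copy_stream):
        gpu_tensor = pinned.to(device, non_blocking=True)
    # Make the default stream wait for the copy before anyone reads gpu_tensor.
    torch.cuda.current_stream(device).wait_stream(copy_stream)
    return gpu_tensor

if torch.cuda.is_available():
    x = torch.randn(1024, 1024)
    y = copy_to_gpu_async(x)
    print(y.sum().item())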
@@ -834,10 +935,13 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs={}):
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
     dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
-    if getattr(module, "comfy_cast_weights", False):
+    input_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+    if getattr(module, "comfy_patched_weights", False):
+        target_device = input_device
+    elif getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
     else:
-        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+        target_device = input_device
     ensure_module_materialized(
         module,
         target_device,
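disk_weight_pre_hook now derives the target device from the actual inputs and treats patched weights differently from cast-on-the-fly weights. The general mechanism it relies on, a forward pre-hook that materializes missing weights just before forward runs, can be sketched with plain PyTorch as follows (lazy_pre_hook is a made-up name, and the torch.empty call stands in for a real disk load, so the values are uninitialized):

import torch

def lazy_pre_hook(module, args):
    # Pick the device from the first tensor input and replace any meta-device
    # parameters with real tensors before forward executes.
    example = next((a for a in args if isinstance(a, torch.Tensor)), None)
    device = example.device if example is not None else torch.device("cpu")
    for name, p in module.named_parameters(recurse=False):
        if p.device.type == "meta":
            real = torch.empty(p.shape, dtype=p.dtype, device=device)  # disk load goes here
            module._parameters[name] = torch.nn.Parameter(real, requires_grad=p.requires_grad)

lin = torch.nn.Linear(8, 8, device="meta")
lin.register_forward_pre_hook(lazy_pre_hook)
out = lin(torch.randn(2, 8))   # hook fills in the weights first, then forward runs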
@@ -859,11 +963,73 @@ def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
     safetensors_stream._reap_pinned_inflight()
-    from . import model_management
-    model_management._reap_pinned_inflight()
     return CACHE.evict_bytes(bytes_to_free)


+def _move_cache_entry_to_cpu(entry: CacheEntry):
+    module = entry.module_ref()
+    if module is None:
+        return
+    if entry.is_buffer:
+        current = module._buffers.get(entry.name)
+    else:
+        current = module._parameters.get(entry.name)
+    if current is None or current.device.type == "meta":
+        return
+    from . import model_management
+    non_blocking = model_management.device_supports_non_blocking(torch.device("cpu"))
+    offload_stream = model_management.get_offload_stream(torch.device("cpu")) if non_blocking else None
+    tensor = model_management.cast_to(
+        current,
+        device=torch.device("cpu"),
+        dtype=current.dtype,
+        non_blocking=non_blocking,
+        stream=offload_stream,
+    )
+    if non_blocking and offload_stream is not None:
+        model_management.sync_stream(current.device, offload_stream)
+    if entry.is_buffer:
+        module._buffers[entry.name] = tensor
+    else:
+        module._parameters[entry.name] = torch.nn.Parameter(tensor, requires_grad=current.requires_grad)
+    CACHE.record(module, entry.name, tensor, is_buffer=entry.is_buffer)
+
+
+def _evict_cpu_entry_to_meta(entry: CacheEntry):
+    module = entry.module_ref()
+    if module is None:
+        return
+    _evict_module_weight(module, entry.name, entry.is_buffer)
+    CACHE.remove_entry(module, entry.name)
+
+
+def evict_for_budget(target_device: torch.device, required_bytes: int):
+    if not disk_weights_enabled() or required_bytes <= 0:
+        return
+    from . import model_management
+    free = model_management.get_free_memory(target_device)
+    if free >= required_bytes:
+        return
+    cpu_device = torch.device("cpu")
+    if target_device.type != "cpu":
+        while free < required_bytes:
+            entry = CACHE.pop_lru(target_device)
+            if entry is None:
+                break
+            free_cpu = model_management.get_free_memory(cpu_device)
+            if free_cpu < RAM_HEADROOM_BYTES:
+                CACHE.evict_bytes(RAM_HEADROOM_BYTES - free_cpu)
+            _move_cache_entry_to_cpu(entry)
+            free = model_management.get_free_memory(target_device)
+    else:
+        while free < required_bytes:
+            entry = CACHE.pop_lru(cpu_device)
+            if entry is None:
+                break
+            _evict_cpu_entry_to_meta(entry)
+            free = model_management.get_free_memory(target_device)
+
+
 def materialize_module_tree(module: torch.nn.Module, target_device: torch.device):
     if not disk_weights_enabled():
         return
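evict_for_budget frees memory in two tiers: accelerator-resident cache entries are demoted to CPU first, and CPU entries are only pushed back to meta (leaving the weight on disk) when RAM itself is short. A condensed sketch of that control flow with stand-in callables (pop_lru, demote, drop and free_memory_fn are placeholders, not the module's API):

def evict_for_budget(device, required_bytes, headroom,
                     pop_lru, demote, drop, free_memory_fn):
    """Free memory on `device` by evicting least-recently-used cache entries."""
    free = free_memory_fn(device)
    while free < required_bytes:
        if device != "cpu":
            entry = pop_lru(device)               # oldest entry on the accelerator
            if entry is None:
                break                             # nothing left to demote
            if free_memory_fn("cpu") < headroom:  # make sure RAM can take the copy
                victim = pop_lru("cpu")
                if victim is not None:
                    drop(victim)                  # CPU entry goes back to meta/disk
            demote(entry)                         # accelerator entry becomes a CPU copy
        else:
            entry = pop_lru("cpu")
            if entry is None:
                break
            drop(entry)                           # CPU entry goes back to meta/disk
        free = free_memory_fn(device)

Demoting to CPU before dropping keeps recently used weights warm in RAM, so a later reload can often skip the disk read.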
@@ -901,8 +1067,34 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     return None


-def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
-    ensure_module_materialized(module, device_to, dtype_override=dtype_override)
+def move_module_tensors(
+    module: torch.nn.Module,
+    device_to: torch.device,
+    dtype_override: Optional[torch.dtype] = None,
+    non_blocking: bool = False,
+):
+    from . import model_management
+    offload_stream = None
+    if non_blocking and model_management.device_supports_non_blocking(device_to):
+        offload_stream = model_management.get_offload_stream(device_to)
+
+    def apply_fn(tensor):
+        if tensor is None or tensor.device.type == "meta":
+            return tensor
+        target_dtype = dtype_override or tensor.dtype
+        if tensor.device == device_to and tensor.dtype == target_dtype:
+            return tensor
+        return model_management.cast_to(
+            tensor,
+            device=device_to,
+            dtype=target_dtype,
+            non_blocking=non_blocking,
+            stream=offload_stream,
+        )
+
+    module._apply(apply_fn)
+    if non_blocking and offload_stream is not None:
+        model_management.sync_stream(device_to, offload_stream)
     return module


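The rewritten move_module_tensors routes every parameter and buffer through one closure via Module._apply instead of hand-written loops over named_parameters/named_buffers. A small sketch of that pattern in plain PyTorch (move_module is a hypothetical helper; _apply is technically a private torch.nn.Module method, but it is the same hook that Module.to/half/cuda use internally):

import torch

def move_module(module: torch.nn.Module, device, dtype=None):
    def apply_fn(t):
        if t is None or t.device.type == "meta":
            return t                              # leave offloaded weights alone
        target_dtype = dtype or t.dtype
        if t.device == torch.device(device) and t.dtype == target_dtype:
            return t                              # already where we want it
        return t.to(device=device, dtype=target_dtype)
    # _apply rewraps Parameters and walks child modules for us.
    return module._apply(apply_fn)

m = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
move_module(m, "cpu", dtype=torch.float16)
print(next(m.parameters()).dtype)  # torch.float16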
@@ -946,41 +1138,34 @@ def module_to(
         target_device = device or arg_device
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
+        dtype_override = dtype or arg_dtype
         if target_device.type == "meta":
+            cpu_device = torch.device("cpu")
+            for submodule in module.modules():
+                offload_module_weights(submodule)
+                move_module_tensors(
+                    submodule,
+                    cpu_device,
+                    dtype_override=dtype_override,
+                    non_blocking=non_blocking,
+                )
             return module
         if not allow_materialize:
+            move_module_tensors(
+                module,
+                target_device,
+                dtype_override=dtype_override,
+                non_blocking=non_blocking,
+            )
             return module
-        dtype_override = dtype or arg_dtype
-        to_kwargs = {}
-        if non_blocking:
-            to_kwargs["non_blocking"] = non_blocking
-        if memory_format is not None:
-            to_kwargs["memory_format"] = memory_format
         for submodule in module.modules():
             ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
-            refs = REGISTRY.get(submodule) or {}
-            for name, param in submodule.named_parameters(recurse=False):
-                if name in refs:
-                    continue
-                if param is None or param.device.type == "meta":
-                    continue
-                if param.device != target_device or (dtype_override is not None and param.dtype != dtype_override):
-                    if dtype_override is not None:
-                        tensor = param.to(device=target_device, dtype=dtype_override, **to_kwargs)
-                    else:
-                        tensor = param.to(device=target_device, **to_kwargs)
-                    submodule._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
-            for name, buf in submodule.named_buffers(recurse=False):
-                if name in refs:
-                    continue
-                if buf is None or buf.device.type == "meta":
-                    continue
-                if buf.device != target_device or (dtype_override is not None and buf.dtype != dtype_override):
-                    if dtype_override is not None:
-                        tensor = buf.to(device=target_device, dtype=dtype_override, **to_kwargs)
-                    else:
-                        tensor = buf.to(device=target_device, **to_kwargs)
-                    submodule._buffers[name] = tensor
+        move_module_tensors(
+            module,
+            target_device,
+            dtype_override=dtype_override,
+            non_blocking=non_blocking,
+        )
         return module
     base_kwargs = dict(kwargs)
     if device is not None and arg_device is None:
@@ -1024,11 +1209,24 @@ def load_module_tensor(
         from . import model_management
         headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
         _ensure_free_memory(device, _tensor_nbytes(current), headroom)
-        if target_dtype is not None and current.dtype != target_dtype:
-            tensor = current.to(device=device, dtype=target_dtype)
-        else:
-            tensor = current.to(device=device)
+        non_blocking = model_management.device_supports_non_blocking(device)
+        offload_stream = model_management.get_offload_stream(device) if non_blocking else None
+        tensor = model_management.cast_to(
+            current,
+            device=device,
+            dtype=target_dtype if target_dtype is not None else current.dtype,
+            non_blocking=non_blocking,
+            stream=offload_stream,
+        )
+        if non_blocking and offload_stream is not None:
+            model_management.sync_stream(device, offload_stream)
         if not temporary:
+            if (
+                current.device.type == "cpu"
+                and current.data_ptr() in model_management.PINNED_MEMORY
+            ):
+                model_management.wait_for_pinned_tensor(current)
+                model_management.unpin_memory(current)
             if is_buffer:
                 module._buffers[name] = tensor
             else:
@@ -1044,15 +1242,26 @@ def load_module_tensor(
     from . import model_management
     headroom = RAM_HEADROOM_BYTES if device.type == "cpu" else model_management.extra_reserved_memory()
     _ensure_free_memory(device, required_bytes, headroom)

+    non_blocking = model_management.device_supports_non_blocking(device)
+    offload_stream = model_management.get_offload_stream(device) if non_blocking else None
     tensor = disk_ref.load(device, ALLOW_GDS, PIN_IF_CPU, dtype_override=target_dtype)
+    if tensor.device != device or (target_dtype is not None and tensor.dtype != target_dtype):
+        tensor = model_management.cast_to(
+            tensor,
+            device=device,
+            dtype=target_dtype if target_dtype is not None else tensor.dtype,
+            non_blocking=non_blocking,
+            stream=offload_stream,
+        )
+        if non_blocking and offload_stream is not None:
+            model_management.sync_stream(device, offload_stream)
     if temporary:
         return tensor
     if is_buffer:
         module._buffers[name] = tensor
     else:
         module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
-    if tensor.device.type == "cpu" and record_cache:
+    if tensor.device.type != "meta" and record_cache:
         CACHE.record(module, name, tensor, is_buffer=is_buffer)
     state = _get_materialization_state(module)
     _rebuild_materialization_state(module, refs, state)
@@ -1068,8 +1277,11 @@ def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_
     attr = parts[-1]
     if is_buffer:
         module._buffers[attr] = tensor
+        return tensor
     else:
-        module._parameters[attr] = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+        param = torch.nn.Parameter(tensor, requires_grad=requires_grad)
+        module._parameters[attr] = param
+        return param


 def _materialize_module_from_state_dict(
@@ -1142,14 +1354,25 @@ def _materialize_module_from_state_dict(
         module.factory_kwargs["device"] = factory_device
     if len(error_msgs) > 0:
         raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
+    for name, disk_ref in refs.items():
+        if name in module._parameters:
+            tensor = module._parameters[name]
+            is_buffer = False
+        elif name in module._buffers:
+            tensor = module._buffers[name]
+            is_buffer = True
+        else:
+            continue
+        if tensor is not None and tensor.device.type == "meta":
+            _attach_disk_identity(tensor, module, name, is_buffer)
     _rebuild_materialization_state(module, refs, state)
     lazy_state.loaded = True
     _log_materialization(module, target_device, free_mem_start, refs, state, "Disk weight streamed")
     for name, param in module.named_parameters(recurse=False):
-        if param.device.type == "cpu":
+        if param.device.type != "meta":
             CACHE.record(module, name, param, is_buffer=False)
     for name, buf in module.named_buffers(recurse=False):
-        if buf is not None and buf.device.type == "cpu":
+        if buf is not None and buf.device.type != "meta":
             CACHE.record(module, name, buf, is_buffer=True)


@@ -1178,14 +1401,16 @@ def lazy_load_state_dict(model: torch.nn.Module, state_dict, strict: bool = Fals
             continue
         meta = state_dict.meta(name)
         meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
-        _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+        stored = _replace_tensor(model, name, meta_tensor, is_buffer=False, requires_grad=param.requires_grad)
+        _attach_disk_identity(stored, model, name, False)

     for name, buf in model.named_buffers(recurse=True):
         if buf is None or name not in state_keys:
             continue
         meta = state_dict.meta(name)
         meta_tensor = torch.empty(meta.shape, dtype=meta.dtype, device="meta")
-        _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+        stored = _replace_tensor(model, name, meta_tensor, is_buffer=True, requires_grad=False)
+        _attach_disk_identity(stored, model, name, True)

     register_module_weights(model, state_dict)
     register_lazy_modules(model, state_dict)

@@ -609,7 +609,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         free_before,
         device,
     )
-    _reap_pinned_inflight()
+    comfy.disk_weights.evict_for_budget(device, memory_required)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -1159,6 +1159,10 @@ def sync_stream(device, stream):
     current_stream(device).wait_stream(stream)

 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
+    if comfy.disk_weights.disk_weights_enabled() and weight.device.type == "meta":
+        target_device = device if device is not None else torch.device("cpu")
+        target_dtype = dtype if dtype is not None else weight.dtype
+        weight = comfy.disk_weights.materialize_meta_tensor(weight, target_device, target_dtype)
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
@@ -1171,7 +1175,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
                 return weight.to(dtype=dtype, copy=copy)
         return weight.to(dtype=dtype, copy=copy)
-

     if stream is not None:
         wf_context = stream
         if hasattr(wf_context, "as_context"):
@@ -1182,15 +1185,13 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
     else:
         r = torch.empty_like(weight, dtype=dtype, device=device)
         r.copy_(weight, non_blocking=non_blocking)
-    if (
-        non_blocking
-        and is_device_cuda(device)
-        and is_device_cpu(weight.device)
-        and weight.is_pinned()
-    ):
-        record_stream = stream if stream is not None else current_stream(device)
+    if non_blocking and (is_device_cuda(device) or is_device_cuda(weight.device)):
+        record_stream = stream if stream is not None else current_stream(device if is_device_cuda(device) else weight.device)
         if record_stream is not None:
-            _track_pinned_inflight(record_stream, weight)
+            if is_device_cpu(weight.device) and weight.is_pinned():
+                _record_pinned_event(weight.data_ptr(), record_stream)
+            if is_device_cpu(r.device) and r.is_pinned():
+                _record_pinned_event(r.data_ptr(), record_stream)
     return r

 def cast_to_device(tensor, device, dtype, copy=False):
@@ -1201,8 +1202,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
 PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
-PINNED_INFLIGHT = collections.deque()
-DEFERRED_UNPIN = collections.deque()
+PINNED_IN_FLIGHT = {}
 if not args.disable_pinned_memory:
     if is_nvidia() or is_amd():
         if WINDOWS:
@@ -1270,17 +1270,17 @@ def pin_memory(tensor):

     return False

-def _track_pinned_inflight(stream, tensor):
+def _record_pinned_event(ptr, stream):
+    events = PINNED_IN_FLIGHT.setdefault(ptr, [])
     event = torch.cuda.Event()
     event.record(stream)
-    PINNED_INFLIGHT.append((event, tensor))
+    events.append(event)

-def _tensor_inflight(tensor):
+def wait_for_pinned_tensor(tensor):
     ptr = tensor.data_ptr()
-    for _, inflight_tensor in PINNED_INFLIGHT:
-        if inflight_tensor.data_ptr() == ptr:
-            return True
-    return False
+    events = PINNED_IN_FLIGHT.pop(ptr, [])
+    for event in events:
+        event.synchronize()

 def _unpin_memory_now(tensor):
     global TOTAL_PINNED_MEMORY
@@ -1313,32 +1313,6 @@ def _unpin_memory_now(tensor):

     return False

-def _retry_deferred_unpins():
-    if not DEFERRED_UNPIN:
-        return
-    remaining = collections.deque()
-    while DEFERRED_UNPIN:
-        tensor = DEFERRED_UNPIN.popleft()
-        if _tensor_inflight(tensor):
-            remaining.append(tensor)
-            continue
-        if not _unpin_memory_now(tensor):
-            remaining.append(tensor)
-    DEFERRED_UNPIN.extend(remaining)
-
-def _reap_pinned_inflight():
-    if not PINNED_INFLIGHT:
-        _retry_deferred_unpins()
-        return
-    remaining = collections.deque()
-    while PINNED_INFLIGHT:
-        event, tensor = PINNED_INFLIGHT.popleft()
-        if event.query():
-            continue
-        remaining.append((event, tensor))
-    PINNED_INFLIGHT.extend(remaining)
-    _retry_deferred_unpins()
-
 def unpin_memory(tensor):
     global TOTAL_PINNED_MEMORY
     if MAX_PINNED_MEMORY <= 0:
@@ -1347,10 +1321,7 @@ def unpin_memory(tensor):
     if not is_device_cpu(tensor.device):
         return False

-    _reap_pinned_inflight()
-    if _tensor_inflight(tensor):
-        DEFERRED_UNPIN.append(tensor)
-        return False
+    wait_for_pinned_tensor(tensor)
     return _unpin_memory_now(tensor)

 def sage_attention_enabled():
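On the model_management side, the old scheme (a global deque of in-flight copies plus a deferred-unpin retry queue) is replaced by per-pointer CUDA events: each async copy that touches a pinned buffer records an event keyed by the buffer's data_ptr, and unpinning simply synchronizes those events first. A self-contained sketch of the idea (requires CUDA; record_pinned_use/wait_for_pinned are illustrative names, not the module's API):

import torch

PENDING = {}  # data_ptr -> list of CUDA events that may still be in flight

def record_pinned_use(ptr, stream):
    event = torch.cuda.Event()
    event.record(stream)
    PENDING.setdefault(ptr, []).append(event)

def wait_for_pinned(tensor):
    # Block until every copy that read or wrote this pinned buffer has finished.
    for event in PENDING.pop(tensor.data_ptr(), []):
        event.synchronize()

if torch.cuda.is_available():
    host = torch.randn(1 << 20).pin_memory()
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        dev = host.to("cuda", non_blocking=True)
    record_pinned_use(host.data_ptr(), stream)
    wait_for_pinned(host)   # now it is safe to unpin or reuse the host buffer

Keying the events by data_ptr makes each wait proportional to the copies that actually touched that buffer, and removes the need for the deferred-unpin retry loop that the diff deletes above.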