# Mirrored from https://github.com/comfyanonymous/ComfyUI.git
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Comfy

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations

import collections
import weakref
from dataclasses import dataclass
from typing import Dict, Optional

import torch
|
ALLOW_GDS = False
|
|
PIN_IF_CPU = False
|
|
DISK_WEIGHTS_ENABLED = False
|
|
|
|
|
|
@dataclass
|
|
class DiskTensorRef:
|
|
state_dict: object
|
|
key: str
|
|
meta: object
|
|
requires_grad: bool
|
|
is_buffer: bool
|
|
|
|
def load(self, device: torch.device, allow_gds: bool, pin_if_cpu: bool) -> torch.Tensor:
|
|
dtype = getattr(self.meta, "dtype", None)
|
|
if hasattr(self.state_dict, "get_tensor"):
|
|
return self.state_dict.get_tensor(
|
|
self.key,
|
|
device=device,
|
|
dtype=dtype,
|
|
allow_gds=allow_gds,
|
|
pin_if_cpu=pin_if_cpu,
|
|
)
|
|
tensor = self.state_dict[self.key]
|
|
if device is not None and tensor.device != device:
|
|
tensor = tensor.to(device=device)
|
|
if dtype is not None and tensor.dtype != dtype:
|
|
tensor = tensor.to(dtype=dtype)
|
|
return tensor
|
|
|
|
|
|
class DiskWeightRegistry:
|
|
def __init__(self):
|
|
self._registry = weakref.WeakKeyDictionary()
|
|
|
|
def register(self, module: torch.nn.Module, name: str, ref: DiskTensorRef):
|
|
module_refs = self._registry.setdefault(module, {})
|
|
module_refs[name] = ref
|
|
|
|
def get(self, module: torch.nn.Module) -> Optional[Dict[str, DiskTensorRef]]:
|
|
return self._registry.get(module)
|
|
|
|
def has(self, module: torch.nn.Module) -> bool:
|
|
return module in self._registry
|
|
|
|
|
|
@dataclass
|
|
class CacheEntry:
|
|
module_ref: weakref.ReferenceType
|
|
name: str
|
|
size_bytes: int
|
|
is_buffer: bool
|
|
|
|
|
|
class DiskWeightCache:
|
|
def __init__(self, max_bytes: int = 0):
|
|
self.max_bytes = max_bytes
|
|
self.current_bytes = 0
|
|
self._entries: "collections.OrderedDict[tuple[int, str], CacheEntry]" = collections.OrderedDict()
|
|
|
|
def set_limit(self, max_bytes: int):
|
|
self.max_bytes = max_bytes
|
|
self._evict_if_needed()
|
|
|
|
def _entry_key(self, module: torch.nn.Module, name: str) -> tuple[int, str]:
|
|
return (id(module), name)
|
|
|
|
def record(self, module: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool):
|
|
if tensor.device.type != "cpu":
|
|
return
|
|
size_bytes = tensor.numel() * tensor.element_size()
|
|
key = self._entry_key(module, name)
|
|
if key in self._entries:
|
|
entry = self._entries.pop(key)
|
|
self.current_bytes -= entry.size_bytes
|
|
module_ref = weakref.ref(module, self._drop_module_entries)
|
|
self._entries[key] = CacheEntry(module_ref=module_ref, name=name, size_bytes=size_bytes, is_buffer=is_buffer)
|
|
self.current_bytes += size_bytes
|
|
self._evict_if_needed()
|
|
|
|
def touch(self, module: torch.nn.Module, name: str):
|
|
key = self._entry_key(module, name)
|
|
if key in self._entries:
|
|
entry = self._entries.pop(key)
|
|
self._entries[key] = entry
|
|
|
|
def evict_bytes(self, bytes_to_free: int):
|
|
freed = 0
|
|
while self._entries and freed < bytes_to_free:
|
|
_, entry = self._entries.popitem(last=False)
|
|
freed += entry.size_bytes
|
|
self.current_bytes -= entry.size_bytes
|
|
module = entry.module_ref()
|
|
if module is not None:
|
|
_evict_module_weight(module, entry.name, entry.is_buffer)
|
|
return freed
|
|
|
|
def _drop_module_entries(self, module_ref: weakref.ReferenceType):
|
|
to_remove = []
|
|
for key, entry in self._entries.items():
|
|
if entry.module_ref is module_ref:
|
|
to_remove.append(key)
|
|
for key in to_remove:
|
|
entry = self._entries.pop(key)
|
|
self.current_bytes -= entry.size_bytes
|
|
|
|
def _evict_if_needed(self):
|
|
while self._entries and self.current_bytes > self.max_bytes:
|
|
_, entry = self._entries.popitem(last=False)
|
|
self.current_bytes -= entry.size_bytes
|
|
module = entry.module_ref()
|
|
if module is not None:
|
|
_evict_module_weight(module, entry.name, entry.is_buffer)
|
|
|
|
|
|
REGISTRY = DiskWeightRegistry()
|
|
CACHE = DiskWeightCache(0)
|
|
|
|
|
|
def configure(cache_bytes: int, allow_gds: bool, pin_if_cpu: bool, enabled: bool = True):
|
|
global ALLOW_GDS, PIN_IF_CPU, DISK_WEIGHTS_ENABLED
|
|
ALLOW_GDS = allow_gds
|
|
PIN_IF_CPU = pin_if_cpu
|
|
DISK_WEIGHTS_ENABLED = enabled
|
|
CACHE.set_limit(cache_bytes if enabled else 0)
|
|
if not enabled:
|
|
CACHE._entries.clear()
|
|
CACHE.current_bytes = 0
|
|
|
|
|
|
def disk_weights_enabled() -> bool:
|
|
return DISK_WEIGHTS_ENABLED
|
|
|
|
|
|
def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = ""):
|
|
if not disk_weights_enabled():
|
|
return
|
|
if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
|
|
return
|
|
for name, param in module.named_parameters(recurse=True):
|
|
key = f"{prefix}{name}" if prefix else name
|
|
if key in state_dict:
|
|
meta = state_dict.meta(key)
|
|
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
|
|
REGISTRY.register(module, name, ref)
|
|
if param.device.type == "cpu":
|
|
CACHE.record(module, name, param, is_buffer=False)
|
|
for name, buf in module.named_buffers(recurse=True):
|
|
key = f"{prefix}{name}" if prefix else name
|
|
if key in state_dict and buf is not None:
|
|
meta = state_dict.meta(key)
|
|
ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
|
|
REGISTRY.register(module, name, ref)
|
|
if buf.device.type == "cpu":
|
|
CACHE.record(module, name, buf, is_buffer=True)
|
|
|
|
|
|
def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
|
|
ref = REGISTRY.get(module)
|
|
if not ref or name not in ref:
|
|
return
|
|
disk_ref = ref[name]
|
|
shape = getattr(disk_ref.meta, "shape", None)
|
|
dtype = getattr(disk_ref.meta, "dtype", None)
|
|
if shape is None or dtype is None:
|
|
return
|
|
meta_tensor = torch.empty(shape, dtype=dtype, device="meta")
|
|
if is_buffer:
|
|
module._buffers[name] = meta_tensor
|
|
else:
|
|
module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
|
|
|
|
|
|
def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
|
|
def check(obj):
|
|
if torch.is_tensor(obj):
|
|
return obj.device
|
|
if isinstance(obj, (list, tuple)):
|
|
for item in obj:
|
|
dev = check(item)
|
|
if dev is not None:
|
|
return dev
|
|
if isinstance(obj, dict):
|
|
for item in obj.values():
|
|
dev = check(item)
|
|
if dev is not None:
|
|
return dev
|
|
return None
|
|
|
|
dev = check(args)
|
|
if dev is not None:
|
|
return dev
|
|
return check(kwargs)
|
|
|
|
|
|
def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
|
|
refs = REGISTRY.get(module)
|
|
if not refs:
|
|
return
|
|
for name, disk_ref in refs.items():
|
|
if name in module._parameters:
|
|
current = module._parameters[name]
|
|
is_buffer = False
|
|
elif name in module._buffers:
|
|
current = module._buffers[name]
|
|
is_buffer = True
|
|
else:
|
|
continue
|
|
if current is None:
|
|
continue
|
|
if current.device.type != "meta":
|
|
if current.device.type == "cpu":
|
|
CACHE.touch(module, name)
|
|
continue
|
|
tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
|
|
if is_buffer:
|
|
module._buffers[name] = tensor
|
|
else:
|
|
module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
|
|
if tensor.device.type == "cpu":
|
|
CACHE.record(module, name, tensor, is_buffer=is_buffer)
|
|
|
|
|
|
def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
|
|
if not REGISTRY.has(module):
|
|
return
|
|
if getattr(module, "comfy_cast_weights", False):
|
|
target_device = torch.device("cpu")
|
|
else:
|
|
target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
|
|
ensure_module_materialized(module, target_device)
|
|
|
|
|
|
def attach_disk_weight_hooks(model: torch.nn.Module):
|
|
if not disk_weights_enabled():
|
|
return
|
|
for module in model.modules():
|
|
if getattr(module, "_disk_weight_hook_attached", False):
|
|
continue
|
|
module.register_forward_pre_hook(disk_weight_pre_hook)
|
|
module._disk_weight_hook_attached = True
|
|
|
|
|
|
def evict_ram_cache(bytes_to_free: int):
|
|
if bytes_to_free <= 0:
|
|
return 0
|
|
return CACHE.evict_bytes(bytes_to_free)
|