Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-17 09:40:50 +08:00)
Implement partial disk weight materialization

This commit is contained in:
  parent 5f2188e31b
  commit 45a77073ac
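
The change makes weight materialization budget-aware: tensors that fit in the free memory of the target device are loaded eagerly, while the rest are left as zero-byte meta tensors, tracked per module, and streamed in on demand or onto a fallback device. The first group of hunks below belongs to the disk-weights module (apparently comfy/disk_weights.py, which comfy/ops.py imports further down); the second group edits comfy/ops.py.

comfy/disk_weights.py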
@@ -20,8 +20,9 @@ from __future__ import annotations
 import collections
 import weakref
-from dataclasses import dataclass
-from typing import Dict, MutableMapping, Optional
+from dataclasses import dataclass, field
+from types import SimpleNamespace
+from typing import Dict, MutableMapping, Optional, Set

 import torch

@@ -33,6 +34,8 @@ PIN_IF_CPU = False
 DISK_WEIGHTS_ENABLED = False
 BASE_LOAD_FROM_STATE_DICT = torch.nn.Module._load_from_state_dict
 LAZY_MODULE_STATE = weakref.WeakKeyDictionary()
+DISK_MATERIALIZATION_STATE = weakref.WeakKeyDictionary()
+_MISSING = object()


 @dataclass
@@ -178,22 +181,24 @@ def register_module_weights(module: torch.nn.Module, state_dict, prefix: str = "
         return
     if not hasattr(state_dict, "meta") or not hasattr(state_dict, "get_tensor"):
         return
-    for name, param in module.named_parameters(recurse=True):
-        key = f"{prefix}{name}" if prefix else name
-        if key in state_dict:
-            meta = state_dict.meta(key)
-            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
-            REGISTRY.register(module, name, ref)
-            if param.device.type == "cpu":
-                CACHE.record(module, name, param, is_buffer=False)
-    for name, buf in module.named_buffers(recurse=True):
-        key = f"{prefix}{name}" if prefix else name
-        if key in state_dict and buf is not None:
-            meta = state_dict.meta(key)
-            ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
-            REGISTRY.register(module, name, ref)
-            if buf.device.type == "cpu":
-                CACHE.record(module, name, buf, is_buffer=True)
+    for module_name, submodule in module.named_modules():
+        module_prefix = f"{prefix}{module_name}." if module_name else prefix
+        for name, param in submodule.named_parameters(recurse=False):
+            key = f"{module_prefix}{name}" if module_prefix else name
+            if key in state_dict:
+                meta = state_dict.meta(key)
+                ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=param.requires_grad, is_buffer=False)
+                REGISTRY.register(submodule, name, ref)
+                if param.device.type == "cpu":
+                    CACHE.record(submodule, name, param, is_buffer=False)
+        for name, buf in submodule.named_buffers(recurse=False):
+            key = f"{module_prefix}{name}" if module_prefix else name
+            if key in state_dict and buf is not None:
+                meta = state_dict.meta(key)
+                ref = DiskTensorRef(state_dict=state_dict, key=key, meta=meta, requires_grad=False, is_buffer=True)
+                REGISTRY.register(submodule, name, ref)
+                if buf.device.type == "cpu":
+                    CACHE.record(submodule, name, buf, is_buffer=True)


 @dataclass
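Registering refs on each owning submodule (rather than recursing from the root) is what lets a single layer materialize or evict its own weights independently. A minimal standalone sketch of the traversal above, outside the commit's code:

import torch

model = torch.nn.Sequential(torch.nn.Linear(2, 2))
for module_name, submodule in model.named_modules():
    module_prefix = f"{module_name}." if module_name else ""
    for name, param in submodule.named_parameters(recurse=False):
        # prints "0.weight -> Linear" and "0.bias -> Linear"
        print(f"{module_prefix}{name} -> {type(submodule).__name__}")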
@@ -203,6 +208,245 @@ class LazyModuleState:
     loaded: bool = False


+@dataclass
+class DiskMaterializationState:
+    loaded_keys: Set[str] = field(default_factory=set)
+    deferred_keys: Set[str] = field(default_factory=set)
+    loaded_bytes: int = 0
+    deferred_bytes: int = 0
+
+
+def _get_materialization_state(module: torch.nn.Module) -> DiskMaterializationState:
+    state = DISK_MATERIALIZATION_STATE.get(module)
+    if state is None:
+        state = DiskMaterializationState()
+        DISK_MATERIALIZATION_STATE[module] = state
+    return state
+
+
+def _update_disk_state_attrs(module: torch.nn.Module, state: DiskMaterializationState):
+    module.disk_loaded_weight_memory = state.loaded_bytes
+    module.disk_offload_buffer_memory = state.deferred_bytes
+
+
+def _tensor_nbytes(tensor: torch.Tensor) -> int:
+    return tensor.numel() * tensor.element_size()
+
+
+def _meta_nbytes(meta) -> Optional[int]:
+    return getattr(meta, "nbytes", None)
+
+
+def _meta_tensor(meta, dtype_override: Optional[torch.dtype] = None) -> torch.Tensor:
+    dtype = dtype_override or getattr(meta, "dtype", None)
+    shape = getattr(meta, "shape", None)
+    if dtype is None or shape is None:
+        raise KeyError("Missing metadata for meta tensor")
+    return torch.empty(shape, dtype=dtype, device="meta")
+
+
+def _state_dict_meta(state_dict: MutableMapping, key: str):
+    if hasattr(state_dict, "meta"):
+        return state_dict.meta(key)
+    if hasattr(state_dict, "get_tensor"):
+        t = state_dict.get_tensor(key, device=torch.device("meta"))
+    else:
+        t = state_dict[key]
+    numel = t.numel()
+    return SimpleNamespace(
+        dtype=t.dtype,
+        shape=tuple(t.shape),
+        numel=numel,
+        nbytes=numel * t.element_size(),
+    )
+
+
+def _rebuild_materialization_state(module: torch.nn.Module, refs: Dict[str, DiskTensorRef], state: DiskMaterializationState):
+    state.loaded_keys.clear()
+    state.deferred_keys.clear()
+    state.loaded_bytes = 0
+    state.deferred_bytes = 0
+    for name, ref in refs.items():
+        if name in module._parameters:
+            tensor = module._parameters[name]
+        elif name in module._buffers:
+            tensor = module._buffers[name]
+        else:
+            continue
+        if tensor is None:
+            continue
+        nbytes = _meta_nbytes(ref.meta) or _tensor_nbytes(tensor)
+        if tensor.device.type == "meta":
+            state.deferred_keys.add(name)
+            state.deferred_bytes += nbytes
+        else:
+            state.loaded_keys.add(name)
+            state.loaded_bytes += nbytes
+    _update_disk_state_attrs(module, state)
+
+
+def _device_free_memory(device: torch.device) -> int:
+    from . import model_management
+    return int(model_management.get_free_memory(device))
+
+
+def _evict_ram_for_budget(required_bytes: int) -> int:
+    if required_bytes <= 0:
+        return 0
+    return evict_ram_cache(required_bytes)
+
+
+def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
+    free_mem = _device_free_memory(device)
+    if device.type == "cpu" and free_mem < required_bytes:
+        _evict_ram_for_budget(required_bytes - free_mem)
+        free_mem = _device_free_memory(device)
+    return free_mem
+
+
+def _choose_alternate_device(device: torch.device) -> Optional[torch.device]:
+    from . import model_management
+    if device.type == "cpu":
+        alt = model_management.get_torch_device()
+        if alt.type != "cpu":
+            return alt
+    else:
+        return torch.device("cpu")
+    return None
+
+
+class _BudgetedStateDict(MutableMapping):
+    is_stream_state_dict = True
+
+    def __init__(
+        self,
+        base: MutableMapping,
+        allowed_keys: Set[str],
+        device: torch.device,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+        mutate_base: bool = False,
+        overrides: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        self._base = base
+        self._allowed_keys = allowed_keys
+        self._device = device
+        self._allow_gds = allow_gds
+        self._pin_if_cpu = pin_if_cpu
+        self._mutate_base = mutate_base
+        self._overrides = overrides or {}
+        self._deleted: Set[str] = set()
+
+    def _get_meta(self, key: str):
+        if key in self._overrides:
+            t = self._overrides[key]
+            return safetensors_stream.TensorMeta(
+                dtype=t.dtype,
+                shape=tuple(t.shape),
+                numel=t.numel(),
+                nbytes=_tensor_nbytes(t),
+                data_offsets=(0, _tensor_nbytes(t)),
+                filename="<override>",
+                fst_dtype=None,
+                strides=tuple(t.stride()),
+            )
+        if hasattr(self._base, "meta"):
+            return self._base.meta(key)
+        if hasattr(self._base, "get_tensor"):
+            t = self._base.get_tensor(key, device=torch.device("meta"))
+        else:
+            t = self._base[key]
+        return safetensors_stream.TensorMeta(
+            dtype=t.dtype,
+            shape=tuple(t.shape),
+            numel=t.numel(),
+            nbytes=_tensor_nbytes(t),
+            data_offsets=(0, _tensor_nbytes(t)),
+            filename="<tensor>",
+            fst_dtype=None,
+            strides=tuple(t.stride()),
+        )
+
+    def get_tensor(
+        self,
+        key: str,
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        allow_gds: Optional[bool] = None,
+        pin_if_cpu: bool = False,
+    ) -> torch.Tensor:
+        if key in self._overrides:
+            t = self._overrides[key]
+            if device is not None and t.device != device:
+                t = t.to(device=device)
+            if dtype is not None and t.dtype != dtype:
+                t = t.to(dtype=dtype)
+            return t
+        if key in self._deleted:
+            raise KeyError(key)
+        if key not in self._allowed_keys:
+            meta = self._get_meta(key)
+            target_dtype = dtype or meta.dtype
+            return _meta_tensor(meta, dtype_override=target_dtype)
+        if hasattr(self._base, "get_tensor"):
+            return self._base.get_tensor(
+                key,
+                device=self._device if device is None else device,
+                dtype=dtype,
+                allow_gds=self._allow_gds if allow_gds is None else allow_gds,
+                pin_if_cpu=self._pin_if_cpu if not pin_if_cpu else pin_if_cpu,
+            )
+        t = self._base[key]
+        if device is not None and t.device != device:
+            t = t.to(device=device)
+        if dtype is not None and t.dtype != dtype:
+            t = t.to(dtype=dtype)
+        return t
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self.get_tensor(key)
+
+    def __setitem__(self, key: str, value: torch.Tensor) -> None:
+        self._overrides[key] = value
+        self._deleted.discard(key)
+
+    def __delitem__(self, key: str) -> None:
+        if key in self._overrides:
+            del self._overrides[key]
+            return
+        if key in self._deleted:
+            raise KeyError(key)
+        self._deleted.add(key)
+
+    def __iter__(self):
+        for k in self._base.keys():
+            if k in self._deleted:
+                continue
+            yield k
+        for k in self._overrides.keys():
+            if k not in self._deleted:
+                yield k
+
+    def __len__(self) -> int:
+        base_keys = list(self._base.keys())
+        return len(base_keys) - len(self._deleted) + len(self._overrides)
+
+    def pop(self, key: str, default: object = _MISSING) -> torch.Tensor:
+        if key in self._overrides:
+            return self._overrides.pop(key)
+        if key in self._deleted:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        if key not in self._base:
+            if default is _MISSING:
+                raise KeyError(key)
+            return default
+        self._deleted.add(key)
+        return self.get_tensor(key)
+
+    def meta(self, key: str):
+        return self._get_meta(key)
+
+
 def _has_custom_load(module: torch.nn.Module) -> bool:
     return module.__class__._load_from_state_dict is not BASE_LOAD_FROM_STATE_DICT

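The key behavior of _BudgetedStateDict: keys inside the budget yield real tensors on the target device, everything else comes back as a zero-cost meta tensor. A simplified, self-contained sketch of that policy (hypothetical helper, not the class above):

import torch

def budgeted_view(state_dict, budget_bytes):
    # Greedily admit keys while the byte budget holds; defer the rest to meta.
    allowed, remaining = set(), budget_bytes
    for key, t in state_dict.items():
        nbytes = t.numel() * t.element_size()
        if nbytes <= remaining:
            allowed.add(key)
            remaining -= nbytes
    return {
        key: (t if key in allowed else torch.empty(t.shape, dtype=t.dtype, device="meta"))
        for key, t in state_dict.items()
    }

sd = {"small": torch.zeros(8), "large": torch.zeros(1 << 20)}
view = budgeted_view(sd, budget_bytes=1024)
print(view["small"].device, view["large"].device)  # cpu meta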
@@ -239,6 +483,7 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
        CACHE.remove_module(module)
        refs = REGISTRY.get(module)
        if refs:
+            state = _get_materialization_state(module)
            for ref_name, disk_ref in refs.items():
                shape = getattr(disk_ref.meta, "shape", None)
                dtype = getattr(disk_ref.meta, "dtype", None)
@@ -249,6 +494,14 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
                    module._buffers[ref_name] = meta_tensor
                else:
                    module._parameters[ref_name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+                nbytes = _meta_nbytes(disk_ref.meta)
+                if nbytes is not None:
+                    state.loaded_keys.discard(ref_name)
+                    if ref_name not in state.deferred_keys:
+                        state.deferred_keys.add(ref_name)
+                        state.deferred_bytes += nbytes
+                    state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
+            _update_disk_state_attrs(module, state)
        lazy_state.loaded = False
        return
    ref = REGISTRY.get(module)
@@ -264,6 +517,15 @@ def _evict_module_weight(module: torch.nn.Module, name: str, is_buffer: bool):
        module._buffers[name] = meta_tensor
    else:
        module._parameters[name] = torch.nn.Parameter(meta_tensor, requires_grad=disk_ref.requires_grad)
+    state = _get_materialization_state(module)
+    nbytes = _meta_nbytes(disk_ref.meta)
+    if nbytes is not None:
+        state.loaded_keys.discard(name)
+        if name not in state.deferred_keys:
+            state.deferred_keys.add(name)
+            state.deferred_bytes += nbytes
+        state.loaded_bytes = max(0, state.loaded_bytes - nbytes)
+    _update_disk_state_attrs(module, state)


 def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
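Eviction swaps a materialized parameter for a meta placeholder, moving its bytes from the loaded tally to the deferred tally while the DiskTensorRef remembers where to reload it from. A standalone sketch of the swap itself (hypothetical, outside the commit):

import torch

lin = torch.nn.Linear(4, 4)
nbytes = lin.weight.numel() * lin.weight.element_size()  # 64 bytes for fp32 4x4
lin._parameters["weight"] = torch.nn.Parameter(
    torch.empty_like(lin.weight, device="meta"), requires_grad=False
)
print(lin.weight.device.type, nbytes)  # meta 64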
@@ -288,15 +550,23 @@ def _find_tensor_device(args, kwargs) -> Optional[torch.device]:
    return check(kwargs)


-def ensure_module_materialized(module: torch.nn.Module, target_device: torch.device):
+def ensure_module_materialized(
+    module: torch.nn.Module,
+    target_device: torch.device,
+    fallback_device: Optional[torch.device] = None,
+):
    lazy_state = LAZY_MODULE_STATE.get(module)
-    if lazy_state is not None and not lazy_state.loaded:
+    if lazy_state is not None:
        _materialize_module_from_state_dict(module, lazy_state, target_device)
        return
    refs = REGISTRY.get(module)
    if not refs:
        return
-    for name, disk_ref in refs.items():
+    state = _get_materialization_state(module)
+    _rebuild_materialization_state(module, refs, state)
+    remaining_budget = _device_free_memory(target_device)
+    for name in sorted(refs.keys()):
+        disk_ref = refs[name]
        if name in module._parameters:
            current = module._parameters[name]
            is_buffer = False
@@ -307,30 +577,52 @@ def ensure_module_materialized(module: torch.nn.Module, target_device: torch.dev
            continue
        if current is None:
            continue
-        if current.device.type == "meta":
-            tensor = disk_ref.load(target_device, ALLOW_GDS, PIN_IF_CPU)
-        elif current.device != target_device:
-            tensor = current.to(device=target_device)
-        else:
+        if current.device.type != "meta" and current.device == target_device:
            if current.device.type == "cpu":
                CACHE.touch(module, name)
            continue
+        meta_nbytes = _meta_nbytes(disk_ref.meta)
+        if meta_nbytes is None:
+            continue
+        required_bytes = meta_nbytes
+        if target_device.type == "cpu":
+            free_mem = _maybe_free_ram_budget(target_device, required_bytes)
+            remaining_budget = min(remaining_budget, free_mem)
+        if required_bytes > remaining_budget:
+            if fallback_device is not None and fallback_device != target_device:
+                fallback_free = _maybe_free_ram_budget(fallback_device, required_bytes)
+                if fallback_free >= required_bytes:
+                    target_for_load = fallback_device
+                else:
+                    continue
+            else:
+                continue
+        else:
+            target_for_load = target_device
+        if current.device.type == "meta":
+            tensor = disk_ref.load(target_for_load, ALLOW_GDS, PIN_IF_CPU)
+        else:
+            tensor = current.to(device=target_for_load)
        if is_buffer:
            module._buffers[name] = tensor
        else:
            module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
        if tensor.device.type == "cpu":
            CACHE.record(module, name, tensor, is_buffer=is_buffer)
+        remaining_budget = max(0, remaining_budget - required_bytes)
+    _rebuild_materialization_state(module, refs, state)


 def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
-    if not REGISTRY.has(module):
+    if not REGISTRY.has(module) and module not in LAZY_MODULE_STATE:
        return
-    target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
-    ensure_module_materialized(module, target_device)
+    if getattr(module, "comfy_cast_weights", False):
+        target_device = torch.device("cpu")
+        fallback_device = _find_tensor_device(args, kwargs)
+    else:
+        target_device = _find_tensor_device(args, kwargs) or torch.device("cpu")
+        fallback_device = None
+    ensure_module_materialized(module, target_device, fallback_device=fallback_device)


 def attach_disk_weight_hooks(model: torch.nn.Module):
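attach_disk_weight_hooks (context above) presumably wires disk_weight_pre_hook through PyTorch's register_forward_pre_hook(..., with_kwargs=True), so weights materialize just before a module runs. A minimal standalone sketch of that shape (the hook body here is a stand-in, not the commit's code):

import torch

def pre_hook(module, args, kwargs):
    # stand-in for disk_weight_pre_hook: pick a device from the inputs
    device = next((t.device for t in args if isinstance(t, torch.Tensor)), torch.device("cpu"))
    # ensure_module_materialized(module, device)  # the commit's call
    return None

lin = torch.nn.Linear(2, 2)
handle = lin.register_forward_pre_hook(pre_hook, with_kwargs=True)
lin(torch.randn(1, 2))
handle.remove()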
@@ -386,6 +678,87 @@ def module_to(module: torch.nn.Module, *args, **kwargs):
    return module.to(*args, **kwargs)


+def load_module_tensor(
+    module: torch.nn.Module,
+    name: str,
+    device: torch.device,
+    *,
+    allow_alternate: bool = True,
+    record_cache: bool = True,
+    temporary: bool = False,
+) -> Optional[torch.Tensor]:
+    refs = REGISTRY.get(module)
+    if not refs or name not in refs:
+        return None
+    if name in module._parameters:
+        current = module._parameters[name]
+        is_buffer = False
+    elif name in module._buffers:
+        current = module._buffers[name]
+        is_buffer = True
+    else:
+        return None
+    if current is None:
+        return None
+    if current.device.type != "meta":
+        if current.device != device:
+            tensor = current.to(device=device)
+            if not temporary:
+                if is_buffer:
+                    module._buffers[name] = tensor
+                else:
+                    module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=refs[name].requires_grad)
+                _rebuild_materialization_state(module, refs, _get_materialization_state(module))
+            return tensor
+        return current
+
+    disk_ref = refs[name]
+    required_bytes = _meta_nbytes(disk_ref.meta)
+    if required_bytes is None:
+        return current
+    free_mem = _maybe_free_ram_budget(device, required_bytes)
+    load_device = device
+    if free_mem < required_bytes and allow_alternate:
+        alt = _choose_alternate_device(device)
+        if alt is not None:
+            alt_free = _maybe_free_ram_budget(alt, required_bytes)
+            if alt_free >= required_bytes:
+                load_device = alt
+            else:
+                state = _get_materialization_state(module)
+                if name not in state.deferred_keys:
+                    state.deferred_keys.add(name)
+                    state.deferred_bytes += required_bytes
+                _update_disk_state_attrs(module, state)
+                return current
+        else:
+            state = _get_materialization_state(module)
+            if name not in state.deferred_keys:
+                state.deferred_keys.add(name)
+                state.deferred_bytes += required_bytes
+            _update_disk_state_attrs(module, state)
+            return current
+    elif free_mem < required_bytes:
+        state = _get_materialization_state(module)
+        if name not in state.deferred_keys:
+            state.deferred_keys.add(name)
+            state.deferred_bytes += required_bytes
+        _update_disk_state_attrs(module, state)
+        return current
+
+    tensor = disk_ref.load(load_device, ALLOW_GDS, PIN_IF_CPU)
+    if temporary:
+        return tensor
+    if is_buffer:
+        module._buffers[name] = tensor
+    else:
+        module._parameters[name] = torch.nn.Parameter(tensor, requires_grad=disk_ref.requires_grad)
+    if tensor.device.type == "cpu" and record_cache:
+        CACHE.record(module, name, tensor, is_buffer=is_buffer)
+    _rebuild_materialization_state(module, refs, _get_materialization_state(module))
+    return tensor
+
+
 def _replace_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor, is_buffer: bool, requires_grad: bool):
    parts = name.split(".")
    module = model
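load_module_tensor's device choice reduces to one decision: use the requested device if it has room, fall back to an alternate device if allowed, and otherwise leave the weight deferred. A hypothetical standalone sketch of that decision:

def pick_load_device(required_bytes, free_on_target, free_on_alt, alt_device):
    if free_on_target >= required_bytes:
        return "target"
    if alt_device is not None and free_on_alt >= required_bytes:
        return alt_device
    return None  # stay deferred on meta

print(pick_load_device(1 << 20, 1 << 30, 0, "cuda:0"))  # target
print(pick_load_device(1 << 30, 0, 1 << 31, "cuda:0"))  # cuda:0
print(pick_load_device(1 << 30, 0, 0, None))            # None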
@@ -404,12 +777,42 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
    error_msgs = []
    metadata = getattr(lazy_state.state_dict, "_metadata", None)
    local_metadata = {} if metadata is None else metadata.get(lazy_state.prefix[:-1], {})
-    state_dict = safetensors_stream.DeviceViewStateDict(
+    refs = REGISTRY.get(module) or {}
+    state = _get_materialization_state(module)
+    _rebuild_materialization_state(module, refs, state)
+    keys = sorted(lazy_state.state_dict.keys())
+    existing = {}
+    for name, param in module.named_parameters(recurse=False):
+        key = f"{lazy_state.prefix}{name}"
+        if key in lazy_state.state_dict and param is not None and param.device.type != "meta":
+            existing[key] = param
+    for name, buf in module.named_buffers(recurse=False):
+        key = f"{lazy_state.prefix}{name}"
+        if key in lazy_state.state_dict and buf is not None and buf.device.type != "meta":
+            existing[key] = buf
+    remaining_budget = _device_free_memory(target_device)
+    allowed = set(existing.keys())
+    for key in keys:
+        if key in allowed:
+            continue
+        meta = _state_dict_meta(lazy_state.state_dict, key)
+        required = _meta_nbytes(meta)
+        if required is None:
+            continue
+        if target_device.type == "cpu":
+            free_mem = _maybe_free_ram_budget(target_device, required)
+            remaining_budget = min(remaining_budget, free_mem)
+        if required <= remaining_budget:
+            allowed.add(key)
+            remaining_budget = max(0, remaining_budget - required)
+    deferred_state_dict_keys = {key for key in keys if key not in allowed}
+    state_dict = _BudgetedStateDict(
        lazy_state.state_dict,
+        allowed_keys=allowed,
        device=target_device,
        allow_gds=ALLOW_GDS,
        pin_if_cpu=PIN_IF_CPU,
+        mutate_base=False,
+        overrides=existing,
    )
    factory_device = None
    if hasattr(module, "factory_kwargs") and "device" in module.factory_kwargs:
@@ -435,7 +838,8 @@ def _materialize_module_from_state_dict(module: torch.nn.Module, lazy_state: Laz
        module.factory_kwargs["device"] = factory_device
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(module.__class__.__name__, "\n\t".join(error_msgs)))
-    lazy_state.loaded = True
+    _rebuild_materialization_state(module, refs, state)
+    lazy_state.loaded = len(deferred_state_dict_keys) == 0
    for name, param in module.named_parameters(recurse=False):
        if param.device.type == "cpu":
            CACHE.record(module, name, param, is_buffer=False)
comfy/ops.py (51 lines changed)
@@ -19,6 +19,7 @@
 import torch
 import logging
 import comfy.model_management
+import comfy.disk_weights
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
@@ -98,11 +99,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
    weight_has_function = len(s.weight_function) > 0
    bias_has_function = len(s.bias_function) > 0

-    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    weight_source = s.weight
+    bias_source = s.bias
+    if comfy.disk_weights.disk_weights_enabled():
+        if weight_source.device.type == "meta":
+            loaded = comfy.disk_weights.load_module_tensor(s, "weight", device, temporary=True)
+            if loaded is not None:
+                weight_source = loaded
+        if bias_source is not None and bias_source.device.type == "meta":
+            loaded_bias = comfy.disk_weights.load_module_tensor(s, "bias", device, temporary=True)
+            if loaded_bias is not None:
+                bias_source = loaded_bias
+
+    weight = comfy.model_management.cast_to(weight_source, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)

    bias = None
-    if s.bias is not None:
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+    if bias_source is not None:
+        bias = comfy.model_management.cast_to(bias_source, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)

    comfy.model_management.sync_stream(device, offload_stream)
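With temporary=True, cast_bias_weight streams a real tensor for the duration of a single cast while the module keeps holding its meta placeholder. A standalone sketch of that shape (the zeros() call stands in for load_module_tensor):

import torch

lin = torch.nn.Linear(2, 2)
lin._parameters["weight"] = torch.nn.Parameter(
    torch.empty_like(lin.weight, device="meta"), requires_grad=False
)
weight_source = lin.weight
if weight_source.device.type == "meta":
    # stand-in for comfy.disk_weights.load_module_tensor(lin, "weight", device, temporary=True)
    weight_source = torch.zeros(weight_source.shape, dtype=weight_source.dtype)
print(lin.weight.device.type, weight_source.device.type)  # meta cpu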
@@ -532,9 +545,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
            key = f"{prefix}{param_name}"
            value = state_dict.pop(key, None)
            if value is not None:
-                value = value.to(device=device)
-                if dtype is not None:
-                    value = value.view(dtype=dtype)
+                if value.device.type != "meta":
+                    value = value.to(device=device)
+                    if dtype is not None:
+                        value = value.view(dtype=dtype)
                manually_loaded_keys.append(key)
            return value
@@ -551,11 +565,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
            manually_loaded_keys = [weight_key]

            layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
-            if layer_conf is not None:
+            if layer_conf is not None and layer_conf.device.type != "meta":
                layer_conf = json.loads(layer_conf.numpy().tobytes())
+            elif layer_conf is not None:
+                layer_conf = None

            if layer_conf is None:
-                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                if weight.device.type == "meta":
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+                else:
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
            else:
                self.quant_format = layer_conf.get("format", None)
                self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
@@ -601,10 +620,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                else:
                    raise ValueError(f"Unsupported quantization format: {self.quant_format}")

-                self.weight = torch.nn.Parameter(
-                    QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
-                    requires_grad=False
-                )
+                if weight.device.type == "meta":
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+                else:
+                    self.weight = torch.nn.Parameter(
+                        QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
+                        requires_grad=False
+                    )

                for param_name in qconfig["parameters"]:
                    if param_name in {"weight_scale", "weight_scale_2"}:
@@ -614,7 +636,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                    _v = state_dict.pop(param_key, None)
                    if _v is None:
                        continue
-                    self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    if _v.device.type == "meta":
+                        self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
+                    else:
+                        self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                    manually_loaded_keys.append(param_key)

                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
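The meta-device guards added throughout mixed_precision_ops follow one pattern: only move or reinterpret a tensor that carries real data; a meta tensor is kept as-is so the deferred load can fill it in later. A hypothetical standalone sketch of the pattern:

import torch

def maybe_to(value, device, dtype=None):
    # Meta tensors have no storage; leave them untouched for a later load.
    if value.device.type != "meta":
        value = value.to(device=device)
        if dtype is not None:
            value = value.view(dtype=dtype)
    return value

print(maybe_to(torch.zeros(2, device="meta"), "cpu").device)  # meta
print(maybe_to(torch.zeros(2), "cpu").device)                 # cpu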