mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-04 16:57:31 +08:00
DynamicVram: Add casting / fix torch Buffer weights (#12749)
* respect model dtype in non-comfy caster * utils: factor out parent and name functionality of set_attr * utils: implement set_attr_buffer for torch buffers * ModelPatcherDynamic: Implement torch Buffer loading If there is a buffer in dynamic - force load it.
This commit is contained in:
parent
b6ddc590ed
commit
ac6513e142
@ -796,6 +796,8 @@ def archive_model_dtypes(model):
|
||||
for name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
|
||||
for buf_name, buf in module.named_buffers(recurse=False):
|
||||
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
|
||||
@ -241,6 +241,7 @@ class ModelPatcher:
|
||||
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.backup_buffers = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
@ -309,7 +310,7 @@ class ModelPatcher:
|
||||
return comfy.model_management.get_free_memory(device)
|
||||
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.object_patches_backup, self.pinned)
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
class_ = self.__class__
|
||||
@ -336,7 +337,7 @@ class ModelPatcher:
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
||||
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
|
||||
|
||||
# attachments
|
||||
n.attachments = {}
|
||||
@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
|
||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
|
||||
casted_weight = weight.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_param(self.model, key, casted_weight)
|
||||
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
for key, buf in self.model.named_buffers(recurse=True):
|
||||
if key not in self.backup_buffers:
|
||||
self.backup_buffers[key] = buf
|
||||
module, buf_name = comfy.utils.resolve_attr(self.model, key)
|
||||
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
|
||||
casted_buf = buf.to(dtype=model_dtype, device=device_to)
|
||||
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
for key in list(self.backup.keys()):
|
||||
bk = self.backup.pop(key)
|
||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||
for key in list(self.backup_buffers.keys()):
|
||||
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||
freed += self.model.model_loaded_weight_memory
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
|
||||
|
||||
@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
|
||||
ATTR_UNSET={}
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
def resolve_attr(obj, attr):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1], ATTR_UNSET)
|
||||
return obj, attrs[-1]
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
obj, name = resolve_attr(obj, attr)
|
||||
prev = getattr(obj, name, ATTR_UNSET)
|
||||
if value is ATTR_UNSET:
|
||||
delattr(obj, attrs[-1])
|
||||
delattr(obj, name)
|
||||
else:
|
||||
setattr(obj, attrs[-1], value)
|
||||
setattr(obj, name, value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
obj, name = resolve_attr(obj, attr)
|
||||
prev = getattr(obj, name, ATTR_UNSET)
|
||||
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
|
||||
obj.register_buffer(name, value, persistent=persistent)
|
||||
return prev
|
||||
|
||||
def copy_to_param(obj, attr, value):
|
||||
# inplace update tensor instead of replacing it
|
||||
attrs = attr.split(".")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user