mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 02:07:32 +08:00
DynamicVram: Add casting / fix torch Buffer weights (#12749)
* respect model dtype in non-comfy caster
* utils: factor out parent-and-name functionality of set_attr
* utils: implement set_attr_buffer for torch buffers
* ModelPatcherDynamic: implement torch Buffer loading

If there is a buffer in dynamic mode, force-load it.
This commit is contained in:
parent
b6ddc590ed
commit
ac6513e142
def archive_model_dtypes(model):
    """Record the original dtype of every parameter and buffer.

    For each module in *model*, stores a ``<name>_comfy_model_dtype``
    attribute on the owning module so the dtype can be restored or
    respected later when weights are cast/moved.
    NOTE(review): reconstructed from a diff hunk — the hunk header shows
    the signature; confirm no preamble lines precede the loop upstream.
    """
    for _, module in model.named_modules():
        # recurse=False: only this module's own tensors; children are
        # visited by named_modules() itself.
        for p_name, p in module.named_parameters(recurse=False):
            setattr(module, p_name + "_comfy_model_dtype", p.dtype)
        for b_name, b in module.named_buffers(recurse=False):
            setattr(module, b_name + "_comfy_model_dtype", b.dtype)
def cleanup_models():
|
def cleanup_models():
|
||||||
|
|||||||
@ -241,6 +241,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
self.patches = {}
|
self.patches = {}
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
self.backup_buffers = {}
|
||||||
self.object_patches = {}
|
self.object_patches = {}
|
||||||
self.object_patches_backup = {}
|
self.object_patches_backup = {}
|
||||||
self.weight_wrapper_patches = {}
|
self.weight_wrapper_patches = {}
|
||||||
@ -309,7 +310,7 @@ class ModelPatcher:
|
|||||||
return comfy.model_management.get_free_memory(device)
|
return comfy.model_management.get_free_memory(device)
|
||||||
|
|
||||||
def get_clone_model_override(self):
    """Return the state a clone shares with this patcher.

    Gives back the wrapped model plus the tuple of backup/bookkeeping
    containers (weight backups, buffer backups, object-patch backups and
    the pinned set) so a clone can alias them instead of copying.
    """
    shared_state = (
        self.backup,
        self.backup_buffers,
        self.object_patches_backup,
        self.pinned,
    )
    return self.model, shared_state
|
|
||||||
def clone(self, disable_dynamic=False, model_override=None):
|
def clone(self, disable_dynamic=False, model_override=None):
|
||||||
class_ = self.__class__
|
class_ = self.__class__
|
||||||
@ -336,7 +337,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
n.force_cast_weights = self.force_cast_weights
|
n.force_cast_weights = self.force_cast_weights
|
||||||
|
|
||||||
n.backup, n.object_patches_backup, n.pinned = model_override[1]
|
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
|
||||||
|
|
||||||
# attachments
|
# attachments
|
||||||
n.attachments = {}
|
n.attachments = {}
|
||||||
@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
weight, _, _ = get_key_weight(self.model, key)
|
weight, _, _ = get_key_weight(self.model, key)
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
|
||||||
comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to))
|
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
|
||||||
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
|
casted_weight = weight.to(dtype=model_dtype, device=device_to)
|
||||||
|
comfy.utils.set_attr_param(self.model, key, casted_weight)
|
||||||
|
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
|
||||||
|
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
|
|
||||||
|
for key, buf in self.model.named_buffers(recurse=True):
|
||||||
|
if key not in self.backup_buffers:
|
||||||
|
self.backup_buffers[key] = buf
|
||||||
|
module, buf_name = comfy.utils.resolve_attr(self.model, key)
|
||||||
|
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
|
||||||
|
casted_buf = buf.to(dtype=model_dtype, device=device_to)
|
||||||
|
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
|
||||||
|
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||||
|
|
||||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||||
|
|
||||||
@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
for key in list(self.backup.keys()):
|
for key in list(self.backup.keys()):
|
||||||
bk = self.backup.pop(key)
|
bk = self.backup.pop(key)
|
||||||
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
|
for key in list(self.backup_buffers.keys()):
|
||||||
|
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
|
||||||
freed += self.model.model_loaded_weight_memory
|
freed += self.model.model_loaded_weight_memory
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
|||||||
@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
|||||||
|
|
||||||
# Unique sentinel meaning "attribute was not present"; compared by identity.
ATTR_UNSET = {}


def resolve_attr(obj, attr):
    """Walk the dotted path *attr* from *obj*.

    Returns ``(parent_object, leaf_name)`` — the object owning the final
    component and that component's name — without touching the leaf itself.
    """
    *parent_parts, leaf = attr.split(".")
    for part in parent_parts:
        obj = getattr(obj, part)
    return obj, leaf


def set_attr(obj, attr, value):
    """Set the attribute at dotted path *attr* on *obj* to *value*.

    Passing ``ATTR_UNSET`` as *value* deletes the attribute instead.
    Returns the previous value, or ``ATTR_UNSET`` if it did not exist.
    """
    parent, leaf = resolve_attr(obj, attr)
    previous = getattr(parent, leaf, ATTR_UNSET)
    if value is ATTR_UNSET:
        delattr(parent, leaf)
    else:
        setattr(parent, leaf, value)
    return previous
||||||
def set_attr_param(obj, attr, value):
    """Install *value* at dotted path *attr* as a frozen ``nn.Parameter``.

    Returns the previous attribute value (or ``ATTR_UNSET`` if absent),
    mirroring :func:`set_attr`.
    """
    frozen = torch.nn.Parameter(value, requires_grad=False)
    return set_attr(obj, attr, frozen)
|
def set_attr_buffer(obj, attr, value):
    """Replace the torch buffer at dotted path *attr* with *value*.

    Re-registers the buffer on its owning module, preserving whether it was
    persistent (i.e. included in the state dict). Returns the previous
    buffer, or ``ATTR_UNSET`` if none existed.
    """
    module, buf_name = resolve_attr(obj, attr)
    previous = getattr(module, buf_name, ATTR_UNSET)
    # nn.Module tracks non-persistent buffer names in this private set;
    # default to empty so plain objects are treated as persistent.
    non_persistent = getattr(module, "_non_persistent_buffers_set", set())
    module.register_buffer(buf_name, value, persistent=buf_name not in non_persistent)
    return previous
def copy_to_param(obj, attr, value):
|
def copy_to_param(obj, attr, value):
|
||||||
# inplace update tensor instead of replacing it
|
# inplace update tensor instead of replacing it
|
||||||
attrs = attr.split(".")
|
attrs = attr.split(".")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user