From ac6513e142f881202c40eacc5e337982b777ccd0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Mar 2026 18:19:40 -0800 Subject: [PATCH] DynamicVram: Add casting / fix torch Buffer weights (#12749) * respect model dtype in non-comfy caster * utils: factor out parent and name functionality of set_attr * utils: implement set_attr_buffer for torch buffers * ModelPatcherDynamic: Implement torch Buffer loading If there is a buffer in dynamic - force load it. --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 22 ++++++++++++++++++---- comfy/utils.py | 19 +++++++++++++++---- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0e0e96672..0f5966371 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -796,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e380e406b..70f78a089 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -309,7 +310,7 @@ class ModelPatcher: return comfy.model_management.get_free_memory(device) def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +337,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -1579,11 +1580,22 @@ class ModelPatcherDynamic(ModelPatcher): weight, _, _ = get_key_weight(self.model, key) if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) - comfy.utils.set_attr_param(self.model, key, weight.to(device=device_to)) - self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") @@ -1607,6 +1619,8 @@ class ModelPatcherDynamic(ModelPatcher): for key in list(self.backup.keys()): bk = self.backup.pop(key) comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) freed += self.model.model_loaded_weight_memory self.model.model_loaded_weight_memory = 0 diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".")