diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2e8ce2613..a486c2723 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -665,12 +666,18 @@ class ModelPatcher: module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if weight_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) - if bias_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + def check_module_offload_mem(key): + if key in self.patches: + return low_vram_patch_estimate_vram(self.model, key) + model_dtype = getattr(self.model, "manual_cast_dtype", None) + weight, _, _ = get_key_weight(self.model, key) + if model_dtype is None or weight is None: + return 0 + if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)): + return weight.numel() * model_dtype.itemsize + return 0 + module_offload_mem += check_module_offload_mem("{}.weight".format(n)) + module_offload_mem += check_module_offload_mem("{}.bias".format(n)) loading.append((module_offload_mem, module_mem, n, m, params)) return loading diff --git a/comfy/sd.py b/comfy/sd.py index 754b1703d..a16f2d14f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -127,6 +127,8 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + #Match torch.float32 hardcode upcast in TE implemention + self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None diff --git a/comfy/utils.py b/comfy/utils.py index 89846bc95..9dc0d76ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +ATTR_UNSET={} + def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], value) + prev = getattr(obj, attrs[-1], ATTR_UNSET) + if value is ATTR_UNSET: + delattr(obj, attrs[-1]) + else: + setattr(obj, attrs[-1], value) return prev def set_attr_param(obj, attr, value):