Mirror of https://github.com/comfyanonymous/ComfyUI.git
dequantization offload accounting (fixes Flux2 OOMs - incl TEs) (#11171)
* Make set_attr safe for non-existent attributes: handle the case where the attribute doesn't exist by returning a static sentinel (distinct from None). If the sentinel is passed back in as the set value, the attribute is deleted.
* Account for dequantization and type-casts in offload costs: when measuring the cost of offloading, identify weights that need a type change or dequantization and add the size of the conversion result to the offload cost. This is mutually exclusive with lowvram patches, which already carry a large, conservative estimate that won't overlap the dequant cost, so nothing is double-counted.
* Set the compute dtype on CLIP ModelPatchers so that the loader knows the size of weights for dequant accounting.
parent d50f342c90
commit e136b6dbb0
@@ -35,6 +35,7 @@ import comfy.model_management
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 
 
@@ -665,12 +666,18 @@ class ModelPatcher:
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):
-                    weight_key = "{}.weight".format(n)
-                    bias_key = "{}.bias".format(n)
-                    if weight_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                    if bias_key in self.patches:
-                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                    def check_module_offload_mem(key):
+                        if key in self.patches:
+                            return low_vram_patch_estimate_vram(self.model, key)
+                        model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        if model_dtype is None or weight is None:
+                            return 0
+                        if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                            return weight.numel() * model_dtype.itemsize
+                        return 0
+                    module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                    module_offload_mem += check_module_offload_mem("{}.bias".format(n))
                 loading.append((module_offload_mem, module_mem, n, m, params))
         return loading
 
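For a rough sense of what the new check_module_offload_mem term contributes: the estimate is the size of the weight after conversion to the model's compute dtype, not its stored size. A minimal arithmetic sketch, assuming a hypothetical 4096x4096 weight stored as fp8 with a bfloat16 compute dtype (both choices are illustrative, not taken from the commit):

import torch

# Illustrative arithmetic only: the added offload term is
# weight.numel() * compute_dtype.itemsize, i.e. the post-cast size.
numel = 4096 * 4096                                   # hypothetical weight shape
stored_bytes = numel * torch.float8_e4m3fn.itemsize   # 16 MiB as stored (fp8)
cast_bytes = numel * torch.bfloat16.itemsize          # 32 MiB once dequantized/cast
print(stored_bytes // 2**20, cast_bytes // 2**20)     # -> 16 32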
@@ -127,6 +127,8 @@ class CLIP:
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        # Match the torch.float32 hardcoded upcast in the TE implementation
+        self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
         self.patcher.is_clip = True
         self.apply_hooks_to_conds = None
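The two added lines above matter for the new accounting because the check in ModelPatcher reads the compute dtype and returns 0 when none is known. A minimal sketch of that fallback (a plain function, not the ComfyUI API; the 3072x3072 shape is hypothetical):

import torch

# Sketch of the fallback behaviour: with no compute dtype declared, nothing is
# budgeted for casting; with float32 declared (as CLIP now does), a weight is
# budgeted at 4 bytes per element regardless of its storage dtype.
def cast_budget(weight_numel, compute_dtype):
    if compute_dtype is None:
        return 0                                   # no target dtype -> nothing added
    return weight_numel * compute_dtype.itemsize   # post-cast size in bytes

print(cast_budget(3072 * 3072, None))              # -> 0
print(cast_budget(3072 * 3072, torch.float32))     # -> 37748736 (~36 MiB)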
@@ -803,11 +803,16 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)
 
+ATTR_UNSET={}
+
 def set_attr(obj, attr, value):
     attrs = attr.split(".")
     for name in attrs[:-1]:
         obj = getattr(obj, name)
-    prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], value)
+    prev = getattr(obj, attrs[-1], ATTR_UNSET)
+    if value is ATTR_UNSET:
+        delattr(obj, attrs[-1])
+    else:
+        setattr(obj, attrs[-1], value)
     return prev
 
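A usage sketch of the new sentinel behaviour (a standalone restatement of the helper for illustration; the real one is comfy.utils.set_attr): reading an attribute that does not exist returns ATTR_UNSET instead of raising, and passing ATTR_UNSET back as the value deletes the attribute, so a patch/restore round-trip leaves the object exactly as it started.

# Standalone restatement, for illustration only.
ATTR_UNSET = {}

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1], ATTR_UNSET)   # sentinel instead of AttributeError
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])                  # setting the sentinel removes the attribute
    else:
        setattr(obj, attrs[-1], value)
    return prev

class Layer:
    pass

layer = Layer()
prev = set_attr(layer, "scale", 2.0)   # attribute did not exist: prev is ATTR_UNSET
set_attr(layer, "scale", prev)         # restoring the sentinel deletes it again
print(hasattr(layer, "scale"))         # -> False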