mp: order the module list by offload expense
Calculate an approximate temporary VRAM cost for offloading a weight and order the module load list primarily by that. In the simple case this is just the module's weight size, but a weight with a LoRA attached consumes considerably more VRAM because the LoRA application is done on the fly. This slightly prioritizes LoRA-patched weights, but the real purpose is proper VRAM offload accounting.
commit 74fe85e0a8 (parent 1a7b1d6846)
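
As a rough, hypothetical illustration of the estimate (the shape and dtype below are made up, not taken from the commit): a 4096x4096 fp16 weight occupies about 32 MiB resident, but if a LoRA patch targets it, the on-the-fly patch math is budgeted at fp32 x 3, roughly 192 MiB of temporary VRAM on top of the weight itself.

    import torch

    # Hypothetical 4096x4096 fp16 linear weight (illustrative only).
    weight = torch.empty(4096, 4096, dtype=torch.float16)

    module_mem = weight.numel() * weight.element_size()           # resident size: ~32 MiB
    patch_estimate = weight.numel() * torch.float32.itemsize * 3  # fp32 x 3 budget: ~192 MiB

    print(module_mem // 2**20, patch_estimate // 2**20)  # 32 192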
@@ -148,6 +148,15 @@ class LowVramPatch:
         else:
             return out

+#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+
+def low_vram_patch_estimate_vram(model, key):
+    weight, set_func, convert_func = get_key_weight(model, key)
+    if weight is None:
+        return 0
+    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+
 def get_key_weight(model, key):
     set_func = None
     convert_func = None
@@ -662,7 +671,16 @@ class ModelPatcher:
                     skip = True # skip random weights in non leaf modules
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
-                loading.append((comfy.model_management.module_size(m), n, m, params))
+                module_mem = comfy.model_management.module_size(m)
+                module_offload_mem = module_mem
+                if hasattr(m, "comfy_cast_weights"):
+                    weight_key = "{}.weight".format(n)
+                    bias_key = "{}.bias".format(n)
+                    if weight_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
+                    if bias_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                loading.append((module_offload_mem, module_mem, n, m, params))
         return loading

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
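
A note on the tuple layout above: load() sorts this list with loading.sort(reverse=True), and Python compares tuples element by element, so putting module_offload_mem first makes the estimated offload expense the primary sort key (raw module size and name act as tie-breakers). A minimal sketch with made-up entries:

    # Made-up entries in the new (module_offload_mem, module_mem, name, module, params) layout.
    loading = [
        (32 * 2**20, 32 * 2**20, "plain.linear", None, ["weight", "bias"]),
        (224 * 2**20, 32 * 2**20, "lora.linear", None, ["weight", "bias"]),
    ]
    loading.sort(reverse=True)
    print([entry[2] for entry in loading])  # ['lora.linear', 'plain.linear']

The LoRA-patched module sorts first even though both weights have the same resident size.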
@@ -678,10 +696,7 @@ class ModelPatcher:
             offloaded = []
             loading.sort(reverse=True)
             for x in loading:
-                n = x[1]
-                m = x[2]
-                params = x[3]
-                module_mem = x[0]
+                module_offload_mem, module_mem, n, m, params = x

                 lowvram_weight = False

@@ -852,10 +867,7 @@ class ModelPatcher:
             for unload in unload_list:
                 if memory_to_free < memory_freed:
                     break
-                module_mem = unload[0]
-                n = unload[1]
-                m = unload[2]
-                params = unload[3]
+                module_offload_mem, module_mem, n, m, params = unload

                 lowvram_possible = hasattr(m, "comfy_cast_weights")
                 if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: