mp: order the module list by offload expense

Calculate an approximate temporary VRAM cost of offloading a weight and
use it as the primary sort key for the module load list. In the simple
case this is just the module's weight size, but a weight with a LoRA
attached consumes considerably more VRAM, because the LoRA has to be
applied on the fly.

This slightly prioritizes LoRA-patched weights, but the real purpose is
proper VRAM offload accounting.
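
For intuition, a minimal sketch (not from the commit; shapes and buffer
names are invented) of why on-the-fly patching can peak at roughly three
fp32-sized buffers per weight:

import torch

# Illustrative only: during on-the-fly LoRA application the stored weight
# is cast up to fp32, a delta of the same shape is produced in fp32, and
# the two are summed, so about three fp32-sized buffers are alive at peak.
weight = torch.zeros(4096, 4096, dtype=torch.float16)  # 32 MiB as stored

w32 = weight.to(torch.float32)  # buffer 1: upcast copy (64 MiB)
delta = torch.zeros_like(w32)   # buffer 2: accumulated LoRA delta (64 MiB)
patched = w32 + delta           # buffer 3: patched result (64 MiB)

estimate = weight.numel() * torch.float32.itemsize * 3
print(estimate // (1024 * 1024))  # 192 (MiB), six times the stored fp16 size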
commit 74fe85e0a8
parent 1a7b1d6846
Author: Rattus
Date:   2025-11-12 20:13:18 +10:00


@@ -148,6 +148,15 @@ class LowVramPatch:
         else:
             return out
 
+#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+
+def low_vram_patch_estimate_vram(model, key):
+    weight, set_func, convert_func = get_key_weight(model, key)
+    if weight is None:
+        return 0
+    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+
 def get_key_weight(model, key):
     set_func = None
     convert_func = None
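
One property worth noting about the new helper: the estimate depends only
on the element count, not the stored dtype, since the patch math is
assumed to run in fp32. A standalone sketch of the arithmetic (shapes
invented; FACTOR mirrors the constant above):

import torch

FACTOR = 3  # mirrors LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

for dtype in (torch.float16, torch.float32):
    w = torch.zeros(1024, 1024, dtype=dtype)
    stored = w.numel() * w.element_size()                   # 2 MiB / 4 MiB
    estimate = w.numel() * torch.float32.itemsize * FACTOR  # 12 MiB either way
    print(dtype, stored, estimate)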
@@ -662,7 +671,16 @@ class ModelPatcher:
                     skip = True # skip random weights in non leaf modules
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
-                loading.append((comfy.model_management.module_size(m), n, m, params))
+                module_mem = comfy.model_management.module_size(m)
+                module_offload_mem = module_mem
+                if hasattr(m, "comfy_cast_weights"):
+                    weight_key = "{}.weight".format(n)
+                    bias_key = "{}.bias".format(n)
+                    if weight_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
+                    if bias_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                loading.append((module_offload_mem, module_mem, n, m, params))
 
         return loading
@@ -678,10 +696,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
             offloaded = []
             loading.sort(reverse=True)
             for x in loading:
-                n = x[1]
-                m = x[2]
-                params = x[3]
-                module_mem = x[0]
+                module_offload_mem, module_mem, n, m, params = x
 
                 lowvram_weight = False
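
Because Python compares tuples element by element, keeping
module_offload_mem as the first field means loading.sort(reverse=True)
orders modules by offload expense first, with raw module size as the
tie-breaker. A toy illustration with invented values:

loading = [
    (16, 16, "blocks.0.ff", None, []),     # unpatched: offload cost == module size
    (192, 32, "blocks.1.attn", None, []),  # LoRA-patched: far costlier to offload
    (16, 8, "blocks.2.norm", None, []),
]
loading.sort(reverse=True)
print([n for _, _, n, _, _ in loading])
# ['blocks.1.attn', 'blocks.0.ff', 'blocks.2.norm']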
@@ -852,10 +867,7 @@
             for unload in unload_list:
                 if memory_to_free < memory_freed:
                     break
-                module_mem = unload[0]
-                n = unload[1]
-                m = unload[2]
-                params = unload[3]
+                module_offload_mem, module_mem, n, m, params = unload
                 lowvram_possible = hasattr(m, "comfy_cast_weights")
 
                 if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: