From 74fe85e0a8a6cf2b4fe1da1a5a8a0b5f8e98db5b Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 12 Nov 2025 20:13:18 +1000 Subject: [PATCH] mp: order the module list by offload expense Calculate an approximate offloading temporary VRAM cost to offload a weight and primary order the module load list by that. In the simple case this is just the same as the module weight, but with Loras, a weight with a lora consumes considerably more VRAM to do the Lora application on-the-fly. This will slightly prioritize lora weights, but is really for proper VRAM offload accounting. --- comfy/model_patcher.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 73adc7f70..5241b7b33 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -148,6 +148,15 @@ class LowVramPatch: else: return out +#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None convert_func = None @@ -662,7 +671,16 @@ class ModelPatcher: skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append((comfy.model_management.module_size(m), n, m, params)) + module_mem = comfy.model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append((module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -678,10 +696,7 @@ class ModelPatcher: offloaded = [] loading.sort(reverse=True) for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False @@ -852,10 +867,7 @@ class ModelPatcher: for unload in unload_list: if memory_to_free < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: