From f17251bec65b5760cfedec29eace7d77f4b35130 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 27 Nov 2025 16:03:03 +1000
Subject: [PATCH] Account for the VRAM cost of weight offloading (#10733)

* mm: default to 0 for NUM_STREAMS

Don't count the compute stream as an offload stream. This makes async
offload accounting easier.

* mm: remove the 128MB minimum

This is a requirement left over from a previous offloading system.
Remove it so the loader and the partial unloader behave consistently.

* mp: order the module list by offload expense

Calculate an approximate temporary VRAM cost of offloading each weight
and order the module load list primarily by that. In the simple case
this is just the module weight size, but a weight with a lora attached
consumes considerably more VRAM, since the lora is applied on-the-fly.
This slightly prioritizes lora weights, but is mainly there for proper
VRAM offload accounting.

* mp: account for the VRAM cost of weight offloading

When checking the VRAM headroom, assume that the weight will need to be
offloaded, and only load it if there is space for both the load and the
offload buffer multiplied by the number of streams. As the weights are
ordered from largest to smallest offload cost, this is guaranteed to
fit in VRAM (tm), because every weight that follows will be smaller.

Make the partial unloader aware of this system as well by saving the
offload VRAM budget to the model state and accounting accordingly. It
is possible that a partial unload increases the size of the largest
offloaded weight, in which case it needs to unload a little more than
asked in order to accommodate the bigger temporary buffers.

Honor the existing code's 128MB floor on model weight loading by having
the patcher apply it separately, without regard to offloading.
Otherwise, when MM specifies its 128MB minimum, MP would see the
biggest weights, budget the whole 128MB to the offload buffer alone and
load nothing, which is not the intent of these minimums. The same clamp
applies in the case of a partial offload of the currently loading
model.
---
 comfy/model_management.py |  6 ++--
 comfy/model_patcher.py    | 59 +++++++++++++++++++++++++++++----------
 2 files changed, 48 insertions(+), 17 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index a9327ac80..9c403d580 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         loaded_memory = loaded_model.model_loaded_memory()
         current_free_mem = get_free_memory(torch_dev) + loaded_memory
 
-        lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+        lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
         lowvram_model_memory = lowvram_model_memory - loaded_memory
 
         if lowvram_model_memory == 0:
@@ -1012,7 +1012,7 @@ def force_channels_last():
 
 STREAMS = {}
-NUM_STREAMS = 1
+NUM_STREAMS = 0
 if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
@@ -1030,7 +1030,7 @@ def current_stream(device):
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
-    if NUM_STREAMS <= 1:
+    if NUM_STREAMS == 0:
         return None
 
     if device in STREAMS:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 73adc7f70..3eac77275 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -148,6 +148,15 @@ class LowVramPatch:
         else:
             return out
 
+#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+
+def low_vram_patch_estimate_vram(model, key):
+    weight, set_func, convert_func = get_key_weight(model, key)
+    if weight is None:
+        return 0
+    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+
 def get_key_weight(model, key):
     set_func = None
     convert_func = None
@@ -269,6 +278,9 @@ class ModelPatcher:
         if not hasattr(self.model, 'current_weight_patches_uuid'):
             self.model.current_weight_patches_uuid = None
 
+        if not hasattr(self.model, 'model_offload_buffer_memory'):
+            self.model.model_offload_buffer_memory = 0
+
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -662,7 +674,16 @@ class ModelPatcher:
                     skip = True # skip random weights in non leaf modules
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
-                loading.append((comfy.model_management.module_size(m), n, m, params))
+                module_mem = comfy.model_management.module_size(m)
+                module_offload_mem = module_mem
+                if hasattr(m, "comfy_cast_weights"):
+                    weight_key = "{}.weight".format(n)
+                    bias_key = "{}.bias".format(n)
+                    if weight_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
+                    if bias_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                loading.append((module_offload_mem, module_mem, n, m, params))
         return loading
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -676,20 +697,22 @@ class ModelPatcher:
 
         load_completely = []
         offloaded = []
+        offload_buffer = 0
         loading.sort(reverse=True)
         for x in loading:
-            n = x[1]
-            m = x[2]
-            params = x[3]
-            module_mem = x[0]
+            module_offload_mem, module_mem, n, m, params = x
 
             lowvram_weight = False
+            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
+
             weight_key = "{}.weight".format(n)
             bias_key = "{}.bias".format(n)
 
             if not full_load and hasattr(m, "comfy_cast_weights"):
-                if mem_counter + module_mem >= lowvram_model_memory:
+                if not lowvram_fits:
+                    offload_buffer = potential_offload
                     lowvram_weight = True
                     lowvram_counter += 1
                     lowvram_mem_counter += module_mem
@@ -723,9 +746,11 @@ class ModelPatcher:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
 
-                if full_load or mem_counter + module_mem < lowvram_model_memory:
+                if full_load or lowvram_fits:
                     mem_counter += module_mem
                     load_completely.append((module_mem, n, m, params))
+                else:
+                    offload_buffer = potential_offload
 
             if cast_weight and hasattr(m, "comfy_cast_weights"):
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -766,7 +791,7 @@ class ModelPatcher:
                 self.pin_weight_to_device("{}.{}".format(n, param))
 
         if lowvram_counter > 0:
-            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
+            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
             logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@@ -778,6 +803,7 @@ class ModelPatcher:
         self.model.lowvram_patch_counter += patch_counter
         self.model.device = device_to
         self.model.model_loaded_weight_memory = mem_counter
+        self.model.model_offload_buffer_memory = offload_buffer
         self.model.current_weight_patches_uuid = self.patches_uuid
 
         for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@@ -831,6 +857,7 @@ class ModelPatcher:
             self.model.to(device_to)
             self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
+            self.model.model_offload_buffer_memory = 0
 
         for m in self.model.modules():
             if hasattr(m, "comfy_patched_weights"):
@@ -849,13 +876,14 @@ class ModelPatcher:
         patch_counter = 0
         unload_list = self._load_list()
         unload_list.sort()
+        offload_buffer = self.model.model_offload_buffer_memory
+
        for unload in unload_list:
-            if memory_to_free < memory_freed:
+            if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                 break
-            module_mem = unload[0]
-            n = unload[1]
-            m = unload[2]
-            params = unload[3]
+            module_offload_mem, module_mem, n, m, params = unload
+
+            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
             lowvram_possible = hasattr(m, "comfy_cast_weights")
 
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -906,15 +934,18 @@ class ModelPatcher:
                 m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
                 memory_freed += module_mem
+                offload_buffer = max(offload_buffer, potential_offload)
                 logging.debug("freed {}".format(n))
 
                 for param in params:
                     self.pin_weight_to_device("{}.{}".format(n, param))
+
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
-        logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
+        self.model.model_offload_buffer_memory = offload_buffer
+        logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
         return memory_freed
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):