From 95e6413587e7086606c239ffa02b03dfe1cc5621 Mon Sep 17 00:00:00 2001 From: Rattus Date: Wed, 12 Nov 2025 20:16:34 +1000 Subject: [PATCH] mp: Account for the VRAM cost of weight offloading when checking the VRAM headroom, assume that the weight needs to be offloaded, and only load if it has space for both the load and offload * the number of streams. As the weights are ordered from largest to smallest by offload cost this is guaranteed to fit in VRAM (tm), as all weights that follow will be smaller. Make the partial unload aware of this system as well by saving the budget for offload VRAM to the model state and accounting accordingly. Its possible that partial unload increases the size of the largest offloaded weights, and thus needs to unload a little bit more than asked to accomodate the bigger temp buffers. Honor the existing codes floor on model weight loading of 128MB by having the patcher honor this separately withough regard to offloading. Otherwise when MM specifies its 128MB minimum, MP will see the biggest weights, and budget that 128MB to only offload buffer and load nothing which isnt the intent of these minimums. The same clamp applies in case of partial offload of the currently loading model. --- comfy/model_patcher.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5241b7b33..3eac77275 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -278,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'current_weight_patches_uuid'): self.model.current_weight_patches_uuid = None + if not hasattr(self.model, 'model_offload_buffer_memory'): + self.model.model_offload_buffer_memory = 0 + def model_size(self): if self.size > 0: return self.size @@ -694,17 +697,22 @@ class ModelPatcher: load_completely = [] offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) for x in loading: module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 1 lowvram_mem_counter += module_mem @@ -738,9 +746,11 @@ class ModelPatcher: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -781,7 +791,7 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) @@ -793,6 +803,7 @@ class ModelPatcher: self.model.lowvram_patch_counter += patch_counter self.model.device = device_to self.model.model_loaded_weight_memory = mem_counter + self.model.model_offload_buffer_memory = offload_buffer self.model.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -846,6 +857,7 @@ class ModelPatcher: self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 + self.model.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -864,11 +876,15 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break module_offload_mem, module_mem, n, m, params = unload + potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True @@ -918,15 +934,18 @@ class ModelPatcher: m.comfy_cast_weights = True m.comfy_patched_weights = False memory_freed += module_mem + offload_buffer = max(offload_buffer, potential_offload) logging.debug("freed {}".format(n)) for param in params: self.pin_weight_to_device("{}.{}".format(n, param)) + self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed - logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) + self.model.model_offload_buffer_memory = offload_buffer + logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):