Unload weights if vram usage goes up between runs. (#10690)

This commit is contained in:
comfyanonymous 2025-11-09 15:51:33 -08:00 committed by GitHub
parent e632e5de28
commit dea899f221
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 9 deletions

View File

@ -503,7 +503,11 @@ class LoadedModel:
use_more_vram = lowvram_model_memory use_more_vram = lowvram_model_memory
if use_more_vram == 0: if use_more_vram == 0:
use_more_vram = 1e32 use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) if use_more_vram > 0:
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
else:
self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_free_mem = get_free_memory(torch_dev) + loaded_memory current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1 lowvram_model_memory = 0.1

View File

@ -843,7 +843,7 @@ class ModelPatcher:
self.object_patches_backup.clear() self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0): def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected(): with self.use_ejected():
hooks_unpatched = False hooks_unpatched = False
memory_freed = 0 memory_freed = 0
@ -887,13 +887,19 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to) module_mem += move_weight_functions(m, device_to)
if lowvram_possible: if lowvram_possible:
if weight_key in self.patches: if weight_key in self.patches:
_, set_func, convert_func = get_key_weight(self.model, weight_key) if force_patch_weights:
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) self.patch_weight_to_device(weight_key)
patch_counter += 1 else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
_, set_func, convert_func = get_key_weight(self.model, bias_key) if force_patch_weights:
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) self.patch_weight_to_device(bias_key)
patch_counter += 1 else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True cast_weight = True
if cast_weight: if cast_weight: