Add lowvram hint node

This commit is contained in:
City 2025-08-20 01:24:17 +02:00
parent 4977f203fa
commit f17ac1f03d
2 changed files with 34 additions and 3 deletions

View File

@ -252,6 +252,9 @@ class ModelPatcher:
if not hasattr(self.model, 'model_lowvram'):
self.model.model_lowvram = False
if not hasattr(self.model, 'lowvram_hints'):
self.model.lowvram_hints = []
if not hasattr(self.model, 'current_weight_patches_uuid'):
self.model.current_weight_patches_uuid = None
@ -596,6 +599,7 @@ class ModelPatcher:
patch_counter = 0
lowvram_counter = 0
loading = self._load_list()
hints = self.get_model_object("lowvram_hints")
load_completely = []
loading.sort(reverse=True)
@ -611,7 +615,7 @@ class ModelPatcher:
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
if (mem_counter + module_mem) >= lowvram_model_memory or any(x in n for x in hints):
lowvram_weight = True
lowvram_counter += 1
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
@ -676,10 +680,10 @@ class ModelPatcher:
x[2].to(device_to)
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
logging.info("loaded partially {}MB {}MB {}".format(round(lowvram_model_memory / (1024 * 1024)), round(mem_counter / (1024 * 1024)), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely {}MB {}MB {}".format(round(lowvram_model_memory / (1024 * 1024)), round(mem_counter / (1024 * 1024)), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)

View File

@ -316,6 +316,32 @@ class ModelComputeDtype:
return (m, )
class ModelLowvramHint:
    """Node that attaches lowvram hints to a model.

    Each non-empty line of the ``hints`` string is a substring rule; any
    weight whose module name contains one of these substrings is forced to
    load in lowvram mode (consumed elsewhere via the model object patch
    ``"lowvram_hints"``).
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "hints": ("STRING", {"default": "img_mlp.\ntxt_mlp.\n", "multiline": True}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "set_hints"

    EXPERIMENTAL = True
    DESCRIPTION = "Force some weights to always use lowvram. One rule per line."
    CATEGORY = "advanced/debug/model"

    def set_hints(self, model, hints=""):
        # One rule per line; ignore blank lines and surrounding whitespace.
        hints = [x.strip() for x in hints.split("\n") if x.strip()]
        if not hints:
            # BUGFIX: must return a tuple to match RETURN_TYPES; the
            # original returned the bare model here.
            return (model, )
        # Clone so the patch does not mutate the caller's model.
        m = model.clone()
        m.add_object_patch("lowvram_hints", hints)
        return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
@ -326,4 +352,5 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingFlux": ModelSamplingFlux,
"RescaleCFG": RescaleCFG,
"ModelComputeDtype": ModelComputeDtype,
"ModelLowvramHint": ModelLowvramHint,
}