mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
Add lowvram hint node
This commit is contained in:
parent
4977f203fa
commit
f17ac1f03d
@ -252,6 +252,9 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'model_lowvram'):
|
||||
self.model.model_lowvram = False
|
||||
|
||||
if not hasattr(self.model, 'lowvram_hints'):
|
||||
self.model.lowvram_hints = []
|
||||
|
||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||
self.model.current_weight_patches_uuid = None
|
||||
|
||||
@ -596,6 +599,7 @@ class ModelPatcher:
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
loading = self._load_list()
|
||||
hints = self.get_model_object("lowvram_hints")
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
@ -611,7 +615,7 @@ class ModelPatcher:
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
if (mem_counter + module_mem) >= lowvram_model_memory or any(x in n for x in hints):
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
@ -676,10 +680,10 @@ class ModelPatcher:
|
||||
x[2].to(device_to)
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially {}MB {}MB {}".format(round(lowvram_model_memory / (1024 * 1024)), round(mem_counter / (1024 * 1024)), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely {}MB {}MB {}".format(round(lowvram_model_memory / (1024 * 1024)), round(mem_counter / (1024 * 1024)), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
|
||||
@ -316,6 +316,32 @@ class ModelComputeDtype:
|
||||
return (m, )
|
||||
|
||||
|
||||
class ModelLowvramHint:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", ),
|
||||
"hints": ("STRING", {"default": "img_mlp.\ntxt_mlp.\n", "multiline": True}),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "set_hints"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
DESCRIPTION = "Force some weights to always use lowvram. One rule per line."
|
||||
CATEGORY = "advanced/debug/model"
|
||||
|
||||
def set_hints(self, model, hints=""):
|
||||
hints = [x.strip() for x in hints.split("\n") if x.strip()]
|
||||
if not hints:
|
||||
return model
|
||||
|
||||
m = model.clone()
|
||||
m.add_object_patch("lowvram_hints", hints)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||
@ -326,4 +352,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingFlux": ModelSamplingFlux,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
"ModelComputeDtype": ModelComputeDtype,
|
||||
"ModelLowvramHint": ModelLowvramHint,
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user