mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
Add lowvram hint node
This commit is contained in:
parent
4977f203fa
commit
f17ac1f03d
@ -252,6 +252,9 @@ class ModelPatcher:
|
|||||||
if not hasattr(self.model, 'model_lowvram'):
|
if not hasattr(self.model, 'model_lowvram'):
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'lowvram_hints'):
|
||||||
|
self.model.lowvram_hints = []
|
||||||
|
|
||||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||||
self.model.current_weight_patches_uuid = None
|
self.model.current_weight_patches_uuid = None
|
||||||
|
|
||||||
@ -596,6 +599,7 @@ class ModelPatcher:
|
|||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
loading = self._load_list()
|
loading = self._load_list()
|
||||||
|
hints = self.get_model_object("lowvram_hints")
|
||||||
|
|
||||||
load_completely = []
|
load_completely = []
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
@ -611,7 +615,7 @@ class ModelPatcher:
|
|||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if (mem_counter + module_mem) >= lowvram_model_memory or any(x in n for x in hints):
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
lowvram_counter += 1
|
lowvram_counter += 1
|
||||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||||
@ -676,10 +680,10 @@ class ModelPatcher:
|
|||||||
x[2].to(device_to)
|
x[2].to(device_to)
|
||||||
|
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
logging.info("loaded partially {}MB {}MB {}".format(round(lowvram_model_memory / (1024 * 1024)), round(mem_counter / (1024 * 1024)), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
logging.info("loaded completely {}MB {}MB {}".format(round(lowvram_model_memory / (1024 * 1024)), round(mem_counter / (1024 * 1024)), full_load))
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
if full_load:
|
if full_load:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
|
|||||||
@ -316,6 +316,32 @@ class ModelComputeDtype:
|
|||||||
return (m, )
|
return (m, )
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLowvramHint:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL", ),
|
||||||
|
"hints": ("STRING", {"default": "img_mlp.\ntxt_mlp.\n", "multiline": True}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
FUNCTION = "set_hints"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
DESCRIPTION = "Force some weights to always use lowvram. One rule per line."
|
||||||
|
CATEGORY = "advanced/debug/model"
|
||||||
|
|
||||||
|
def set_hints(self, model, hints=""):
|
||||||
|
hints = [x.strip() for x in hints.split("\n") if x.strip()]
|
||||||
|
if not hints:
|
||||||
|
return model
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.add_object_patch("lowvram_hints", hints)
|
||||||
|
return (m, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||||
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
"ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
|
||||||
@ -326,4 +352,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ModelSamplingFlux": ModelSamplingFlux,
|
"ModelSamplingFlux": ModelSamplingFlux,
|
||||||
"RescaleCFG": RescaleCFG,
|
"RescaleCFG": RescaleCFG,
|
||||||
"ModelComputeDtype": ModelComputeDtype,
|
"ModelComputeDtype": ModelComputeDtype,
|
||||||
|
"ModelLowvramHint": ModelLowvramHint,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user