Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-25 13:50:15 +08:00)

commit b20f5b1e32
Merge branch 'comfyanonymous:master' into master
@@ -116,6 +116,9 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
 vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
 vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
 
+parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
+
+
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
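Note: the new --reserve-vram flag (presumably in comfy/cli_args.py) takes a value in gigabytes. A minimal, self-contained sketch of how such an option parses, using a throwaway argparse parser rather than ComfyUI's real one:

import argparse

# Throwaway parser mirroring the flag added above; the values are illustrative.
demo = argparse.ArgumentParser()
demo.add_argument("--reserve-vram", type=float, default=None,
                  help="Amount of vram in GB to reserve for the OS/other software.")

args = demo.parse_args(["--reserve-vram", "1.5"])
print(args.reserve_vram)   # 1.5 (GB); stays None when the flag is omitted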
@@ -329,10 +329,7 @@ class LoadedModel:
             self.model_use_more_vram(use_more_vram)
         else:
             try:
-                if lowvram_model_memory > 0 and load_weights:
-                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-                else:
-                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+                self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
             except Exception as e:
                 self.model.unpatch_model(self.model.offload_device)
                 self.model_unload()
@@ -384,6 +381,17 @@ def offloaded_memory(loaded_models, device):
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2
 
+EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
+if any(platform.win32_ver()):
+    EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 #Windows is higher because of the shared vram issue
+
+if args.reserve_vram is not None:
+    EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
+    logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
+
+def extra_reserved_memory():
+    return EXTRA_RESERVED_VRAM
+
 def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     to_unload = []
     for i in range(len(current_loaded_models)):
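Note: for quick reference, the reservation logic added above can be read as the following self-contained sketch (the helper name compute_extra_reserved_vram is ours, not ComfyUI's; the 200MB/400MB defaults and the GB-to-bytes conversion come straight from the hunk):

import platform

def compute_extra_reserved_vram(reserve_vram_gb=None):
    reserved = 200 * 1024 * 1024                          # default: 200MB
    if any(platform.win32_ver()):                         # truthy only on Windows
        reserved = 400 * 1024 * 1024                      # 400MB, shared-vram workaround
    if reserve_vram_gb is not None:                       # --reserve-vram overrides both
        reserved = reserve_vram_gb * 1024 * 1024 * 1024   # GB -> bytes
    return reserved

print(compute_extra_reserved_vram() / (1024 * 1024))      # 200.0 (or 400.0 on Windows) MB
print(compute_extra_reserved_vram(1.5) / (1024 * 1024))   # 1536.0 MB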
@@ -453,11 +461,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     global vram_state
 
     inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
+    extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
     if minimum_memory_required is None:
         minimum_memory_required = extra_mem
     else:
-        minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
+        minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
 
     models = set(models)
 
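Note: the fixed 300MB pad in load_models_gpu is replaced by the configurable reservation. Illustrative arithmetic only (the sizes below are assumptions, not measurements):

GB = 1024 * 1024 * 1024
MB = 1024 * 1024

inference_memory = 1.2 * GB              # minimum_inference_memory() from the earlier hunk
memory_required  = 4 * GB                # assumed size of the models being loaded

extra_mem_old = max(inference_memory, memory_required + 300 * MB)   # previous hard-coded pad
extra_mem_new = max(inference_memory, memory_required + 400 * MB)   # e.g. the Windows default reservation

print(extra_mem_old / GB, extra_mem_new / GB)   # ~4.29 vs ~4.39 -> the pad now tracks --reserve-vram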
@@ -96,7 +96,7 @@ class LowVramPatch:
         self.key = key
         self.model_patcher = model_patcher
     def __call__(self, weight):
-        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
 
 class ModelPatcher:
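Note: LowVramPatch now asks calculate_weight to do its intermediate math in the weight's own dtype instead of the torch.float32 default. A toy sketch of what that dtype choice means (the patch arithmetic below is made up for illustration and is not ComfyUI's LoRA math):

import torch

def apply_patch(weight, delta, strength, intermediate_dtype):
    # Do the arithmetic in intermediate_dtype, then cast back to the weight's dtype.
    result = weight.to(intermediate_dtype) + strength * delta.to(intermediate_dtype)
    return result.to(weight.dtype)

w = torch.zeros(4, 4, dtype=torch.float16)
d = torch.randn(4, 4).to(torch.float16)

old_style = apply_patch(w, d, 0.5, torch.float32)   # previous default: upcast to fp32
new_style = apply_patch(w, d, 0.5, w.dtype)         # what LowVramPatch now passes: weight.dtype
print(old_style.dtype, new_style.dtype)             # both come back as torch.float16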
@@ -336,33 +336,7 @@ class ModelPatcher:
         else:
             comfy.utils.set_attr_param(self.model, key, out_weight)
 
-    def patch_model(self, device_to=None, patch_weights=True):
-        for k in self.object_patches:
-            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
-            if k not in self.object_patches_backup:
-                self.object_patches_backup[k] = old
-
-        if patch_weights:
-            model_sd = self.model_state_dict()
-            keys_sort = []
-            for key in self.patches:
-                if key not in model_sd:
-                    logging.warning("could not patch. key doesn't exist in model: {}".format(key))
-                    continue
-                keys_sort.append((math.prod(model_sd[key].shape), key))
-
-            keys_sort.sort(reverse=True)
-            for ks in keys_sort:
-                self.patch_weight_to_device(ks[1], device_to)
-
-            if device_to is not None:
-                self.model.to(device_to)
-                self.model.device = device_to
-                self.model.model_loaded_weight_memory = self.model_size()
-
-        return self.model
-
-    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
@@ -413,15 +387,14 @@ class ModelPatcher:
             m = x[2]
             weight_key = "{}.weight".format(n)
             bias_key = "{}.bias".format(n)
-            param = list(m.parameters())
-            if len(param) > 0:
-                weight = param[0]
-                if weight.device == device_to:
+            if hasattr(m, "comfy_patched_weights"):
+                if m.comfy_patched_weights == True:
                     continue
 
             self.patch_weight_to_device(weight_key, device_to=device_to)
             self.patch_weight_to_device(bias_key, device_to=device_to)
             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
+            m.comfy_patched_weights = True
 
         for x in load_completely:
             x[2].to(device_to)
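Note: instead of inferring "already loaded" from the first parameter's device, modules now carry an explicit comfy_patched_weights marker. A small runnable illustration of the idempotency pattern (load_module_once is a stand-in name, not a ComfyUI function):

import torch

def load_module_once(m):
    if getattr(m, "comfy_patched_weights", False):   # same effect as the hasattr/== True pair above
        return False                                 # already patched: skip
    # ...patch the module's weight/bias here...
    m.comfy_patched_weights = True
    return True

lin = torch.nn.Linear(4, 4)
print(load_module_once(lin))   # True  -> patched on the first pass
print(load_module_once(lin))   # False -> later passes become no-ops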
@@ -430,16 +403,29 @@ class ModelPatcher:
             logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
+            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
+            if full_load:
+                self.model.to(device_to)
+                mem_counter = self.model_size()
 
         self.model.lowvram_patch_counter += patch_counter
         self.model.device = device_to
         self.model.model_loaded_weight_memory = mem_counter
 
-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
-        self.patch_model(device_to, patch_weights=False)
-        self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+        for k in self.object_patches:
+            old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
+            if k not in self.object_patches_backup:
+                self.object_patches_backup[k] = old
+
+        if lowvram_model_memory == 0:
+            full_load = True
+        else:
+            full_load = False
+
+        if load_weights:
+            self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
         return self.model
 
     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
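Note: patch_model() is now the single entry point (patch_model_lowvram is gone) and lowvram_model_memory=0 means a full load. A toy stand-in, not the real ModelPatcher, that mirrors only the dispatch logic shown above:

class ToyPatcher:
    def load(self, device_to, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
        print("load ->", device_to, "budget:", lowvram_model_memory, "full_load:", full_load)

    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
        full_load = lowvram_model_memory == 0            # 0 now means "load everything"
        if load_weights:
            self.load(device_to, lowvram_model_memory=lowvram_model_memory,
                      force_patch_weights=force_patch_weights, full_load=full_load)

p = ToyPatcher()
p.patch_model("cuda:0")                                    # full load
p.patch_model("cuda:0", lowvram_model_memory=2 * 1024**3)  # partial (lowvram) load
p.patch_model("cuda:0", load_weights=False)                # object patches only, no weights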
@@ -635,6 +621,10 @@ class ModelPatcher:
                 self.model.device = device_to
                 self.model.model_loaded_weight_memory = 0
 
+            for m in self.model.modules():
+                if hasattr(m, "comfy_patched_weights"):
+                    del m.comfy_patched_weights
+
         keys = list(self.object_patches_backup.keys())
         for k in keys:
             comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@@ -662,7 +652,7 @@ class ModelPatcher:
             weight_key = "{}.weight".format(n)
             bias_key = "{}.bias".format(n)
 
-            if m.weight is not None and m.weight.device != device_to:
+            if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                 for key in [weight_key, bias_key]:
                     bk = self.backup.get(key, None)
                     if bk is not None:
@@ -682,6 +672,7 @@ class ModelPatcher:
 
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
+                m.comfy_patched_weights = False
                 memory_freed += module_mem
                 logging.debug("freed {}".format(n))
 
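Note: taken together with the earlier hunks, the marker has a small lifecycle: load() sets it to True, partially_unload() flips it to False, and unpatch_model() deletes it. A toy demonstration of those three states (the module choice is arbitrary):

import torch

m = torch.nn.Linear(4, 4)
print(hasattr(m, "comfy_patched_weights"))   # False: the patcher never touched this module

m.comfy_patched_weights = True               # load() finished patching it onto the device
m.comfy_patched_weights = False              # partially_unload() freed it again
del m.comfy_patched_weights                  # unpatch_model() removes the marker entirely

print(hasattr(m, "comfy_patched_weights"))   # False again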
@@ -692,14 +683,14 @@ class ModelPatcher:
 
     def partially_load(self, device_to, extra_memory=0):
        self.unpatch_model(unpatch_weights=False)
-        self.patch_model(patch_weights=False)
+        self.patch_model(load_weights=False)
         full_load = False
         if self.model.model_lowvram == False:
             return 0
         if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
             full_load = True
         current_used = self.model.model_loaded_weight_memory
-        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
+        self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
         return self.model.model_loaded_weight_memory - current_used
 
     def current_loaded_device(self):
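Note: partially_load() keeps its contract (return how much additional weight memory ended up on the device) but now routes through load(). Rough arithmetic with assumed numbers to show how the extra_memory budget is used:

MB = 1024 * 1024

model_size     = 6000 * MB    # assumed total weight size (self.model_size())
already_loaded = 2500 * MB    # assumed model_loaded_weight_memory before the call
extra_memory   = 1000 * MB    # budget handed to partially_load()

full_load = already_loaded + extra_memory > model_size    # False here, so stay in lowvram mode
budget    = already_loaded + extra_memory                  # passed to load() as lowvram_model_memory

# load() fills whole modules until roughly `budget` is in vram; suppose it lands here:
loaded_after = 3450 * MB                                   # assumed result
print(full_load, (loaded_after - already_loaded) / MB)     # False 950.0 -> 950MB newly loaded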
@@ -365,3 +365,7 @@ NODE_CLASS_MAPPINGS = {
     "VAESave": VAESave,
     "ModelSave": ModelSave,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CheckpointSave": "Save Checkpoint",
+}
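Note: NODE_DISPLAY_NAME_MAPPINGS gives nodes a human-readable title in the UI, keyed by the same id used in NODE_CLASS_MAPPINGS. A minimal sketch of the convention with a made-up node class (ExampleSave is not a real ComfyUI node):

class ExampleSave:
    CATEGORY = "advanced/model_merging"
    FUNCTION = "save"
    RETURN_TYPES = ()
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    def save(self):
        return {}

NODE_CLASS_MAPPINGS = {
    "ExampleSave": ExampleSave,        # key: internal node id, value: node class
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ExampleSave": "Save Example",     # same key, value: label shown in the UI
}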