Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-08-20 00:31:41 +03:00 committed by GitHub
commit b20f5b1e32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 51 additions and 45 deletions

View File

@ -116,6 +116,9 @@ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet i
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")

View File

@ -329,10 +329,7 @@ class LoadedModel:
self.model_use_more_vram(use_more_vram)
else:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
@ -384,6 +381,17 @@ def offloaded_memory(loaded_models, device):
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
if any(platform.win32_ver()):
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 #Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
for i in range(len(current_loaded_models)):
@ -453,11 +461,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models)

View File

@ -96,7 +96,7 @@ class LowVramPatch:
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
class ModelPatcher:
@ -336,33 +336,7 @@ class ModelPatcher:
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if patch_weights:
model_sd = self.model_state_dict()
keys_sort = []
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
keys_sort.append((math.prod(model_sd[key].shape), key))
keys_sort.sort(reverse=True)
for ks in keys_sort:
self.patch_weight_to_device(ks[1], device_to)
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.model.model_loaded_weight_memory = self.model_size()
return self.model
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
@ -413,15 +387,14 @@ class ModelPatcher:
m = x[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
param = list(m.parameters())
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
@ -430,16 +403,29 @@ class ModelPatcher:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self.model.lowvram_patch_counter += patch_counter
self.model.device = device_to
self.model.model_loaded_weight_memory = mem_counter
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
old = comfy.utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
if lowvram_model_memory == 0:
full_load = True
else:
full_load = False
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
return self.model
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
@ -635,6 +621,10 @@ class ModelPatcher:
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
del m.comfy_patched_weights
keys = list(self.object_patches_backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@ -662,7 +652,7 @@ class ModelPatcher:
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
@ -682,6 +672,7 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
@ -692,14 +683,14 @@ class ModelPatcher:
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False)
self.patch_model(load_weights=False)
full_load = False
if self.model.model_lowvram == False:
return 0
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self.model.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self.model.model_loaded_weight_memory - current_used
def current_loaded_device(self):

View File

@ -365,3 +365,7 @@ NODE_CLASS_MAPPINGS = {
"VAESave": VAESave,
"ModelSave": ModelSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}