Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-03-29 13:35:41 -07:00
commit 8f548d4d19
10 changed files with 82 additions and 28 deletions

View File

@@ -50,7 +50,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
              "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
              "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
              "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
-             "GeForce GTX 1650", "GeForce GTX 1630"
+             "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
              }

 def cuda_malloc_supported():
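
Note: the blacklist above is consulted to decide whether the cudaMallocAsync allocator can be used; cards on the list (now including the Tesla M series) fall back to the default allocator. Below is a minimal sketch of that kind of name-based check, under the assumption that matching is done by substring against a reported device name; the real cuda_malloc_supported() takes no argument and discovers the device itself.

# Illustrative sketch only: substring match of a reported GPU name against the blacklist.
blacklist = {"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M40", "Tesla M60"}  # abbreviated

def cuda_malloc_supported(device_name):
    # Cards on the blacklist are too old for cudaMallocAsync, so fall back to the default allocator.
    for entry in blacklist:
        if entry in device_name:
            return False
    return True

print(cuda_malloc_supported("NVIDIA GeForce RTX 3090"))  # True
print(cuda_malloc_supported("Tesla M40 24GB"))           # False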

View File

@@ -409,7 +409,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d

-        model_management.cleanup_models()
+        model_management.cleanup_models(keep_clone_weights_loaded=True)
         self.add_message("execution_cached",
                         {"nodes": list(current_outputs), "prompt_id": prompt_id},
                         broadcast=False)

View File

@@ -21,6 +21,12 @@ def load_lora(lora, to_load):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)

+        dora_scale_name = "{}.dora_scale".format(x)
+        dora_scale = None
+        if dora_scale_name in lora.keys():
+            dora_scale = lora[dora_scale_name]
+            loaded_keys.add(dora_scale_name)
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@@ -44,7 +50,7 @@ def load_lora(lora, to_load):
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
-           patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
+           patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)
@@ -65,7 +71,7 @@ def load_lora(lora, to_load):
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

-           patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
+           patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
@@ -117,7 +123,7 @@ def load_lora(lora, to_load):
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-           patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+           patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        #glora
        a1_name = "{}.a1.weight".format(x)
@@ -125,7 +131,7 @@ def load_lora(lora, to_load):
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
-           patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+           patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
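
All of the lora.py hunks above do the same thing: for each mapped weight name x, an optional {x}.dora_scale tensor (DoRA) is looked up once and appended as a trailing element of the tuple stored in patch_dict, for every patch type (lora, loha, lokr, glora). A toy sketch of that lookup with a made-up state dict standing in for a real LoRA/DoRA file; shapes and key names other than .dora_scale, .lora_up.weight and .alpha are illustrative assumptions.

import torch

# Toy DoRA-style state dict; a plain LoRA would simply lack the ".dora_scale" entry.
lora = {
    "unet.attn1.lora_up.weight": torch.randn(16, 4),
    "unet.attn1.lora_down.weight": torch.randn(4, 16),
    "unet.attn1.alpha": torch.tensor(4.0),
    "unet.attn1.dora_scale": torch.randn(1, 16),
}

x = "unet.attn1"
dora_scale = None
dora_scale_name = "{}.dora_scale".format(x)
if dora_scale_name in lora:
    dora_scale = lora[dora_scale_name]

# The scale rides along as the new trailing tuple element; it stays None for plain LoRAs.
patch = ("lora", (lora["{}.lora_up.weight".format(x)],
                  lora["{}.lora_down.weight".format(x)],
                  lora["{}.alpha".format(x)].item(),
                  None,          # optional mid weight (locon), not present in this toy example
                  dora_scale))
print(patch[1][-1] is not None)  # True for this toy DoRA entry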

View File

@@ -345,7 +345,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                 'use_temporal_attention': False, 'use_temporal_resblock': False}

+    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
+               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
+               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
+               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
+
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]

    for unet_config in supported_models:
        matches = True
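
The new SD09_XS entry is picked up by the same loop as the others: the candidate config derived from the diffusers state dict is compared key by key against each reference dict and the first full match wins. A minimal sketch of that subset-matching idea; the detected dict below is invented for illustration and all configs are heavily abbreviated.

def first_matching_config(detected, supported_models):
    # Return the first reference config whose every key/value pair agrees with the
    # values detected from the checkpoint, or None if nothing matches.
    for unet_config in supported_models:
        if all(detected.get(k) == v for k, v in unet_config.items()):
            return unet_config
    return None

SD09_XS = {'model_channels': 320, 'context_dim': 1024, 'transformer_depth_middle': -2}  # abbreviated
SD15 = {'model_channels': 320, 'context_dim': 768, 'transformer_depth_middle': 1}       # abbreviated
detected = {'model_channels': 320, 'context_dim': 1024, 'transformer_depth_middle': -2}
print(first_matching_config(detected, [SD15, SD09_XS]) is SD09_XS)  # True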

View File

@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+from typing import Literal
+
 import psutil
 import logging
 from enum import Enum
@@ -278,6 +282,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None

     def model_memory(self):
         return self.model.model_size()
@@ -316,6 +321,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None

     def __eq__(self, other):
         return self.model is other.model
@@ -323,7 +329,7 @@ class LoadedModel:
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model, unload_weights_only=True, force_unload=True):
+def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
     with model_management_lock:
         to_unload = []
         for i in range(len(current_loaded_models)):
@@ -331,7 +337,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
                to_unload = [i] + to_unload

        if len(to_unload) == 0:
-           return None
+           return True

        same_weights = 0
        for i in to_unload:
@@ -355,20 +361,27 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
 def free_memory(memory_required, device, keep_loaded=[]):
     with model_management_lock:
-        unloaded_model = False
+        unloaded_model = []
+        can_unload = []
+
         for i in range(len(current_loaded_models) -1, -1, -1):
-            if not DISABLE_SMART_MEMORY:
-                if get_free_memory(device) > memory_required:
-                    break
             shift_model = current_loaded_models[i]
             if shift_model.device == device:
                 if shift_model not in keep_loaded:
-                    m = current_loaded_models.pop(i)
-                    m.model_unload()
-                    del m
-                    unloaded_model = True
+                    can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))

-        if unloaded_model:
+        for x in sorted(can_unload):
+            i = x[-1]
+            if not DISABLE_SMART_MEMORY:
+                if get_free_memory(device) > memory_required:
+                    break
+            current_loaded_models[i].model_unload()
+            unloaded_model.append(i)
+
+        for i in sorted(unloaded_model, reverse=True):
+            current_loaded_models.pop(i)
+
+        if len(unloaded_model) > 0:
             soft_empty_cache()
         else:
             if vram_state != VRAMState.HIGH_VRAM:
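
The rewritten free_memory separates the scan from the eviction: it first records a (refcount, model_memory, index) tuple per candidate, then unloads in ascending sorted order, so models with the fewest outside references go first (ties broken by smaller memory footprint), stopping as soon as enough memory is free unless smart memory is disabled. Only afterwards are the unloaded entries popped from current_loaded_models, by descending index so positions stay valid. A toy illustration of the ordering with stand-in objects:

import sys

class FakeLoadedModel:
    def __init__(self, name, memory):
        self.name = name
        self.memory = memory

current_loaded_models = [FakeLoadedModel("vae", 300), FakeLoadedModel("unet", 4000), FakeLoadedModel("clip", 800)]
extra_ref = current_loaded_models[1]  # something else still holds a reference to the unet

can_unload = []
for i in range(len(current_loaded_models) - 1, -1, -1):
    m = current_loaded_models[i]
    can_unload.append((sys.getrefcount(m), m.memory, i))

# Least-referenced, smallest models come first; the extra-referenced unet is tried last.
for refcount, memory, i in sorted(can_unload):
    print(current_loaded_models[i].name, memory)
# vae 300
# clip 800
# unet 4000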
@@ -408,8 +421,8 @@ def load_models_gpu(models, memory_required=0):
        total_memory_required = {}
        for loaded_model in models_to_load:
-           unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-           total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+           if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):#unload clones where the weights are different
+               total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

        for device in total_memory_required:
            if device != torch.device("cpu"):
@@ -449,12 +462,16 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])

-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
         for i in range(len(current_loaded_models)):
             if sys.getrefcount(current_loaded_models[i].model) <= 2:
-                to_delete = [i] + to_delete
+                if not keep_clone_weights_loaded:
+                    to_delete = [i] + to_delete
+                #TODO: find a less fragile way to do this.
+                elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                    to_delete = [i] + to_delete

         for i in to_delete:
             x = current_loaded_models.pop(i)
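
cleanup_models drops a LoadedModel entry once nothing outside current_loaded_models references its patcher any more; with keep_clone_weights_loaded=True (which execution.py now passes, see the second hunk above) the entry is kept until even its .real_model is no longer referenced elsewhere, so patched weights stay warm across prompts. The thresholds 2 and 3 come from sys.getrefcount counting the container's reference plus its own temporary argument reference, as this toy example shows:

import sys

class FakePatcher:
    pass

current_loaded_models = [FakePatcher()]

clone = current_loaded_models[0]
# The list entry, 'clone' and getrefcount's argument all count: 3 refs -> keep the model.
print(sys.getrefcount(current_loaded_models[0]) <= 2)  # False

del clone
# Only the list entry and the temporary argument remain: 2 refs -> safe to delete.
print(sys.getrefcount(current_loaded_models[0]) <= 2)  # True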

View File

@@ -7,6 +7,18 @@ import uuid
 from . import utils
 from . import model_management

+def apply_weight_decompose(dora_scale, weight):
+    weight_norm = (
+        weight.transpose(0, 1)
+        .reshape(weight.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    return weight * (dora_scale / weight_norm)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
@@ -309,6 +321,7 @@ class ModelPatcher:
            elif patch_type == "lora": #lora/locon
                mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
                mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
+               dora_scale = v[4]
                if v[2] is not None:
                    alpha *= v[2] / mat2.shape[0]
                if v[3] is not None:
@@ -318,6 +331,8 @@ class ModelPatcher:
                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                try:
                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                   if dora_scale is not None:
+                       weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "lokr":
@@ -328,6 +343,7 @@ class ModelPatcher:
                w2_a = v[5]
                w2_b = v[6]
                t2 = v[7]
+               dora_scale = v[8]
                dim = None

                if w1 is None:
@@ -357,6 +373,8 @@ class ModelPatcher:
                try:
                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                   if dora_scale is not None:
+                       weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "loha":
@@ -366,6 +384,7 @@ class ModelPatcher:
                    alpha *= v[2] / w1b.shape[0]
                w2a = v[3]
                w2b = v[4]
+               dora_scale = v[7]
                if v[5] is not None: #cp decomposition
                    t1 = v[5]
                    t2 = v[6]
@@ -386,12 +405,16 @@ class ModelPatcher:
                try:
                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                   if dora_scale is not None:
+                       weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            elif patch_type == "glora":
                if v[4] is not None:
                    alpha *= v[4] / v[0].shape[0]
+               dora_scale = v[5]

                a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
@@ -399,6 +422,8 @@ class ModelPatcher:
                try:
                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                   if dora_scale is not None:
+                       weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                except Exception as e:
                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
            else:
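
Every patch type above now finishes the same way: after the low-rank delta has been merged into weight, an optional DoRA magnitude vector is applied via the new apply_weight_decompose helper, which divides each input column of the merged weight by its norm and multiplies by the learned dora_scale. A small worked example replicating the helper on a toy matrix so the effect is visible; the (1, in_features) shape of dora_scale is an assumption for this illustration.

import torch

def apply_weight_decompose(dora_scale, weight):
    # Same computation as the helper added at the top of the file: per input-column
    # norm of the merged weight, then rescale so the norm matches the learned magnitude.
    weight_norm = (
        weight.transpose(0, 1)
        .reshape(weight.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
        .transpose(0, 1)
    )
    return weight * (dora_scale / weight_norm)

weight = torch.randn(4, 3)            # (out_features, in_features), already includes the LoRA delta
dora_scale = torch.full((1, 3), 2.0)  # learned per-column magnitude
decomposed = apply_weight_decompose(dora_scale, weight)
print(decomposed.norm(dim=0))         # every column norm is now 2.0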

View File

@@ -1 +1 @@
-MAX_RESOLUTION=8192
+MAX_RESOLUTION=16384

View File

@@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE):
    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
-           out = state_dict[k]
-           if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+           out = state_dict.get(k, None)
+           if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS
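
The .get(k, None) change makes the v-prediction heuristic tolerant of state dicts that simply lack the probed bias key; such checkpoints now fall back to epsilon prediction instead of raising KeyError. A tiny standalone illustration of the same pattern (guess_model_type is a hypothetical wrapper, not a function in the codebase):

import torch

def guess_model_type(state_dict, prefix=""):
    # Mirrors SD20.model_type above: a high std on this norm bias hints at a
    # v-prediction model; a missing key or low std falls back to epsilon prediction.
    k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
    out = state_dict.get(k, None)
    if out is not None and torch.std(out, unbiased=False) > 0.09:
        return "V_PREDICTION"
    return "EPS"

sd = {"output_blocks.11.1.transformer_blocks.0.norm1.bias": torch.randn(1280)}
print(guess_model_type(sd))  # V_PREDICTION (std of a randn vector is ~1.0)
print(guess_model_type({}))  # EPS, no KeyError even though the key is absent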

View File

@@ -204,13 +204,13 @@ class Sharpen:
                    "default": 1.0,
                    "min": 0.1,
                    "max": 10.0,
-                   "step": 0.1
+                   "step": 0.01
                }),
                "alpha": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 5.0,
-                   "step": 0.1
+                   "step": 0.01
                }),
            },
        }

View File

@@ -947,7 +947,7 @@ describe("group node", () => {
    expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
    expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
-   expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
+   expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max
    expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
    expect(p1.widgets.value.value).toBe(128);