Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-03-29 13:35:41 -07:00
commit 8f548d4d19
10 changed files with 82 additions and 28 deletions

View File

@ -50,7 +50,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
"Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
"Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
"GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
"GeForce GTX 1650", "GeForce GTX 1630"
"GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
}
def cuda_malloc_supported():

View File

@ -409,7 +409,7 @@ class PromptExecutor:
d = self.outputs_ui.pop(x)
del d
model_management.cleanup_models()
model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached",
{"nodes": list(current_outputs), "prompt_id": prompt_id},
broadcast=False)

View File

@ -21,6 +21,12 @@ def load_lora(lora, to_load):
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = None
if dora_scale_name in lora.keys():
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@ -44,7 +50,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@ -65,7 +71,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@ -117,7 +123,7 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
#glora
a1_name = "{}.a1.weight".format(x)
@ -125,7 +131,7 @@ def load_lora(lora, to_load):
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)

View File

@ -345,7 +345,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
for unet_config in supported_models:
matches = True

View File

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Literal
import psutil
import logging
from enum import Enum
@ -278,6 +282,7 @@ class LoadedModel:
self.model = model
self.device = model.load_device
self.weights_loaded = False
self.real_model = None
def model_memory(self):
return self.model.model_size()
@ -316,6 +321,7 @@ class LoadedModel:
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights
self.real_model = None
def __eq__(self, other):
return self.model is other.model
@ -323,7 +329,7 @@ class LoadedModel:
def minimum_inference_memory():
return (1024 * 1024 * 1024)
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
with model_management_lock:
to_unload = []
for i in range(len(current_loaded_models)):
@ -331,7 +337,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = [i] + to_unload
if len(to_unload) == 0:
return None
return True
same_weights = 0
for i in to_unload:
@ -355,20 +361,27 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
def free_memory(memory_required, device, keep_loaded=[]):
with model_management_lock:
unloaded_model = False
unloaded_model = []
can_unload = []
for i in range(len(current_loaded_models) -1, -1, -1):
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded:
m = current_loaded_models.pop(i)
m.model_unload()
del m
unloaded_model = True
can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
if unloaded_model:
for x in sorted(can_unload):
i = x[-1]
if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required:
break
current_loaded_models[i].model_unload()
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
current_loaded_models.pop(i)
if len(unloaded_model) > 0:
soft_empty_cache()
else:
if vram_state != VRAMState.HIGH_VRAM:
@ -408,8 +421,8 @@ def load_models_gpu(models, memory_required=0):
total_memory_required = {}
for loaded_model in models_to_load:
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):#unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required:
if device != torch.device("cpu"):
@ -449,12 +462,16 @@ def load_model_gpu(model):
with model_management_lock:
return load_models_gpu([model])
def cleanup_models():
def cleanup_models(keep_clone_weights_loaded=False):
with model_management_lock:
to_delete = []
for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
to_delete = [i] + to_delete
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
#TODO: find a less fragile way to do this.
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
to_delete = [i] + to_delete
for i in to_delete:
x = current_loaded_models.pop(i)

View File

@ -7,6 +7,18 @@ import uuid
from . import utils
from . import model_management
def apply_weight_decompose(dora_scale, weight):
weight_norm = (
weight.transpose(0, 1)
.reshape(weight.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
.transpose(0, 1)
)
return weight * (dora_scale / weight_norm)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
@ -309,6 +321,7 @@ class ModelPatcher:
elif patch_type == "lora": #lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
@ -318,6 +331,8 @@ class ModelPatcher:
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
@ -328,6 +343,7 @@ class ModelPatcher:
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
@ -357,6 +373,8 @@ class ModelPatcher:
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
@ -366,6 +384,7 @@ class ModelPatcher:
alpha *= v[2] / w1b.shape[0]
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
@ -386,12 +405,16 @@ class ModelPatcher:
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
dora_scale = v[5]
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
@ -399,6 +422,8 @@ class ModelPatcher:
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
if dora_scale is not None:
weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:

View File

@ -1 +1 @@
MAX_RESOLUTION=8192
MAX_RESOLUTION=16384

View File

@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE):
def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
out = state_dict.get(k, None)
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return model_base.ModelType.V_PREDICTION
return model_base.ModelType.EPS

View File

@ -204,13 +204,13 @@ class Sharpen:
"default": 1.0,
"min": 0.1,
"max": 10.0,
"step": 0.1
"step": 0.01
}),
"alpha": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 5.0,
"step": 0.1
"step": 0.01
}),
},
}

View File

@ -947,7 +947,7 @@ describe("group node", () => {
expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max
expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10
expect(p1.widgets.value.value).toBe(128);