diff --git a/comfy/cmd/cuda_malloc.py b/comfy/cmd/cuda_malloc.py
index f41647785..2e3a4be9a 100644
--- a/comfy/cmd/cuda_malloc.py
+++ b/comfy/cmd/cuda_malloc.py
@@ -50,7 +50,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor
              "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000",
              "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000",
              "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M",
-             "GeForce GTX 1650", "GeForce GTX 1630"
+             "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60"
              }

 def cuda_malloc_supported():
diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 4cbe59e1a..e78e7dda1 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -409,7 +409,7 @@ class PromptExecutor:
                     d = self.outputs_ui.pop(x)
                     del d

-            model_management.cleanup_models()
+            model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                              {"nodes": list(current_outputs), "prompt_id": prompt_id}, broadcast=False)
diff --git a/comfy/lora.py b/comfy/lora.py
index 49c71d321..2580ff5fc 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -21,6 +21,12 @@
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)

+        dora_scale_name = "{}.dora_scale".format(x)
+        dora_scale = None
+        if dora_scale_name in lora.keys():
+            dora_scale = lora[dora_scale_name]
+            loaded_keys.add(dora_scale_name)
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@@ -44,7 +50,7 @@
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
@@ -65,7 +71,7 @@
                 loaded_keys.add(hada_t1_name)
                 loaded_keys.add(hada_t2_name)

-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
+            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
             loaded_keys.add(hada_w2_a_name)
@@ -117,7 +123,7 @@
             loaded_keys.add(lokr_t2_name)

         if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

         #glora
         a1_name = "{}.a1.weight".format(x)
@@ -125,7 +131,7 @@
         b1_name = "{}.b1.weight".format(x)
         b2_name = "{}.b2.weight".format(x)
         if a1_name in lora:
-            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
             loaded_keys.add(a1_name)
             loaded_keys.add(a2_name)
             loaded_keys.add(b1_name)
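Every LoRA flavor that `load_lora` handles (lora/locon, loha, lokr, glora) now also looks for an optional `{}.dora_scale` tensor and threads it through as the last element of the corresponding patch tuple; plain LoRA files simply yield `None`. A minimal sketch of the lookup and the resulting tuple layout, using an invented key prefix and invented tensor shapes purely for illustration:

```python
import torch

# Hypothetical DoRA state dict: the usual up/down/alpha keys plus one
# dora_scale tensor per patched module (key prefix and shapes are invented).
lora_sd = {
    "lora_unet_x.lora_up.weight": torch.randn(320, 4),
    "lora_unet_x.lora_down.weight": torch.randn(4, 320),
    "lora_unet_x.alpha": torch.tensor(4.0),
    "lora_unet_x.dora_scale": torch.randn(1, 320),
}

x = "lora_unet_x"
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = lora_sd[dora_scale_name] if dora_scale_name in lora_sd else None

# load_lora() appends it (possibly None) to every patch tuple, e.g. for lora/locon:
patch = ("lora", (lora_sd["{}.lora_up.weight".format(x)],
                  lora_sd["{}.lora_down.weight".format(x)],
                  lora_sd["{}.alpha".format(x)].item(),
                  None,          # mid weight (locon only)
                  dora_scale))   # new trailing element
```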
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 1861fadb6..65fd41abd 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -345,7 +345,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
                 'use_temporal_attention': False, 'use_temporal_resblock': False}

-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
+    SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+               'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
+               'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
+               'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
+               'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]

     for unet_config in supported_models:
         matches = True
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 8954973cb..94849a7b7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,3 +1,7 @@
+from __future__ import annotations
+
+from typing import Literal
+
 import psutil
 import logging
 from enum import Enum
@@ -278,6 +282,7 @@
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None

     def model_memory(self):
         return self.model.model_size()
@@ -316,6 +321,7 @@
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None

     def __eq__(self, other):
         return self.model is other.model
@@ -323,7 +329,7 @@
 def minimum_inference_memory():
     return (1024 * 1024 * 1024)

-def unload_model_clones(model, unload_weights_only=True, force_unload=True):
+def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
     with model_management_lock:
         to_unload = []
         for i in range(len(current_loaded_models)):
@@ -331,7 +337,7 @@
                 to_unload = [i] + to_unload

         if len(to_unload) == 0:
-            return None
+            return True

         same_weights = 0
         for i in to_unload:
@@ -355,20 +361,27 @@

 def free_memory(memory_required, device, keep_loaded=[]):
     with model_management_lock:
-        unloaded_model = False
+        unloaded_model = []
+        can_unload = []
+
         for i in range(len(current_loaded_models) -1, -1, -1):
-            if not DISABLE_SMART_MEMORY:
-                if get_free_memory(device) > memory_required:
-                    break
             shift_model = current_loaded_models[i]
             if shift_model.device == device:
                 if shift_model not in keep_loaded:
-                    m = current_loaded_models.pop(i)
-                    m.model_unload()
-                    del m
-                    unloaded_model = True
+                    can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))

-        if unloaded_model:
+        for x in sorted(can_unload):
+            i = x[-1]
+            if not DISABLE_SMART_MEMORY:
+                if get_free_memory(device) > memory_required:
+                    break
+            current_loaded_models[i].model_unload()
+            unloaded_model.append(i)
+
+        for i in sorted(unloaded_model, reverse=True):
+            current_loaded_models.pop(i)
+
+        if len(unloaded_model) > 0:
             soft_empty_cache()
         else:
             if vram_state != VRAMState.HIGH_VRAM:
@@ -408,8 +421,8 @@ def load_models_gpu(models, memory_required=0):

     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False):#unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -449,12 +462,16 @@ def load_model_gpu(model):
     with model_management_lock:
         return load_models_gpu([model])

-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     with model_management_lock:
         to_delete = []
         for i in range(len(current_loaded_models)):
             if sys.getrefcount(current_loaded_models[i].model) <= 2:
-                to_delete = [i] + to_delete
+                if not keep_clone_weights_loaded:
+                    to_delete = [i] + to_delete
+                #TODO: find a less fragile way to do this.
+                elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                    to_delete = [i] + to_delete

         for i in to_delete:
             x = current_loaded_models.pop(i)
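A behavioural note on the model_management changes above: free_memory no longer pops models in plain reverse order. It first collects can_unload as (refcount, size, index) tuples and walks them in sorted order, so entries that nothing else references are unloaded first, and the loop stops as soon as enough memory is free. A rough, self-contained sketch of that ordering with invented numbers:

```python
# Invented candidates: (refcount, memory_bytes, index) per loaded model,
# mirroring what free_memory() collects before sorting.
can_unload = [
    (4, 6_000_000_000, 0),  # still referenced elsewhere, large
    (2,   300_000_000, 2),  # only held by the loader, small
    (2, 1_200_000_000, 1),  # only held by the loader, larger
]

free_now, required = 2_000_000_000, 3_000_000_000
for refcount, mem, i in sorted(can_unload):  # tuples sort lexicographically
    if free_now > required:                  # smart-memory early exit
        break
    print("unload model", i)                 # unloads 2, then 1; 0 stays loaded
    free_now += mem
```

The same refcount idea drives cleanup_models(keep_clone_weights_loaded=True): an entry whose .model is only held by the loader (getrefcount <= 2) is normally dropped, but with the flag set it survives as long as its .real_model still has outside references (<= 3 covers the references from .real_model and the .model patcher, as the TODO comment notes), which is presumably why execution.py now passes keep_clone_weights_loaded=True when reusing cached outputs.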
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 6d679aa85..bc51743af 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -7,6 +7,18 @@
 import uuid

 from . import utils
 from . import model_management

+def apply_weight_decompose(dora_scale, weight):
+    weight_norm = (
+        weight.transpose(0, 1)
+        .reshape(weight.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    return weight * (dora_scale / weight_norm)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
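apply_weight_decompose is the DoRA step: once the LoRA/LoHa/LoKr/GLoRA delta has been merged into weight (see the hunks below), the merged matrix is renormalized column-wise and rescaled by the stored dora_scale magnitudes, i.e. W' = dora_scale * W / ||W||_col. A quick numeric sanity check of that helper; the shapes are chosen only for illustration:

```python
import torch

def apply_weight_decompose(dora_scale, weight):
    # copy of the helper added above, for a standalone check
    weight_norm = (
        weight.transpose(0, 1)
        .reshape(weight.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
        .transpose(0, 1)
    )
    return weight * (dora_scale / weight_norm)

weight = torch.randn(320, 64)          # merged (base + LoRA delta) weight, illustrative shape
dora_scale = torch.rand(1, 64) + 0.5   # per-column magnitudes, illustrative shape

out = apply_weight_decompose(dora_scale, weight)
# every column of the result now has L2 norm equal to its dora_scale entry
print(torch.allclose(out.norm(dim=0), dora_scale.squeeze(0), atol=1e-5))  # True
```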
@@ -309,6 +321,7 @@
             elif patch_type == "lora": #lora/locon
                 mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
+                dora_scale = v[4]
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
@@ -318,6 +331,8 @@
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
@@ -328,6 +343,7 @@
                 w2_a = v[5]
                 w2_b = v[6]
                 t2 = v[7]
+                dora_scale = v[8]
                 dim = None

                 if w1 is None:
@@ -357,6 +373,8 @@

                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
@@ -366,6 +384,7 @@
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
+                dora_scale = v[7]
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
@@ -386,12 +405,16 @@

                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
                     alpha *= v[4] / v[0].shape[0]

+                dora_scale = v[5]
+
                 a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                 a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                 b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
@@ -399,6 +422,8 @@

                 try:
                     weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
diff --git a/comfy/nodes/common.py b/comfy/nodes/common.py
index 84474a17d..1ace3b961 100644
--- a/comfy/nodes/common.py
+++ b/comfy/nodes/common.py
@@ -1 +1 @@
-MAX_RESOLUTION=8192
+MAX_RESOLUTION=16384
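MAX_RESOLUTION is the ceiling nodes use for their width/height (and similar) integer widgets, so raising it to 16384 loosens every node that declares it; the groupNode UI test at the end of this diff is updated to match. A sketch of the usual consumption pattern, with a hypothetical node class purely for illustration:

```python
from comfy.nodes.common import MAX_RESOLUTION  # module path as in this repo

class EmptyLatentSketch:
    # Hypothetical node; real nodes declare width/height widgets the same way.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
        }}
```

The test below asserts min 16 and max 16384 (plus step * 10 = 80) on a group-node width/height widget, so the constant and the expectation have to move together.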
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 2ce9736b7..5b2eb73fd 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE):
     def model_type(self, state_dict, prefix=""):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
             k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
-            out = state_dict[k]
-            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+            out = state_dict.get(k, None)
+            if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                 return model_base.ModelType.V_PREDICTION
         return model_base.ModelType.EPS
diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py
index 7012c2702..0760ceb46 100644
--- a/comfy_extras/nodes/nodes_post_processing.py
+++ b/comfy_extras/nodes/nodes_post_processing.py
@@ -204,13 +204,13 @@ class Sharpen:
                     "default": 1.0,
                     "min": 0.1,
                     "max": 10.0,
-                    "step": 0.1
+                    "step": 0.01
                 }),
                 "alpha": ("FLOAT", {
                     "default": 1.0,
                     "min": 0.0,
                     "max": 5.0,
-                    "step": 0.1
+                    "step": 0.01
                 }),
             },
         }
diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js
index 1afde42c6..e114e5f93 100644
--- a/tests-ui/tests/groupNode.test.js
+++ b/tests-ui/tests/groupNode.test.js
@@ -947,7 +947,7 @@ describe("group node", () => {
 		expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10

 		expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min
-		expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max
+		expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max
 		expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10

 		expect(p1.widgets.value.value).toBe(128);