diff --git a/README.md b/README.md index a94a212ad..ba1e844b3 100644 --- a/README.md +++ b/README.md @@ -142,7 +142,7 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve 1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly). 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. 1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). -1. Launch ComfyUI by running `python main.py --force-fp16`. Note that --force-fp16 will only work if you installed the latest pytorch nightly. +1. Launch ComfyUI by running `python main.py` > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). diff --git a/comfy/controlnet.py b/comfy/controlnet.py index b6941d8c4..8cf4a61a6 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -138,11 +138,13 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): + def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model self.load_device = load_device - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + if control_model is not None: + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.global_average_pooling = global_average_pooling self.model_sampling_current = None self.manual_cast_dtype = manual_cast_dtype @@ -183,7 +185,9 @@ class ControlNet(ControlBase): return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped self.copy_to(c) return c diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index 08018c54d..ed2a45fea 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys())) # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp code2idx = {"q": 0, "k": 1, "v": 2} +# This function exists because at the time of writing torch.cat can't do fp8 with cuda +def cat_tensors(tensors): + x = 0 + for t in tensors: + x += t.shape[0] + + shape = [x] + list(tensors[0].shape)[1:] + out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype) + + x = 0 + for t in tensors: + out[x:x + t.shape[0]] = t + x += t.shape[0] + + return out def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): new_state_dict = {} @@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors) for k_pre, tensors in capture_qkv_bias.items(): if None in tensors: raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing") relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre) - new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors) + new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors) return new_state_dict diff --git a/comfy/lora.py b/comfy/lora.py index 637380d54..096285bba 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -21,6 +21,12 @@ def load_lora(lora, to_load): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) @@ -44,7 +50,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -65,7 +71,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -117,7 +123,7 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) #glora a1_name = "{}.a1.weight".format(x) @@ -125,7 +131,7 @@ def load_lora(lora, to_load): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) diff --git a/comfy/model_base.py b/comfy/model_base.py index bc019de53..8c89adf5e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module): self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 - self.inpaint_model = False + + self.concat_keys = () logging.info("model_type {}".format(model_type.name)) logging.debug("adm {}".format(self.adm_channels)) @@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module): def extra_conds(self, **kwargs): out = {} - if self.inpaint_model: - concat_keys = ("mask", "masked_image") + if len(self.concat_keys) > 0: cond_concat = [] denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) concat_latent_image = kwargs.get("concat_latent_image", None) @@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module): concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) - if len(denoise_mask.shape) == len(noise.shape): - denoise_mask = denoise_mask[:,:1] + if denoise_mask is not None: + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] - denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) - if denoise_mask.shape[-2:] != noise.shape[-2:]: - denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") - denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) - def blank_inpaint_image_like(latent_image): - blank_image = torch.ones_like(latent_image) - # these are the values for "zero" in pixel space translated to latent space - blank_image[:,0] *= 0.8223 - blank_image[:,1] *= -0.6876 - blank_image[:,2] *= 0.6364 - blank_image[:,3] *= 0.1380 - return blank_image - - for ck in concat_keys: + for ck in self.concat_keys: if denoise_mask is not None: if ck == "mask": cond_concat.append(denoise_mask.to(device)) @@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module): if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) + cond_concat.append(self.blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = comfy.conds.CONDNoiseShape(data) @@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module): return unet_state_dict def set_inpaint(self): - self.inpaint_model = True + self.concat_keys = ("mask", "masked_image") + def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + self.blank_inpaint_image_like = blank_inpaint_image_like def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): @@ -472,6 +473,42 @@ class SD_X4Upscaler(BaseModel): out['y'] = comfy.conds.CONDRegular(noise_level) return out +class IP2P: + def extra_conds(self, **kwargs): + out = {} + + image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if image is None: + image = torch.zeros_like(noise) + + if image.shape[1:] != noise.shape[1:]: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + image = utils.resize_to_batch_size(image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image)) + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + return out + +class SD15_instructpix2pix(IP2P, BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) + self.process_ip2p_image_in = lambda image: image + +class SDXL_instructpix2pix(IP2P, SDXL): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device) + if model_type == ModelType.V_PREDICTION_EDM: + self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p + else: + self.process_ip2p_image_in = lambda image: image #diffusers ip2p + + class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): super().__init__(model_config, model_type, device=device, unet_model=StageC) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b7c3be309..23358a2c0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix): return unet_config -def model_config_from_unet_config(unet_config): +def model_config_from_unet_config(unet_config, state_dict=None): for model_config in comfy.supported_models.models: - if model_config.matches(unet_config): + if model_config.matches(unet_config, state_dict): return model_config(unet_config) logging.error("no match {}".format(unet_config)) @@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config): def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): unet_config = detect_unet_config(state_dict, unet_key_prefix) - model_config = model_config_from_unet_config(unet_config) + model_config = model_config_from_unet_config(unet_config, state_dict) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) else: @@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], 'use_temporal_attention': False, 'use_temporal_resblock': False} + SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} + SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], @@ -345,7 +351,20 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B] + SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1], + 'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, + 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]} + + SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1], + 'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False, + 'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index 11c97f290..537df41ed 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -274,6 +274,7 @@ class LoadedModel: self.model = model self.device = model.load_device self.weights_loaded = False + self.real_model = None def model_memory(self): return self.model.model_size() @@ -312,6 +313,7 @@ class LoadedModel: self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.model_patches_to(self.model.offload_device) self.weights_loaded = self.weights_loaded and not unpatch_weights + self.real_model = None def __eq__(self, other): return self.model is other.model @@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True): to_unload = [i] + to_unload if len(to_unload) == 0: - return None + return True same_weights = 0 for i in to_unload: @@ -349,20 +351,27 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True): return unload_weight def free_memory(memory_required, device, keep_loaded=[]): - unloaded_model = False + unloaded_model = [] + can_unload = [] + for i in range(len(current_loaded_models) -1, -1, -1): - if not DISABLE_SMART_MEMORY: - if get_free_memory(device) > memory_required: - break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - m = current_loaded_models.pop(i) - m.model_unload() - del m - unloaded_model = True + can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) - if unloaded_model: + for x in sorted(can_unload): + i = x[-1] + if not DISABLE_SMART_MEMORY: + if get_free_memory(device) > memory_required: + break + current_loaded_models[i].model_unload() + unloaded_model.append(i) + + for i in sorted(unloaded_model, reverse=True): + current_loaded_models.pop(i) + + if len(unloaded_model) > 0: soft_empty_cache() else: if vram_state != VRAMState.HIGH_VRAM: @@ -376,6 +385,8 @@ def load_models_gpu(models, memory_required=0): inference_memory = minimum_inference_memory() extra_mem = max(inference_memory, memory_required) + models = set(models) + models_to_load = [] models_already_loaded = [] for x in models: @@ -401,8 +412,8 @@ def load_models_gpu(models, memory_required=0): total_memory_required = {} for loaded_model in models_to_load: - unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) for device in total_memory_required: if device != torch.device("cpu"): @@ -441,11 +452,15 @@ def load_models_gpu(models, memory_required=0): def load_model_gpu(model): return load_models_gpu([model]) -def cleanup_models(): +def cleanup_models(keep_clone_weights_loaded=False): to_delete = [] for i in range(len(current_loaded_models)): if sys.getrefcount(current_loaded_models[i].model) <= 2: - to_delete = [i] + to_delete + if not keep_clone_weights_loaded: + to_delete = [i] + to_delete + #TODO: find a less fragile way to do this. + elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model + to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) @@ -602,7 +617,8 @@ def supports_dtype(device, dtype): #TODO def device_supports_non_blocking(device): if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking - return True + return False + # return True #TODO: figure out why this causes issues def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aa78302d2..cf51c4ad8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -7,6 +7,38 @@ import uuid import comfy.utils import comfy.model_management +def apply_weight_decompose(dora_scale, weight): + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * (weight.dim() - 1)) + .transpose(0, 1) + ) + + return weight * (dora_scale / weight_norm) + +def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): + to = model_options["transformer_options"].copy() + + if "patches_replace" not in to: + to["patches_replace"] = {} + else: + to["patches_replace"] = to["patches_replace"].copy() + + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + else: + to["patches_replace"][name] = to["patches_replace"][name].copy() + + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch + model_options["transformer_options"] = to + return model_options + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size @@ -97,16 +129,7 @@ class ModelPatcher: to["patches"][name] = to["patches"].get(name, []) + [patch] def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): - to = self.model_options["transformer_options"] - if "patches_replace" not in to: - to["patches_replace"] = {} - if name not in to["patches_replace"]: - to["patches_replace"][name] = {} - if transformer_index is not None: - block = (block_name, number, transformer_index) - else: - block = (block_name, number) - to["patches_replace"][name][block] = patch + self.model_options = set_model_options_patch_replace(self.model_options, patch, name, block_name, number, transformer_index=transformer_index) def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -138,6 +161,15 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj + def get_model_object(self, name): + if name in self.object_patches: + return self.object_patches[name] + else: + if name in self.object_patches_backup: + return self.object_patches_backup[name] + else: + return comfy.utils.get_attr(self.model, name) + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: @@ -266,7 +298,7 @@ class ModelPatcher: if weight_key in self.patches: m.weight_function = LowVramPatch(weight_key, self) if bias_key in self.patches: - m.bias_function = LowVramPatch(weight_key, self) + m.bias_function = LowVramPatch(bias_key, self) m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True @@ -309,6 +341,7 @@ class ModelPatcher: elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) + dora_scale = v[4] if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: @@ -318,6 +351,8 @@ class ModelPatcher: mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": @@ -328,6 +363,7 @@ class ModelPatcher: w2_a = v[5] w2_b = v[6] t2 = v[7] + dora_scale = v[8] dim = None if w1 is None: @@ -357,6 +393,8 @@ class ModelPatcher: try: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": @@ -366,6 +404,7 @@ class ModelPatcher: alpha *= v[2] / w1b.shape[0] w2a = v[3] w2b = v[4] + dora_scale = v[7] if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] @@ -386,12 +425,16 @@ class ModelPatcher: try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": if v[4] is not None: alpha *= v[4] / v[0].shape[0] + dora_scale = v[5] + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) @@ -399,6 +442,8 @@ class ModelPatcher: try: weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: @@ -437,4 +482,4 @@ class ModelPatcher: for k in keys: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) - self.object_patches_backup = {} + self.object_patches_backup.clear() diff --git a/comfy/sample.py b/comfy/sample.py index 5c8a7d130..e51bd67d6 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,10 +1,9 @@ import torch import comfy.model_management import comfy.samplers -import comfy.conds import comfy.utils -import math import numpy as np +import logging def prepare_noise(latent_image, seed, noise_inds=None): """ @@ -25,94 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None): noises = torch.cat(noises, axis=0) return noises -def prepare_mask(noise_mask, shape, device): - """ensures noise mask is of proper dimensions""" - noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) - noise_mask = noise_mask.to(device) - return noise_mask - -def get_models_from_cond(cond, model_type): - models = [] - for c in cond: - if model_type in c: - models += [c[model_type]] - return models - -def convert_cond(cond): - out = [] - for c in cond: - temp = c[1].copy() - model_conds = temp.get("model_conds", {}) - if c[0] is not None: - model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove - temp["cross_attn"] = c[0] - temp["model_conds"] = model_conds - out.append(temp) - return out - -def get_additional_models(positive, negative, dtype): - """loads additional models in positive and negative conditioning""" - control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) - - inference_memory = 0 - control_models = [] - for m in control_nets: - control_models += m.get_models() - inference_memory += m.inference_memory_requirements(dtype) - - gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") - gligen = [x[1] for x in gligen] - models = control_models + gligen - return models, inference_memory +def prepare_sampling(model, noise_shape, positive, negative, noise_mask): + logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed") + return model, positive, negative, noise_mask, [] def cleanup_additional_models(models): - """cleanup additional models that were loaded""" - for m in models: - if hasattr(m, 'cleanup'): - m.cleanup() - -def prepare_sampling(model, noise_shape, positive, negative, noise_mask): - device = model.load_device - positive = convert_cond(positive) - negative = convert_cond(negative) - - if noise_mask is not None: - noise_mask = prepare_mask(noise_mask, noise_shape, device) - - real_model = None - models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) - real_model = model.model - - return real_model, positive, negative, noise_mask, models - + logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed") def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) + sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(comfy.model_management.intermediate_device()) - - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): - real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) - noise = noise.to(model.load_device) - latent_image = latent_image.to(model.load_device) - sigmas = sigmas.to(model.load_device) - - samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(comfy.model_management.intermediate_device()) - cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples - diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py new file mode 100644 index 000000000..a18abd9e9 --- /dev/null +++ b/comfy/sampler_helpers.py @@ -0,0 +1,76 @@ +import torch +import comfy.model_management +import comfy.conds + +def prepare_mask(noise_mask, shape, device): + """ensures noise mask is of proper dimensions""" + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) + noise_mask = noise_mask.to(device) + return noise_mask + +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c: + models += [c[model_type]] + return models + +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] + temp["model_conds"] = model_conds + out.append(temp) + return out + +def get_additional_models(conds, dtype): + """loads additional models in conditioning""" + cnets = [] + gligen = [] + + for k in conds: + cnets += get_models_from_cond(conds[k], "control") + gligen += get_models_from_cond(conds[k], "gligen") + + control_nets = set(cnets) + + inference_memory = 0 + control_models = [] + for m in control_nets: + control_models += m.get_models() + inference_memory += m.inference_memory_requirements(dtype) + + gligen = [x[1] for x in gligen] + models = control_models + gligen + return models, inference_memory + +def cleanup_additional_models(models): + """cleanup additional models that were loaded""" + for m in models: + if hasattr(m, 'cleanup'): + m.cleanup() + + +def prepare_sampling(model, noise_shape, conds): + device = model.load_device + real_model = None + models, inference_memory = get_additional_models(conds, model.model_dtype()) + comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) + real_model = model.model + + return real_model, conds, models + +def cleanup_models(conds, models): + cleanup_additional_models(models) + + control_cleanup = [] + for k in conds: + control_cleanup += get_models_from_cond(conds[k], "control") + + cleanup_additional_models(set(control_cleanup)) diff --git a/comfy/samplers.py b/comfy/samplers.py index 3678dc818..415a35cc3 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -5,6 +5,7 @@ import collections from comfy import model_management import math import logging +import comfy.sampler_helpers def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) @@ -127,30 +128,23 @@ def cond_cat(c_list): return out -def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - +def calc_cond_batch(model, conds, x_in, timestep, model_options): + out_conds = [] + out_counts = [] to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) - to_run += [(p, UNCOND)] + cond = conds[i] + if cond is not None: + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, i)] while len(to_run) > 0: first = to_run[0] @@ -222,74 +216,66 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult + cond_index = cond_or_uncond[o] + out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove + logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") + return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) + +def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None): + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: - uncond_ = None - else: - uncond_ = uncond + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: + uncond_ = None + else: + uncond_ = uncond - cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) - else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + conds = [cond, uncond_] + out = calc_cond_batch(model, conds, x, timestep, model_options) + return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_) - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, - "sigma": timestep, "model_options": model_options, "input": x} - cfg_result = fn(args) - return cfg_result - -class CFGNoisePredictor(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): - out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) - return out - def forward(self, *args, **kwargs): - return self.apply_model(*args, **kwargs) - -class KSamplerX0Inpaint(torch.nn.Module): +class KSamplerX0Inpaint: def __init__(self, model, sigmas): - super().__init__() self.inner_model = model self.sigmas = sigmas - def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): + def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None): if denoise_mask is not None: if "denoise_mask_function" in model_options: denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas}) latent_mask = 1. - denoise_mask x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask - out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) + out = self.inner_model(x, sigma, model_options=model_options, seed=seed) if denoise_mask is not None: out = out * denoise_mask + self.latent_image * latent_mask return out -def simple_scheduler(model, steps): - s = model.model_sampling +def simple_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = len(s.sigmas) / steps for x in range(steps): @@ -297,8 +283,8 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def ddim_scheduler(model, steps): - s = model.model_sampling +def ddim_scheduler(model_sampling, steps): + s = model_sampling sigs = [] ss = max(len(s.sigmas) // steps, 1) x = 1 @@ -309,8 +295,8 @@ def ddim_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) -def normal_scheduler(model, steps, sgm=False, floor=False): - s = model.model_sampling +def normal_scheduler(model_sampling, steps, sgm=False, floor=False): + s = model_sampling start = s.timestep(s.sigma_max) end = s.timestep(s.sigma_min) @@ -571,61 +557,122 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def wrap_model(model): - model_denoise = CFGNoisePredictor(model) - return model_denoise -def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - positive = positive[:] - negative = negative[:] +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): + for k in conds: + conds[k] = conds[k][:] + resolve_areas_and_cond_masks(conds[k], noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) - resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) - - model_wrap = wrap_model(model) - - calculate_start_end_timesteps(model, negative) - calculate_start_end_timesteps(model, positive) - - if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. - latent_image = model.process_latent_in(latent_image) + for k in conds: + calculate_start_end_timesteps(model, conds[k]) if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + for k in conds: + conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) #make sure each cond area has an opposite one with the same area - for c in positive: - create_cond_with_same_area_if_none(negative, c) - for c in negative: - create_cond_with_same_area_if_none(positive, c) + for k in conds: + for c in conds[k]: + for kk in conds: + if k != kk: + create_cond_with_same_area_if_none(conds[kk], c) - pre_run_control(model, negative + positive) + for k in conds: + pre_run_control(model, conds[k]) - apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) - apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + if "positive" in conds: + positive = conds["positive"] + for k in conds: + if k != "positive": + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x]) - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + return conds + +class CFGGuider: + def __init__(self, model_patcher): + self.model_patcher = model_patcher + self.model_options = model_patcher.model_options + self.original_conds = {} + self.cfg = 1.0 + + def set_conds(self, positive, negative): + self.inner_set_conds({"positive": positive, "negative": negative}) + + def set_cfg(self, cfg): + self.cfg = cfg + + def inner_set_conds(self, conds): + for k in conds: + self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) + + def __call__(self, *args, **kwargs): + return self.predict_noise(*args, **kwargs) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) + + def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed): + if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = self.inner_model.process_latent_in(latent_image) + + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + + extra_args = {"model_options": self.model_options, "seed":seed} + + samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + return self.inner_model.process_latent_out(samples.to(torch.float32)) + + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + device = self.model_patcher.load_device + + if denoise_mask is not None: + denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) + + noise = noise.to(device) + latent_image = latent_image.to(device) + sigmas = sigmas.to(device) + + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + + comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) + del self.inner_model + del self.conds + del self.loaded_models + return output + + +def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + cfg_guider = CFGGuider(model) + cfg_guider.set_conds(positive, negative) + cfg_guider.set_cfg(cfg) + return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) - samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) - return model.process_latent_out(samples.to(torch.float32)) SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] -def calculate_sigmas_scheduler(model, scheduler_name, steps): +def calculate_sigmas(model_sampling, scheduler_name, steps): if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = normal_scheduler(model, steps) + sigmas = normal_scheduler(model_sampling, steps) elif scheduler_name == "simple": - sigmas = simple_scheduler(model, steps) + sigmas = simple_scheduler(model_sampling, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model, steps) + sigmas = ddim_scheduler(model_sampling, steps) elif scheduler_name == "sgm_uniform": - sigmas = normal_scheduler(model, steps, sgm=True) + sigmas = normal_scheduler(model_sampling, steps, sgm=True) else: logging.error("error invalid scheduler {}".format(scheduler_name)) return sigmas @@ -667,7 +714,7 @@ class KSampler: steps += 1 discard_penultimate_sigma = True - sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) + sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) @@ -678,9 +725,12 @@ class KSampler: if denoise is None or denoise > 0.9999: self.sigmas = self.calculate_sigmas(steps).to(self.device) else: - new_steps = int(steps/denoise) - sigmas = self.calculate_sigmas(new_steps).to(self.device) - self.sigmas = sigmas[-(steps + 1):] + if denoise <= 0.0: + self.sigmas = torch.FloatTensor([]) + else: + new_steps = int(steps/denoise) + sigmas = self.calculate_sigmas(new_steps).to(self.device) + self.sigmas = sigmas[-(steps + 1):] def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: diff --git a/comfy/sd.py b/comfy/sd.py index 85821120e..57dba0b44 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -214,12 +214,18 @@ class VAE: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + if 'encoder.down.2.downsample.conv.weight' not in sd and 'decoder.up.3.upsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE ddconfig['ch_mult'] = [1, 2, 4] self.downscale_ratio = 4 self.upscale_ratio = 4 - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'quant_conv.weight' in sd: + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + else: + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() @@ -600,7 +606,7 @@ def load_unet(unet_path): raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}): clip_sd = None load_models = [model] if clip is not None: @@ -610,4 +616,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) + for k in extra_keys: + sd[k] = extra_keys[k] + comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2ce9736b7..b3b69e05b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE): def model_type(self, state_dict, prefix=""): if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix) - out = state_dict[k] - if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. + out = state_dict.get(k, None) + if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. return model_base.ModelType.V_PREDICTION return model_base.ModelType.EPS @@ -174,6 +174,11 @@ class SDXL(supported_models_base.BASE): self.sampling_settings["sigma_max"] = 80.0 self.sampling_settings["sigma_min"] = 0.002 return model_base.ModelType.EDM + elif "edm_vpred.sigma_max" in state_dict: + self.sampling_settings["sigma_max"] = float(state_dict["edm_vpred.sigma_max"].item()) + if "edm_vpred.sigma_min" in state_dict: + self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item()) + return model_base.ModelType.V_PREDICTION_EDM elif "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION else: @@ -334,6 +339,11 @@ class Stable_Zero123(supported_models_base.BASE): "num_head_channels": -1, } + required_keys = { + "cc_projection.weight": None, + "cc_projection.bias": None, + } + clip_vision_prefix = "cond_stage_model.model.visual." latent_format = latent_formats.SD15 @@ -439,6 +449,33 @@ class Stable_Cascade_B(Stable_Cascade_C): out = model_base.StableCascade_B(self, device=device) return out +class SD15_instructpix2pix(SD15): + unet_config = { + "context_dim": 768, + "model_channels": 320, + "use_linear_in_transformer": False, + "adm_in_channels": None, + "use_temporal_attention": False, + "in_channels": 8, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SD15_instructpix2pix(self, device=device) + +class SDXL_instructpix2pix(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 10, 10], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + "in_channels": 8, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 4d7e25936..cf7cdff34 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -16,6 +16,8 @@ class BASE: "num_head_channels": 64, } + required_keys = {} + clip_prefix = [] clip_vision_prefix = None noise_aug_config = None @@ -28,10 +30,14 @@ class BASE: manual_cast_dtype = None @classmethod - def matches(s, unet_config): + def matches(s, unet_config, state_dict=None): for k in s.unet_config: if k not in unet_config or s.unet_config[k] != unet_config[k]: return False + if state_dict is not None: + for k in s.required_keys: + if k not in state_dict: + return False return True def model_type(self, state_dict, prefix=""): @@ -41,7 +47,8 @@ class BASE: return self.unet_config["in_channels"] > 4 def __init__(self, unet_config): - self.unet_config = unet_config + self.unet_config = unet_config.copy() + self.sampling_settings = self.sampling_settings.copy() self.latent_format = self.latent_format() for x in self.unet_extra_config: self.unet_config[x] = self.unet_extra_config[x] diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py new file mode 100644 index 000000000..b59f6945b --- /dev/null +++ b/comfy_extras/nodes_align_your_steps.py @@ -0,0 +1,45 @@ +#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. + """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], + "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], + "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} + +class AlignYourStepsScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model_type": (["SD1", "SDXL", "SVD"], ), + "steps": ("INT", {"default": 10, "min": 10, "max": 10000}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model_type, steps): + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "AlignYourStepsScheduler": AlignYourStepsScheduler, +} diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 8138b5f73..fab2ab7ac 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -1,4 +1,3 @@ -#From https://github.com/kornia/kornia import math import torch diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index dcf8859fa..3087b917b 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -8,7 +8,7 @@ class CLIPTextEncodeSDXLRefiner: "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" @@ -30,8 +30,8 @@ class CLIPTextEncodeSDXL: "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), - "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ), - "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ), + "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index 646fefa17..4c3a1d5bf 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -3,7 +3,7 @@ class CLIPTextEncodeControlnet: @classmethod def INPUT_TYPES(s): - return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}} + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 72ff7957f..06238f892 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -4,6 +4,7 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling import latent_preview import torch import comfy.utils +import node_helpers class BasicScheduler: @@ -24,10 +25,11 @@ class BasicScheduler: def get_sigmas(self, model, scheduler, steps, denoise): total_steps = steps if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) total_steps = int(steps/denoise) - comfy.model_management.load_models_gpu([model]) - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -160,6 +162,9 @@ class FlipSigmas: FUNCTION = "get_sigmas" def get_sigmas(self, sigmas): + if len(sigmas) == 0: + return (sigmas,) + sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 @@ -310,6 +315,24 @@ class SamplerDPMAdaptative: "s_noise":s_noise }) return (sampler, ) +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -337,10 +360,9 @@ class SamplerCustom: latent = latent_image latent_image = latent["samples"] if not add_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + noise = Noise_EmptyNoise().generate_noise(latent) else: - batch_inds = latent["batch_index"] if "batch_index" in latent else None - noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds) + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) noise_mask = None if "noise_mask" in latent: @@ -361,6 +383,207 @@ class SamplerCustom: out_denoised = out return (out, out_denoised) +class Guider_Basic(comfy.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + +class BasicGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "conditioning": ("CONDITIONING", ), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, conditioning): + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return (guider,) + +class CFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, positive, negative, cfg): + guider = comfy.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return (guider,) + +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "cond1": ("CONDITIONING", ), + "cond2": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative) + return (guider,) + +class DisableNoise: + @classmethod + def INPUT_TYPES(s): + return {"required":{ + } + } + + RETURN_TYPES = ("NOISE",) + FUNCTION = "get_noise" + CATEGORY = "sampling/custom_sampling/noise" + + def get_noise(self): + return (Noise_EmptyNoise(),) + + +class RandomNoise(DisableNoise): + @classmethod + def INPUT_TYPES(s): + return {"required":{ + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + } + } + + def get_noise(self, noise_seed): + return (Noise_RandomNoise(noise_seed),) + + +class SamplerCustomAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"noise": ("NOISE", ), + "guider": ("GUIDER", ), + "sampler": ("SAMPLER", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT","LATENT") + RETURN_NAMES = ("output", "denoised_output") + + FUNCTION = "sample" + + CATEGORY = "sampling/custom_sampling" + + def sample(self, noise, guider, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(comfy.model_management.intermediate_device()) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return (out, out_denoised) + +class AddNoise: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "noise": ("NOISE", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT",) + + FUNCTION = "add_noise" + + CATEGORY = "_for_testing/custom_sampling/noise" + + def add_noise(self, model, noise, sigmas, latent_image): + if len(sigmas) == 0: + return latent_image + + latent = latent_image + latent_image = latent["samples"] + + noisy = noise.generate_noise(latent) + + model_sampling = model.get_model_object("model_sampling") + process_latent_out = model.get_model_object("process_latent_out") + process_latent_in = model.get_model_object("process_latent_in") + + if len(sigmas) > 1: + scale = torch.abs(sigmas[0] - sigmas[-1]) + else: + scale = sigmas[0] + + if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. + latent_image = process_latent_in(latent_image) + noisy = model_sampling.noise_scaling(scale, noisy, latent_image) + noisy = process_latent_out(noisy) + noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0) + + out = latent.copy() + out["samples"] = noisy + return (out,) + + NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, "BasicScheduler": BasicScheduler, @@ -378,4 +601,12 @@ NODE_CLASS_MAPPINGS = { "SamplerDPMAdaptative": SamplerDPMAdaptative, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, + + "CFGGuider": CFGGuider, + "DualCFGGuider": DualCFGGuider, + "BasicGuider": BasicGuider, + "RandomNoise": RandomNoise, + "DisableNoise": DisableNoise, + "AddNoise": AddNoise, + "SamplerCustomAdvanced": SamplerCustomAdvanced, } diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py new file mode 100644 index 000000000..c2e70a84c --- /dev/null +++ b/comfy_extras/nodes_ip2p.py @@ -0,0 +1,45 @@ +import torch + +class InstructPixToPixConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/instructpix2pix" + + def encode(self, positive, negative, pixels, vae): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + + concat_latent = vae.encode(pixels) + + out_latent = {} + out_latent["samples"] = torch.zeros_like(concat_latent) + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + +NODE_CLASS_MAPPINGS = { + "InstructPixToPixConditioning": InstructPixToPixConditioning, +} diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index a25b73ca7..2a431f65d 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -2,7 +2,9 @@ import comfy.sd import comfy.utils import comfy.model_base import comfy.model_management +import comfy.model_sampling +import torch import folder_paths import json import os @@ -189,6 +191,13 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", # "v2-inpainting" + extra_keys = {} + model_sampling = model.get_model_object("model_sampling") + if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM): + if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION): + extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float() + extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float() + if model.model.model_type == comfy.model_base.ModelType.EPS: metadata["modelspec.predict_key"] = "epsilon" elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: @@ -203,7 +212,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: def __init__(self): diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py new file mode 100644 index 000000000..f2d008d8b --- /dev/null +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -0,0 +1,60 @@ +import comfy_extras.nodes_model_merging + +class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(12): + arg_dict["input_blocks.{}.".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}.".format(i)] = argument + + for i in range(12): + arg_dict["output_blocks.{}.".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + + +class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(9): + arg_dict["input_blocks.{}".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}".format(i)] = argument + + for i in range(9): + arg_dict["output_blocks.{}".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + + +NODE_CLASS_MAPPINGS = { + "ModelMergeSD1": ModelMergeSD1, + "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks + "ModelMergeSDXL": ModelMergeSDXL, +} diff --git a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py new file mode 100644 index 000000000..63f43fd62 --- /dev/null +++ b/comfy_extras/nodes_pag.py @@ -0,0 +1,56 @@ +#Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention +#If you want the one with more options see the above repo. + +#My modified one here is more basic but has less chances of breaking with ComfyUI updates. + +import comfy.model_patcher +import comfy.samplers + +class PerturbedAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale): + unet_block = "middle" + unet_block_id = 0 + m = model.clone() + + def perturbed_attention(q, k, v, extra_options, mask=None): + return v + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"].copy() + x = args["input"] + + if scale == 0: + return cfg_result + + # Replace Self-attention with PAG + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id) + (pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + + return cfg_result + (cond_pred - pag) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m,) + +NODE_CLASS_MAPPINGS = { + "PerturbedAttentionGuidance": PerturbedAttentionGuidance, +} diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index dc73c5528..306cf9cd0 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -1,10 +1,20 @@ import torch import comfy.model_management -import comfy.sample +import comfy.sampler_helpers import comfy.samplers import comfy.utils +import node_helpers +def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + return cfg_result + +#TODO: This node should be removed, it has been replaced with PerpNegGuider class PerpNeg: @classmethod def INPUT_TYPES(s): @@ -19,7 +29,7 @@ class PerpNeg: def patch(self, model, empty_conditioning, neg_scale): m = model.clone() - nocond = comfy.sample.convert_cond(empty_conditioning) + nocond = comfy.sampler_helpers.convert_cond(empty_conditioning) def cfg_function(args): model = args["model"] @@ -31,14 +41,9 @@ class PerpNeg: model_options = args["model_options"] nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") - (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) + (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) - pos = noise_pred_pos - noise_pred_nocond - neg = noise_pred_neg - noise_pred_nocond - perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos - perp_neg = perp * neg_scale - cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) - cfg_result = x - cfg_result + cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale) return cfg_result m.set_model_sampler_cfg_function(cfg_function) @@ -46,10 +51,52 @@ class PerpNeg: return (m, ) +class Guider_PerpNeg(comfy.samplers.CFGGuider): + def set_conds(self, positive, negative, empty_negative_prompt): + empty_negative_prompt = node_helpers.conditioning_set_values(empty_negative_prompt, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "empty_negative_prompt": empty_negative_prompt, "negative": negative}) + + def set_cfg(self, cfg, neg_scale): + self.cfg = cfg + self.neg_scale = neg_scale + + def predict_noise(self, x, timestep, model_options={}, seed=None): + positive_cond = self.conds.get("positive", None) + negative_cond = self.conds.get("negative", None) + empty_cond = self.conds.get("empty_negative_prompt", None) + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, positive_cond, empty_cond], x, timestep, model_options) + return perp_neg(x, out[1], out[0], out[2], self.neg_scale, self.cfg) + +class PerpNegGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "empty_conditioning": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "_for_testing" + + def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale): + guider = Guider_PerpNeg(model) + guider.set_conds(positive, negative, empty_conditioning) + guider.set_cfg(cfg, neg_scale) + return (guider,) + NODE_CLASS_MAPPINGS = { "PerpNeg": PerpNeg, + "PerpNegGuider": PerpNegGuider, } NODE_DISPLAY_NAME_MAPPINGS = { - "PerpNeg": "Perp-Neg", + "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", } diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index 90130142b..29d127d74 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -141,7 +141,7 @@ class PhotoMakerEncode: return {"required": { "photomaker": ("PHOTOMAKER",), "image": ("IMAGE",), "clip": ("CLIP", ), - "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), }} RETURN_TYPES = ("CONDITIONING",) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index cb5c7d228..68f6ef51e 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -5,6 +5,7 @@ from PIL import Image import math import comfy.utils +import comfy.model_management class Blend: @@ -102,6 +103,7 @@ class Blur: if blur_radius == 0: return (image,) + image = image.to(comfy.model_management.get_torch_device()) batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 @@ -112,7 +114,7 @@ class Blur: blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) - return (blurred,) + return (blurred.to(comfy.model_management.intermediate_device()),) class Quantize: def __init__(self): @@ -204,13 +206,13 @@ class Sharpen: "default": 1.0, "min": 0.1, "max": 10.0, - "step": 0.1 + "step": 0.01 }), "alpha": ("FLOAT", { "default": 1.0, "min": 0.0, "max": 5.0, - "step": 0.1 + "step": 0.01 }), }, } @@ -225,6 +227,7 @@ class Sharpen: return (image,) batch_size, height, width, channels = image.shape + image = image.to(comfy.model_management.get_torch_device()) kernel_size = sharpen_radius * 2 + 1 kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) @@ -239,7 +242,7 @@ class Sharpen: result = torch.clamp(sharpened, 0, 1) - return (result,) + return (result.to(comfy.model_management.intermediate_device()),) class ImageScaleToTotalPixels: upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index bbd380807..69084e91d 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -150,7 +150,7 @@ class SelfAttentionGuidance: degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) degraded_noised = degraded + x - uncond_pred # call into the UNet - (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options) return cfg_result + (degraded - sag) * sag_scale m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) diff --git a/cuda_malloc.py b/cuda_malloc.py index 70e7ecf9a..eb2857c5f 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -47,7 +47,7 @@ blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeFor "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000", "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", - "GeForce GTX 1650", "GeForce GTX 1630" + "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60" } def cuda_malloc_supported(): diff --git a/execution.py b/execution.py index 4688be2ad..4abc10d9e 100644 --- a/execution.py +++ b/execution.py @@ -393,7 +393,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - logging.error("!!! Exception during processing !!!") + logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) error_details = { @@ -473,6 +473,7 @@ class PromptExecutor: current_outputs = self.caches.outputs.all_node_ids() + comfy.model_management.cleanup_models(keep_clone_weights_loaded=True) self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) diff --git a/folder_paths.py b/folder_paths.py index f1bf40f8c..489795002 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,7 +1,8 @@ import os import time +import logging -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl']) folder_names_and_paths = {} @@ -44,7 +45,7 @@ if not os.path.exists(input_directory): try: os.makedirs(input_directory) except: - print("Failed to create input directory") + logging.error("Failed to create input directory") def set_output_directory(output_dir): global output_directory @@ -146,21 +147,23 @@ def recursive_search(directory, excluded_dir_names=None): try: dirs[directory] = os.path.getmtime(directory) except FileNotFoundError: - print(f"Warning: Unable to access {directory}. Skipping this path.") - + logging.warning(f"Warning: Unable to access {directory}. Skipping this path.") + + logging.debug("recursive file list on directory {}".format(directory)) for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) result.append(relative_path) - + for d in subdirs: path = os.path.join(dirpath, d) try: dirs[path] = os.path.getmtime(path) except FileNotFoundError: - print(f"Warning: Unable to access {path}. Skipping this path.") + logging.warning(f"Warning: Unable to access {path}. Skipping this path.") continue + logging.debug("found {} files".format(len(result))) return result, dirs def filter_files_extensions(files, extensions): @@ -178,6 +181,8 @@ def get_full_path(folder_name, filename): full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + elif os.path.islink(full_path): + logging.warning("WARNING path {} exists but doesn't link anywhere, skipping.".format(full_path)) return None @@ -248,8 +253,8 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height err = "**** ERROR: Saving image outside the output folder is not allowed." + \ "\n full_output_folder: " + os.path.abspath(full_output_folder) + \ "\n output_dir: " + output_dir + \ - "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) - print(err) + "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) + logging.error(err) raise Exception(err) try: diff --git a/node_helpers.py b/node_helpers.py new file mode 100644 index 000000000..8828a4ec9 --- /dev/null +++ b/node_helpers.py @@ -0,0 +1,10 @@ + +def conditioning_set_values(conditioning, values={}): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + for k in values: + n[1][k] = values[k] + c.append(n) + + return c diff --git a/nodes.py b/nodes.py index 453f6e606..28359e939 100644 --- a/nodes.py +++ b/nodes.py @@ -34,6 +34,7 @@ import importlib import folder_paths import latent_preview +import node_helpers def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -41,12 +42,12 @@ def before_node_execution(): def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) -MAX_RESOLUTION=8192 +MAX_RESOLUTION=16384 class CLIPTextEncode: @classmethod def INPUT_TYPES(s): - return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}} + return {"required": {"text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", )}} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "encode" @@ -151,13 +152,9 @@ class ConditioningSetArea: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) - n[1]['strength'] = strength - n[1]['set_area_to_bounds'] = False - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8), + "strength": strength, + "set_area_to_bounds": False}) return (c, ) class ConditioningSetAreaPercentage: @@ -176,13 +173,9 @@ class ConditioningSetAreaPercentage: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['area'] = ("percentage", height, width, y, x) - n[1]['strength'] = strength - n[1]['set_area_to_bounds'] = False - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x), + "strength": strength, + "set_area_to_bounds": False}) return (c, ) class ConditioningSetAreaStrength: @@ -197,11 +190,7 @@ class ConditioningSetAreaStrength: CATEGORY = "conditioning" def append(self, conditioning, strength): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]['strength'] = strength - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"strength": strength}) return (c, ) @@ -219,19 +208,15 @@ class ConditioningSetMask: CATEGORY = "conditioning" def append(self, conditioning, mask, set_cond_area, strength): - c = [] set_area_to_bounds = False if set_cond_area != "default": set_area_to_bounds = True if len(mask.shape) < 3: mask = mask.unsqueeze(0) - for t in conditioning: - n = [t[0], t[1].copy()] - _, h, w = mask.shape - n[1]['mask'] = mask - n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['mask_strength'] = strength - c.append(n) + + c = node_helpers.conditioning_set_values(conditioning, {"mask": mask, + "set_area_to_bounds": set_area_to_bounds, + "mask_strength": strength}) return (c, ) class ConditioningZeroOut: @@ -266,13 +251,8 @@ class ConditioningSetTimestepRange: CATEGORY = "advanced/conditioning" def set_range(self, conditioning, start, end): - c = [] - for t in conditioning: - d = t[1].copy() - d['start_percent'] = start - d['end_percent'] = end - n = [t[0], d] - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"start_percent": start, + "end_percent": end}) return (c, ) class VAEDecode: @@ -413,13 +393,8 @@ class InpaintModelConditioning: out = [] for conditioning in [positive, negative]: - c = [] - for t in conditioning: - d = t[1].copy() - d["concat_latent_image"] = concat_latent - d["concat_mask"] = mask - n = [t[0], d] - c.append(n) + c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent, + "concat_mask": mask}) out.append(c) return (out[0], out[1], out_latent) @@ -991,7 +966,7 @@ class GLIGENTextBoxApply: return {"required": {"conditioning_to": ("CONDITIONING", ), "clip": ("CLIP", ), "gligen_textbox_model": ("GLIGEN", ), - "text": ("STRING", {"multiline": True}), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), @@ -1876,6 +1851,7 @@ def load_custom_node(module_path, ignore=set()): sp = os.path.splitext(module_path) module_name = sp[0] try: + logging.debug("Trying to load custom node {}".format(module_path)) if os.path.isfile(module_path): module_spec = importlib.util.spec_from_file_location(module_name, module_path) module_dir = os.path.split(module_path)[0] @@ -1964,6 +1940,10 @@ def init_custom_nodes(): "nodes_morphology.py", "nodes_stable_cascade.py", "nodes_differential_diffusion.py", + "nodes_ip2p.py", + "nodes_model_merging_model_specific.py", + "nodes_pag.py", + "nodes_align_your_steps.py", ] import_failed = [] diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index 15b784d67..8e58195ac 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -949,7 +949,7 @@ describe("group node", () => { expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min - expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max + expect(p2.widgets.value.widget.options?.max).toBe(16384); // width/height max expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 expect(p1.widgets.value.value).toBe(128); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 8a55246ee..97be7aa72 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -204,13 +204,17 @@ export class EzWidget { convertToWidget() { if (!this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} cannot be converted as it is already a widget.`); - this.node.menu[`Convert ${this.widget.name} to widget`].call(); + var menu = this.node.menu["Convert Input to Widget"].item.submenu.options; + var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to widget`); + menu[index].callback.call(); } convertToInput() { if (this.isConvertedToInput) throw new Error(`Widget ${this.widget.name} cannot be converted as it is already an input.`); - this.node.menu[`Convert ${this.widget.name} to input`].call(); + var menu = this.node.menu["Convert Widget to Input"].item.submenu.options; + var index = menu.findIndex(a => a.content == `Convert ${this.widget.name} to input`); + menu[index].callback.call(); } } diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index b8d83613d..02546782f 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -20,6 +20,10 @@ const colorPalettes = { "MODEL": "#B39DDB", // light lavender-purple "STYLE_MODEL": "#C2FFAE", // light green-yellow "VAE": "#FF6E6E", // bright red + "NOISE": "#B0B0B0", // gray + "GUIDER": "#66FFFF", // cyan + "SAMPLER": "#ECB4B4", // very soft red + "SIGMAS": "#CDFFCD", // soft lime green "TAESD": "#DCC274", // cheesecake }, "litegraph_base": { diff --git a/web/extensions/core/dynamicPrompts.js b/web/extensions/core/dynamicPrompts.js index 599a9e685..7417361ba 100644 --- a/web/extensions/core/dynamicPrompts.js +++ b/web/extensions/core/dynamicPrompts.js @@ -17,7 +17,7 @@ app.registerExtension({ // Locate dynamic prompt text widgets // Include any widgets with dynamicPrompts set to true, and customtext const widgets = node.widgets.filter( - (n) => (n.type === "customtext" && n.dynamicPrompts !== false) || n.dynamicPrompts + (n) => n.dynamicPrompts ); for (const widget of widgets) { // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 0f41d10f8..9075e72bb 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -256,8 +256,18 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi return { customConfig }; } +let useConversionSubmenusSetting; app.registerExtension({ name: "Comfy.WidgetInputs", + init() { + useConversionSubmenusSetting = app.ui.settings.addSetting({ + id: "Comfy.NodeInputConversionSubmenus", + name: "Node widget/input conversion sub-menus", + tooltip: "In the node context menu, place the entries that convert between input/widget in sub-menus.", + type: "boolean", + defaultValue: true, + }); + }, async beforeRegisterNodeDef(nodeType, nodeData, app) { // Add menu options to conver to/from widgets const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; @@ -292,12 +302,31 @@ app.registerExtension({ } } } + + //Convert.. main menu if (toInput.length) { - options.push(...toInput, null); + if (useConversionSubmenusSetting.value) { + options.push({ + content: "Convert Widget to Input", + submenu: { + options: toInput, + }, + }); + } else { + options.push(...toInput, null); + } } - if (toWidget.length) { - options.push(...toWidget, null); + if (useConversionSubmenusSetting.value) { + options.push({ + content: "Convert Input to Widget", + submenu: { + options: toWidget, + }, + }); + } else { + options.push(...toWidget, null); + } } } diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4ff05ae81..427a62b59 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -7247,7 +7247,7 @@ LGraphNode.prototype.executeAction = function(action) //create links for (var i = 0; i < clipboard_info.links.length; ++i) { var link_info = clipboard_info.links[i]; - var origin_node; + var origin_node = undefined; var origin_node_relative_id = link_info[0]; if (origin_node_relative_id != null) { origin_node = nodes[origin_node_relative_id]; diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 169609209..7132fb60f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -170,9 +170,12 @@ export async function importA1111(graph, parameters) { const opts = parameters .substr(p) .split("\n")[1] - .split(",") + .match(new RegExp("\\s*([^:]+:\\s*([^\"\\{].*?|\".*?\"|\\{.*?\\}))\\s*(,|$)", "g")) .reduce((p, n) => { const s = n.split(":"); + if (s[1].endsWith(',')) { + s[1] = s[1].substr(0, s[1].length -1); + } p[s[0].trim().toLowerCase()] = s[1].trim(); return p; }, {}); @@ -191,6 +194,7 @@ export async function importA1111(graph, parameters) { const vaeLoaderNode = LiteGraph.createNode("VAELoader"); const saveNode = LiteGraph.createNode("SaveImage"); let hrSamplerNode = null; + let hrSteps = null; const ceil64 = (v) => Math.ceil(v / 64) * 64; @@ -290,6 +294,9 @@ export async function importA1111(graph, parameters) { model(v) { setWidgetValue(ckptNode, "ckpt_name", v, true); }, + "vae"(v) { + setWidgetValue(vaeLoaderNode, "vae_name", v, true); + }, "cfg scale"(v) { setWidgetValue(samplerNode, "cfg", +v); }, @@ -316,6 +323,7 @@ export async function importA1111(graph, parameters) { const h = ceil64(+wxh[1]); const hrUp = popOpt("hires upscale"); const hrSz = popOpt("hires resize"); + hrSteps = popOpt("hires steps"); let hrMethod = popOpt("hires upscaler"); setWidgetValue(imageNode, "width", w); @@ -398,7 +406,7 @@ export async function importA1111(graph, parameters) { } if (hrSamplerNode) { - setWidgetValue(hrSamplerNode, "steps", getWidget(samplerNode, "steps").value); + setWidgetValue(hrSamplerNode, "steps", hrSteps? +hrSteps : getWidget(samplerNode, "steps").value); setWidgetValue(hrSamplerNode, "cfg", getWidget(samplerNode, "cfg").value); setWidgetValue(hrSamplerNode, "scheduler", getWidget(samplerNode, "scheduler").value); setWidgetValue(hrSamplerNode, "sampler_name", getWidget(samplerNode, "sampler_name").value); @@ -415,7 +423,7 @@ export async function importA1111(graph, parameters) { graph.arrange(); - for (const opt of ["model hash", "ensd"]) { + for (const opt of ["model hash", "ensd", "version", "vae hash", "ti hashes", "lora hashes", "hashes"]) { delete opts[opt]; } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 027bf4a3f..7003ddb6b 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -90,12 +90,15 @@ function dragElement(dragEl, settings) { }).observe(dragEl); function ensureInBounds() { - if (dragEl.classList.contains("comfy-menu-manual-pos")) { + try { newPosX = Math.min(document.body.clientWidth - dragEl.clientWidth, Math.max(0, dragEl.offsetLeft)); newPosY = Math.min(document.body.clientHeight - dragEl.clientHeight, Math.max(0, dragEl.offsetTop)); positionElement(); } + catch(exception){ + // robust + } } function positionElement() { diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 678b1b8ec..00c91914d 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -307,7 +307,9 @@ export const ComfyWidgets = { return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { - this.value = Math.round(v/config.round)*config.round; + this.value = Math.round((v + Number.EPSILON)/config.round)*config.round; + if (this.value > config.max) this.value = config.max; + if (this.value < config.min) this.value = config.min; } else { this.value = v; }