diff --git a/.gitignore b/.gitignore index 8380a2f7c..38d2ba11b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,8 @@ custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs -.idea/ \ No newline at end of file +.idea/ +venv/ +web/extensions/* +!web/extensions/logging.js.example +!web/extensions/core/ \ No newline at end of file diff --git a/README.md b/README.md index 78f34a9bb..1de9d4c3b 100644 --- a/README.md +++ b/README.md @@ -119,12 +119,22 @@ After this you should have everything installed and can proceed to running Comfy ### Others: -[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) +#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) -Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own. +#### Apple Mac silicon -Directml: ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` +You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version. +1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide. +1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. +1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). +1. Launch ComfyUI by running `python main.py`. + +> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). + +#### DirectML (AMD Cards on Windows) + +```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? @@ -162,6 +172,8 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or ( You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}. +Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`. 
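As an illustrative aside (not part of the diff itself): the comment support described above is implemented by the `stripComments()` helper added to `web/extensions/core/dynamicPrompts.js` later in this change. The Python sketch below mirrors that behaviour so the README text is easier to follow; the wildcard step is a simplified stand-in that ignores the `\{`/`\}` escape handling the frontend performs, and the helper names and sample prompt are made up.

```python
import random
import re

def strip_comments(prompt: str) -> str:
    # Mirrors the stripComments() regex added in dynamicPrompts.js:
    # drop /* ... */ block comments and // line comments.
    return re.sub(r"/\*[\s\S]*?\*/|//.*", "", prompt)

def resolve_wildcards(prompt: str) -> str:
    # Simplified stand-in for the frontend's {a|b|c} handling
    # (does not implement the \{ / \} escape logic).
    return re.sub(r"\{([^{}]+)\}",
                  lambda m: random.choice(m.group(1).split("|")),
                  prompt)

text = "a {red|blue|green} vintage car  // colour picked each time the prompt is queued\n/* this block comment is removed */ watercolor style"
print(resolve_wildcards(strip_comments(text)))
```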
+ To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension): ```embedding:embedding_filename.pt``` diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index 43877fb83..f494f1d30 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -4,7 +4,7 @@ import yaml import folder_paths from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint import os.path as osp import re import torch @@ -84,28 +84,4 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb # Put together new checkpoint sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae + return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config) diff --git a/comfy/model_base.py b/comfy/model_base.py new file mode 100644 index 000000000..9adea9a5d --- /dev/null +++ b/comfy/model_base.py @@ -0,0 +1,97 @@ +import torch +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule +import numpy as np + +class BaseModel(torch.nn.Module): + def __init__(self, unet_config, v_prediction=False): + super().__init__() + + self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.diffusion_model = UNetModel(**unet_config) + self.v_prediction = v_prediction + if self.v_prediction: + self.parameterization = "v" + else: + self.parameterization = "eps" + if "adm_in_channels" in unet_config: + self.adm_channels = unet_config["adm_in_channels"] + else: + self.adm_channels = 0 + print("v_prediction", v_prediction) + print("adm", self.adm_channels) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + + self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + context = torch.cat(c_crossattn, 1) + return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options) + + def get_dtype(self): + return self.diffusion_model.dtype + + def is_adm(self): + return self.adm_channels > 0 + +class SD21UNCLIP(BaseModel): + def __init__(self, unet_config, noise_aug_config, v_prediction=True): + super().__init__(unet_config, v_prediction) + self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) + + def encode_adm(self, **kwargs): + unclip_conditioning = kwargs.get("unclip_conditioning", None) + device = kwargs["device"] + + if unclip_conditioning is not None: + adm_inputs = [] + weights = [] + noise_aug = [] + for unclip_cond in unclip_conditioning: + adm_cond = unclip_cond["clip_vision_output"].image_embeds + weight = unclip_cond["strength"] + noise_augment = unclip_cond["noise_augmentation"] + noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + noise_aug.append(noise_augment) + adm_inputs.append(adm_out) + + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: add a way to control this + noise_augment = 0.05 + noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) + else: + adm_out = torch.zeros((1, self.adm_channels)) + + return adm_out + +class SDInpaint(BaseModel): + def __init__(self, unet_config, v_prediction=False): + super().__init__(unet_config, v_prediction) + self.concat_keys = ("mask", "masked_image") diff --git a/comfy/samplers.py b/comfy/samplers.py index 1fb928f8d..d3cd901e7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c['transformer_options'] = transformer_options - output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) + output = model_function(input_x, timestep_, **c).chunk(batch_chunks) del input_x model_management.throw_exception_if_processing_interrupted() @@ -460,36 +460,18 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond[temp[1]] = [o[0], n] -def encode_adm(noise_augmentor, conds, batch_size, device): +def encode_adm(model, conds, batch_size, device): for t in range(len(conds)): x = conds[t] + adm_out = None if 'adm' in x[1]: - adm_inputs = [] - weights = [] - noise_aug = [] - adm_in = x[1]["adm"] - for adm_c in adm_in: - 
adm_cond = adm_c[0].image_embeds - weight = adm_c[1] - noise_augment = adm_c[2] - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight - weights.append(weight) - noise_aug.append(noise_augment) - adm_inputs.append(adm_out) - - if len(noise_aug) > 1: - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: add a way to control this - noise_augment = 0.05 - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) + adm_out = x[1]["adm"] else: - adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) - x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) + params = x[1].copy() + adm_out = model.encode_adm(device=device, **params) + if adm_out is not None: + x[1] = x[1].copy() + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) return conds @@ -591,14 +573,14 @@ class KSampler: apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - if self.model.model.diffusion_model.dtype == torch.float16: + if self.model.get_dtype() == torch.float16: precision_scope = torch.autocast else: precision_scope = contextlib.nullcontext - if hasattr(self.model, 'noise_augmentor'): #unclip - positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) - negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) + if self.model.is_adm(): + positive = encode_adm(self.model, positive, noise.shape[0], self.device) + negative = encode_adm(self.model, negative, noise.shape[0], self.device) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} diff --git a/comfy/sd.py b/comfy/sd.py index 04eaaa9fe..d898d0197 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -15,8 +15,15 @@ from . import utils from . import clip_vision from . import gligen from . import diffusers_convert +from . 
import model_base def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): + replace_prefix = {"model.diffusion_model.": "diffusion_model."} + for rp in replace_prefix: + replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys()))) + for x in replace: + sd[x[1]] = sd.pop(x[0]) + m, u = model.load_state_dict(sd, strict=False) k = list(sd.keys()) @@ -182,7 +189,7 @@ def model_lora_keys(model, key_map={}): counter = 0 for b in range(12): - tk = "model.diffusion_model.input_blocks.{}.1".format(b) + tk = "diffusion_model.input_blocks.{}.1".format(b) up_counter = 0 for c in LORA_UNET_MAP_ATTENTIONS: k = "{}.{}.weight".format(tk, c) @@ -193,13 +200,13 @@ def model_lora_keys(model, key_map={}): if up_counter >= 4: counter += 1 for c in LORA_UNET_MAP_ATTENTIONS: - k = "model.diffusion_model.middle_block.1.{}.weight".format(c) + k = "diffusion_model.middle_block.1.{}.weight".format(c) if k in sdk: lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c]) key_map[lora_key] = k counter = 3 for b in range(12): - tk = "model.diffusion_model.output_blocks.{}.1".format(b) + tk = "diffusion_model.output_blocks.{}.1".format(b) up_counter = 0 for c in LORA_UNET_MAP_ATTENTIONS: k = "{}.{}.weight".format(tk, c) @@ -223,7 +230,7 @@ def model_lora_keys(model, key_map={}): ds_counter = 0 counter = 0 for b in range(12): - tk = "model.diffusion_model.input_blocks.{}.0".format(b) + tk = "diffusion_model.input_blocks.{}.0".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -242,7 +249,7 @@ def model_lora_keys(model, key_map={}): counter = 0 for b in range(3): - tk = "model.diffusion_model.middle_block.{}".format(b) + tk = "diffusion_model.middle_block.{}".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -256,7 +263,7 @@ def model_lora_keys(model, key_map={}): counter = 0 us_counter = 0 for b in range(12): - tk = "model.diffusion_model.output_blocks.{}.0".format(b) + tk = "diffusion_model.output_blocks.{}.0".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -332,7 +339,7 @@ class ModelPatcher: patch_list[i] = patch_list[i].to(device) def model_dtype(self): - return self.model.diffusion_model.dtype + return self.model.get_dtype() def add_patches(self, patches, strength=1.0): p = {} @@ -537,6 +544,19 @@ class VAE: / 3.0) / 2.0, min=0.0, max=1.0) return output + def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) + pbar = utils.ProgressBar(steps) + + encode_fn = lambda a: self.first_stage_model.encode(2. 
* a.to(self.device) - 1.).sample() * self.scale_factor + samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples /= 3.0 + return samples + def decode(self, samples_in): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) @@ -567,28 +587,29 @@ class VAE: def encode(self, pixel_samples): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) - pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor + pixel_samples = pixel_samples.movedim(-1,1) + try: + free_memory = model_management.get_free_memory(self.device) + batch_number = int((free_memory * 0.7) / (2078 * pixel_samples.shape[2] * pixel_samples.shape[3])) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + batch_number = max(1, batch_number) + samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") + for x in range(0, pixel_samples.shape[0], batch_number): + pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device) + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor + + except model_management.OOM_EXCEPTION as e: + print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + samples = self.encode_tiled_(pixel_samples) + self.first_stage_model = self.first_stage_model.cpu() - samples = samples.cpu() return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) - pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - - steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) - steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) - steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) - pbar = utils.ProgressBar(steps) - - samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. 
* a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples /= 3.0 + pixel_samples = pixel_samples.movedim(-1,1) + samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) self.first_stage_model = self.first_stage_model.cpu() - samples = samples.cpu() return samples def broadcast_image_to(tensor, target_batch_size, batched_number): @@ -764,7 +785,7 @@ def load_controlnet(ckpt_path, model=None): for x in controlnet_data: c_m = "control_model." if x.startswith(c_m): - sd_key = "model.diffusion_model.{}".format(x[len(c_m):]) + sd_key = "diffusion_model.{}".format(x[len(c_m):]) if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) @@ -931,9 +952,10 @@ def load_gligen(ckpt_path): model = model.half() return model -def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) +def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): + if config is None: + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) model_config_params = config['model']['params'] clip_config = model_config_params['cond_stage_config'] scale_factor = model_config_params['scale_factor'] @@ -942,8 +964,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e fp16 = False if "unet_config" in model_config_params: if "params" in model_config_params["unet_config"]: - if "use_fp16" in model_config_params["unet_config"]["params"]: - fp16 = model_config_params["unet_config"]["params"]["use_fp16"] + unet_config = model_config_params["unet_config"]["params"] + if "use_fp16" in unet_config: + fp16 = unet_config["use_fp16"] + + noise_aug_config = None + if "noise_aug_config" in model_config_params: + noise_aug_config = model_config_params["noise_aug_config"] + + v_prediction = False + + if "parameterization" in model_config_params: + if model_config_params["parameterization"] == "v": + v_prediction = True clip = None vae = None @@ -963,9 +996,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] - model = instantiate_from_config(config["model"]) - sd = utils.load_torch_file(ckpt_path) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + if config['model']["target"].endswith("LatentInpaintDiffusion"): + model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): + model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + else: + model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + + if state_dict is None: + state_dict = utils.load_torch_file(ckpt_path) + model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: model = model.half() @@ -1073,16 +1113,20 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} + unclip_model = False + inpaint_model 
= False if noise_aug_config is not None: #SD2.x unclip model sd_config["noise_aug_config"] = noise_aug_config sd_config["image_size"] = 96 sd_config["embedding_dropout"] = 0.25 sd_config["conditioning_key"] = 'crossattn-adm' + unclip_model = True model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + inpaint_model = True else: sd_config["conditioning_key"] = "crossattn" @@ -1096,13 +1140,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o unet_config["num_classes"] = "sequential" unet_config["adm_in_channels"] = sd[unclip].shape[1] + v_prediction = False if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" out = sd[k] if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. + v_prediction = True sd_config["parameterization"] = 'v' - model = instantiate_from_config(model_config) + if inpaint_model: + model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + elif unclip_model: + model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + else: + model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b1a392736..91fb4ff27 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,6 +82,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): next_new_token += 1 else: print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) + while len(tokens_temp) < len(x): + tokens_temp += [self.empty_tokens[0][-1]] out_tokens += [tokens_temp] if len(embedding_weights) > 0: diff --git a/execution.py b/execution.py index 218a84c36..fc9578bc8 100644 --- a/execution.py +++ b/execution.py @@ -728,9 +728,14 @@ class PromptQueue: return True return False - def get_history(self): + def get_history(self, prompt_id=None): with self.mutex: - return copy.deepcopy(self.history) + if prompt_id is None: + return copy.deepcopy(self.history) + elif prompt_id in self.history: + return {prompt_id: copy.deepcopy(self.history[prompt_id])} + else: + return {} def wipe_history(self): with self.mutex: diff --git a/nodes.py b/nodes.py index b057504ed..658e32dad 100644 --- a/nodes.py +++ b/nodes.py @@ -623,11 +623,11 @@ class unCLIPConditioning: c = [] for t in conditioning: o = t[1].copy() - x = (clip_vision_output, strength, noise_augmentation) - if "adm" in o: - o["adm"] = o["adm"][:] + [x] + x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation} + if "unclip_conditioning" in o: + o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x] else: - o["adm"] = [x] + o["unclip_conditioning"] = [x] n = [t[0], o] c.append(n) return (c, ) @@ -1192,6 +1192,26 @@ class ImageScale: s = s.movedim(1,-1) return (s,) +class ImageScaleBy: + upscale_methods = ["nearest-exact", "bilinear", "area"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), 
"upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, scale_by): + samples = image.movedim(-1,1) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1,-1) + return (s,) + class ImageInvert: @classmethod @@ -1290,6 +1310,7 @@ NODE_CLASS_MAPPINGS = { "LoadImage": LoadImage, "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, + "ImageScaleBy": ImageScaleBy, "ImageInvert": ImageInvert, "ImagePadForOutpaint": ImagePadForOutpaint, "ConditioningAverage ": ConditioningAverage , @@ -1371,6 +1392,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadImage": "Load Image", "LoadImageMask": "Load Image (as Mask)", "ImageScale": "Upscale Image", + "ImageScaleBy": "Upscale Image By", "ImageUpscaleWithModel": "Upscale Image (using Model)", "ImageInvert": "Invert Image", "ImagePadForOutpaint": "Pad Image for Outpainting", diff --git a/script_examples/websockets_api_example.py b/script_examples/websockets_api_example.py new file mode 100644 index 000000000..25882d821 --- /dev/null +++ b/script_examples/websockets_api_example.py @@ -0,0 +1,164 @@ +#This is an example that uses the websockets api to know when a prompt execution is done +#Once the prompt execution is done it downloads the images using the /history endpoint + +import websocket +import uuid +import json +import urllib.request +import urllib.parse + +server_address = "127.0.0.1:8188" +client_id = str(uuid.uuid4()) + +def queue_prompt(prompt): + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + +def get_image(filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response: + return response.read() + +def get_history(prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response: + return json.loads(response.read()) + +def get_images(ws, prompt): + prompt_id = queue_prompt(prompt)['prompt_id'] + output_images = {} + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break #Execution is done + else: + continue #previews are binary data + + history = get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + +prompt_text = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": 8, + "denoise": 1, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "euler", + "scheduler": "normal", + "seed": 8566257, + 
"steps": 20 + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "v1-5-pruned-emaonly.ckpt" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": 512, + "width": 512 + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "bad hands" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "9": { + "class_type": "SaveImage", + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "8", + 0 + ] + } + } +} +""" + +prompt = json.loads(prompt_text) +#set the text prompt for our positive CLIPTextEncode +prompt["6"]["inputs"]["text"] = "masterpiece best quality man" + +#set the seed for our KSampler node +prompt["3"]["inputs"]["seed"] = 5 + +ws = websocket.WebSocket() +ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id)) +images = get_images(ws, prompt) + +#Commented out code to display the output images: + +# for node_id in images: +# for image_data in images[node_id]: +# from PIL import Image +# import io +# image = Image.open(io.BytesIO(image_data)) +# image.show() + diff --git a/server.py b/server.py index 174d38af1..300221f6c 100644 --- a/server.py +++ b/server.py @@ -372,6 +372,11 @@ class PromptServer(): async def get_history(request): return web.json_response(self.prompt_queue.get_history()) + @routes.get("/history/{prompt_id}") + async def get_history(request): + prompt_id = request.match_info.get("prompt_id", None) + return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) + @routes.get("/queue") async def get_queue(request): queue_info = {} diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 51e66f924..662d87e74 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -1,132 +1,138 @@ -import { app } from "/scripts/app.js"; +import {app} from "/scripts/app.js"; // Adds filtering to combo context menus -const id = "Comfy.ContextMenuFilter"; -app.registerExtension({ - name: id, +const ext = { + name: "Comfy.ContextMenuFilter", init() { const ctxMenu = LiteGraph.ContextMenu; + LiteGraph.ContextMenu = function (values, options) { const ctx = ctxMenu.call(this, values, options); // If we are a dark menu (only used for combo boxes) then add a filter input if (options?.className === "dark" && values?.length > 10) { const filter = document.createElement("input"); - Object.assign(filter.style, { - width: "calc(100% - 10px)", - border: "0", - boxSizing: "border-box", - background: "#333", - border: "1px solid #999", - margin: "0 0 5px 5px", - color: "#fff", - }); + filter.classList.add("comfy-context-menu-filter"); filter.placeholder = "Filter list"; this.root.prepend(filter); - let selectedIndex = 0; - let items = this.root.querySelectorAll(".litemenu-entry"); - let itemCount = items.length; - let selectedItem; + const items = Array.from(this.root.querySelectorAll(".litemenu-entry")); + let displayedItems = [...items]; + let itemCount = displayedItems.length; - // Apply highlighting to the selected item - function updateSelected() { - if (selectedItem) { - selectedItem.style.setProperty("background-color", ""); - selectedItem.style.setProperty("color", ""); - } - selectedItem = items[selectedIndex]; - if (selectedItem) { - 
selectedItem.style.setProperty("background-color", "#ccc", "important"); - selectedItem.style.setProperty("color", "#000", "important"); - } - } + // We must request an animation frame for the current node of the active canvas to update. + requestAnimationFrame(() => { + const currentNode = LGraphCanvas.active_canvas.current_node; + const clickedComboValue = currentNode.widgets + .filter(w => w.type === "combo" && w.options.values.length === values.length) + .find(w => w.options.values.every((v, i) => v === values[i])) + .value; - const positionList = () => { - const rect = this.root.getBoundingClientRect(); - - // If the top is off screen then shift the element with scaling applied - if (rect.top < 0) { - const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; - const shift = (this.root.clientHeight * scale) / 2; - this.root.style.top = -shift + "px"; - } - } - - updateSelected(); - - // Arrow up/down to select items - filter.addEventListener("keydown", (e) => { - if (e.key === "ArrowUp") { - if (selectedIndex === 0) { - selectedIndex = itemCount - 1; - } else { - selectedIndex--; - } - updateSelected(); - e.preventDefault(); - } else if (e.key === "ArrowDown") { - if (selectedIndex === itemCount - 1) { - selectedIndex = 0; - } else { - selectedIndex++; - } - updateSelected(); - e.preventDefault(); - } else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) { - selectedItem.click(); - } else if(e.key === "Escape") { - this.close(); - } - }); - - filter.addEventListener("input", () => { - // Hide all items that dont match our filter - const term = filter.value.toLocaleLowerCase(); - items = this.root.querySelectorAll(".litemenu-entry"); - // When filtering recompute which items are visible for arrow up/down - // Try and maintain selection - let visibleItems = []; - for (const item of items) { - const visible = !term || item.textContent.toLocaleLowerCase().includes(term); - if (visible) { - item.style.display = "block"; - if (item === selectedItem) { - selectedIndex = visibleItems.length; - } - visibleItems.push(item); - } else { - item.style.display = "none"; - if (item === selectedItem) { - selectedIndex = 0; - } - } - } - items = visibleItems; + let selectedIndex = values.findIndex(v => v === clickedComboValue); + let selectedItem = displayedItems?.[selectedIndex]; updateSelected(); - // If we have an event then we can try and position the list under the source - if (options.event) { - let top = options.event.clientY - 10; - - const bodyRect = document.body.getBoundingClientRect(); - const rootRect = this.root.getBoundingClientRect(); - if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { - top = Math.max(0, bodyRect.height - rootRect.height - 10); - } - - this.root.style.top = top + "px"; - positionList(); + // Apply highlighting to the selected item + function updateSelected() { + selectedItem?.style.setProperty("background-color", ""); + selectedItem?.style.setProperty("color", ""); + selectedItem = displayedItems[selectedIndex]; + selectedItem?.style.setProperty("background-color", "#ccc", "important"); + selectedItem?.style.setProperty("color", "#000", "important"); } - }); - requestAnimationFrame(() => { - // Focus the filter box when opening - filter.focus(); + const positionList = () => { + const rect = this.root.getBoundingClientRect(); - positionList(); - }); + // If the top is off-screen then shift the element with scaling applied + if (rect.top < 0) { + const scale = 1 - 
this.root.getBoundingClientRect().height / this.root.clientHeight; + const shift = (this.root.clientHeight * scale) / 2; + this.root.style.top = -shift + "px"; + } + } + + // Arrow up/down to select items + filter.addEventListener("keydown", (event) => { + switch (event.key) { + case "ArrowUp": + event.preventDefault(); + if (selectedIndex === 0) { + selectedIndex = itemCount - 1; + } else { + selectedIndex--; + } + updateSelected(); + break; + case "ArrowRight": + event.preventDefault(); + selectedIndex = itemCount - 1; + updateSelected(); + break; + case "ArrowDown": + event.preventDefault(); + if (selectedIndex === itemCount - 1) { + selectedIndex = 0; + } else { + selectedIndex++; + } + updateSelected(); + break; + case "ArrowLeft": + event.preventDefault(); + selectedIndex = 0; + updateSelected(); + break; + case "Enter": + selectedItem?.click(); + break; + case "Escape": + this.close(); + break; + } + }); + + filter.addEventListener("input", () => { + // Hide all items that don't match our filter + const term = filter.value.toLocaleLowerCase(); + // When filtering, recompute which items are visible for arrow up/down and maintain selection. + displayedItems = items.filter(item => { + const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term); + item.style.display = isVisible ? "block" : "none"; + return isVisible; + }); + + selectedIndex = 0; + if (displayedItems.includes(selectedItem)) { + selectedIndex = displayedItems.findIndex(d => d === selectedItem); + } + itemCount = displayedItems.length; + + updateSelected(); + + // If we have an event then we can try and position the list under the source + if (options.event) { + let top = options.event.clientY - 10; + + const bodyRect = document.body.getBoundingClientRect(); + const rootRect = this.root.getBoundingClientRect(); + if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { + top = Math.max(0, bodyRect.height - rootRect.height - 10); + } + + this.root.style.top = top + "px"; + positionList(); + } + }); + + requestAnimationFrame(() => { + // Focus the filter box when opening + filter.focus(); + + positionList(); + }); + }) } return ctx; @@ -134,4 +140,6 @@ app.registerExtension({ LiteGraph.ContextMenu.prototype = ctxMenu.prototype; }, -}); +} + +app.registerExtension(ext); diff --git a/web/extensions/core/dynamicPrompts.js b/web/extensions/core/dynamicPrompts.js index 7dae07f4d..599a9e685 100644 --- a/web/extensions/core/dynamicPrompts.js +++ b/web/extensions/core/dynamicPrompts.js @@ -3,6 +3,13 @@ import { app } from "../../scripts/app.js"; // Allows for simple dynamic prompt replacement // Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued. 
+/* + * Strips C-style line and block comments from a string + */ +function stripComments(str) { + return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g,''); +} + app.registerExtension({ name: "Comfy.DynamicPrompts", nodeCreated(node) { @@ -15,7 +22,7 @@ app.registerExtension({ for (const widget of widgets) { // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node widget.serializeValue = (workflowNode, widgetIndex) => { - let prompt = widget.value; + let prompt = stripComments(widget.value); while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) { const startIndex = prompt.replace("\\{", "00").indexOf("{"); const endIndex = prompt.replace("\\}", "00").indexOf("}"); diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 4fe0a6013..c356655b0 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -200,8 +200,23 @@ app.registerExtension({ applyToGraph() { if (!this.outputs[0].links?.length) return; + function get_links(node) { + let links = []; + for (const l of node.outputs[0].links) { + const linkInfo = app.graph.links[l]; + const n = node.graph.getNodeById(linkInfo.target_id); + if (n.type == "Reroute") { + links = links.concat(get_links(n)); + } else { + links.push(l); + } + } + return links; + } + + let links = get_links(this); // For each output link copy our value over the original widget value - for (const l of this.outputs[0].links) { + for (const l of links) { const linkInfo = app.graph.links[l]; const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; diff --git a/web/style.css b/web/style.css index 47571a16e..5fea5bba8 100644 --- a/web/style.css +++ b/web/style.css @@ -50,7 +50,7 @@ body { padding: 30px 30px 10px 30px; background-color: var(--comfy-menu-bg); /* Modal background */ color: var(--error-text); - box-shadow: 0px 0px 20px #888888; + box-shadow: 0 0 20px #888888; border-radius: 10px; top: 50%; left: 50%; @@ -84,7 +84,7 @@ body { font-size: 15px; position: absolute; top: 50%; - right: 0%; + right: 0; text-align: center; z-index: 100; width: 170px; @@ -252,7 +252,7 @@ button.comfy-queue-btn { bottom: 0 !important; left: auto !important; right: 0 !important; - border-radius: 0px; + border-radius: 0; } .comfy-menu span.drag-handle { visibility:hidden @@ -291,7 +291,7 @@ button.comfy-queue-btn { .litegraph .dialog { z-index: 1; - font-family: Arial; + font-family: Arial, sans-serif; } .litegraph .litemenu-entry.has_submenu { @@ -330,6 +330,13 @@ button.comfy-queue-btn { color: var(--input-text) !important; } +.comfy-context-menu-filter { + box-sizing: border-box; + border: 1px solid #999; + margin: 0 0 5px 5px; + width: calc(100% - 10px); +} + /* Search box */ .litegraph.litesearchbox {
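To round off the API-facing changes: `server.py` now registers a `GET /history/{prompt_id}` route backed by the new `prompt_id` parameter on `PromptQueue.get_history()`. Below is a minimal sketch, in the same style as `script_examples/websockets_api_example.py`, of how that route could be queried from a script; the server address matches the default used in that example, and the placeholder prompt id is an assumption (use the id returned by `POST /prompt`).

```python
import json
import urllib.request

server_address = "127.0.0.1:8188"  # default ComfyUI listen address, adjust if needed
prompt_id = "00000000-0000-0000-0000-000000000000"  # placeholder; use the id from POST /prompt

# GET /history/<prompt_id> returns {prompt_id: {...}} once the prompt has finished,
# or an empty object if the id is unknown (see PromptQueue.get_history in the diff).
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
    history = json.loads(response.read())

if history:
    outputs = history[prompt_id]["outputs"]
    print("finished nodes:", list(outputs.keys()))
else:
    print("no history entry for", prompt_id)
```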