From 0782ac2a96fab2c436f78379db1de0df9737aa1d Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 8 Sep 2023 14:53:29 +1000 Subject: [PATCH 01/39] defaultInput --- web/extensions/core/widgetInputs.js | 2 +- web/scripts/app.js | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index f9a5b7278..606605f0a 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -142,7 +142,7 @@ app.registerExtension({ const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : undefined; if (this.widgets) { for (const w of this.widgets) { - if (w?.options?.forceInput) { + if (w?.options?.forceInput || w?.options?.defaultInput) { const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; convertToInput(this, w, config); } diff --git a/web/scripts/app.js b/web/scripts/app.js index a3661da64..40295b350 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1248,6 +1248,10 @@ export class ComfyApp { if (!config.widget.options) config.widget.options = {}; config.widget.options.forceInput = inputData[1].forceInput; } + if(widgetCreated && inputData[1]?.defaultInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + config.widget.options.defaultInput = inputData[1].defaultInput; + } } for (const o in nodeData["output"]) { From 264867bf87c37abdf794c9e1bab1bc512c2f5ff4 Mon Sep 17 00:00:00 2001 From: Michael Abrahams Date: Fri, 8 Sep 2023 11:17:45 -0400 Subject: [PATCH 02/39] Clear clipboard on copy --- web/scripts/app.js | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index a3661da64..72844a92b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -735,9 +735,17 @@ export class ComfyApp { */ #addCopyHandler() { document.addEventListener("copy", (e) => { - // copy + if (e.target.type === "text" || e.target.type === "textarea") { + // Default system copy + return; + } + // copy nodes and clear clipboard if (this.canvas.selected_nodes) { - this.canvas.copyToClipboard(); + this.canvas.copyToClipboard(); + e.clipboardData.clearData(); + e.preventDefault(); + e.stopImmediatePropagation(); + return false; } }); } @@ -842,10 +850,13 @@ export class ComfyApp { if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { if (e.shiftKey) { this.copyToClipboard(true); + e.clipboardData.clearData(); block_default = true; } - // Trigger default onCopy - return true; + else { + // Trigger onCopy + return true; + } } // Ctrl+V Paste @@ -855,7 +866,7 @@ export class ComfyApp { block_default = true; } else { - // Trigger default onPaste + // Trigger onPaste return true; } } From 7df822212fb2da45c8523155086456c2cd119062 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Sep 2023 02:36:04 -0400 Subject: [PATCH 03/39] Allow checkpoints with .pt and .bin extensions. 
--- folder_paths.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 82aedd43f..4a10c68e7 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,13 @@ import os import time -supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} base_path = os.path.dirname(os.path.realpath(__file__)) models_dir = os.path.join(base_path, "models") -folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) +folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) From 9562a6b49e63e63a16f3e45ff4965f72385f51fa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Sep 2023 11:19:31 -0400 Subject: [PATCH 04/39] Fix a few clipboard issues. --- web/scripts/app.js | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 7ef2fc4e3..9db4e9230 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -742,7 +742,7 @@ export class ComfyApp { // copy nodes and clear clipboard if (this.canvas.selected_nodes) { this.canvas.copyToClipboard(); - e.clipboardData.clearData(); + e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard e.preventDefault(); e.stopImmediatePropagation(); return false; @@ -848,27 +848,14 @@ export class ComfyApp { // Ctrl+C Copy if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { - if (e.shiftKey) { - this.copyToClipboard(true); - e.clipboardData.clearData(); - block_default = true; - } - else { - // Trigger onCopy - return true; - } + // Trigger onCopy + return true; } // Ctrl+V Paste - if ((e.key === 'v') && (e.metaKey || e.ctrlKey)) { - if (e.shiftKey) { - this.pasteFromClipboard(true); - block_default = true; - } - else { - // Trigger onPaste - return true; - } + if ((e.key === 'v' || e.key == 'V') && (e.metaKey || e.ctrlKey)) { + // Trigger onPaste + return true; } } From 7d401ed1d0fcc78b14d61d9f585ace40b9de0ddb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Sep 2023 16:36:50 -0400 Subject: [PATCH 05/39] Add ldm format support to UNETLoader. 
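As a quick illustration of what this enables (the file names below are placeholders, not shipped models), load_unet() now accepts either layout and picks the code path by checking for the ldm key "input_blocks.0.0.weight":

    import comfy.sd

    # diffusers-layout UNET: keys converted through unet_to_diffusers() as before
    model = comfy.sd.load_unet("models/unet/example_diffusers_unet.safetensors")

    # ldm-layout UNET: detected via "input_blocks.0.0.weight" and loaded directly
    model = comfy.sd.load_unet("models/unet/example_ldm_unet.safetensors")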
--- comfy/sd.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8be0bcbc8..9bdb2ad64 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -454,20 +454,26 @@ def load_unet(unet_path): #load unet in diffusers format sd = comfy.utils.load_torch_file(unet_path) parameters = comfy.utils.calculate_parameters(sd) fp16 = model_management.should_use_fp16(model_params=parameters) + if "input_blocks.0.0.weight" in sd: #ldm + model_config = model_detection.model_config_from_unet(sd, "", fp16) + if model_config is None: + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + new_sd = sd - model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) - if model_config is None: - print("ERROR UNSUPPORTED UNET", unet_path) - return None + else: #diffusers + model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) + if model_config is None: + print("ERROR UNSUPPORTED UNET", unet_path) + return None - diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) + diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - print(diffusers_keys[k], k) + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + print(diffusers_keys[k], k) offload_device = model_management.unet_offload_device() model = model_config.get_model(new_sd, "") model = model.to(offload_device) From fb3b7282034a37dbed377055f843c9a9302fdd8c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Sep 2023 21:49:56 -0400 Subject: [PATCH 06/39] Fix issue where autocast fp32 CLIP gave different results from regular. --- comfy/sd1_clip.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 477d5c309..b84a38490 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): if dtype is not None: self.transformer.to(dtype) + self.transformer.text_model.embeddings.token_embedding.to(torch.float32) + self.transformer.text_model.embeddings.position_embedding.to(torch.float32) + self.max_length = max_length if freeze: self.freeze() @@ -138,7 +141,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if backup_embeds.weight.dtype != torch.float32: + if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: precision_scope = torch.autocast else: precision_scope = lambda a, b: contextlib.nullcontext(a) From ed58730658d0213600b64849d721a6bb92c675bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Sep 2023 15:09:10 -0400 Subject: [PATCH 07/39] Don't leave very large hidden states in the clip vision output. 
--- comfy/clip_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 9b95ae003..1206c680d 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -56,6 +56,7 @@ class ClipVisionModel(): if t is not None: if k == 'hidden_states': outputs["penultimate_hidden_states"] = t[-2].cpu() + outputs["hidden_states"] = None else: outputs[k] = t.cpu() From 0b829fe35b3ae626494735eb149c43345e5c55a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Sep 2023 18:44:05 -0400 Subject: [PATCH 08/39] .gitignore refactor. --- .gitignore | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 0177e1d7d..98d91318d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ __pycache__/ *.py[cod] -output/ -input/ -!input/example.png -models/ -temp/ -custom_nodes/ +/output/ +/input/ +!/input/example.png +/models/ +/temp/ +/custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs .idea/ venv/ -web/extensions/* -!web/extensions/logging.js.example -!web/extensions/core/ +/web/extensions/* +!/web/extensions/logging.js.example +!/web/extensions/core/ From 30de95e4b420aa02d25d151271dca9867492288f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Sep 2023 01:10:31 -0400 Subject: [PATCH 09/39] Add some nodes to subtract and add model weights. --- comfy_extras/nodes_model_merging.py | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index bce4b3dd0..ebcbd4be9 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -27,6 +27,44 @@ class ModelMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +class ModelSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, multiplier): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + +class ModelAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class CLIPMergeSimple: @classmethod def INPUT_TYPES(s): @@ -144,6 +182,8 @@ class CheckpointSave: NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, + "ModelMergeSubtract": ModelSubtract, + "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, } From 3039b08eb16777431946ed9ae4a63c5466336bff Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Sep 2023 11:38:20 -0400 Subject: [PATCH 10/39] Only parse command line args when main.py is called. 
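In practice this means comfy modules can be imported as a library without argparse touching the importer's own sys.argv; only an entry point that calls enable_args_parsing() (as main.py now does) gets real command line parsing. A minimal sketch, with the importing script being hypothetical:

    # external_tool.py -- uses comfy without being a ComfyUI entry point
    from comfy.cli_args import args   # enable_args_parsing() was never called,
                                      # so this resolves to parser.parse_args([])

    print(args.disable_metadata)      # default value, not read from this script's argv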
--- comfy/cli_args.py | 7 +++++-- comfy/options.py | 6 ++++++ main.py | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 comfy/options.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fda245433..ffae81c49 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,6 +1,6 @@ import argparse import enum - +import comfy.options class EnumAction(argparse.Action): """ @@ -94,7 +94,10 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") -args = parser.parse_args() +if comfy.options.args_parsing: + args = parser.parse_args() +else: + args = parser.parse_args([]) if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/options.py b/comfy/options.py new file mode 100644 index 000000000..f7f8af41e --- /dev/null +++ b/comfy/options.py @@ -0,0 +1,6 @@ + +args_parsing = False + +def enable_args_parsing(enable=True): + global args_parsing + args_parsing = enable diff --git a/main.py b/main.py index 9f0f80458..7c5eaee0a 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,6 @@ +import comfy.options +comfy.options.enable_args_parsing() + import os import importlib.util import folder_paths From 0e4395a8a3f7b5da15c46308eee9721ce3f4f475 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 13 Sep 2023 18:42:44 +0100 Subject: [PATCH 11/39] Allow pasting nodes with connections in firefox --- web/scripts/app.js | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9db4e9230..6dd1f3edd 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -671,6 +671,10 @@ export class ComfyApp { */ #addPasteHandler() { document.addEventListener("paste", (e) => { + // ctrl+shift+v is used to paste nodes with connections + // this is handled by litegraph + if(this.shiftDown) return; + let data = (e.clipboardData || window.clipboardData); const items = data.items; @@ -853,7 +857,7 @@ export class ComfyApp { } // Ctrl+V Paste - if ((e.key === 'v' || e.key == 'V') && (e.metaKey || e.ctrlKey)) { + if ((e.key === 'v' || e.key == 'V') && (e.metaKey || e.ctrlKey) && !e.shiftKey) { // Trigger onPaste return true; } From 0966d3ce823dd9e0d668bd0f4049fb5b879c6672 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Sep 2023 12:16:07 -0400 Subject: [PATCH 12/39] Don't run text encoders on xpu because there are issues. --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index b663e8f59..e38ef4eea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -451,6 +451,8 @@ def text_encoder_device(): if args.gpu_only: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + if is_intel_xpu(): + return torch.device("cpu") if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From 0d8f3764468999bc34700799553919ded9b34ef8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Sep 2023 18:12:36 -0400 Subject: [PATCH 13/39] Set last layer on SD2.x models uses the proper indexes now. Before I had made the last layer the penultimate layer because some checkpoints don't have them but it's not consistent with the others models. TLDR: for SD2.x models only: CLIPSetLastLayer -1 is now -2. 
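For reference, a small worked example of the remapping that this patch removes (a sketch of the old override only; SD2.x now uses the shared base class behaviour unchanged):

    # old SD2.x-only remapping of CLIPSetLastLayer values, removed below:
    def old_sd2_layer_idx(layer_idx):
        if layer_idx < 0:
            layer_idx -= 1      # -1 silently became -2, -2 became -3, ...
        if abs(layer_idx) >= 24:
            layer_idx = -2      # clamped to the penultimate hidden layer
        return layer_idx

    old_sd2_layer_idx(-1)  # -> -2, hence the TLDR: to keep old outputs, use -2 now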
--- comfy/sd2_clip.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 818c9711e..05e50a005 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -12,16 +12,6 @@ class SD2ClipModel(sd1_clip.SD1ClipModel): super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] - def clip_layer(self, layer_idx): - if layer_idx < 0: - layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage. - if abs(layer_idx) >= 24: - self.layer = "hidden" - self.layer_idx = -2 - else: - self.layer = "hidden" - self.layer_idx = layer_idx - class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) From 44361f6344f53c32b1cd902515b9071f6d08ecc7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 01:56:07 -0400 Subject: [PATCH 14/39] Support for text encoder models that need attention_mask. --- comfy/sd1_clip.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b84a38490..9978b6c35 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -71,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.empty_tokens = [[49406] + [49407] * 76] self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.enable_attention_masks = False self.layer_norm_hidden_state = True if layer == "hidden": @@ -147,7 +148,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): precision_scope = lambda a, b: contextlib.nullcontext(a) with precision_scope(model_management.get_autocast_device(device), torch.float32): - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": From 076f3e63107fac1a9f4da705dfd18b428cb1340c Mon Sep 17 00:00:00 2001 From: karrycharon Date: Fri, 15 Sep 2023 16:37:58 +0800 Subject: [PATCH 15/39] fix structuredClone undefined error; --- web/scripts/app.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 6dd1f3edd..4beaf03ae 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1297,7 +1297,13 @@ export class ComfyApp { let reset_invalid_values = false; if (!graphData) { - graphData = structuredClone(defaultGraph); + if (typeof structuredClone === "undefined") + { + graphData = JSON.parse(JSON.stringify(defaultGraph)); + }else + { + graphData = structuredClone(defaultGraph); + } reset_invalid_values = true; } From 94e4fe39d868a0bb939c2f91746de09680e4657d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 12:03:03 
-0400 Subject: [PATCH 16/39] This isn't used anywhere. --- comfy/ldm/models/diffusion/ddim.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 139c8e01e..befab0075 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -33,7 +33,6 @@ class DDIMSampler(object): assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) - self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) @@ -195,7 +194,7 @@ class DDIMSampler(object): temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): - device = self.model.betas.device + device = self.model.alphas_cumprod.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) From 415abb275f8ef74615cbb3c5ebc90b20d1a713b7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 19:22:47 -0400 Subject: [PATCH 17/39] Add DDPM sampler. --- comfy/k_diffusion/sampling.py | 31 +++++++++++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index eb088d92b..937c5a388 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -706,3 +706,34 @@ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disab noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) + +def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): + alpha_cumprod = 1 / ((sigma * sigma) + 1) + alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) + alpha = (alpha_cumprod / alpha_cumprod_prev) + + mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) + if sigma_prev > 0: + mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. 
- alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) + return mu + + +def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) + if sigmas[i + 1] != 0: + x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) + return x + + +@torch.no_grad() +def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) + diff --git a/comfy/samplers.py b/comfy/samplers.py index c60288fd1..7f1987167 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -546,7 +546,7 @@ class KSampler: SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model From 43d4935a1da0b78dac101a28cc98de0b7d556729 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 22:21:14 -0400 Subject: [PATCH 18/39] Add cond_or_uncond array to transformer_options so hooks can check what is cond and what is uncond. 
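A rough sketch of how a hook might read the new field; the signature below is illustrative only, since what a registered patch actually receives depends on the patch type:

    # hypothetical patch callable -- the only point here is the dictionary lookup
    def my_patch(x, transformer_options):
        cond_or_uncond = transformer_options.get("cond_or_uncond", [])
        # one entry per chunk of the batch, marking whether that chunk came from the
        # cond or the uncond pass, so a hook can treat the two passes differently
        return x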
--- comfy/samplers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7f1987167..57673a029 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -255,6 +255,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con else: transformer_options["patches"] = patches + transformer_options["cond_or_uncond"] = cond_or_uncond[:] c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: From 69680fede7de62f503a59efbbd8aa058b8e50395 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sat, 16 Sep 2023 20:36:00 +0900 Subject: [PATCH 19/39] fix: thumbnail ratio fix for mixed ratio images --- web/scripts/app.js | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 4beaf03ae..84090764a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -532,7 +532,17 @@ export class ComfyApp { } } this.imageRects.push([x, y, cellWidth, cellHeight]); - ctx.drawImage(img, x, y, cellWidth, cellHeight); + + let wratio = cellWidth/img.width; + let hratio = cellHeight/img.height; + var ratio = Math.min(wratio, hratio); + + let imgHeight = ratio * img.height; + let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; + let imgWidth = ratio * img.width; + let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; + + ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight); ctx.filter = "none"; } From 4d5e057bb2e32117c945cc9dfe8039dad2329297 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sat, 16 Sep 2023 20:37:27 +0900 Subject: [PATCH 20/39] fix indent --- web/scripts/app.js | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 84090764a..f0bb8640c 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -533,14 +533,14 @@ export class ComfyApp { } this.imageRects.push([x, y, cellWidth, cellHeight]); - let wratio = cellWidth/img.width; - let hratio = cellHeight/img.height; - var ratio = Math.min(wratio, hratio); + let wratio = cellWidth/img.width; + let hratio = cellHeight/img.height; + var ratio = Math.min(wratio, hratio); - let imgHeight = ratio * img.height; - let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; - let imgWidth = ratio * img.width; - let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; + let imgHeight = ratio * img.height; + let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; + let imgWidth = ratio * img.width; + let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight); ctx.filter = "none"; From 61b1f67734f445aabdbd941537c22bfe6f9237aa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Sep 2023 12:59:54 -0400 Subject: [PATCH 21/39] Support models without previews. 
--- comfy/latent_formats.py | 4 ++++ latent_preview.py | 7 +++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8b59cfbdc..fadc0eec7 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,5 +1,9 @@ class LatentFormat: + scale_factor = 1.0 + latent_rgb_factors = None + taesd_decoder_name = None + def process_in(self, latent): return latent * self.scale_factor diff --git a/latent_preview.py b/latent_preview.py index 30c1d1317..87240a582 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -53,7 +53,9 @@ def get_previewer(device, latent_format): method = args.preview_method if method != LatentPreviewMethod.NoPreviews: # TODO previewer methods - taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) + taesd_decoder_path = None + if latent_format.taesd_decoder_name is not None: + taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -68,7 +70,8 @@ def get_previewer(device, latent_format): print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) + if latent_format.latent_rgb_factors is not None: + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer From 0665749b1a13f149f3c1770db7f366643acafdd7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Sep 2023 02:10:06 -0400 Subject: [PATCH 22/39] Move ModelSubtract and ModelAdd to advanced/model_merging --- comfy_extras/nodes_model_merging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index ebcbd4be9..3d42d7806 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -37,7 +37,7 @@ class ModelSubtract: RETURN_TYPES = ("MODEL",) FUNCTION = "merge" - CATEGORY = "_for_testing/model_merging" + CATEGORY = "advanced/model_merging" def merge(self, model1, model2, multiplier): m = model1.clone() @@ -55,7 +55,7 @@ class ModelAdd: RETURN_TYPES = ("MODEL",) FUNCTION = "merge" - CATEGORY = "_for_testing/model_merging" + CATEGORY = "advanced/model_merging" def merge(self, model1, model2): m = model1.clone() From 321c5fa2958a2cdb05a08f6792fd2f72336e8c90 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Sep 2023 04:09:19 -0400 Subject: [PATCH 23/39] Enable pytorch attention by default on xpu. --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e38ef4eea..d8bc3bfea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -165,6 +165,9 @@ try: ENABLE_PYTORCH_ATTENTION = True if torch.cuda.is_bf16_supported(): VAE_DTYPE = torch.bfloat16 + if is_intel_xpu(): + if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + ENABLE_PYTORCH_ATTENTION = True except: pass From db63aa7e53c459b016cfa4159be004e59af84da9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Sep 2023 12:49:06 -0400 Subject: [PATCH 24/39] Nodes can now control the rounding in the UI. 
--- custom_nodes/example_node.py.example | 8 +++++++- nodes.py | 4 ++-- web/scripts/widgets.js | 23 ++++++++++++++++------- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index e37808b03..733014f3c 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -54,7 +54,13 @@ class Example: "step": 64, #Slider's step "display": "number" # Cosmetic only: display as "number" or "slider" }), - "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}), + "float_field": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 10.0, + "step": 0.01, + "round": 0.001, #The value represeting the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + "display": "number"}), "print_to_screen": (["enable", "disable"],), "string_field": ("STRING", { "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node diff --git a/nodes.py b/nodes.py index 77d180526..3bc08663e 100644 --- a/nodes.py +++ b/nodes.py @@ -1217,7 +1217,7 @@ class KSampler: {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1243,7 +1243,7 @@ class KSamplerAdvanced: "add_noise": (["enable", "disable"], ), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 30caa6a8c..40b3067b7 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -2,17 +2,22 @@ import { api } from "./api.js" function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; - let { min, max, step } = inputData[1]; + let { min, max, step, round} = inputData[1]; if (defaultVal == undefined) defaultVal = 0; if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; -// precision is the number of decimal places to show. -// by default, display the the smallest number of decimal places such that changes of size step are visible. - let precision = Math.max(-Math.floor(Math.log10(step)),0) -// by default, round the value to those decimal places shown. - let round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + + // precision is the number of decimal places to show. + // by default, display the the smallest number of decimal places such that changes of size step are visible. + let precision = Math.max(-Math.floor(Math.log10(step)),0); + + if (round == undefined || round === true) { + // by default, round the value to those decimal places shown. 
+ round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + } + return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } @@ -271,7 +276,11 @@ export const ComfyWidgets = { const { val, config } = getNumberDefaults(inputData, 0.5); return { widget: node.addWidget(widgetType, inputName, val, function (v) { - this.value = Math.round(v/config.round)*config.round; + if (config.round) { + this.value = Math.round(v/config.round)*config.round; + } else { + this.value = v; + } }, config) }; }, INT(node, inputName, inputData, app) { From 01094316268cab9ed5cd53b825b359a7becb9d6c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 16:20:03 -0400 Subject: [PATCH 25/39] Lower the minimum resolution of EmptyLatentImage. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 3bc08663e..9ccf179ce 100644 --- a/nodes.py +++ b/nodes.py @@ -889,8 +889,8 @@ class EmptyLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" From b92bf8196e0d3158b3e981d056a2be15ce5ab1cd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 23:04:49 -0400 Subject: [PATCH 26/39] Do lora cast on GPU instead of CPU for higher performance. --- comfy/model_patcher.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a6ee0bae1..85bf5bd2a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -187,13 +187,13 @@ class ModelPatcher: else: weight += alpha * w1.type(weight.dtype).to(weight.device) elif len(v) == 4: #lora/locon - mat1 = v[0].float().to(weight.device) - mat2 = v[1].float().to(weight.device) + mat1 = v[0].to(weight.device).float() + mat2 = v[1].to(weight.device).float() if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].float().to(weight.device) + mat3 = v[3].to(weight.device).float() final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: @@ -212,18 +212,18 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(w1_a.float(), w1_b.float()) + w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float()) else: - w1 = w1.float().to(weight.device) + w1 = w1.to(weight.device).float() if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) + w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float()) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float()) else: - w2 = w2.float().to(weight.device) + w2 
= w2.to(weight.device).float() if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -244,11 +244,11 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float()) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float()) else: - m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) - m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) + m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float()) + m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float()) try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) From 26cd8405ddb32216d02b5eed23f6481cb36873c7 Mon Sep 17 00:00:00 2001 From: enzymezoo-code <103286087+enzymezoo-code@users.noreply.github.com> Date: Mon, 18 Sep 2023 22:18:06 -0500 Subject: [PATCH 27/39] Ci quality workflows (#1423) * Add inference tests * Clean up * Rename test graph file * Add readme for tests * Separate server fixture * test file name change * Assert images are generated * Clean up comments * Add __init__.py so tests can run with command line `pytest` * Fix command line args for pytest * Loop all samplers/schedulers in test_inference.py * Ci quality workflows compare (#1) * Add image comparison tests * Comparison tests do not pass with empty metadata * Ensure tests are run in correct order * Save image files with test name * Update tests readme * Reduce step counts in tests to ~halve runtime * Ci quality workflows build (#2) * Add build test github workflow --- .github/workflows/test-build.yml | 31 +++ pytest.ini | 5 + tests/README.md | 29 ++ tests/__init__.py | 0 tests/compare/conftest.py | 41 +++ tests/compare/test_quality.py | 195 ++++++++++++++ tests/conftest.py | 36 +++ tests/inference/__init__.py | 0 .../graphs/default_graph_sdxl1_0.json | 144 ++++++++++ tests/inference/test_inference.py | 247 ++++++++++++++++++ 10 files changed, 728 insertions(+) create mode 100644 .github/workflows/test-build.yml create mode 100644 pytest.ini create mode 100644 tests/README.md create mode 100644 tests/__init__.py create mode 100644 tests/compare/conftest.py create mode 100644 tests/compare/test_quality.py create mode 100644 tests/conftest.py create mode 100644 tests/inference/__init__.py create mode 100644 tests/inference/graphs/default_graph_sdxl1_0.json create mode 100644 tests/inference/test_inference.py diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml new file mode 100644 index 000000000..421dd5ee4 --- /dev/null +++ b/.github/workflows/test-build.yml @@ -0,0 +1,31 @@ +name: Build package + +# +# This workflow is a test of the python package build. +# Install Python dependencies across different Python versions. 
+# + +on: + push: + paths: + - "requirements.txt" + - ".github/workflows/test-build.yml" + +jobs: + build: + name: Build Test + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..b5a68e0f1 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') +testpaths = tests +addopts = -s \ No newline at end of file diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..2005fd45b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,29 @@ +# Automated Testing + +## Running tests locally + +Additional requirements for running tests: +``` +pip install pytest +pip install websocket-client==1.6.1 +opencv-python==4.6.0.66 +scikit-image==0.21.0 +``` +Run inference tests: +``` +pytest tests/inference +``` + +## Quality regression test +Compares images in 2 directories to ensure they are the same + +1) Run an inference test to save a directory of "ground truth" images +``` + pytest tests/inference --output_dir tests/inference/baseline +``` +2) Make code edits + +3) Run inference and quality comparison tests +``` +pytest +``` \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/compare/conftest.py b/tests/compare/conftest.py new file mode 100644 index 000000000..dd5078c9e --- /dev/null +++ b/tests/compare/conftest.py @@ -0,0 +1,41 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images') + parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test') + parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics') + parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images') + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['baseline_dir'] = pytestconfig.getoption('baseline_dir') + args['test_dir'] = pytestconfig.getoption('test_dir') + args['metrics_file'] = pytestconfig.getoption('metrics_file') + args['img_output_dir'] = pytestconfig.getoption('img_output_dir') + + # Initialize metrics file + with open(args['metrics_file'], 'a') as f: + # if file is empty, write header + if os.stat(args['metrics_file']).st_size == 0: + f.write("| date | run | file | status | value | \n") + f.write("| --- | --- | --- | --- | --- | \n") + + return args + + +def gather_file_basenames(directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + +# Creates the list of baseline file names to use as a fixture +def pytest_generate_tests(metafunc): + if "baseline_fname" in metafunc.fixturenames: + baseline_fnames = 
gather_file_basenames(metafunc.config.getoption("baseline_dir")) + metafunc.parametrize("baseline_fname", baseline_fnames) diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py new file mode 100644 index 000000000..92a2d5a8b --- /dev/null +++ b/tests/compare/test_quality.py @@ -0,0 +1,195 @@ +import datetime +import numpy as np +import os +from PIL import Image +import pytest +from pytest import fixture +from typing import Tuple, List + +from cv2 import imread, cvtColor, COLOR_BGR2RGB +from skimage.metrics import structural_similarity as ssim + + +""" +This test suite compares images in 2 directories by file name +The directories are specified by the command line arguments --baseline_dir and --test_dir + +""" +# ssim: Structural Similarity Index +# Returns a tuple of (ssim, diff_image) +def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: + score, diff = ssim(img0, img1, channel_axis=-1, full=True) + # rescale the difference image to 0-255 range + diff = (diff * 255).astype("uint8") + return score, diff + +# Metrics must return a tuple of (score, diff_image) +METRICS = {"ssim": ssim_score} +METRICS_PASS_THRESHOLD = {"ssim": 0.95} + + +class TestCompareImageMetrics: + @fixture(scope="class") + def test_file_names(self, args_pytest): + test_dir = args_pytest['test_dir'] + fnames = self.gather_file_basenames(test_dir) + yield fnames + del fnames + + @fixture(scope="class", autouse=True) + def teardown(self, args_pytest): + yield + # Runs after all tests are complete + # Aggregate output files into a grid of images + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + img_output_dir = args_pytest['img_output_dir'] + metrics_file = args_pytest['metrics_file'] + + grid_dir = os.path.join(img_output_dir, "grid") + os.makedirs(grid_dir, exist_ok=True) + + for metric_dir in METRICS.keys(): + metric_path = os.path.join(img_output_dir, metric_dir) + for file in os.listdir(metric_path): + if file.endswith(".png"): + score = self.lookup_score_from_fname(file, metrics_file) + image_file_list = [] + image_file_list.append([ + os.path.join(baseline_dir, file), + os.path.join(test_dir, file), + os.path.join(metric_path, file) + ]) + # Create grid + image_list = [[Image.open(file) for file in files] for files in image_file_list] + grid = self.image_grid(image_list) + grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}")) + + # Tests run for each baseline file name + @fixture() + def fname(self, baseline_fname): + yield baseline_fname + del baseline_fname + + def test_directories_not_empty(self, args_pytest): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty" + assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty" + + def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest): + # Check that all files in baseline_dir have a file in test_dir with matching metadata + baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname) + file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names] + file_match = self.find_file_match(baseline_file_path, file_paths) + assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}" + + # For a baseline image file, finds the corresponding file name in test_dir and + # compares the images using the metrics in METRICS + 
@pytest.mark.parametrize("metric", METRICS.keys()) + def test_pipeline_compare( + self, + args_pytest, + fname, + test_file_names, + metric, + ): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + metrics_output_file = args_pytest['metrics_file'] + img_output_dir = args_pytest['img_output_dir'] + + baseline_file_path = os.path.join(baseline_dir, fname) + + # Find file match + file_paths = [os.path.join(test_dir, f) for f in test_file_names] + test_file = self.find_file_match(baseline_file_path, file_paths) + + # Run metrics + sample_baseline = self.read_img(baseline_file_path) + sample_secondary = self.read_img(test_file) + + score, metric_img = METRICS[metric](sample_baseline, sample_secondary) + metric_status = score > METRICS_PASS_THRESHOLD[metric] + + # Save metric values + with open(metrics_output_file, 'a') as f: + run_info = os.path.splitext(fname)[0] + metric_status_str = "PASS ✅" if metric_status else "FAIL ❌" + date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n") + + # Save metric image + metric_img_dir = os.path.join(img_output_dir, metric) + os.makedirs(metric_img_dir, exist_ok=True) + output_filename = f'{fname}' + Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename)) + + assert score > METRICS_PASS_THRESHOLD[metric] + + def read_img(self, filename: str) -> np.ndarray: + cvImg = imread(filename) + cvImg = cvtColor(cvImg, COLOR_BGR2RGB) + return cvImg + + def image_grid(self, img_list: list[list[Image.Image]]): + # imgs is a 2D list of images + # Assumes the input images are a rectangular grid of equal sized images + rows = len(img_list) + cols = len(img_list[0]) + + w, h = img_list[0][0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + + for i, row in enumerate(img_list): + for j, img in enumerate(row): + grid.paste(img, box=(j*w, i*h)) + return grid + + def lookup_score_from_fname(self, + fname: str, + metrics_output_file: str + ) -> float: + fname_basestr = os.path.splitext(fname)[0] + with open(metrics_output_file, 'r') as f: + for line in f: + if fname_basestr in line: + score = float(line.split('|')[5]) + return score + raise ValueError(f"Could not find score for {fname} in {metrics_output_file}") + + def gather_file_basenames(self, directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + + def read_file_prompt(self, fname:str) -> str: + # Read prompt from image file metadata + img = Image.open(fname) + img.load() + return img.info['prompt'] + + def find_file_match(self, baseline_file: str, file_paths: List[str]): + # Find a file in file_paths with matching metadata to baseline_file + baseline_prompt = self.read_file_prompt(baseline_file) + + # Do not match empty prompts + if baseline_prompt is None or baseline_prompt == "": + return None + + # Find file match + # Reorder test_file_names so that the file with matching name is first + # This is an optimization because matching file names are more likely + # to have matching metadata if they were generated with the same script + basename = os.path.basename(baseline_file) + file_path_basenames = [os.path.basename(f) for f in file_paths] + if basename in file_path_basenames: + match_index = file_path_basenames.index(basename) + file_paths.insert(0, file_paths.pop(match_index)) + + for f in file_paths: + test_file_prompt = self.read_file_prompt(f) + if baseline_prompt == test_file_prompt: 
+ return f \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..1a35880af --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') + parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") + parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['output_dir'] = pytestconfig.getoption('output_dir') + args['listen'] = pytestconfig.getoption('listen') + args['port'] = pytestconfig.getoption('port') + + os.makedirs(args['output_dir'], exist_ok=True) + + return args + +def pytest_collection_modifyitems(items): + # Modifies items so tests run in the correct order + + LAST_TESTS = ['test_quality'] + + # Move the last items to the end + last_items = [] + for test_name in LAST_TESTS: + for item in items.copy(): + print(item.module.__name__, item) + if item.module.__name__ == test_name: + last_items.append(item) + items.remove(item) + + items.extend(last_items) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/graphs/default_graph_sdxl1_0.json b/tests/inference/graphs/default_graph_sdxl1_0.json new file mode 100644 index 000000000..c06c6829c --- /dev/null +++ b/tests/inference/graphs/default_graph_sdxl1_0.json @@ -0,0 +1,144 @@ +{ + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage" + }, + "6": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "10": { + "inputs": { + "add_noise": "enable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 32, + "return_with_leftover_noise": "enable", + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "15", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "12": { + "inputs": { + "samples": [ + "14", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode" + }, + "13": { + "inputs": { + "filename_prefix": "test_inference", + "images": [ + "12", + 0 + ] + }, + "class_type": "SaveImage" + }, + "14": { + "inputs": { + "add_noise": "disable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 32, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "16", + 0 + ], + "positive": [ + "17", + 0 + ], + "negative": [ + "20", + 0 + ], + "latent_image": [ + "10", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "15": { + "inputs": { + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "ConditioningZeroOut" + }, + "16": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0.safetensors" + }, + "class_type": 
"CheckpointLoaderSimple" + }, + "17": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "20": { + "inputs": { + "text": "", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + } + } \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 000000000..a96f94550 --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,247 @@ +from copy import deepcopy +from io import BytesIO +from urllib import request +import numpy +import os +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Union +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse + +# Currently causes an error when running pytest with built-in pytest args +# TODO: modify cli_args.py to not parse args on import +# We will hard-code sampler and scheduler lists for now +# from comfy.samplers import KSampler + +""" +These tests generate and save images through a range of parameters +""" + +class ComfyGraph: + def __init__(self, + graph: dict, + sampler_nodes: list[str], + ): + self.graph = graph + self.sampler_nodes = sampler_nodes + + def set_prompt(self, prompt, negative_prompt=None): + # Sets the prompt for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + prompt_node = self.graph[node]['inputs']['positive'][0] + self.graph[prompt_node]['inputs']['text'] = prompt + if negative_prompt: + negative_prompt_node = self.graph[node]['inputs']['negative'][0] + self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt + + def set_sampler_name(self, sampler_name:str, ): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['sampler_name'] = sampler_name + + def set_scheduler(self, scheduler:str): + # sets the sampler name for the sampler nodes (eg. 
base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['scheduler'] = scheduler + + def set_filename_prefix(self, prefix:str): + # sets the filename prefix for the save nodes + for node in self.graph: + if self.graph[node]['class_type'] == 'SaveImage': + self.graph[node]['inputs']['filename_prefix'] = prefix + + +class ComfyClient: + # From examples/websockets_api_example.py + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def get_images(self, graph, save=True): + prompt = graph + if not save: + # Replace save nodes with preview nodes + prompt_str = json.dumps(prompt) + prompt_str = prompt_str.replace('SaveImage', 'PreviewImage') + prompt = json.loads(prompt_str) + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + output_images = {} + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break #Execution is done + else: + continue #previews are binary data + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + +# +# Initialize graphs +# +default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json' +with open(default_graph_file, 'r') as file: + default_graph = json.loads(file.read()) +DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14']) +DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0] + +# +# Loop through these variables +# +comfy_graph_list = [DEFAULT_COMFY_GRAPH] +comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID] +prompt_list = [ + 'a painting of a cat', +] +#TODO use sampler and scheduler list from comfy.samplers.KSampler +# sampler_list = KSampler.SAMPLERS +# scheduler_list = KSampler.SCHEDULERS +# Hard coded sampler and scheduler lists for now +SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] +SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "dpmpp_2m", "dpmpp_2m_sde", 
"dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] +sampler_list = SAMPLERS +scheduler_list = SCHEDULERS +@pytest.mark.inference +@pytest.mark.parametrize("sampler", sampler_list) +@pytest.mark.parametrize("scheduler", scheduler_list) +@pytest.mark.parametrize("prompt", prompt_list) +class TestInference: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + ]) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + # + # Client and graph fixtures with server warmup + # + # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server + # The "graph" is the default graph + @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True) + def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph): + comfy_graph = request.param + + # Start client + comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"]) + + # Warm up pipeline + comfy_client.get_images(graph=comfy_graph.graph, save=False) + + yield comfy_client, comfy_graph + del comfy_client + del comfy_graph + torch.cuda.empty_cache() + + @fixture + def client(self, _client_graph): + client = _client_graph[0] + yield client + + @fixture + def comfy_graph(self, _client_graph): + # avoid mutating the graph + graph = deepcopy(_client_graph[1]) + yield graph + + def test_comfy( + self, + client, + comfy_graph, + sampler, + scheduler, + prompt, + request + ): + test_info = request.node.name + comfy_graph.set_filename_prefix(test_info) + # Settings for comfy graph + comfy_graph.set_sampler_name(sampler) + comfy_graph.set_scheduler(scheduler) + comfy_graph.set_prompt(prompt) + + # Generate + images = client.get_images(comfy_graph.graph) + + assert len(images) != 0, "No images generated" + # assert all images are not blank + for images_output in images.values(): + for image_data in images_output: + pil_image = Image.open(BytesIO(image_data)) + assert numpy.array(pil_image).any() != 0, "Image is blank" + + From 7c93afd2cd826aea7b49e49f42502b5ac03b647d Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Tue, 19 Sep 2023 05:20:00 +0200 Subject: [PATCH 28/39] Manual float precision, toggle for old behavior (#1541) * Add toggle for float rounding * Add manual precision override --- web/scripts/ui.js | 19 +++++++++++++++++++ web/scripts/widgets.js | 12 +++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index f39939bf3..1e7920167 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -577,6 +577,25 @@ export class ComfyUI { defaultValue: false, }); + this.settings.addSetting({ + id: "Comfy.DisableFloatRounding", + name: "Disable rounding floats (requires page reload).", + type: "boolean", + defaultValue: false, + }); + + this.settings.addSetting({ + id: "Comfy.FloatRoundingPrecision", + name: "Decimal places [0 = 
auto] (requires page reload).", + type: "slider", + attrs: { + min: 0, + max: 6, + step: 1, + }, + defaultValue: 0, + }); + const fileInput = $el("input", { id: "comfy-file-input", type: "file", diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 40b3067b7..942be8f36 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,6 +1,6 @@ import { api } from "./api.js" -function getNumberDefaults(inputData, defaultStep) { +function getNumberDefaults(inputData, defaultStep, app) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -8,12 +8,14 @@ function getNumberDefaults(inputData, defaultStep) { if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; - // precision is the number of decimal places to show. // by default, display the the smallest number of decimal places such that changes of size step are visible. let precision = Math.max(-Math.floor(Math.log10(step)),0); + if (app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision") > 0) { + precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + } - if (round == undefined || round === true) { + if (!app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") && (round == undefined || round === true)) { // by default, round the value to those decimal places shown. round = Math.round(1000000*Math.pow(0.1,precision))/1000000; } @@ -273,7 +275,7 @@ export const ComfyWidgets = { "INT:noise_seed": seedWidget, FLOAT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 0.5); + const { val, config } = getNumberDefaults(inputData, 0.5, app); return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { @@ -285,7 +287,7 @@ export const ComfyWidgets = { }, INT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 1); + const { val, config } = getNumberDefaults(inputData, 1, app); Object.assign(config, { precision: 0 }); return { widget: node.addWidget( From f32463936d3b8205df7b66dbd9c3f9a2fd69668a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 23:23:25 -0400 Subject: [PATCH 29/39] Unhardcode sampler and scheduler list in test. 
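
The hard-coded sampler and scheduler copies are replaced with the canonical
lists from comfy.samplers. Minimal sketch of what the parametrization now keys
off, assuming comfy.samplers is importable at test-collection time (see the
earlier TODO about cli_args parsing on import):

    # Sketch only: the test matrix is derived from KSampler's own lists.
    from comfy.samplers import KSampler

    # One test case per (sampler, scheduler, prompt) combination, so samplers
    # and schedulers added to KSampler are covered without editing the test.
    print(len(KSampler.SAMPLERS) * len(KSampler.SCHEDULERS), "sampler/scheduler pairs")
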
--- tests/inference/test_inference.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index a96f94550..141cc5c7e 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -16,10 +16,8 @@ import uuid import urllib.request import urllib.parse -# Currently causes an error when running pytest with built-in pytest args -# TODO: modify cli_args.py to not parse args on import -# We will hard-code sampler and scheduler lists for now -# from comfy.samplers import KSampler + +from comfy.samplers import KSampler """ These tests generate and save images through a range of parameters @@ -140,16 +138,10 @@ comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID] prompt_list = [ 'a painting of a cat', ] -#TODO use sampler and scheduler list from comfy.samplers.KSampler -# sampler_list = KSampler.SAMPLERS -# scheduler_list = KSampler.SCHEDULERS -# Hard coded sampler and scheduler lists for now -SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] -SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", - "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] -sampler_list = SAMPLERS -scheduler_list = SCHEDULERS + +sampler_list = KSampler.SAMPLERS +scheduler_list = KSampler.SCHEDULERS + @pytest.mark.inference @pytest.mark.parametrize("sampler", sampler_list) @pytest.mark.parametrize("scheduler", scheduler_list) From 6d3dee9d16254979592b95399835a54428b3cea6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 23:33:19 -0400 Subject: [PATCH 30/39] Clean up #1541. --- web/scripts/widgets.js | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 942be8f36..2b0239374 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,6 +1,6 @@ import { api } from "./api.js" -function getNumberDefaults(inputData, defaultStep, app) { +function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -10,17 +10,15 @@ function getNumberDefaults(inputData, defaultStep, app) { if (step == undefined) step = defaultStep; // precision is the number of decimal places to show. // by default, display the the smallest number of decimal places such that changes of size step are visible. - let precision = Math.max(-Math.floor(Math.log10(step)),0); - if (app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision") > 0) { - precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + if (precision == undefined) { + precision = Math.max(-Math.floor(Math.log10(step)),0); } - if (!app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") && (round == undefined || round === true)) { + if (enable_rounding && (round == undefined || round === true)) { // by default, round the value to those decimal places shown. 
round = Math.round(1000000*Math.pow(0.1,precision))/1000000; } - return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } @@ -275,7 +273,10 @@ export const ComfyWidgets = { "INT:noise_seed": seedWidget, FLOAT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 0.5, app); + let precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") + if (precision == 0) precision = undefined; + const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding); return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { @@ -287,7 +288,7 @@ export const ComfyWidgets = { }, INT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 1, app); + const { val, config } = getNumberDefaults(inputData, 1, 0, true); Object.assign(config, { precision: 0 }); return { widget: node.addWidget( From 2b6b17817331a24afc7106bfe9ec3e2f9b03fab1 Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Tue, 19 Sep 2023 10:40:38 +0200 Subject: [PATCH 31/39] Added support for lanczos scaling --- comfy/utils.py | 11 +++++++++++ comfy_extras/nodes_post_processing.py | 2 +- nodes.py | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 3ed32e372..4e08bcb80 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,8 +1,10 @@ import torch +import torchvision import math import struct import comfy.checkpoint_pickle import safetensors.torch +from PIL import Image def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -346,6 +348,13 @@ def bislerp(samples, width, height): result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result +def lanczos(samples, width, height): + images = [torchvision.transforms.functional.to_pil_image(image) for image in samples] + images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] + images = [torchvision.transforms.functional.to_tensor(image) for image in images] + result = torch.stack(images) + return result + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop): if upscale_method == "bislerp": return bislerp(s, width, height) + elif upscale_method == "lanczos": + return lanczos(s, width, height) else: return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 51bdb24fa..3f651e594 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -211,7 +211,7 @@ class Sharpen: return (result,) class ImageScaleToTotalPixels: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod diff --git a/nodes.py b/nodes.py index 9ccf179ce..59c50a161 100644 --- a/nodes.py +++ b/nodes.py @@ -1423,7 +1423,7 @@ class LoadImageMask: return True class ImageScale: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", 
"bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod @@ -1444,7 +1444,7 @@ class ImageScale: return (s,) class ImageScaleBy: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod def INPUT_TYPES(s): From 83215924081198b7b4cd95a89046a4527951fc68 Mon Sep 17 00:00:00 2001 From: Sean Lynch Date: Tue, 19 Sep 2023 08:18:29 -0400 Subject: [PATCH 32/39] Escape paths when passing them to globs Try to prevent JS search from breaking on pathnames with square brackets. --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index d04060499..b2e16716b 100644 --- a/server.py +++ b/server.py @@ -132,12 +132,12 @@ class PromptServer(): @routes.get("/extensions") async def get_extensions(request): files = glob.glob(os.path.join( - self.web_root, 'extensions/**/*.js'), recursive=True) + glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): - files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True) + files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) From 7c9a92f552552cb51c9230d80d05ee42ebd8be90 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 19 Sep 2023 13:12:47 -0400 Subject: [PATCH 33/39] Don't depend on torchvision. --- comfy/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 4e08bcb80..7843b58cc 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,9 +1,9 @@ import torch -import torchvision import math import struct import comfy.checkpoint_pickle import safetensors.torch +import numpy as np from PIL import Image def load_torch_file(ckpt, safe_load=False, device=None): @@ -349,9 +349,9 @@ def bislerp(samples, width, height): return result def lanczos(samples, width, height): - images = [torchvision.transforms.functional.to_pil_image(image) for image in samples] + images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] - images = [torchvision.transforms.functional.to_tensor(image) for image in images] + images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) return result From b92a86d7370b28af6777c3859f7d486191f6379a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 13:24:08 -0400 Subject: [PATCH 34/39] Update litegraph to upstream. 
--- web/lib/litegraph.core.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4a21a1b34..8fb5d07a8 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,7 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font = o.font; + this.font_size = o.font_size; }; LGraphGroup.prototype.serialize = function() { @@ -4942,7 +4942,7 @@ LGraphNode.prototype.executeAction = function(action) Math.round(b[3]) ], color: this.color, - font: this.font + font_size: this.font_size }; }; From 1cdfb3dba4e7af11e2e05dc6a6276ba84eb1adf2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 17:52:41 -0400 Subject: [PATCH 35/39] Only do the cast on the device if the device supports it. --- comfy/model_management.py | 17 ++++++++++++++++ comfy/model_patcher.py | 43 ++++++++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d8bc3bfea..1050c13a4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -481,6 +481,23 @@ def get_autocast_device(dev): return dev.type return "cuda" +def cast_to_device(tensor, device, dtype, copy=False): + device_supports_cast = False + if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: + device_supports_cast = True + elif tensor.dtype == torch.bfloat16: + if hasattr(device, 'type') and device.type.startswith("cuda"): + device_supports_cast = True + + if device_supports_cast: + if copy: + if tensor.device == device: + return tensor.to(dtype, copy=copy) + return tensor.to(device, copy=copy).to(dtype) + else: + return tensor.to(device).to(dtype) + else: + return tensor.to(dtype).to(device, copy=copy) def xformers_enabled(): global directml_enabled diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 85bf5bd2a..10551656e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -3,6 +3,7 @@ import copy import inspect import comfy.utils +import comfy.model_management class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): @@ -154,7 +155,7 @@ class ModelPatcher: self.backup[key] = weight.to(self.offload_device) if device_to is not None: - temp_weight = weight.float().to(device_to, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) @@ -185,15 +186,15 @@ class ModelPatcher: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += alpha * w1.type(weight.dtype).to(weight.device) + weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif len(v) == 4: #lora/locon - mat1 = v[0].to(weight.device).float() - mat2 = v[1].to(weight.device).float() + mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) + mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].to(weight.device).float() + mat3 = comfy.model_management.cast_to_device(v[3], 
weight.device, torch.float32) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: @@ -212,18 +213,23 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float()) + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32)) else: - w1 = w1.to(weight.device).float() + w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float()) + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32)) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float()) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32)) else: - w2 = w2.to(weight.device).float() + w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -244,11 +250,20 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float()) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float()) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1a, weight.device, torch.float32)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2a, weight.device, torch.float32)) else: - m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float()) - m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float()) + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32)) try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) From 1122df1a2018eda31605703e7b3388ad80f209e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 17:58:54 -0400 Subject: [PATCH 36/39] Increase range of lora strengths. 
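
strength_model and strength_clip now accept values in [-20.0, 20.0] instead of
[-10.0, 10.0]. Minimal sketch of the effect; the LoRA file name and the
already-loaded model/clip objects below are placeholders, not part of this
patch:

    # Sketch only: strengths past the old +/-10.0 bound now pass validation.
    from nodes import LoraLoader

    def apply_strong_lora(model, clip):
        loader = LoraLoader()
        return loader.load_lora(model, clip, "example_lora.safetensors",
                                strength_model=15.0, strength_clip=-12.5)
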
--- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 59c50a161..18d82ea80 100644 --- a/nodes.py +++ b/nodes.py @@ -543,8 +543,8 @@ class LoraLoader: return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), "lora_name": (folder_paths.get_filename_list("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL", "CLIP") FUNCTION = "load_lora" From 4d41bd595c1e2bf55f9e3ccee0921b1213c0184a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 21:46:41 -0400 Subject: [PATCH 37/39] Fix loading group titles. --- web/lib/litegraph.core.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 8fb5d07a8..f81c83a8a 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font_size = o.font_size; + if (o.font_size) { + this.font_size = o.font_size; + } }; LGraphGroup.prototype.serialize = function() { From 0793eb926933034997cc2383adc414d080643e77 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 23:16:01 -0400 Subject: [PATCH 38/39] Only clear clipboard when copying nodes. --- web/scripts/app.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index f0bb8640c..5efe08c00 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -753,8 +753,9 @@ export class ComfyApp { // Default system copy return; } + // copy nodes and clear clipboard - if (this.canvas.selected_nodes) { + if (e.target.className === "litegraph" && this.canvas.selected_nodes) { this.canvas.copyToClipboard(); e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard e.preventDefault(); From 492db2de8db7e082addf131b40adb4a1b7535821 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Sep 2023 01:14:42 -0400 Subject: [PATCH 39/39] Allow having a different pooled output for each image in a batch. 
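
clip_pooled may now carry one row per image in the batch. The size/crop
embeddings are tiled to match its batch dimension instead of being collapsed
to a single row, and encode_adm repeats the result up to the requested batch
size rather than concatenating copies. Rough sketch of the repeat helper this
relies on (assumed behaviour of comfy.utils.repeat_to_batch_size, not a
verbatim copy):

    # Sketch only: pad or trim a tensor's batch dimension to batch_size.
    import math
    import torch

    def repeat_to_batch_size(tensor, batch_size):
        if tensor.shape[0] > batch_size:
            return tensor[:batch_size]
        if tensor.shape[0] < batch_size:
            repeats = math.ceil(batch_size / tensor.shape[0])
            return tensor.repeat([repeats] + [1] * (tensor.ndim - 1))[:batch_size]
        return tensor

    print(repeat_to_batch_size(torch.zeros(2, 4), 5).shape)  # torch.Size([5, 4])
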
--- comfy/model_base.py | 4 ++-- comfy/samplers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index ca154dba0..ed2dc83e4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -181,7 +181,7 @@ class SDXLRefiner(BaseModel): out.append(self.embedder(torch.Tensor([crop_h]))) out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([aesthetic_score]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) class SDXL(BaseModel): @@ -206,5 +206,5 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([target_height]))) out.append(self.embedder(torch.Tensor([target_width]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) diff --git a/comfy/samplers.py b/comfy/samplers.py index 57673a029..e3192ca58 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,6 +7,7 @@ from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from comfy import model_base +import comfy.utils def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) return abs(a*b) // math.gcd(a, b) @@ -538,7 +539,7 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): if adm_out is not None: x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) + x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device) return conds