From af3219706788586323b8fa5eb6f904f54f71c7ca Mon Sep 17 00:00:00 2001 From: Gregor Adams <1148334+pixelass@users.noreply.github.com> Date: Wed, 9 Aug 2023 13:03:30 +0200 Subject: [PATCH 001/150] feat(extensions): Allow hiding link connectors Thank you for adding this feature (linksRenderMode) to core. I would like to add the "Hidden" option (invalid number 3 will just hide the connector lines), so that I can remove that extension from my extension pack to prevent conflicts https://github.com/failfa-st/failfast-comfyui-extensions --- web/extensions/core/linkRenderMode.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/linkRenderMode.js b/web/extensions/core/linkRenderMode.js index 1e9091ec1..fb4df4234 100644 --- a/web/extensions/core/linkRenderMode.js +++ b/web/extensions/core/linkRenderMode.js @@ -9,7 +9,7 @@ const ext = { name: "Link Render Mode", defaultValue: 2, type: "combo", - options: LiteGraph.LINK_RENDER_MODES.map((m, i) => ({ + options: [...LiteGraph.LINK_RENDER_MODES, "Hidden"].map((m, i) => ({ value: i, text: m, selected: i == app.canvas.links_render_mode, From 15adc3699f74e0f3fcd3c29e62f8256825098e88 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Aug 2023 14:22:53 -0400 Subject: [PATCH 002/150] Move beta_schedule to model_config and allow disabling unet creation. --- comfy/model_base.py | 5 +++-- comfy/supported_models_base.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index d654f56f6..acd4169a8 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -19,8 +19,9 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config - self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) - self.diffusion_model = UNetModel(**unet_config, device=device) + self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + if not unet_config.get("disable_unet_model_creation", False): + self.diffusion_model = UNetModel(**unet_config, device=device) self.model_type = model_type self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index d0088bbd5..c72838008 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -33,6 +33,7 @@ class BASE: clip_prefix = [] clip_vision_prefix = None noise_aug_config = None + beta_schedule = "linear" @classmethod def matches(s, unet_config): From f2f5e5dcbb9a39d15514a00de008b50bb4cba8e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Aug 2023 16:44:57 -0400 Subject: [PATCH 003/150] Support SDXL t2i adapters with 3 channel input. 
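Background: the loader infers the adapter variant from the width of conv_in, since t2i adapters pixel-unshuffle the hint image before the first convolution. 256 was previously the only XL case; 768 is presumably its 3-channel counterpart (3 x 16^2 under the assumed unshuffle factor). A minimal sketch of the extended check, with an illustrative function name:

    def looks_like_xl_adapter(cin: int) -> bool:
        # cin is conv_in.weight.shape[1]; 256 = 1-channel XL hint,
        # 768 = 3-channel XL hint (the case this patch adds).
        return cin in (256, 768)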
--- comfy/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 7098186f9..83e1be058 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -465,7 +465,7 @@ def load_t2i_adapter(t2i_data): if len(down_opts) > 0: use_conv = True xl = False - if cin == 256: + if cin == 256 or cin == 768: xl = True model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) else: From 81d9200e1851f160dacf9ad30e55444a2e42241e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Aug 2023 17:55:42 -0400 Subject: [PATCH 004/150] Add node to convert a specific colour in an image to a mask. --- comfy_extras/nodes_mask.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 5adb468ac..43f623a62 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -125,6 +125,27 @@ class ImageToMask: mask = image[0, :, :, channels.index(channel)] return (mask,) +class ImageColorToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, color): + temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int) + temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2] + mask = torch.where(temp == color, 255, 0).float() + return (mask,) + class SolidMask: @classmethod def INPUT_TYPES(cls): @@ -315,6 +336,7 @@ NODE_CLASS_MAPPINGS = { "ImageCompositeMasked": ImageCompositeMasked, "MaskToImage": MaskToImage, "ImageToMask": ImageToMask, + "ImageColorToMask": ImageColorToMask, "SolidMask": SolidMask, "InvertMask": InvertMask, "CropMask": CropMask, From d70b0bc43c2d1d0a2212e9493926ddf231c80a41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Aug 2023 17:58:40 -0400 Subject: [PATCH 005/150] Use the GPU for the canny preprocessor when available. --- comfy_extras/nodes_canny.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index d7c3f132f..94d453f2c 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -3,7 +3,7 @@ import math import torch import torch.nn.functional as F - +import comfy.model_management def get_canny_nms_kernel(device=None, dtype=None): """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" @@ -290,8 +290,8 @@ class Canny: CATEGORY = "image/preprocessors" def detect_edge(self, image, low_threshold, high_threshold): - output = canny(image.movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].repeat(1, 3, 1, 1).movedim(1, -1) + output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1) return (img_out,) NODE_CLASS_MAPPINGS = { From fe4c07400c792ff8c5247cc4697235e33d96fcc9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 29 Aug 2023 23:58:32 -0400 Subject: [PATCH 006/150] Fix "Load Checkpoint with config" node. 
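The ad-hoc EmptyClass config stopped carrying everything BaseModel reads: register_schedule now pulls beta_schedule off model_config (patch 002), so building the model from a bare object raises AttributeError. Instantiating comfy.supported_models_base.BASE({}) instead picks up the class-level defaults. Rough illustration of the failure mode:

    class EmptyClass:
        pass

    model_config = EmptyClass()
    model_config.unet_config = unet_config
    # BaseModel.__init__ now calls:
    #   self.register_schedule(..., beta_schedule=model_config.beta_schedule, ...)
    # -> AttributeError: 'EmptyClass' object has no attribute 'beta_schedule'
    # BASE defines beta_schedule = "linear" (and, with this patch, a default
    # latent_format) as class attributes, so the same lookups succeed.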
--- comfy/sd.py | 6 ++++-- comfy/supported_models_base.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index a63a0d1de..e98dabe88 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -22,6 +22,7 @@ from . import sdxl_clip import comfy.model_patcher import comfy.lora import comfy.t2i_adapter.adapter +import comfy.supported_models_base def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -348,10 +349,11 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl class EmptyClass: pass - model_config = EmptyClass() - model_config.unet_config = unet_config + model_config = comfy.supported_models_base.BASE({}) + from . import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) + model_config.unet_config = unet_config if config['model']["target"].endswith("LatentInpaintDiffusion"): model = model_base.SDInpaint(model_config, model_type=model_type) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index c72838008..c9cd54d0e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -1,6 +1,7 @@ import torch from . import model_base from . import utils +from . import latent_formats def state_dict_key_replace(state_dict, keys_to_replace): @@ -34,6 +35,7 @@ class BASE: clip_vision_prefix = None noise_aug_config = None beta_schedule = "linear" + latent_format = latent_formats.LatentFormat @classmethod def matches(s, unet_config): From 18617967e5be09d4d24ff0bb337ab4468dd80e6c Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Wed, 30 Aug 2023 00:25:04 -0700 Subject: [PATCH 007/150] Fix error message in model_patcher.py Found while tinkering. --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2f087a600..a6ee0bae1 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -145,7 +145,7 @@ class ModelPatcher: model_sd = self.model_state_dict() for key in self.patches: if key not in model_sd: - print("could not patch. key doesn't exist in model:", k) + print("could not patch. key doesn't exist in model:", key) continue weight = model_sd[key] From 7e941f9f247f9b013a33c2e7d117466108414e99 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 30 Aug 2023 12:55:07 -0400 Subject: [PATCH 008/150] Clean up DiffusersLoader node. --- comfy/diffusers_load.py | 101 ++++++++++------------------------------ nodes.py | 2 +- 2 files changed, 26 insertions(+), 77 deletions(-) diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index 11d94c340..a52e0102b 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -1,87 +1,36 @@ import json import os -import yaml -import folder_paths -from comfy.sd import load_checkpoint -import os.path as osp -import re -import torch -from safetensors.torch import load_file, save_file -from . 
import diffusers_convert +import comfy.sd +def first_file(path, filenames): + for f in filenames: + p = os.path.join(path, f) + if os.path.exists(p): + return p + return None -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) +def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None): + diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"] + unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) + vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names) - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"] + text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names) + text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names) - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + text_encoder_paths = [text_encoder1_path] + if text_encoder2_path is not None: + text_encoder_paths.append(text_encoder2_path) - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) + unet = comfy.sd.load_unet(unet_path) - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + clip = None + if output_clip: + clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory) - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + vae = None + if output_vae: + vae = comfy.sd.VAE(ckpt_path=vae_path) - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) - 
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config) + return (unet, clip, vae) diff --git a/nodes.py b/nodes.py index 3e4d5240b..5e755f149 100644 --- a/nodes.py +++ b/nodes.py @@ -475,7 +475,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: From 2cd3980199ea1769ea3007009c516683b472337b Mon Sep 17 00:00:00 2001 From: Ridan Vandenbergh Date: Wed, 30 Aug 2023 20:46:53 +0200 Subject: [PATCH 009/150] Remove forced lowercase on embeddings endpoint --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 0d7d28a0b..57d5a65df 100644 --- a/server.py +++ b/server.py @@ -127,7 +127,7 @@ class PromptServer(): @routes.get("/embeddings") def get_embeddings(self): embeddings = folder_paths.get_filename_list("embeddings") - return web.json_response(list(map(lambda a: os.path.splitext(a)[0].lower(), embeddings))) + return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) @routes.get("/extensions") async def get_extensions(request): From 5f101f4da14e0b4a360ca1d0c380fab174d301bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 31 Aug 2023 02:25:21 -0400 Subject: [PATCH 010/150] Update litegraph with upstream: middle mouse dragging. 
--- web/lib/litegraph.core.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 356c71ac2..4bb2f0d99 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -6233,11 +6233,17 @@ LGraphNode.prototype.executeAction = function(action) ,posAdd:[!mClikSlot_isOut?-30:30, -alphaPosY*130] //-alphaPosY*30] ,posSizeFix:[!mClikSlot_isOut?-1:0, 0] //-alphaPosY*2*/ }); - + skip_action = true; } } } } + + if (!skip_action && this.allow_dragcanvas) { + //console.log("pointerevents: dragging_canvas start from middle button"); + this.dragging_canvas = true; + } + } else if (e.which == 3 || this.pointer_is_double) { From 1c012d69afa8bd92a007a3e468e2a1f874365d39 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 31 Aug 2023 13:25:00 -0400 Subject: [PATCH 011/150] It doesn't make sense for c_crossattn and c_concat to be lists. --- comfy/model_base.py | 4 ++-- comfy/samplers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index acd4169a8..677a23de7 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,10 +50,10 @@ class BaseModel(torch.nn.Module): def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): if c_concat is not None: - xc = torch.cat([x] + c_concat, dim=1) + xc = torch.cat([x] + [c_concat], dim=1) else: xc = x - context = torch.cat(c_crossattn, 1) + context = c_crossattn dtype = self.get_dtype() xc = xc.to(dtype) t = t.to(dtype) diff --git a/comfy/samplers.py b/comfy/samplers.py index 134336de6..103ac33ff 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -165,9 +165,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn_out.append(c) if len(c_crossattn_out) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn_out)] + out['c_crossattn'] = torch.cat(c_crossattn_out) if len(c_concat) > 0: - out['c_concat'] = [torch.cat(c_concat)] + out['c_concat'] = torch.cat(c_concat) if len(c_adm) > 0: out['c_adm'] = torch.cat(c_adm) return out From 57beace324b49f4b6b45291e3940b99c84387e89 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 31 Aug 2023 14:26:16 -0400 Subject: [PATCH 012/150] Fix VAEDecodeTiled minimum. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 5e755f149..38d947d65 100644 --- a/nodes.py +++ b/nodes.py @@ -245,7 +245,7 @@ class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), - "tile_size": ("INT", {"default": 512, "min": 192, "max": 4096, "step": 64}) + "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" From cfe1c54de88e7525ec7e4189a8a3294dfc3cd4c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 31 Aug 2023 15:16:58 -0400 Subject: [PATCH 013/150] Fix controlnet issue. 
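Follow-up to patch 011, which stopped wrapping c_crossattn in a list: controlnet.py still passed the entry to torch.cat, which expects a sequence of tensors and no longer receives one. Sketch of the before/after (shapes illustrative):

    # before patch 011: cond['c_crossattn'] was [context]   (a list)
    context = torch.cat(cond['c_crossattn'], 1)  # concatenating a list: fine
    # after patch 011:  cond['c_crossattn'] is context      (a plain tensor)
    context = cond['c_crossattn']                # use it directly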
--- comfy/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 83e1be058..f62dd9c88 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -155,7 +155,7 @@ class ControlNet(ControlBase): self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = torch.cat(cond['c_crossattn'], 1) + context = cond['c_crossattn'] y = cond.get('c_adm', None) if y is not None: y = y.to(self.control_model.dtype) From 9a7a52f8b5321e9a67fab18d4f256d2cd6bc338f Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Thu, 3 Aug 2023 19:49:52 -0700 Subject: [PATCH 014/150] refactor/fix: Treat forceInput widgets as standard widgets --- web/extensions/core/widgetInputs.js | 23 ++++++++++++++++++++--- web/scripts/app.js | 29 ++++++++++++++--------------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index d9eaf8a0c..a6b1a1dc2 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -5,7 +5,7 @@ const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; function isConvertableWidget(widget, config) { - return VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0]); + return (VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) && !widget.options?.forceInput; } function hideWidget(node, widget, suffix = "") { @@ -103,6 +103,9 @@ app.registerExtension({ let toInput = []; let toWidget = []; for (const w of this.widgets) { + if (w.options?.forceInput) { + continue; + } if (w.type === CONVERTED_TYPE) { toWidget.push({ content: `Convert ${w.name} to widget`, @@ -130,6 +133,20 @@ app.registerExtension({ return r; }; + const origOnNodeCreated = nodeType.prototype.onNodeCreated + nodeType.prototype.onNodeCreated = function () { + const r = origOnNodeCreated ? 
origOnNodeCreated.apply(this) : undefined; + if (this.widgets) { + for (const w of this.widgets) { + if (w?.options?.forceInput) { + const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; + convertToInput(this, w, config); + } + } + } + return r; + } + // On initial configure of nodes hide all converted widgets const origOnConfigure = nodeType.prototype.onConfigure; nodeType.prototype.onConfigure = function () { @@ -137,7 +154,7 @@ app.registerExtension({ if (this.inputs) { for (const input of this.inputs) { - if (input.widget) { + if (input.widget && !input.widget.config[1]?.forceInput) { const w = this.widgets.find((w) => w.name === input.widget.name); if (w) { hideWidget(this, w); @@ -374,7 +391,7 @@ app.registerExtension({ } for (const k in config1[1]) { - if (k !== "default") { + if (k !== "default" && k !== 'forceInput') { if (config1[1][k] !== config2[1][k]) { return false; } diff --git a/web/scripts/app.js b/web/scripts/app.js index 6a2c63290..42adfde8f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1151,22 +1151,21 @@ export class ComfyApp { const inputData = inputs[inputName]; const type = inputData[0]; - if(inputData[1]?.forceInput) { - this.addInput(inputName, type); + if (Array.isArray(type)) { + // Enums + Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {}); + } else if (`${type}:${inputName}` in widgets) { + // Support custom widgets by Type:Name + Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {}); + } else if (type in widgets) { + // Standard type widgets + Object.assign(config, widgets[type](this, inputName, inputData, app) || {}); } else { - if (Array.isArray(type)) { - // Enums - Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {}); - } else if (`${type}:${inputName}` in widgets) { - // Support custom widgets by Type:Name - Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {}); - } else if (type in widgets) { - // Standard type widgets - Object.assign(config, widgets[type](this, inputName, inputData, app) || {}); - } else { - // Node connection inputs - this.addInput(inputName, type); - } + // Node connection inputs + this.addInput(inputName, type); + } + if(inputData[1]?.forceInput && config?.widget) { + config.widget.options.forceInput = inputData[1].forceInput; } } From 69c5e6de85cdb46e0869704abff51b483caea967 Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Thu, 31 Aug 2023 17:55:24 -0700 Subject: [PATCH 015/150] fix(widgets): Add options object if not present when forceInput: true --- web/scripts/app.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 42adfde8f..3b7483cdf 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1165,6 +1165,7 @@ export class ComfyApp { this.addInput(inputName, type); } if(inputData[1]?.forceInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; config.widget.options.forceInput = inputData[1].forceInput; } } From 5c363a9d86827d194e3a8e5dd6085a67f65c7ee6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Sep 2023 02:01:08 -0400 Subject: [PATCH 016/150] Fix controlnet bug. 
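When the controlnet's timestep range is inactive and there is no previous control, the early exit returned {}. Callers presumably test for "no control" with an is-None check, so a truthy-but-empty dict slipped through as if it were a real set of control tensors; returning None matches what control_prev would have been. Hypothetical caller shape (not quoted from the codebase):

    control_out = control.get_control(x_noisy, t, cond, batched_number)
    if control_out is not None:
        ...  # {} wrongly took this branch; None now skips it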
--- comfy/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f62dd9c88..490be6bbc 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -143,7 +143,7 @@ class ControlNet(ControlBase): if control_prev is not None: return control_prev else: - return {} + return None output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: From 0e3b64117218c50a554b492269f5f35779839695 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Sep 2023 02:12:03 -0400 Subject: [PATCH 017/150] Remove xformers related print. --- comfy/ldm/modules/attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 973619bf2..9fdfbd217 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -402,8 +402,6 @@ class MemoryEfficientCrossAttention(nn.Module): # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops): super().__init__() - print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " - f"{heads} heads.") inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) From 7931ff0fd95c1842b0c8e7f5cc3a2ce5d3b88b3b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Sep 2023 15:18:25 -0400 Subject: [PATCH 018/150] Support SDXL inpaint models. --- comfy/model_base.py | 9 +++------ comfy/model_detection.py | 6 +++++- comfy/sd.py | 7 ++++--- comfy/supported_models.py | 5 ++++- comfy/supported_models_base.py | 11 ++++++----- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 677a23de7..ca154dba0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -111,6 +111,9 @@ class BaseModel(torch.nn.Module): return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + def set_inpaint(self): + self.concat_keys = ("mask", "masked_image") + def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): adm_inputs = [] weights = [] @@ -148,12 +151,6 @@ class SD21UNCLIP(BaseModel): else: return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) - -class SDInpaint(BaseModel): - def __init__(self, model_config, model_type=ModelType.EPS, device=None): - super().__init__(model_config, model_type, device=device) - self.concat_keys = ("mask", "masked_image") - def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0edc4f180..372d5a2df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -183,8 +183,12 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16): 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1} + SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 
'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320, + 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], + 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet] + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint] for unet_config in supported_models: matches = True diff --git a/comfy/sd.py b/comfy/sd.py index e98dabe88..8be0bcbc8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -355,13 +355,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) model_config.unet_config = unet_config - if config['model']["target"].endswith("LatentInpaintDiffusion"): - model = model_base.SDInpaint(model_config, model_type=model_type) - elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): + if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) else: model = model_base.BaseModel(model_config, model_type=model_type) + if config['model']["target"].endswith("LatentInpaintDiffusion"): + model.set_inpaint() + if fp16: model = model.half() diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 95fc8f3f5..0b3e4bcbd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -153,7 +153,10 @@ class SDXL(supported_models_base.BASE): return model_base.ModelType.EPS def get_model(self, state_dict, prefix="", device=None): - return model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device) + out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device) + if self.inpaint_model(): + out.set_inpaint() + return out def process_clip_state_dict(self, state_dict): keys_to_replace = {} diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index c9cd54d0e..395a90ab4 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -57,12 +57,13 @@ class BASE: self.unet_config[x] = self.unet_extra_config[x] def get_model(self, state_dict, prefix="", device=None): - if self.inpaint_model(): - return model_base.SDInpaint(self, model_type=self.model_type(state_dict, prefix), device=device) - elif self.noise_aug_config is not None: - return model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device) + if self.noise_aug_config is not None: + out = model_base.SD21UNCLIP(self, self.noise_aug_config, model_type=self.model_type(state_dict, prefix), device=device) else: - return model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device) + out = model_base.BaseModel(self, model_type=self.model_type(state_dict, prefix), device=device) + if self.inpaint_model(): + out.set_inpaint() + return out def process_clip_state_dict(self, state_dict): return state_dict From 7891d13329c13943c1d90a5a0973262a85da97d0 Mon Sep 17 00:00:00 2001 From: Muhammed Yusuf <32941435+myusf01@users.noreply.github.com> Date: Sat, 2 Sep 2023 09:58:23 +0300 Subject: [PATCH 019/150] Added label for autoQueueCheckbox. 
(#1295) * Added label for autoQueueCheckbox. * Menu gets behind of some custom nodes. * Edited extraOptions. Options divided in to different divs to manage them with ease. --- web/scripts/ui.js | 18 +++++++++++++++--- web/style.css | 2 +- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 86e2a1c41..8611c2482 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -617,7 +617,10 @@ export class ComfyUI { ]), ]), $el("div", {id: "extraOptions", style: {width: "100%", display: "none"}}, [ - $el("label", {innerHTML: "Batch count"}, [ + $el("div",[ + + $el("label", {innerHTML: "Batch count"}), + $el("input", { id: "batchCountInputNumber", type: "number", @@ -639,14 +642,23 @@ export class ComfyUI { this.batchCount = i.srcElement.value; document.getElementById("batchCountInputNumber").value = i.srcElement.value; }, + }), + ]), + + $el("div",[ + $el("label",{ + for:"autoQueueCheckbox", + innerHTML: "Auto Queue" + // textContent: "Auto Queue" }), $el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, - title: "automatically queue prompt when the queue size hits 0", + title: "Automatically queue prompt when the queue size hits 0", + }), - ]), + ]) ]), $el("div.comfy-menu-btns", [ $el("button", { diff --git a/web/style.css b/web/style.css index 5b6b9ec57..692fa31d6 100644 --- a/web/style.css +++ b/web/style.css @@ -88,7 +88,7 @@ body { top: 50%; right: 0; text-align: center; - z-index: 100; + z-index: 999; width: 170px; display: flex; flex-direction: column; From 36ea8784a875bde21c88f84dfb99475b6e8187e8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 03:34:57 -0400 Subject: [PATCH 020/150] Only return tuple of 3 args in CheckpointLoaderSimple. --- nodes.py | 2 +- web/scripts/ui.js | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 38d947d65..fa26e5939 100644 --- a/nodes.py +++ b/nodes.py @@ -449,7 +449,7 @@ class CheckpointLoaderSimple: def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - return out + return out[:3] class DiffusersLoader: @classmethod diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 8611c2482..ce3f4fcee 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -620,7 +620,6 @@ export class ComfyUI { $el("div",[ $el("label", {innerHTML: "Batch count"}), - $el("input", { id: "batchCountInputNumber", type: "number", From 77a176f9e0f4777363a414fbb006cb133d31e034 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 03:42:49 -0400 Subject: [PATCH 021/150] Use common function to reshape batch to. 
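comfy.utils.repeat_to_batch_size (added below) handles both truncating and tiling along dim 0, so prepare_mask and broadcast_cond no longer need their own ceil/repeat arithmetic. Usage sketch, with the helper's body copied from this patch:

    import math
    import torch

    def repeat_to_batch_size(tensor, batch_size):
        if tensor.shape[0] > batch_size:
            return tensor[:batch_size]
        elif tensor.shape[0] < batch_size:
            return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size]
        return tensor

    t = torch.zeros(2, 4, 8, 8)
    repeat_to_batch_size(t, 5).shape  # torch.Size([5, 4, 8, 8]): tiled 3x, cut to 5
    repeat_to_batch_size(t, 1).shape  # torch.Size([1, 4, 8, 8]): truncated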
--- comfy/sample.py | 8 +++----- comfy/utils.py | 7 +++++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 79ea37e0d..e4730b189 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import comfy.samplers +import comfy.utils import math import numpy as np @@ -28,8 +29,7 @@ def prepare_mask(noise_mask, shape, device): noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - if noise_mask.shape[0] < shape[0]: - noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]] + noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = noise_mask.to(device) return noise_mask @@ -37,9 +37,7 @@ def broadcast_cond(cond, batch, device): """broadcasts conditioning to the batch size""" copy = [] for p in cond: - t = p[0] - if t.shape[0] < batch: - t = torch.cat([t] * batch) + t = comfy.utils.repeat_to_batch_size(p[0], batch) t = t.to(device) copy += [[t] + p[1:]] return copy diff --git a/comfy/utils.py b/comfy/utils.py index 693e2612d..47f4b9709 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -223,6 +223,13 @@ def unet_to_diffusers(unet_config): return diffusers_unet_map +def repeat_to_batch_size(tensor, batch_size): + if tensor.shape[0] > batch_size: + return tensor[:batch_size] + elif tensor.shape[0] < batch_size: + return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] + return tensor + def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: From 7291e303f662f35bc545a8aaa0020558e82a8ca9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 11:48:44 -0400 Subject: [PATCH 022/150] Fix issue with some workflows not getting serialized. --- web/extensions/core/widgetInputs.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index a6b1a1dc2..34c656de1 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -16,6 +16,9 @@ function hideWidget(node, widget, suffix = "") { widget.type = CONVERTED_TYPE + suffix; widget.serializeValue = () => { // Prevent serializing the widget if we have no input linked + if (!node.inputs) { + return undefined; + } const { link } = node.inputs.find((i) => i.widget?.name === widget.name); if (link == null) { return undefined; From 6962cb46a99ed0a9895dc06ca293b9e48e3eabc8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 12:17:30 -0400 Subject: [PATCH 023/150] Fix issue when node_input is undefined. --- web/extensions/core/widgetInputs.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 34c656de1..f9a5b7278 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -19,8 +19,9 @@ function hideWidget(node, widget, suffix = "") { if (!node.inputs) { return undefined; } - const { link } = node.inputs.find((i) => i.widget?.name === widget.name); - if (link == null) { + let node_input = node.inputs.find((i) => i.widget?.name === widget.name); + + if (!node_input || !node_input.link) { return undefined; } return widget.origSerializeValue ? 
widget.origSerializeValue() : widget.value; From 62efc78a4b13b87ef0df51323fe1bd71b433fa11 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 15:36:45 -0400 Subject: [PATCH 024/150] Display history in reverse order to make it easier to load last gen. --- web/scripts/ui.js | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index ce3f4fcee..f39939bf3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -431,10 +431,12 @@ class ComfySettingsDialog extends ComfyDialog { class ComfyList { #type; #text; + #reverse; - constructor(text, type) { + constructor(text, type, reverse) { this.#text = text; this.#type = type || text.toLowerCase(); + this.#reverse = reverse || false; this.element = $el("div.comfy-list"); this.element.style.display = "none"; } @@ -451,7 +453,7 @@ class ComfyList { textContent: section, }), $el("div.comfy-list-items", [ - ...items[section].map((item) => { + ...(this.#reverse ? items[section].reverse() : items[section]).map((item) => { // Allow items to specify a custom remove action (e.g. for interrupt current prompt) const removeAction = item.remove || { name: "Delete", @@ -529,7 +531,7 @@ export class ComfyUI { this.batchCount = 1; this.lastQueueSize = 0; this.queue = new ComfyList("Queue"); - this.history = new ComfyList("History"); + this.history = new ComfyList("History", "history", true); api.addEventListener("status", () => { this.queue.update(); From dfd6489c9622fdf48728e336fc263df283c84903 Mon Sep 17 00:00:00 2001 From: Chris Date: Sun, 3 Sep 2023 07:53:02 +1000 Subject: [PATCH 025/150] onExecutionStart --- web/scripts/app.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 3b7483cdf..ce5e27d0c 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -994,6 +994,10 @@ export class ComfyApp { api.addEventListener("execution_start", ({ detail }) => { this.runningNodeId = null; this.lastExecutionError = null + this.graph._nodes.forEach((node) => { + if (node.onExecutionStart) + node.onExecutionStart() + }) }); api.addEventListener("execution_error", ({ detail }) => { From 4a0c4ce4ef3c1e0f2b777dcd20a8864be1420f19 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Sat, 2 Sep 2023 18:22:10 -0700 Subject: [PATCH 026/150] Some fixes to generalize CUDA specific functionality to Intel or other GPUs. 
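The idea is to route backend-specific cache handling through model_management instead of calling torch.cuda.* directly, with is_intel_xpu() centralizing the cpu_state/xpu_available checks. Simplified sketch of the dispatch; the real soft_empty_cache also consults cpu_state and only calls ipc_collect on NVIDIA:

    import torch

    def soft_empty_cache():
        # simplified; the availability guards are assumptions about the runtime
        if hasattr(torch, "mps") and torch.backends.mps.is_available():
            torch.mps.empty_cache()
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.empty_cache()
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()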
--- comfy/ldm/modules/attention.py | 3 +- comfy/ldm/modules/diffusionmodules/util.py | 24 ++++++++++---- comfy/model_management.py | 37 ++++++++++++---------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 9fdfbd217..8f953d337 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -323,8 +323,7 @@ class CrossAttentionDoggettx(nn.Module): break except model_management.OOM_EXCEPTION as e: if first_op_done == False: - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + model_management.soft_empty_cache() if cleared_cache == False: cleared_cache = True print("out of memory error, emptying cache and trying again") diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index d890c8044..9d07d9359 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -15,6 +15,7 @@ import torch.nn as nn import numpy as np from einops import repeat +from comfy import model_management from comfy.ldm.util import instantiate_from_config import comfy.ops @@ -139,13 +140,22 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) + if model_management.is_nvidia(): + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + elif model_management.is_intel_xpu(): + with torch.enable_grad(), \ + torch.xpu.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) input_grads = torch.autograd.grad( output_tensors, ctx.input_tensors + ctx.input_params, diff --git a/comfy/model_management.py b/comfy/model_management.py index aca8af999..bdbbbd843 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -58,8 +58,15 @@ except: if args.cpu: cpu_state = CPUState.CPU -def get_torch_device(): +def is_intel_xpu(): + global cpu_state global xpu_available + if cpu_state == CPUState.GPU: + if xpu_available: + return True + return False + +def get_torch_device(): global directml_enabled global cpu_state if directml_enabled: @@ -70,13 +77,12 @@ def get_torch_device(): if cpu_state == CPUState.CPU: return torch.device("cpu") else: - if xpu_available: + if is_intel_xpu(): return torch.device("xpu") else: return torch.device(torch.cuda.current_device()) def get_total_memory(dev=None, torch_total_too=False): - global xpu_available global directml_enabled if dev is None: dev = get_torch_device() @@ -88,7 +94,7 @@ def get_total_memory(dev=None, torch_total_too=False): if directml_enabled: mem_total = 1024 * 1024 * 1024 #TODO mem_total_torch = mem_total - elif xpu_available: + elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] mem_total = torch.xpu.get_device_properties(dev).total_memory @@ -146,11 +152,11 @@ def is_nvidia(): if cpu_state == CPUState.GPU: if torch.version.cuda: return True + return False ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention VAE_DTYPE = torch.float32 - try: if is_nvidia(): torch_version = torch.version.__version__ @@ -162,6 +168,9 @@ try: except: pass +if is_intel_xpu(): + VAE_DTYPE = torch.bfloat16 + if args.fp16_vae: VAE_DTYPE = torch.float16 elif args.bf16_vae: @@ -220,7 +229,6 @@ if DISABLE_SMART_MEMORY: print("Disabling smart memory management") def get_torch_device_name(device): - global xpu_available if hasattr(device, 'type'): if device.type == "cuda": try: @@ -230,7 +238,7 @@ def get_torch_device_name(device): return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend) else: return "{}".format(device.type) - elif xpu_available: + elif is_intel_xpu(): return "{} {}".format(device, torch.xpu.get_device_name(device)) else: return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) @@ -260,7 +268,6 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0): - global xpu_available patch_model_to = None if lowvram_model_memory == 0: patch_model_to = self.device @@ -281,7 +288,7 @@ class LoadedModel: accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) self.model_accelerated = True - if xpu_available and not args.disable_ipex_optimize: + if is_intel_xpu() and not args.disable_ipex_optimize: self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True) return self.real_model @@ -471,12 +478,11 @@ def get_autocast_device(dev): def xformers_enabled(): - global xpu_available global directml_enabled global cpu_state if cpu_state != CPUState.GPU: return False - if xpu_available: + if is_intel_xpu(): return False if directml_enabled: return False @@ -503,7 +509,6 @@ def pytorch_attention_flash_attention(): return False def get_free_memory(dev=None, torch_free_too=False): - global xpu_available global directml_enabled if dev is None: dev = get_torch_device() @@ -515,7 +520,7 @@ def 
get_free_memory(dev=None, torch_free_too=False): if directml_enabled: mem_free_total = 1024 * 1024 * 1024 #TODO mem_free_torch = mem_free_total - elif xpu_available: + elif is_intel_xpu(): stats = torch.xpu.memory_stats(dev) mem_active = stats['active_bytes.all.current'] mem_allocated = stats['allocated_bytes.all.current'] @@ -577,7 +582,6 @@ def is_device_mps(device): return False def should_use_fp16(device=None, model_params=0, prioritize_performance=True): - global xpu_available global directml_enabled if device is not None: @@ -600,7 +604,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if cpu_mode() or mps_mode(): return False #TODO ? - if xpu_available: + if is_intel_xpu(): return True if torch.cuda.is_bf16_supported(): @@ -636,11 +640,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return True def soft_empty_cache(): - global xpu_available global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() - elif xpu_available: + elif is_intel_xpu(): torch.xpu.empty_cache() elif torch.cuda.is_available(): if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda From 766c7b3815c0203d98200772fd7fe1b908cfaa0c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 22:25:12 -0400 Subject: [PATCH 027/150] Update upscale model code to latest Chainner model code. Don't add SRFormer because the code license is incompatible with the GPL. Remove MAT because it's unused and the license is incompatible with GPL. --- .../chainner_models/architecture/DAT.py | 1182 ++++++++++++ .../chainner_models/architecture/LICENSE-DAT | 201 ++ .../architecture/LICENSE-SCUNet | 201 ++ .../chainner_models/architecture/LICENSE-mat | 161 -- .../chainner_models/architecture/MAT.py | 1636 ----------------- .../architecture/OmniSR/OmniSR.py | 12 +- .../chainner_models/architecture/SCUNet.py | 455 +++++ .../chainner_models/architecture/SPSR.py | 1 - .../chainner_models/architecture/SwinIR.py | 1 + .../chainner_models/architecture/mat/utils.py | 698 ------- comfy_extras/chainner_models/model_loading.py | 25 +- comfy_extras/chainner_models/types.py | 22 +- 12 files changed, 2084 insertions(+), 2511 deletions(-) create mode 100644 comfy_extras/chainner_models/architecture/DAT.py create mode 100644 comfy_extras/chainner_models/architecture/LICENSE-DAT create mode 100644 comfy_extras/chainner_models/architecture/LICENSE-SCUNet delete mode 100644 comfy_extras/chainner_models/architecture/LICENSE-mat delete mode 100644 comfy_extras/chainner_models/architecture/MAT.py create mode 100644 comfy_extras/chainner_models/architecture/SCUNet.py delete mode 100644 comfy_extras/chainner_models/architecture/mat/utils.py diff --git a/comfy_extras/chainner_models/architecture/DAT.py b/comfy_extras/chainner_models/architecture/DAT.py new file mode 100644 index 000000000..0bcc26ef4 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/DAT.py @@ -0,0 +1,1182 @@ +# pylint: skip-file +import math +import re + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from einops.layers.torch import Rearrange +from torch import Tensor +from torch.nn import functional as F + +from .timm.drop import DropPath +from .timm.weight_init import trunc_normal_ + + +def img2windows(img, H_sp, W_sp): + """ + Input: Image (B, C, H, W) + Output: Window Partition (B', N, C) + """ + B, C, H, W = img.shape + img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 
+ img_perm = ( + img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C) + ) + return img_perm + + +def windows2img(img_splits_hw, H_sp, W_sp, H, W): + """ + Input: Window Partition (B', N, C) + Output: Image (B, H, W, C) + """ + B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) + + img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) + img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return img + + +class SpatialGate(nn.Module): + """Spatial-Gate. + Args: + dim (int): Half of input channels. + """ + + def __init__(self, dim): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.conv = nn.Conv2d( + dim, dim, kernel_size=3, stride=1, padding=1, groups=dim + ) # DW Conv + + def forward(self, x, H, W): + # Split + x1, x2 = x.chunk(2, dim=-1) + B, N, C = x.shape + x2 = ( + self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C // 2, H, W)) + .flatten(2) + .transpose(-1, -2) + .contiguous() + ) + + return x1 * x2 + + +class SGFN(nn.Module): + """Spatial-Gate Feed-Forward Network. + Args: + in_features (int): Number of input channels. + hidden_features (int | None): Number of hidden channels. Default: None + out_features (int | None): Number of output channels. Default: None + act_layer (nn.Module): Activation layer. Default: nn.GELU + drop (float): Dropout rate. Default: 0.0 + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.sg = SpatialGate(hidden_features // 2) + self.fc2 = nn.Linear(hidden_features // 2, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), H, W + Output: x: (B, H*W, C) + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + + x = self.sg(x, H, W) + x = self.drop(x) + + x = self.fc2(x) + x = self.drop(x) + return x + + +class DynamicPosBias(nn.Module): + # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py + """Dynamic Relative Position Bias. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + residual (bool): If True, use residual strage to connect conv. + """ + + def __init__(self, dim, num_heads, residual): + super().__init__() + self.residual = residual + self.num_heads = num_heads + self.pos_dim = dim // 4 + self.pos_proj = nn.Linear(2, self.pos_dim) + self.pos1 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos2 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.pos_dim), + ) + self.pos3 = nn.Sequential( + nn.LayerNorm(self.pos_dim), + nn.ReLU(inplace=True), + nn.Linear(self.pos_dim, self.num_heads), + ) + + def forward(self, biases): + if self.residual: + pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads + pos = pos + self.pos1(pos) + pos = pos + self.pos2(pos) + pos = self.pos3(pos) + else: + pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) + return pos + + +class Spatial_Attention(nn.Module): + """Spatial Window Self-Attention. + It supports rectangle window (containing square window). + Args: + dim (int): Number of input channels. + idx (int): The indentix of window. 
(0/1) + split_size (tuple(int)): Height and Width of spatial window. + dim_out (int | None): The dimension of the attention output. Default: None + num_heads (int): Number of attention heads. Default: 6 + attn_drop (float): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float): Dropout ratio of output. Default: 0.0 + qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set + position_bias (bool): The dynamic relative position bias. Default: True + """ + + def __init__( + self, + dim, + idx, + split_size=[8, 8], + dim_out=None, + num_heads=6, + attn_drop=0.0, + proj_drop=0.0, + qk_scale=None, + position_bias=True, + ): + super().__init__() + self.dim = dim + self.dim_out = dim_out or dim + self.split_size = split_size + self.num_heads = num_heads + self.idx = idx + self.position_bias = position_bias + + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + if idx == 0: + H_sp, W_sp = self.split_size[0], self.split_size[1] + elif idx == 1: + W_sp, H_sp = self.split_size[0], self.split_size[1] + else: + print("ERROR MODE", idx) + exit(0) + self.H_sp = H_sp + self.W_sp = W_sp + + if self.position_bias: + self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False) + # generate mother-set + position_bias_h = torch.arange(1 - self.H_sp, self.H_sp) + position_bias_w = torch.arange(1 - self.W_sp, self.W_sp) + biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w])) + biases = biases.flatten(1).transpose(0, 1).contiguous().float() + self.register_buffer("rpe_biases", biases) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.H_sp) + coords_w = torch.arange(self.W_sp) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.H_sp - 1 + relative_coords[:, :, 1] += self.W_sp - 1 + relative_coords[:, :, 0] *= 2 * self.W_sp - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.attn_drop = nn.Dropout(attn_drop) + + def im2win(self, x, H, W): + B, N, C = x.shape + x = x.transpose(-2, -1).contiguous().view(B, C, H, W) + x = img2windows(x, self.H_sp, self.W_sp) + x = ( + x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + return x + + def forward(self, qkv, H, W, mask=None): + """ + Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size + Output: x (B, H, W, C) + """ + q, k, v = qkv[0], qkv[1], qkv[2] + + B, L, C = q.shape + assert L == H * W, "flatten img_tokens has wrong size" + + # partition the q,k,v, image to window + q = self.im2win(q, H, W) + k = self.im2win(k, H, W) + v = self.im2win(v, H, W) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N + + # calculate drpe + if self.position_bias: + pos = self.pos(self.rpe_biases) + # select position bias + relative_position_bias = pos[self.relative_position_index.view(-1)].view( + self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1 + ) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + + N = attn.shape[3] + + # use mask for shift window + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B, nW, 
self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0 + ) + attn = attn.view(-1, self.num_heads, N, N) + + attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) + attn = self.attn_drop(attn) + + x = attn @ v + x = x.transpose(1, 2).reshape( + -1, self.H_sp * self.W_sp, C + ) # B head N N @ B head N C + + # merge the window, window to image + x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C + + return x + + +class Adaptive_Spatial_Attention(nn.Module): + # The implementation builds on CAT code https://github.com/Zhengchen1999/CAT + """Adaptive Spatial Self-Attention + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 6 + split_size (tuple(int)): Height and Width of spatial window. + shift_size (tuple(int)): Shift size for spatial window. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. + drop (float): Dropout rate. Default: 0.0 + attn_drop (float): Attention dropout rate. Default: 0.0 + rg_idx (int): The indentix of Residual Group (RG) + b_idx (int): The indentix of Block in each RG + """ + + def __init__( + self, + dim, + num_heads, + reso=64, + split_size=[8, 8], + shift_size=[1, 2], + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + rg_idx=0, + b_idx=0, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.split_size = split_size + self.shift_size = shift_size + self.b_idx = b_idx + self.rg_idx = rg_idx + self.patches_resolution = reso + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + assert ( + 0 <= self.shift_size[0] < self.split_size[0] + ), "shift_size must in 0-split_size0" + assert ( + 0 <= self.shift_size[1] < self.split_size[1] + ), "shift_size must in 0-split_size1" + + self.branch_num = 2 + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(drop) + + self.attns = nn.ModuleList( + [ + Spatial_Attention( + dim // 2, + idx=i, + split_size=split_size, + num_heads=num_heads // 2, + dim_out=dim // 2, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + position_bias=True, + ) + for i in range(self.branch_num) + ] + ) + + if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or ( + self.rg_idx % 2 != 0 and self.b_idx % 4 == 0 + ): + attn_mask = self.calculate_mask( + self.patches_resolution, self.patches_resolution + ) + self.register_buffer("attn_mask_0", attn_mask[0]) + self.register_buffer("attn_mask_1", attn_mask[1]) + else: + attn_mask = None + self.register_buffer("attn_mask_0", None) + self.register_buffer("attn_mask_1", None) + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.BatchNorm2d(dim), + nn.GELU(), + ) + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + nn.BatchNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1), + ) + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.BatchNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1), + ) + + def calculate_mask(self, H, W): + # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for shift window + img_mask_0 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=0 + img_mask_1 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=1 + h_slices_0 = ( + slice(0, 
-self.split_size[0]), + slice(-self.split_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None), + ) + w_slices_0 = ( + slice(0, -self.split_size[1]), + slice(-self.split_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None), + ) + + h_slices_1 = ( + slice(0, -self.split_size[1]), + slice(-self.split_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None), + ) + w_slices_1 = ( + slice(0, -self.split_size[0]), + slice(-self.split_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None), + ) + cnt = 0 + for h in h_slices_0: + for w in w_slices_0: + img_mask_0[:, h, w, :] = cnt + cnt += 1 + cnt = 0 + for h in h_slices_1: + for w in w_slices_1: + img_mask_1[:, h, w, :] = cnt + cnt += 1 + + # calculate mask for window-0 + img_mask_0 = img_mask_0.view( + 1, + H // self.split_size[0], + self.split_size[0], + W // self.split_size[1], + self.split_size[1], + 1, + ) + img_mask_0 = ( + img_mask_0.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(-1, self.split_size[0], self.split_size[1], 1) + ) # nW, sw[0], sw[1], 1 + mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1]) + attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2) + attn_mask_0 = attn_mask_0.masked_fill( + attn_mask_0 != 0, float(-100.0) + ).masked_fill(attn_mask_0 == 0, float(0.0)) + + # calculate mask for window-1 + img_mask_1 = img_mask_1.view( + 1, + H // self.split_size[1], + self.split_size[1], + W // self.split_size[0], + self.split_size[0], + 1, + ) + img_mask_1 = ( + img_mask_1.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(-1, self.split_size[1], self.split_size[0], 1) + ) # nW, sw[1], sw[0], 1 + mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0]) + attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2) + attn_mask_1 = attn_mask_1.masked_fill( + attn_mask_1 != 0, float(-100.0) + ).masked_fill(attn_mask_1 == 0, float(0.0)) + + return attn_mask_0, attn_mask_1 + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), H, W + Output: x: (B, H*W, C) + """ + B, L, C = x.shape + assert L == H * W, "flatten img_tokens has wrong size" + + qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C + # V without partition + v = qkv[2].transpose(-2, -1).contiguous().view(B, C, H, W) + + # image padding + max_split_size = max(self.split_size[0], self.split_size[1]) + pad_l = pad_t = 0 + pad_r = (max_split_size - W % max_split_size) % max_split_size + pad_b = (max_split_size - H % max_split_size) % max_split_size + + qkv = qkv.reshape(3 * B, H, W, C).permute(0, 3, 1, 2) # 3B C H W + qkv = ( + F.pad(qkv, (pad_l, pad_r, pad_t, pad_b)) + .reshape(3, B, C, -1) + .transpose(-2, -1) + ) # l r t b + _H = pad_b + H + _W = pad_r + W + _L = _H * _W + + # window-0 and window-1 on split channels [C/2, C/2]; for square windows (e.g., 8x8), window-0 and window-1 can be merged + # shift in block: (0, 4, 8, ...), (2, 6, 10, ...), (0, 4, 8, ...), (2, 6, 10, ...), ... 
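+ # For example: an even-indexed RG (rg_idx=0) shifts blocks b_idx = 2, 6, 10, ... + # since b_idx > 0 and (b_idx - 2) % 4 == 0, while an odd-indexed RG (rg_idx=1) + # shifts blocks b_idx = 0, 4, 8, ... since b_idx % 4 == 0; all other blocks + # take the unshifted branch.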
+ if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or ( + self.rg_idx % 2 != 0 and self.b_idx % 4 == 0 + ): + qkv = qkv.view(3, B, _H, _W, C) + qkv_0 = torch.roll( + qkv[:, :, :, :, : C // 2], + shifts=(-self.shift_size[0], -self.shift_size[1]), + dims=(2, 3), + ) + qkv_0 = qkv_0.view(3, B, _L, C // 2) + qkv_1 = torch.roll( + qkv[:, :, :, :, C // 2 :], + shifts=(-self.shift_size[1], -self.shift_size[0]), + dims=(2, 3), + ) + qkv_1 = qkv_1.view(3, B, _L, C // 2) + + if self.patches_resolution != _H or self.patches_resolution != _W: + mask_tmp = self.calculate_mask(_H, _W) + x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device)) + x2_shift = self.attns[1](qkv_1, _H, _W, mask=mask_tmp[1].to(x.device)) + else: + x1_shift = self.attns[0](qkv_0, _H, _W, mask=self.attn_mask_0) + x2_shift = self.attns[1](qkv_1, _H, _W, mask=self.attn_mask_1) + + x1 = torch.roll( + x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2) + ) + x2 = torch.roll( + x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2) + ) + x1 = x1[:, :H, :W, :].reshape(B, L, C // 2) + x2 = x2[:, :H, :W, :].reshape(B, L, C // 2) + # attention output + attened_x = torch.cat([x1, x2], dim=2) + + else: + x1 = self.attns[0](qkv[:, :, :, : C // 2], _H, _W)[:, :H, :W, :].reshape( + B, L, C // 2 + ) + x2 = self.attns[1](qkv[:, :, :, C // 2 :], _H, _W)[:, :H, :W, :].reshape( + B, L, C // 2 + ) + # attention output + attened_x = torch.cat([x1, x2], dim=2) + + # convolution output + conv_x = self.dwconv(v) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + channel_map = ( + self.channel_interaction(conv_x) + .permute(0, 2, 3, 1) + .contiguous() + .view(B, 1, C) + ) + # S-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2, -1).contiguous().view(B, C, H, W) + spatial_map = self.spatial_interaction(attention_reshape) + + # C-I + attened_x = attened_x * torch.sigmoid(channel_map) + # S-I + conv_x = torch.sigmoid(spatial_map) * conv_x + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + x = attened_x + conv_x + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Adaptive_Channel_Attention(nn.Module): + # The implementation builds on XCiT code https://github.com/facebookresearch/xcit + """Adaptive Channel Self-Attention + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. Default: 6 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. + attn_drop (float): Attention dropout rate. Default: 0.0 + drop_path (float): Stochastic depth rate. 
Default: 0.0 + """ + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.dwconv = nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim), + nn.BatchNorm2d(dim), + nn.GELU(), + ) + self.channel_interaction = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(dim, dim // 8, kernel_size=1), + nn.BatchNorm2d(dim // 8), + nn.GELU(), + nn.Conv2d(dim // 8, dim, kernel_size=1), + ) + self.spatial_interaction = nn.Sequential( + nn.Conv2d(dim, dim // 16, kernel_size=1), + nn.BatchNorm2d(dim // 16), + nn.GELU(), + nn.Conv2d(dim // 16, 1, kernel_size=1), + ) + + def forward(self, x, H, W): + """ + Input: x: (B, H*W, C), H, W + Output: x: (B, H*W, C) + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q.transpose(-2, -1) + k = k.transpose(-2, -1) + v = v.transpose(-2, -1) + + v_ = v.reshape(B, C, N).contiguous().view(B, C, H, W) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # attention output + attened_x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + + # convolution output + conv_x = self.dwconv(v_) + + # Adaptive Interaction Module (AIM) + # C-Map (before sigmoid) + attention_reshape = attened_x.transpose(-2, -1).contiguous().view(B, C, H, W) + channel_map = self.channel_interaction(attention_reshape) + # S-Map (before sigmoid) + spatial_map = ( + self.spatial_interaction(conv_x) + .permute(0, 2, 3, 1) + .contiguous() + .view(B, N, 1) + ) + + # S-I + attened_x = attened_x * torch.sigmoid(spatial_map) + # C-I + conv_x = conv_x * torch.sigmoid(channel_map) + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, N, C) + + x = attened_x + conv_x + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class DATB(nn.Module): + def __init__( + self, + dim, + num_heads, + reso=64, + split_size=[2, 4], + shift_size=[1, 2], + expansion_factor=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + rg_idx=0, + b_idx=0, + ): + super().__init__() + + self.norm1 = norm_layer(dim) + + if b_idx % 2 == 0: + # DSTB + self.attn = Adaptive_Spatial_Attention( + dim, + num_heads=num_heads, + reso=reso, + split_size=split_size, + shift_size=shift_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + rg_idx=rg_idx, + b_idx=b_idx, + ) + else: + # DCTB + self.attn = Adaptive_Channel_Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + ffn_hidden_dim = int(dim * expansion_factor) + self.ffn = SGFN( + in_features=dim, + hidden_features=ffn_hidden_dim, + out_features=dim, + act_layer=act_layer, + ) + self.norm2 = norm_layer(dim) + + def forward(self, x, x_size): + """ + Input: x: (B, H*W, C), x_size: (H, W) + Output: x: (B, H*W, C) + """ + H, W = x_size + x = x + 
self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.ffn(self.norm2(x), H, W)) + + return x + + +class ResidualGroup(nn.Module): + """ResidualGroup + Args: + dim (int): Number of input channels. + reso (int): Input resolution. + num_heads (int): Number of attention heads. + split_size (tuple(int)): Height and Width of spatial window. + expansion_factor (float): Ratio of ffn hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop (float): Dropout rate. Default: 0 + attn_drop(float): Attention dropout rate. Default: 0 + drop_paths (float | None): Stochastic depth rate. + act_layer (nn.Module): Activation layer. Default: nn.GELU + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm + depth (int): Number of dual aggregation Transformer blocks in residual group. + use_chk (bool): Whether to use checkpointing to save memory. + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__( + self, + dim, + reso, + num_heads, + split_size=[2, 4], + expansion_factor=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_paths=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + depth=2, + use_chk=False, + resi_connection="1conv", + rg_idx=0, + ): + super().__init__() + self.use_chk = use_chk + self.reso = reso + + self.blocks = nn.ModuleList( + [ + DATB( + dim=dim, + num_heads=num_heads, + reso=reso, + split_size=split_size, + shift_size=[split_size[0] // 2, split_size[1] // 2], + expansion_factor=expansion_factor, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_paths[i], + act_layer=act_layer, + norm_layer=norm_layer, + rg_idx=rg_idx, + b_idx=i, + ) + for i in range(depth) + ] + ) + + if resi_connection == "1conv": + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == "3conv": + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1), + ) + + def forward(self, x, x_size): + """ + Input: x: (B, H*W, C), x_size: (H, W) + Output: x: (B, H*W, C) + """ + H, W = x_size + res = x + for blk in self.blocks: + if self.use_chk: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + x = self.conv(x) + x = rearrange(x, "b c h w -> b (h w) c") + x = res + x + + return x + + +class Upsample(nn.Sequential): + """Upsample module. + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError( + f"scale {scale} is not supported. " "Supported scales: 2^n and 3." + ) + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. 
+ + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + h, w = self.input_resolution + flops = h * w * self.num_feat * 3 * 9 + return flops + + +class DAT(nn.Module): + """Dual Aggregation Transformer + Args: + img_size (int): Input image size. Default: 64 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 180 + depths (tuple(int)): Depth of each residual group (number of DATB in each RG). + split_size (tuple(int)): Height and Width of spatial window. + num_heads (tuple(int)): Number of attention heads in different residual groups. + expansion_factor (float): Ratio of ffn hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + act_layer (nn.Module): Activation layer. Default: nn.GELU + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm + use_chk (bool): Whether to use checkpointing to save memory. + upscale: Upscale factor. 2/3/4 for image SR + img_range: Image range. 1. or 255. + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + """ + + def __init__(self, state_dict): + super().__init__() + + # defaults + img_size = 64 + in_chans = 3 + embed_dim = 180 + split_size = [2, 4] + depth = [2, 2, 2, 2] + num_heads = [2, 2, 2, 2] + expansion_factor = 4.0 + qkv_bias = True + qk_scale = None + drop_rate = 0.0 + attn_drop_rate = 0.0 + drop_path_rate = 0.1 + act_layer = nn.GELU + norm_layer = nn.LayerNorm + use_chk = False + upscale = 2 + img_range = 1.0 + resi_connection = "1conv" + upsampler = "pixelshuffle" + + self.model_arch = "DAT" + self.sub_type = "SR" + self.state = state_dict + + state_keys = state_dict.keys() + if "conv_before_upsample.0.weight" in state_keys: + if "conv_up1.weight" in state_keys: + upsampler = "nearest+conv" + else: + upsampler = "pixelshuffle" + supports_fp16 = False + elif "upsample.0.weight" in state_keys: + upsampler = "pixelshuffledirect" + else: + upsampler = "" + + # channel width of conv_before_upsample; falls back to the DAT default of 64 + num_feat = ( + state_dict["conv_before_upsample.0.weight"].shape[0] + if "conv_before_upsample.0.weight" in state_keys + else 64 + ) + + num_in_ch = state_dict["conv_first.weight"].shape[1] + in_chans = num_in_ch + if "conv_last.weight" in state_keys: + num_out_ch = state_dict["conv_last.weight"].shape[0] + else: + num_out_ch = num_in_ch + + upscale = 1 + if upsampler == "nearest+conv": + upsample_keys = [ + x for x in state_keys if "conv_up" in x and "bias" not in x + ] + + for upsample_key in upsample_keys: + upscale *= 2 + elif upsampler == "pixelshuffle": + upsample_keys = [ + x + for x in state_keys + if "upsample" in x and "conv" not in x and "bias" not in x + ] + for upsample_key in upsample_keys: + shape = state_dict[upsample_key].shape[0] + upscale *= math.sqrt(shape // num_feat) + upscale = int(upscale) + elif upsampler == "pixelshuffledirect": + upscale = int( + math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch) + ) + + max_layer_num = 0 + max_block_num = 0 + for key in state_keys: + result = re.match(r"layers.(\d*).blocks.(\d*).norm1.weight", key) + if result: + layer_num, block_num = result.groups() + max_layer_num = max(max_layer_num, int(layer_num)) + max_block_num = max(max_block_num, int(block_num)) + + depth = [max_block_num + 1 for _ in range(max_layer_num + 1)] + + if "layers.0.blocks.1.attn.temperature" in state_keys: + num_heads_num = state_dict["layers.0.blocks.1.attn.temperature"].shape[0] + num_heads = [num_heads_num for _ in range(max_layer_num + 1)] + else: + num_heads = depth + + embed_dim = state_dict["conv_first.weight"].shape[0] + expansion_factor = float( + state_dict["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim + ) + + # TODO: could actually count the layers, but this should do + if "layers.0.conv.4.weight" in state_keys: + resi_connection = "3conv" + else: + resi_connection = "1conv" + + if "layers.0.blocks.2.attn.attn_mask_0" in state_keys: + attn_mask_0_x, attn_mask_0_y, attn_mask_0_z = state_dict[ + "layers.0.blocks.2.attn.attn_mask_0" + ].shape + + img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y)) + + if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_keys: + split_sizes = ( + state_dict["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1 + ) + split_size = [int(x) for x in split_sizes] + + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.embed_dim = embed_dim + self.num_heads = num_heads + self.depth = depth + self.scale = upscale + self.upsampler = upsampler + self.img_size = img_size + self.img_range = img_range + self.expansion_factor = expansion_factor + self.resi_connection =
resi_connection + self.split_size = split_size + + self.supports_fp16 = False # Too much weirdness to support this at the moment + self.supports_bfp16 = True + self.min_size_restriction = 16 + + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # ------------------------- 1, Shallow Feature Extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, Deep Feature Extraction ------------------------- # + self.num_layers = len(depth) + self.use_chk = use_chk + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + heads = num_heads + + self.before_RG = nn.Sequential( + Rearrange("b c h w -> b (h w) c"), nn.LayerNorm(embed_dim) + ) + + curr_dim = embed_dim + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth)) + ] # stochastic depth decay rule + + self.layers = nn.ModuleList() + for i in range(self.num_layers): + layer = ResidualGroup( + dim=embed_dim, + num_heads=heads[i], + reso=img_size, + split_size=split_size, + expansion_factor=expansion_factor, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_paths=dpr[sum(depth[:i]) : sum(depth[: i + 1])], + act_layer=act_layer, + norm_layer=norm_layer, + depth=depth[i], + use_chk=use_chk, + resi_connection=resi_connection, + rg_idx=i, + ) + self.layers.append(layer) + + self.norm = norm_layer(curr_dim) + # build the last conv layer in deep feature extraction + if resi_connection == "1conv": + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == "3conv": + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1), + ) + + # ------------------------- 3, Reconstruction ------------------------- # + if self.upsampler == "pixelshuffle": + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) + ) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == "pixelshuffledirect": + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep( + upscale, embed_dim, num_out_ch, (img_size, img_size) + ) + + self.apply(self._init_weights) + self.load_state_dict(state_dict, strict=True) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance( + m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d) + ): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + _, _, H, W = x.shape + x_size = [H, W] + x = self.before_RG(x) + for layer in self.layers: + x = layer(x, x_size) + x = self.norm(x) + x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W) + + return x + + def forward(self, x): + """ + Input: x: (B, C, H, W) + """ + self.mean = self.mean.type_as(x) + x = (x - 
self.mean) * self.img_range + + if self.upsampler == "pixelshuffle": + # for image SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == "pixelshuffledirect": + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + + x = x / self.img_range + self.mean + return x diff --git a/comfy_extras/chainner_models/architecture/LICENSE-DAT b/comfy_extras/chainner_models/architecture/LICENSE-DAT new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/LICENSE-DAT @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/comfy_extras/chainner_models/architecture/LICENSE-SCUNet b/comfy_extras/chainner_models/architecture/LICENSE-SCUNet new file mode 100644 index 000000000..ff75c988f --- /dev/null +++ b/comfy_extras/chainner_models/architecture/LICENSE-SCUNet @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 Kai Zhang (cskaizhang@gmail.com, https://cszn.github.io/). All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/comfy_extras/chainner_models/architecture/LICENSE-mat b/comfy_extras/chainner_models/architecture/LICENSE-mat deleted file mode 100644 index 593adf6c6..000000000 --- a/comfy_extras/chainner_models/architecture/LICENSE-mat +++ /dev/null @@ -1,161 +0,0 @@ -## creative commons - -# Attribution-NonCommercial 4.0 International - -Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 
- -### Using Creative Commons Public Licenses - -Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. - -* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). - -* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). - -## Creative Commons Attribution-NonCommercial 4.0 International Public License - -By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. - -### Section 1 – Definitions. - -a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. - -b. 
__Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. - -c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. - -d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. - -e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. - -f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. - -g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. - -h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. - -i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. - -j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. - -k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. - -l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. - -### Section 2 – Scope. - -a. ___License grant.___ - - 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: - - A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and - - B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. - - 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. - - 3. 
__Term.__ The term of this Public License is specified in Section 6(a). - - 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. - - 5. __Downstream recipients.__ - - A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. - - B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. - - 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). - -b. ___Other rights.___ - - 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. - - 2. Patent and trademark rights are not licensed under this Public License. - - 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. - -### Section 3 – License Conditions. - -Your exercise of the Licensed Rights is expressly made subject to the following conditions. - -a. ___Attribution.___ - - 1. If You Share the Licensed Material (including in modified form), You must: - - A. retain the following if it is supplied by the Licensor with the Licensed Material: - - i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); - - ii. a copyright notice; - - iii. a notice that refers to this Public License; - - iv. a notice that refers to the disclaimer of warranties; - - v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; - - B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and - - C. 
indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. - - 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. - - 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. - - 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. - -### Section 4 – Sui Generis Database Rights. - -Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: - -a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; - -b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and - -c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. - -For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. - -### Section 5 – Disclaimer of Warranties and Limitation of Liability. - -a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ - -b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ - -c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. - -### Section 6 – Term and Termination. - -a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. - -b. 
Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: - - 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or - - 2. upon express reinstatement by the Licensor. - - For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. - -c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. - -d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. - -### Section 7 – Other Terms and Conditions. - -a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. - -b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. - -### Section 8 – Interpretation. - -a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. - -b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. - -c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. - -d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. - -> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 
-> -> Creative Commons may be contacted at creativecommons.org diff --git a/comfy_extras/chainner_models/architecture/MAT.py b/comfy_extras/chainner_models/architecture/MAT.py deleted file mode 100644 index 8fe170266..000000000 --- a/comfy_extras/chainner_models/architecture/MAT.py +++ /dev/null @@ -1,1636 +0,0 @@ -# pylint: skip-file -"""Original MAT project is copyright of fenglinglwb: https://github.com/fenglinglwb/MAT -Code used for this implementation of MAT is modified from lama-cleaner, -copyright of Sanster: https://github.com/Sanster/lama-cleaner""" - -import random - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint - -from .mat.utils import ( - Conv2dLayer, - FullyConnectedLayer, - activation_funcs, - bias_act, - conv2d_resample, - normalize_2nd_moment, - setup_filter, - to_2tuple, - upsample2d, -) - - -class ModulatedConv2d(nn.Module): - def __init__( - self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - kernel_size, # Width and height of the convolution kernel. - style_dim, # dimension of the style code - demodulate=True, # perform demodulation - up=1, # Integer upsampling factor. - down=1, # Integer downsampling factor. - resample_filter=[ - 1, - 3, - 3, - 1, - ], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output to +-X, None = disable clamping. - ): - super().__init__() - self.demodulate = demodulate - - self.weight = torch.nn.Parameter( - torch.randn([1, out_channels, in_channels, kernel_size, kernel_size]) - ) - self.out_channels = out_channels - self.kernel_size = kernel_size - self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) - self.padding = self.kernel_size // 2 - self.up = up - self.down = down - self.register_buffer("resample_filter", setup_filter(resample_filter)) - self.conv_clamp = conv_clamp - - self.affine = FullyConnectedLayer(style_dim, in_channels, bias_init=1) - - def forward(self, x, style): - batch, in_channels, height, width = x.shape - style = self.affine(style).view(batch, 1, in_channels, 1, 1).to(x.device) - weight = self.weight.to(x.device) * self.weight_gain * style - - if self.demodulate: - decoefs = (weight.pow(2).sum(dim=[2, 3, 4]) + 1e-8).rsqrt() - weight = weight * decoefs.view(batch, self.out_channels, 1, 1, 1) - - weight = weight.view( - batch * self.out_channels, in_channels, self.kernel_size, self.kernel_size - ) - x = x.view(1, batch * in_channels, height, width) - x = conv2d_resample( - x=x, - w=weight, - f=self.resample_filter, - up=self.up, - down=self.down, - padding=self.padding, - groups=batch, - ) - out = x.view(batch, self.out_channels, *x.shape[2:]) - - return out - - -class StyleConv(torch.nn.Module): - def __init__( - self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - style_dim, # Intermediate latent (W) dimensionality. - resolution, # Resolution of this layer. - kernel_size=3, # Convolution kernel size. - up=1, # Integer upsampling factor. - use_noise=False, # Enable noise input? - activation="lrelu", # Activation function: 'relu', 'lrelu', etc. - resample_filter=[ - 1, - 3, - 3, - 1, - ], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
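For intuition, the modulate/demodulate step in ModulatedConv2d above is the StyleGAN2 trick: scale the kernel's input channels per sample, renormalize each output filter to roughly unit variance, then run one grouped convolution so every sample in the batch gets its own kernel. A minimal self-contained sketch under those assumptions (the name modulated_conv2d_sketch is illustrative, and it omits the up/down-sampling that conv2d_resample adds):

```python
import torch
import torch.nn.functional as F

def modulated_conv2d_sketch(x, weight, style, eps=1e-8):
    # x: [B, Cin, H, W]; weight: [Cout, Cin, k, k]; style: [B, Cin]
    B, Cin, H, W = x.shape
    Cout, _, k, _ = weight.shape
    # Modulate: scale the kernel's input channels per sample.
    w = weight.unsqueeze(0) * style.view(B, 1, Cin, 1, 1)
    # Demodulate: rescale so each output feature map has ~unit variance.
    decoef = (w.pow(2).sum(dim=[2, 3, 4]) + eps).rsqrt()
    w = w * decoef.view(B, Cout, 1, 1, 1)
    # Grouped conv: fold the batch into groups so each sample uses its own kernel.
    x = x.reshape(1, B * Cin, H, W)
    w = w.reshape(B * Cout, Cin, k, k)
    out = F.conv2d(x, w, padding=k // 2, groups=B)
    return out.reshape(B, Cout, H, W)

out = modulated_conv2d_sketch(torch.randn(2, 8, 16, 16), torch.randn(4, 8, 3, 3), torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4, 16, 16])
```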
- demodulate=True, # perform demodulation - ): - super().__init__() - - self.conv = ModulatedConv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - style_dim=style_dim, - demodulate=demodulate, - up=up, - resample_filter=resample_filter, - conv_clamp=conv_clamp, - ) - - self.use_noise = use_noise - self.resolution = resolution - if use_noise: - self.register_buffer("noise_const", torch.randn([resolution, resolution])) - self.noise_strength = torch.nn.Parameter(torch.zeros([])) - - self.bias = torch.nn.Parameter(torch.zeros([out_channels])) - self.activation = activation - self.act_gain = activation_funcs[activation].def_gain - self.conv_clamp = conv_clamp - - def forward(self, x, style, noise_mode="random", gain=1): - x = self.conv(x, style) - - assert noise_mode in ["random", "const", "none"] - - if self.use_noise: - if noise_mode == "random": - xh, xw = x.size()[-2:] - noise = ( - torch.randn([x.shape[0], 1, xh, xw], device=x.device) - * self.noise_strength - ) - if noise_mode == "const": - noise = self.noise_const * self.noise_strength - x = x + noise - - act_gain = self.act_gain * gain - act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None - out = bias_act( - x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp - ) - - return out - - -class ToRGB(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - style_dim, - kernel_size=1, - resample_filter=[1, 3, 3, 1], - conv_clamp=None, - demodulate=False, - ): - super().__init__() - - self.conv = ModulatedConv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - style_dim=style_dim, - demodulate=demodulate, - resample_filter=resample_filter, - conv_clamp=conv_clamp, - ) - self.bias = torch.nn.Parameter(torch.zeros([out_channels])) - self.register_buffer("resample_filter", setup_filter(resample_filter)) - self.conv_clamp = conv_clamp - - def forward(self, x, style, skip=None): - x = self.conv(x, style) - out = bias_act(x, self.bias, clamp=self.conv_clamp) - - if skip is not None: - if skip.shape != out.shape: - skip = upsample2d(skip, self.resample_filter) - out = out + skip - - return out - - -def get_style_code(a, b): - return torch.cat([a, b.to(a.device)], dim=1) - - -class DecBlockFirst(nn.Module): - def __init__( - self, - in_channels, - out_channels, - activation, - style_dim, - use_noise, - demodulate, - img_channels, - ): - super().__init__() - self.fc = FullyConnectedLayer( - in_features=in_channels * 2, - out_features=in_channels * 4**2, - activation=activation, - ) - self.conv = StyleConv( - in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=4, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB( - in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) - - def forward(self, x, ws, gs, E_features, noise_mode="random"): - x = self.fc(x).view(x.shape[0], -1, 4, 4) - x = x + E_features[2] - style = get_style_code(ws[:, 0], gs) - x = self.conv(x, style, noise_mode=noise_mode) - style = get_style_code(ws[:, 1], gs) - img = self.toRGB(x, style, skip=None) - - return x, img - - -class MappingNet(torch.nn.Module): - def __init__( - self, - z_dim, # Input latent (Z) dimensionality, 0 = no latent. - c_dim, # Conditioning label (C) dimensionality, 0 = no label. - w_dim, # Intermediate latent (W) dimensionality. 
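As a side note, the noise branch in StyleConv.forward above reduces to adding a single noise map scaled by one learned scalar per layer; a standalone sketch under that reading (inject_noise is a hypothetical helper, not part of this patch):

```python
import torch

def inject_noise(x, noise_strength, noise_mode="random", noise_const=None):
    # x: [B, C, H, W]; noise_strength: a learned scalar parameter per layer.
    if noise_mode == "random":
        noise = torch.randn(x.shape[0], 1, *x.shape[-2:], device=x.device) * noise_strength
    elif noise_mode == "const":
        noise = noise_const * noise_strength  # fixed buffer -> reproducible output
    else:  # "none"
        return x
    return x + noise
```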
- num_ws, # Number of intermediate latents to output, None = do not broadcast. - num_layers=8, # Number of mapping layers. - embed_features=None, # Label embedding dimensionality, None = same as w_dim. - layer_features=None, # Number of intermediate features in the mapping layers, None = same as w_dim. - activation="lrelu", # Activation function: 'relu', 'lrelu', etc. - lr_multiplier=0.01, # Learning rate multiplier for the mapping layers. - w_avg_beta=0.995, # Decay for tracking the moving average of W during training, None = do not track. - ): - super().__init__() - self.z_dim = z_dim - self.c_dim = c_dim - self.w_dim = w_dim - self.num_ws = num_ws - self.num_layers = num_layers - self.w_avg_beta = w_avg_beta - - if embed_features is None: - embed_features = w_dim - if c_dim == 0: - embed_features = 0 - if layer_features is None: - layer_features = w_dim - features_list = ( - [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] - ) - - if c_dim > 0: - self.embed = FullyConnectedLayer(c_dim, embed_features) - for idx in range(num_layers): - in_features = features_list[idx] - out_features = features_list[idx + 1] - layer = FullyConnectedLayer( - in_features, - out_features, - activation=activation, - lr_multiplier=lr_multiplier, - ) - setattr(self, f"fc{idx}", layer) - - if num_ws is not None and w_avg_beta is not None: - self.register_buffer("w_avg", torch.zeros([w_dim])) - - def forward( - self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False - ): - # Embed, normalize, and concat inputs. - x = None - with torch.autograd.profiler.record_function("input"): - if self.z_dim > 0: - x = normalize_2nd_moment(z.to(torch.float32)) - if self.c_dim > 0: - y = normalize_2nd_moment(self.embed(c.to(torch.float32))) - x = torch.cat([x, y], dim=1) if x is not None else y - - # Main layers. - for idx in range(self.num_layers): - layer = getattr(self, f"fc{idx}") - x = layer(x) - - # Update moving average of W. - if self.w_avg_beta is not None and self.training and not skip_w_avg_update: - with torch.autograd.profiler.record_function("update_w_avg"): - self.w_avg.copy_( - x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta) - ) - - # Broadcast. - if self.num_ws is not None: - with torch.autograd.profiler.record_function("broadcast"): - x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) - - # Apply truncation. 
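The truncation step below interpolates each intermediate latent toward the tracked average w_avg; psi < 1 pulls samples toward the mean (trading diversity for fidelity), and the cutoff restricts this to the earliest layers. In isolation, the operation is just a lerp (toy shapes):

```python
import torch

w_avg = torch.zeros(512)        # running average tracked during training
w = torch.randn(4, 10, 512)     # [batch, num_ws, w_dim]
psi, cutoff = 0.7, 8
# result = w_avg + psi * (w - w_avg), applied only to the first `cutoff` layers
w[:, :cutoff] = w_avg.lerp(w[:, :cutoff], psi)
```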
- if truncation_psi != 1: - with torch.autograd.profiler.record_function("truncate"): - assert self.w_avg_beta is not None - if self.num_ws is None or truncation_cutoff is None: - x = self.w_avg.lerp(x, truncation_psi) - else: - x[:, :truncation_cutoff] = self.w_avg.lerp( - x[:, :truncation_cutoff], truncation_psi - ) - - return x - - -class DisFromRGB(nn.Module): - def __init__( - self, in_channels, out_channels, activation - ): # res = 2, ..., resolution_log2 - super().__init__() - self.conv = Conv2dLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - activation=activation, - ) - - def forward(self, x): - return self.conv(x) - - -class DisBlock(nn.Module): - def __init__( - self, in_channels, out_channels, activation - ): # res = 2, ..., resolution_log2 - super().__init__() - self.conv0 = Conv2dLayer( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - ) - self.conv1 = Conv2dLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - down=2, - activation=activation, - ) - self.skip = Conv2dLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - down=2, - bias=False, - ) - - def forward(self, x): - skip = self.skip(x, gain=np.sqrt(0.5)) - x = self.conv0(x) - x = self.conv1(x, gain=np.sqrt(0.5)) - out = skip + x - - return out - - -def nf(stage, channel_base=32768, channel_decay=1.0, channel_max=512): - NF = {512: 64, 256: 128, 128: 256, 64: 512, 32: 512, 16: 512, 8: 512, 4: 512} - return NF[2**stage] - - -class Mlp(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = FullyConnectedLayer( - in_features=in_features, out_features=hidden_features, activation="lrelu" - ) - self.fc2 = FullyConnectedLayer( - in_features=hidden_features, out_features=out_features - ) - - def forward(self, x): - x = self.fc1(x) - x = self.fc2(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = ( - x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - ) - return windows - - -def window_reverse(windows, window_size: int, H: int, W: int): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - # B = windows.shape[0] / (H * W / window_size / window_size) - x = windows.view( - B, H // window_size, W // window_size, window_size, window_size, -1 - ) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class Conv2dLayerPartial(nn.Module): - def __init__( - self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - kernel_size, # Width and height of the convolution kernel. - bias=True, # Apply additive bias before the activation function? - activation="linear", # Activation function: 'relu', 'lrelu', etc. - up=1, # Integer upsampling factor. - down=1, # Integer downsampling factor. 
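window_partition and window_reverse above are exact inverses of each other; a quick self-contained round-trip check (shapes chosen arbitrarily):

```python
import torch

B, H, W, C, ws = 2, 8, 8, 4, 4
x = torch.randn(B, H, W, C)

# Partition into non-overlapping ws x ws windows...
windows = (
    x.view(B, H // ws, ws, W // ws, ws, C)
    .permute(0, 1, 3, 2, 4, 5)
    .contiguous()
    .view(-1, ws, ws, C)
)  # [B * (H//ws) * (W//ws), ws, ws, C]

# ...and reverse the partition.
x2 = (
    windows.view(B, H // ws, W // ws, ws, ws, C)
    .permute(0, 1, 3, 2, 4, 5)
    .contiguous()
    .view(B, H, W, C)
)
assert torch.equal(x, x2)  # lossless round trip
```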
- resample_filter=[ - 1, - 3, - 3, - 1, - ], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output to +-X, None = disable clamping. - trainable=True, # Update the weights of this layer during training? - ): - super().__init__() - self.conv = Conv2dLayer( - in_channels, - out_channels, - kernel_size, - bias, - activation, - up, - down, - resample_filter, - conv_clamp, - trainable, - ) - - self.weight_maskUpdater = torch.ones(1, 1, kernel_size, kernel_size) - self.slide_winsize = kernel_size**2 - self.stride = down - self.padding = kernel_size // 2 if kernel_size % 2 == 1 else 0 - - def forward(self, x, mask=None): - if mask is not None: - with torch.no_grad(): - if self.weight_maskUpdater.type() != x.type(): - self.weight_maskUpdater = self.weight_maskUpdater.to(x) - update_mask = F.conv2d( - mask, - self.weight_maskUpdater, - bias=None, - stride=self.stride, - padding=self.padding, - ) - mask_ratio = self.slide_winsize / (update_mask + 1e-8) - update_mask = torch.clamp(update_mask, 0, 1) # 0 or 1 - mask_ratio = torch.mul(mask_ratio, update_mask) - x = self.conv(x) - x = torch.mul(x, mask_ratio) - return x, update_mask - else: - x = self.conv(x) - return x, None - - -class WindowAttention(nn.Module): - r"""Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__( - self, - dim, - window_size, - num_heads, - down_ratio=1, - qkv_bias=True, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.q = FullyConnectedLayer(in_features=dim, out_features=dim) - self.k = FullyConnectedLayer(in_features=dim, out_features=dim) - self.v = FullyConnectedLayer(in_features=dim, out_features=dim) - self.proj = FullyConnectedLayer(in_features=dim, out_features=dim) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask_windows=None, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - norm_x = F.normalize(x, p=2.0, dim=-1) - q = ( - self.q(norm_x) - .reshape(B_, N, self.num_heads, C // self.num_heads) - .permute(0, 2, 1, 3) - ) - k = ( - self.k(norm_x) - .view(B_, -1, self.num_heads, C // self.num_heads) - .permute(0, 2, 3, 1) - ) - v = ( - self.v(x) - .view(B_, -1, self.num_heads, C // self.num_heads) - .permute(0, 2, 1, 3) - ) - - attn = (q @ k) * self.scale - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( - 1 - ).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - - if mask_windows is not None: - attn_mask_windows = mask_windows.squeeze(-1).unsqueeze(1).unsqueeze(1) - attn = attn + attn_mask_windows.masked_fill( - attn_mask_windows == 0, float(-100.0) - ).masked_fill(attn_mask_windows == 1, float(0.0)) - with torch.no_grad(): - mask_windows = torch.clamp( - torch.sum(mask_windows, dim=1, keepdim=True), 0, 1 - ).repeat(1, N, 1) - - attn = self.softmax(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - return x, mask_windows - - -class SwinTransformerBlock(nn.Module): - r"""Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer.
Default: nn.LayerNorm - """ - - def __init__( - self, - dim, - input_resolution, - num_heads, - down_ratio=1, - window_size=7, - shift_size=0, - mlp_ratio=4.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - ): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert ( - 0 <= self.shift_size < self.window_size - ), "shift_size must be in [0, window_size)" - - if self.shift_size > 0: - down_ratio = 1 - self.attn = WindowAttention( - dim, - window_size=to_2tuple(self.window_size), - num_heads=num_heads, - down_ratio=down_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - - self.fuse = FullyConnectedLayer( - in_features=dim * 2, out_features=dim, activation="lrelu" - ) - - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - w_slices = ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None), - ) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition( - img_mask, self.window_size - ) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( - attn_mask == 0, float(0.0) - ) - - return attn_mask - - def forward(self, x, x_size, mask=None): - # H, W = self.input_resolution - H, W = x_size - B, _, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - if mask is not None: - mask = mask.view(B, H, W, 1) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll( - x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) - ) - if mask is not None: - shifted_mask = torch.roll( - mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) - ) - else: - shifted_x = x - if mask is not None: - shifted_mask = mask - - # partition windows - x_windows = window_partition( - shifted_x, self.window_size - ) # nW*B, window_size, window_size, C - x_windows = x_windows.view( - -1, self.window_size * self.window_size, C - ) # nW*B, window_size*window_size, C - if mask is not None: - mask_windows = window_partition(shifted_mask, self.window_size) - mask_windows = mask_windows.view(-1, self.window_size * self.window_size, 1) - else: - mask_windows = None - - # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are multiples of the window size) - if self.input_resolution == x_size: - attn_windows, mask_windows = self.attn(
- x_windows, mask_windows, mask=self.attn_mask - ) # nW*B, window_size*window_size, C - else: - attn_windows, mask_windows = self.attn( - x_windows, mask_windows, mask=self.calculate_mask(x_size).to(x.device) - ) # nW*B, window_size*window_size, C - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - if mask is not None: - mask_windows = mask_windows.view(-1, self.window_size, self.window_size, 1) - shifted_mask = window_reverse(mask_windows, self.window_size, H, W) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll( - shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) - ) - if mask is not None: - mask = torch.roll( - shifted_mask, shifts=(self.shift_size, self.shift_size), dims=(1, 2) - ) - else: - x = shifted_x - if mask is not None: - mask = shifted_mask - x = x.view(B, H * W, C) - if mask is not None: - mask = mask.view(B, H * W, 1) - - # FFN - x = self.fuse(torch.cat([shortcut, x], dim=-1)) - x = self.mlp(x) - - return x, mask - - -class PatchMerging(nn.Module): - def __init__(self, in_channels, out_channels, down=2): - super().__init__() - self.conv = Conv2dLayerPartial( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - activation="lrelu", - down=down, - ) - self.down = down - - def forward(self, x, x_size, mask=None): - x = token2feature(x, x_size) - if mask is not None: - mask = token2feature(mask, x_size) - x, mask = self.conv(x, mask) - if self.down != 1: - ratio = 1 / self.down - x_size = (int(x_size[0] * ratio), int(x_size[1] * ratio)) - x = feature2token(x) - if mask is not None: - mask = feature2token(mask) - return x, x_size, mask - - -class PatchUpsampling(nn.Module): - def __init__(self, in_channels, out_channels, up=2): - super().__init__() - self.conv = Conv2dLayerPartial( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - activation="lrelu", - up=up, - ) - self.up = up - - def forward(self, x, x_size, mask=None): - x = token2feature(x, x_size) - if mask is not None: - mask = token2feature(mask, x_size) - x, mask = self.conv(x, mask) - if self.up != 1: - x_size = (int(x_size[0] * self.up), int(x_size[1] * self.up)) - x = feature2token(x) - if mask is not None: - mask = feature2token(mask) - return x, x_size, mask - - -class BasicLayer(nn.Module): - """A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
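The cyclic shift used by the shifted-window branch above relies on torch.roll being exactly invertible: roll the feature map so window boundaries move, attend within windows, then roll back. In isolation:

```python
import torch

x = torch.randn(1, 8, 8, 4)   # [B, H, W, C]
shift = 2
shifted = torch.roll(x, shifts=(-shift, -shift), dims=(1, 2))
# ... windowed attention would happen here ...
restored = torch.roll(shifted, shifts=(shift, shift), dims=(1, 2))
assert torch.equal(x, restored)  # the two rolls cancel exactly
```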
- """ - - def __init__( - self, - dim, - input_resolution, - depth, - num_heads, - window_size, - down_ratio=1, - mlp_ratio=2.0, - qkv_bias=True, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - norm_layer=nn.LayerNorm, - downsample=None, - use_checkpoint=False, - ): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # patch merging layer - if downsample is not None: - # self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - self.downsample = downsample - else: - self.downsample = None - - # build blocks - self.blocks = nn.ModuleList( - [ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - down_ratio=down_ratio, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] - if isinstance(drop_path, list) - else drop_path, - norm_layer=norm_layer, - ) - for i in range(depth) - ] - ) - - self.conv = Conv2dLayerPartial( - in_channels=dim, out_channels=dim, kernel_size=3, activation="lrelu" - ) - - def forward(self, x, x_size, mask=None): - if self.downsample is not None: - x, x_size, mask = self.downsample(x, x_size, mask) - identity = x - for blk in self.blocks: - if self.use_checkpoint: - x, mask = checkpoint.checkpoint(blk, x, x_size, mask) - else: - x, mask = blk(x, x_size, mask) - if mask is not None: - mask = token2feature(mask, x_size) - x, mask = self.conv(token2feature(x, x_size), mask) - x = feature2token(x) + identity - if mask is not None: - mask = feature2token(mask) - return x, x_size, mask - - -class ToToken(nn.Module): - def __init__(self, in_channels=3, dim=128, kernel_size=5, stride=1): - super().__init__() - - self.proj = Conv2dLayerPartial( - in_channels=in_channels, - out_channels=dim, - kernel_size=kernel_size, - activation="lrelu", - ) - - def forward(self, x, mask): - x, mask = self.proj(x, mask) - - return x, mask - - -class EncFromRGB(nn.Module): - def __init__( - self, in_channels, out_channels, activation - ): # res = 2, ..., resolution_log2 - super().__init__() - self.conv0 = Conv2dLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - activation=activation, - ) - self.conv1 = Conv2dLayer( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - activation=activation, - ) - - def forward(self, x): - x = self.conv0(x) - x = self.conv1(x) - - return x - - -class ConvBlockDown(nn.Module): - def __init__( - self, in_channels, out_channels, activation - ): # res = 2, ..., resolution_log - super().__init__() - - self.conv0 = Conv2dLayer( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - activation=activation, - down=2, - ) - self.conv1 = Conv2dLayer( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - activation=activation, - ) - - def forward(self, x): - x = self.conv0(x) - x = self.conv1(x) - - return x - - -def token2feature(x, x_size): - B, _, C = x.shape - h, w = x_size - x = x.permute(0, 2, 1).reshape(B, C, h, w) - return x - - -def feature2token(x): - B, C, _, _ = x.shape - x = x.view(B, C, -1).transpose(1, 2) - return x - - -class Encoder(nn.Module): - def __init__( - self, - res_log2, - img_channels, - activation, - patch_size=5, - channels=16, - drop_path_rate=0.1, - ): - super().__init__() - - self.resolution = [] - - for i in range(res_log2, 3, 
-1): # from input size to 16x16 - res = 2**i - self.resolution.append(res) - if i == res_log2: - block = EncFromRGB(img_channels * 2 + 1, nf(i), activation) - else: - block = ConvBlockDown(nf(i + 1), nf(i), activation) - setattr(self, "EncConv_Block_%dx%d" % (res, res), block) - - def forward(self, x): - out = {} - for res in self.resolution: - res_log2 = int(np.log2(res)) - x = getattr(self, "EncConv_Block_%dx%d" % (res, res))(x) - out[res_log2] = x - - return out - - -class ToStyle(nn.Module): - def __init__(self, in_channels, out_channels, activation, drop_rate): - super().__init__() - self.conv = nn.Sequential( - Conv2dLayer( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - down=2, - ), - Conv2dLayer( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - down=2, - ), - Conv2dLayer( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - down=2, - ), - ) - - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc = FullyConnectedLayer( - in_features=in_channels, out_features=out_channels, activation=activation - ) - # self.dropout = nn.Dropout(drop_rate) - - def forward(self, x): - x = self.conv(x) - x = self.pool(x) - x = self.fc(x.flatten(start_dim=1)) - # x = self.dropout(x) - - return x - - -class DecBlockFirstV2(nn.Module): - def __init__( - self, - res, - in_channels, - out_channels, - activation, - style_dim, - use_noise, - demodulate, - img_channels, - ): - super().__init__() - self.res = res - - self.conv0 = Conv2dLayer( - in_channels=in_channels, - out_channels=in_channels, - kernel_size=3, - activation=activation, - ) - self.conv1 = StyleConv( - in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2**res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB( - in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) - - def forward(self, x, ws, gs, E_features, noise_mode="random"): - # x = self.fc(x).view(x.shape[0], -1, 4, 4) - x = self.conv0(x) - x = x + E_features[self.res] - style = get_style_code(ws[:, 0], gs) - x = self.conv1(x, style, noise_mode=noise_mode) - style = get_style_code(ws[:, 1], gs) - img = self.toRGB(x, style, skip=None) - - return x, img - - -class DecBlock(nn.Module): - def __init__( - self, - res, - in_channels, - out_channels, - activation, - style_dim, - use_noise, - demodulate, - img_channels, - ): # res = 4, ..., resolution_log2 - super().__init__() - self.res = res - - self.conv0 = StyleConv( - in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2**res, - kernel_size=3, - up=2, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.conv1 = StyleConv( - in_channels=out_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2**res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB( - in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) - - def forward(self, x, img, ws, gs, E_features, noise_mode="random"): - style = get_style_code(ws[:, self.res * 2 - 9], gs) - x = self.conv0(x, style, noise_mode=noise_mode) - x = x + E_features[self.res] - style = get_style_code(ws[:, self.res * 2 - 8], gs) - x = self.conv1(x, style, noise_mode=noise_mode) - style = 
get_style_code(ws[:, self.res * 2 - 7], gs) - img = self.toRGB(x, style, skip=img) - - return x, img - - -class Decoder(nn.Module): - def __init__( - self, res_log2, activation, style_dim, use_noise, demodulate, img_channels - ): - super().__init__() - self.Dec_16x16 = DecBlockFirstV2( - 4, nf(4), nf(4), activation, style_dim, use_noise, demodulate, img_channels - ) - for res in range(5, res_log2 + 1): - setattr( - self, - "Dec_%dx%d" % (2**res, 2**res), - DecBlock( - res, - nf(res - 1), - nf(res), - activation, - style_dim, - use_noise, - demodulate, - img_channels, - ), - ) - self.res_log2 = res_log2 - - def forward(self, x, ws, gs, E_features, noise_mode="random"): - x, img = self.Dec_16x16(x, ws, gs, E_features, noise_mode=noise_mode) - for res in range(5, self.res_log2 + 1): - block = getattr(self, "Dec_%dx%d" % (2**res, 2**res)) - x, img = block(x, img, ws, gs, E_features, noise_mode=noise_mode) - - return img - - -class DecStyleBlock(nn.Module): - def __init__( - self, - res, - in_channels, - out_channels, - activation, - style_dim, - use_noise, - demodulate, - img_channels, - ): - super().__init__() - self.res = res - - self.conv0 = StyleConv( - in_channels=in_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2**res, - kernel_size=3, - up=2, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.conv1 = StyleConv( - in_channels=out_channels, - out_channels=out_channels, - style_dim=style_dim, - resolution=2**res, - kernel_size=3, - use_noise=use_noise, - activation=activation, - demodulate=demodulate, - ) - self.toRGB = ToRGB( - in_channels=out_channels, - out_channels=img_channels, - style_dim=style_dim, - kernel_size=1, - demodulate=False, - ) - - def forward(self, x, img, style, skip, noise_mode="random"): - x = self.conv0(x, style, noise_mode=noise_mode) - x = x + skip - x = self.conv1(x, style, noise_mode=noise_mode) - img = self.toRGB(x, style, skip=img) - - return x, img - - -class FirstStage(nn.Module): - def __init__( - self, - img_channels, - img_resolution=256, - dim=180, - w_dim=512, - use_noise=False, - demodulate=True, - activation="lrelu", - ): - super().__init__() - res = 64 - - self.conv_first = Conv2dLayerPartial( - in_channels=img_channels + 1, - out_channels=dim, - kernel_size=3, - activation=activation, - ) - self.enc_conv = nn.ModuleList() - down_time = int(np.log2(img_resolution // res)) - # build the Swin Transformer stages according to the image size - for i in range(down_time): # from input size to 64 - self.enc_conv.append( - Conv2dLayerPartial( - in_channels=dim, - out_channels=dim, - kernel_size=3, - down=2, - activation=activation, - ) - ) - - # from 64 -> 16 -> 64 - depths = [2, 3, 4, 3, 2] - ratios = [1, 1 / 2, 1 / 2, 2, 2] - num_heads = 6 - window_sizes = [8, 16, 16, 16, 8] - drop_path_rate = 0.1 - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - - self.tran = nn.ModuleList() - for i, depth in enumerate(depths): - res = int(res * ratios[i]) - if ratios[i] < 1: - merge = PatchMerging(dim, dim, down=int(1 / ratios[i])) - elif ratios[i] > 1: - merge = PatchUpsampling(dim, dim, up=ratios[i]) - else: - merge = None - self.tran.append( - BasicLayer( - dim=dim, - input_resolution=[res, res], - depth=depth, - num_heads=num_heads, - window_size=window_sizes[i], - drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], - downsample=merge, - ) - ) - - # global style - down_conv = [] - for i in range(int(np.log2(16))): - down_conv.append( - Conv2dLayer( - in_channels=dim, - out_channels=dim, - kernel_size=3, -
down=2, - activation=activation, - ) - ) - down_conv.append(nn.AdaptiveAvgPool2d((1, 1))) - self.down_conv = nn.Sequential(*down_conv) - self.to_style = FullyConnectedLayer( - in_features=dim, out_features=dim * 2, activation=activation - ) - self.ws_style = FullyConnectedLayer( - in_features=w_dim, out_features=dim, activation=activation - ) - self.to_square = FullyConnectedLayer( - in_features=dim, out_features=16 * 16, activation=activation - ) - - style_dim = dim * 3 - self.dec_conv = nn.ModuleList() - for i in range(down_time): # from 64 to input size - res = res * 2 - self.dec_conv.append( - DecStyleBlock( - res, - dim, - dim, - activation, - style_dim, - use_noise, - demodulate, - img_channels, - ) - ) - - def forward(self, images_in, masks_in, ws, noise_mode="random"): - x = torch.cat([masks_in - 0.5, images_in * masks_in], dim=1) - - skips = [] - x, mask = self.conv_first(x, masks_in) # input size - skips.append(x) - for i, block in enumerate(self.enc_conv): # input size to 64 - x, mask = block(x, mask) - if i != len(self.enc_conv) - 1: - skips.append(x) - - x_size = x.size()[-2:] - x = feature2token(x) - mask = feature2token(mask) - mid = len(self.tran) // 2 - for i, block in enumerate(self.tran): # 64 to 16 - if i < mid: - x, x_size, mask = block(x, x_size, mask) - skips.append(x) - elif i > mid: - x, x_size, mask = block(x, x_size, None) - x = x + skips[mid - i] - else: - x, x_size, mask = block(x, x_size, None) - - mul_map = torch.ones_like(x) * 0.5 - mul_map = F.dropout(mul_map, training=True).to(x.device) - ws = self.ws_style(ws[:, -1]).to(x.device) - add_n = self.to_square(ws).unsqueeze(1).to(x.device) - add_n = ( - F.interpolate( - add_n, size=x.size(1), mode="linear", align_corners=False - ) - .squeeze(1) - .unsqueeze(-1) - ).to(x.device) - x = x * mul_map + add_n * (1 - mul_map) - gs = self.to_style( - self.down_conv(token2feature(x, x_size)).flatten(start_dim=1) - ).to(x.device) - style = torch.cat([gs, ws], dim=1) - - x = token2feature(x, x_size).contiguous() - img = None - for i, block in enumerate(self.dec_conv): - x, img = block( - x, img, style, skips[len(self.dec_conv) - i - 1], noise_mode=noise_mode - ) - - # ensemble - img = img * (1 - masks_in) + images_in * masks_in - - return img - - -class SynthesisNet(nn.Module): - def __init__( - self, - w_dim, # Intermediate latent (W) dimensionality. - img_resolution, # Output image resolution. - img_channels=3, # Number of color channels. - channel_base=32768, # Overall multiplier for the number of channels. - channel_decay=1.0, - channel_max=512, # Maximum number of channels in any layer. - activation="lrelu", # Activation function: 'relu', 'lrelu', etc. 
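The mul_map construction in FirstStage.forward above uses F.dropout on a constant 0.5 map to draw a random 0/1 mask, then blends encoder tokens with style-derived features; a standalone sketch with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 180)        # feature tokens
add_n = torch.randn(2, 256, 180)    # style-derived replacement features

# F.dropout(0.5 * ones) with p=0.5 zeroes half the entries and rescales the
# survivors by 2, so mul_map is a random 0/1 mask: each element of x is
# randomly kept or swapped for the corresponding element of add_n.
mul_map = F.dropout(torch.ones_like(x) * 0.5, p=0.5, training=True)
mixed = x * mul_map + add_n * (1 - mul_map)
```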
- drop_rate=0.5, - use_noise=False, - demodulate=True, - ): - super().__init__() - resolution_log2 = int(np.log2(img_resolution)) - assert img_resolution == 2**resolution_log2 and img_resolution >= 4 - - self.num_layers = resolution_log2 * 2 - 3 * 2 - self.img_resolution = img_resolution - self.resolution_log2 = resolution_log2 - - # first stage - self.first_stage = FirstStage( - img_channels, - img_resolution=img_resolution, - w_dim=w_dim, - use_noise=False, - demodulate=demodulate, - ) - - # second stage - self.enc = Encoder( - resolution_log2, img_channels, activation, patch_size=5, channels=16 - ) - self.to_square = FullyConnectedLayer( - in_features=w_dim, out_features=16 * 16, activation=activation - ) - self.to_style = ToStyle( - in_channels=nf(4), - out_channels=nf(2) * 2, - activation=activation, - drop_rate=drop_rate, - ) - style_dim = w_dim + nf(2) * 2 - self.dec = Decoder( - resolution_log2, activation, style_dim, use_noise, demodulate, img_channels - ) - - def forward(self, images_in, masks_in, ws, noise_mode="random", return_stg1=False): - out_stg1 = self.first_stage(images_in, masks_in, ws, noise_mode=noise_mode) - - # encoder - x = images_in * masks_in + out_stg1 * (1 - masks_in) - x = torch.cat([masks_in - 0.5, x, images_in * masks_in], dim=1) - E_features = self.enc(x) - - fea_16 = E_features[4].to(x.device) - mul_map = torch.ones_like(fea_16) * 0.5 - mul_map = F.dropout(mul_map, training=True).to(x.device) - add_n = self.to_square(ws[:, 0]).view(-1, 16, 16).unsqueeze(1) - add_n = F.interpolate( - add_n, size=fea_16.size()[-2:], mode="bilinear", align_corners=False - ).to(x.device) - fea_16 = fea_16 * mul_map + add_n * (1 - mul_map) - E_features[4] = fea_16 - - # style - gs = self.to_style(fea_16).to(x.device) - - # decoder - img = self.dec(fea_16, ws, gs, E_features, noise_mode=noise_mode).to(x.device) - - # ensemble - img = img * (1 - masks_in) + images_in * masks_in - - if not return_stg1: - return img - else: - return img, out_stg1 - - -class Generator(nn.Module): - def __init__( - self, - z_dim, # Input latent (Z) dimensionality, 0 = no latent. - c_dim, # Conditioning label (C) dimensionality, 0 = no label. - w_dim, # Intermediate latent (W) dimensionality. - img_resolution, # resolution of generated image - img_channels, # Number of input color channels. - synthesis_kwargs={}, # Arguments for SynthesisNetwork. - mapping_kwargs={}, # Arguments for MappingNetwork. 
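The final "ensemble" step in SynthesisNet.forward above composites the generator output into the holes only, copying every known pixel straight from the input; as a standalone helper (composite_known_pixels is an illustrative name):

```python
import torch

def composite_known_pixels(pred, image, mask):
    # mask == 1 marks known pixels; the network only ever fills mask == 0.
    return pred * (1 - mask) + image * mask

pred = torch.rand(1, 3, 512, 512)
image = torch.rand(1, 3, 512, 512)
mask = (torch.rand(1, 1, 512, 512) > 0.5).float()
out = composite_known_pixels(pred, image, mask)
```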
- ): - super().__init__() - self.z_dim = z_dim - self.c_dim = c_dim - self.w_dim = w_dim - self.img_resolution = img_resolution - self.img_channels = img_channels - - self.synthesis = SynthesisNet( - w_dim=w_dim, - img_resolution=img_resolution, - img_channels=img_channels, - **synthesis_kwargs, - ) - self.mapping = MappingNet( - z_dim=z_dim, - c_dim=c_dim, - w_dim=w_dim, - num_ws=self.synthesis.num_layers, - **mapping_kwargs, - ) - - def forward( - self, - images_in, - masks_in, - z, - c, - truncation_psi=1, - truncation_cutoff=None, - skip_w_avg_update=False, - noise_mode="none", - return_stg1=False, - ): - ws = self.mapping( - z, - c, - truncation_psi=truncation_psi, - truncation_cutoff=truncation_cutoff, - skip_w_avg_update=skip_w_avg_update, - ) - img = self.synthesis(images_in, masks_in, ws, noise_mode=noise_mode) - return img - - -class MAT(nn.Module): - def __init__(self, state_dict): - super(MAT, self).__init__() - self.model_arch = "MAT" - self.sub_type = "Inpaint" - self.in_nc = 3 - self.out_nc = 3 - self.scale = 1 - - self.supports_fp16 = False - self.supports_bf16 = True - - self.min_size = 512 - self.pad_mod = 512 - self.pad_to_square = True - - seed = 240 # pick up a random number - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - - self.model = Generator( - z_dim=512, c_dim=0, w_dim=512, img_resolution=512, img_channels=3 - ) - self.z = torch.from_numpy(np.random.randn(1, self.model.z_dim)) # [1., 512] - self.label = torch.zeros([1, self.model.c_dim]) - self.state = { - k.replace("synthesis", "model.synthesis").replace( - "mapping", "model.mapping" - ): v - for k, v in state_dict.items() - } - self.load_state_dict(self.state, strict=False) - - def forward(self, image, mask): - """Input images and output images have same size - images: [H, W, C] RGB - masks: [H, W] mask area == 255 - return: BGR IMAGE - """ - - image = image * 2 - 1 # [0, 1] -> [-1, 1] - mask = 1 - mask - - output = self.model( - image, mask, self.z, self.label, truncation_psi=1, noise_mode="none" - ) - - return output * 0.5 + 0.5 diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py index dec169520..1e1c3f35e 100644 --- a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -56,7 +56,17 @@ class OmniSR(nn.Module): residual_layer = [] self.res_num = res_num - self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + if ( + "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight" + in state_dict.keys() + ): + rel_pos_bias_weight = state_dict[ + "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight" + ].shape[0] + self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2) + else: + self.window_size = 8 + self.up_scale = up_scale for _ in range(res_num): diff --git a/comfy_extras/chainner_models/architecture/SCUNet.py b/comfy_extras/chainner_models/architecture/SCUNet.py new file mode 100644 index 000000000..b8354a873 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/SCUNet.py @@ -0,0 +1,455 @@ +# pylint: skip-file +# ----------------------------------------------------------------------------------- +# SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278 +# Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, 
Radu and Van Gool, Luc +# ----------------------------------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange + +from .timm.drop import DropPath +from .timm.weight_init import trunc_normal_ + + +# Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py +class WMSA(nn.Module): + """Self-attention module in Swin Transformer""" + + def __init__(self, input_dim, output_dim, head_dim, window_size, type): + super(WMSA, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.head_dim = head_dim + self.scale = self.head_dim**-0.5 + self.n_heads = input_dim // head_dim + self.window_size = window_size + self.type = type + self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) + + self.relative_position_params = nn.Parameter( + torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads) + ) + # TODO recover + # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1)) + + self.linear = nn.Linear(self.input_dim, self.output_dim) + + trunc_normal_(self.relative_position_params, std=0.02) + self.relative_position_params = torch.nn.Parameter( + self.relative_position_params.view( + 2 * window_size - 1, 2 * window_size - 1, self.n_heads + ) + .transpose(1, 2) + .transpose(0, 1) + ) + + def generate_mask(self, h, w, p, shift): + """Generate the attention mask for SW-MSA. + Args: + shift: shift parameters in CyclicShift. + Returns: + attn_mask: bool mask of shape (1, 1, h*w, p*p, p*p) + """ + # only square window grids are supported. + attn_mask = torch.zeros( + h, + w, + p, + p, + p, + p, + dtype=torch.bool, + device=self.relative_position_params.device, + ) + if self.type == "W": + return attn_mask + + s = p - shift + attn_mask[-1, :, :s, :, s:, :] = True + attn_mask[-1, :, s:, :, :s, :] = True + attn_mask[:, -1, :, :s, :, s:] = True + attn_mask[:, -1, :, s:, :, :s] = True + attn_mask = rearrange( + attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)" + ) + return attn_mask + + def forward(self, x): + """Forward pass of Window Multi-head Self-attention module. + Args: + x: input tensor with shape of [b h w c]; + attn_mask: attention mask, fill -inf where the value is True; + Returns: + output: tensor shape [b h w c] + """ + if self.type != "W": + x = torch.roll( + x, + shifts=(-(self.window_size // 2), -(self.window_size // 2)), + dims=(1, 2), + ) + + x = rearrange( + x, + "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c", + p1=self.window_size, + p2=self.window_size, + ) + h_windows = x.size(1) + w_windows = x.size(2) + # square validation + # assert h_windows == w_windows + + x = rearrange( + x, + "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c", + p1=self.window_size, + p2=self.window_size, + ) + qkv = self.embedding_layer(x) + q, k, v = rearrange( + qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim + ).chunk(3, dim=0) + sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale + # Adding learnable relative embedding + sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q") + # Using Attn Mask to distinguish different subwindows.
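The mask application that follows is the usual masked-softmax pattern: cross-sub-window positions are filled with -inf so they receive exactly zero attention weight. A minimal sketch with toy shapes:

```python
import torch

sim = torch.randn(2, 4, 16, 16)                # [heads, windows, p*p, p*p] scores
attn_mask = torch.zeros(16, 16, dtype=torch.bool)
attn_mask[:8, 8:] = True                       # block attention across the
attn_mask[8:, :8] = True                       # two halves of the window
sim = sim.masked_fill(attn_mask, float("-inf"))
probs = sim.softmax(dim=-1)                    # masked entries get weight 0
assert torch.allclose(probs.sum(-1), torch.ones_like(probs.sum(-1)))
```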
+ if self.type != "W": + attn_mask = self.generate_mask( + h_windows, w_windows, self.window_size, shift=self.window_size // 2 + ) + sim = sim.masked_fill_(attn_mask, float("-inf")) + + probs = nn.functional.softmax(sim, dim=-1) + output = torch.einsum("hbwij,hbwjc->hbwic", probs, v) + output = rearrange(output, "h b w p c -> b w p (h c)") + output = self.linear(output) + output = rearrange( + output, + "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c", + w1=h_windows, + p1=self.window_size, + ) + + if self.type != "W": + output = torch.roll( + output, + shifts=(self.window_size // 2, self.window_size // 2), + dims=(1, 2), + ) + + return output + + def relative_embedding(self): + cord = torch.tensor( + np.array( + [ + [i, j] + for i in range(self.window_size) + for j in range(self.window_size) + ] + ) + ) + relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 + # negative is allowed + return self.relative_position_params[ + :, relation[:, :, 0].long(), relation[:, :, 1].long() + ] + + +class Block(nn.Module): + def __init__( + self, + input_dim, + output_dim, + head_dim, + window_size, + drop_path, + type="W", + input_resolution=None, + ): + """SwinTransformer Block""" + super(Block, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + assert type in ["W", "SW"] + self.type = type + if input_resolution <= window_size: + self.type = "W" + + self.ln1 = nn.LayerNorm(input_dim) + self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.ln2 = nn.LayerNorm(input_dim) + self.mlp = nn.Sequential( + nn.Linear(input_dim, 4 * input_dim), + nn.GELU(), + nn.Linear(4 * input_dim, output_dim), + ) + + def forward(self, x): + x = x + self.drop_path(self.msa(self.ln1(x))) + x = x + self.drop_path(self.mlp(self.ln2(x))) + return x + + +class ConvTransBlock(nn.Module): + def __init__( + self, + conv_dim, + trans_dim, + head_dim, + window_size, + drop_path, + type="W", + input_resolution=None, + ): + """SwinTransformer and Conv Block""" + super(ConvTransBlock, self).__init__() + self.conv_dim = conv_dim + self.trans_dim = trans_dim + self.head_dim = head_dim + self.window_size = window_size + self.drop_path = drop_path + self.type = type + self.input_resolution = input_resolution + + assert self.type in ["W", "SW"] + if self.input_resolution <= self.window_size: + self.type = "W" + + self.trans_block = Block( + self.trans_dim, + self.trans_dim, + self.head_dim, + self.window_size, + self.drop_path, + self.type, + self.input_resolution, + ) + self.conv1_1 = nn.Conv2d( + self.conv_dim + self.trans_dim, + self.conv_dim + self.trans_dim, + 1, + 1, + 0, + bias=True, + ) + self.conv1_2 = nn.Conv2d( + self.conv_dim + self.trans_dim, + self.conv_dim + self.trans_dim, + 1, + 1, + 0, + bias=True, + ) + + self.conv_block = nn.Sequential( + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + nn.ReLU(True), + nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), + ) + + def forward(self, x): + conv_x, trans_x = torch.split( + self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1 + ) + conv_x = self.conv_block(conv_x) + conv_x + trans_x = Rearrange("b c h w -> b h w c")(trans_x) + trans_x = self.trans_block(trans_x) + trans_x = Rearrange("b h w c -> b c h w")(trans_x) + res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) + x = x + res + + return x + + +class SCUNet(nn.Module): + def __init__( + self, + state_dict, + in_nc=3, + config=[4, 4, 4, 4, 4, 4, 4], + 
dim=64, + drop_path_rate=0.0, + input_resolution=256, + ): + super(SCUNet, self).__init__() + self.model_arch = "SCUNet" + self.sub_type = "SR" + + self.num_filters: int = 0 + + self.state = state_dict + self.config = config + self.dim = dim + self.head_dim = 32 + self.window_size = 8 + + self.in_nc = in_nc + self.out_nc = self.in_nc + self.scale = 1 + self.supports_fp16 = True + + # drop path rate for each layer + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] + + self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] + + begin = 0 + self.m_down1 = [ + ConvTransBlock( + dim // 2, + dim // 2, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution, + ) + for i in range(config[0]) + ] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] + + begin += config[0] + self.m_down2 = [ + ConvTransBlock( + dim, + dim, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution // 2, + ) + for i in range(config[1]) + ] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] + + begin += config[1] + self.m_down3 = [ + ConvTransBlock( + 2 * dim, + 2 * dim, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution // 4, + ) + for i in range(config[2]) + ] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] + + begin += config[2] + self.m_body = [ + ConvTransBlock( + 4 * dim, + 4 * dim, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution // 8, + ) + for i in range(config[3]) + ] + + begin += config[3] + self.m_up3 = [ + nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), + ] + [ + ConvTransBlock( + 2 * dim, + 2 * dim, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution // 4, + ) + for i in range(config[4]) + ] + + begin += config[4] + self.m_up2 = [ + nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), + ] + [ + ConvTransBlock( + dim, + dim, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution // 2, + ) + for i in range(config[5]) + ] + + begin += config[5] + self.m_up1 = [ + nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), + ] + [ + ConvTransBlock( + dim // 2, + dim // 2, + self.head_dim, + self.window_size, + dpr[i + begin], + "W" if not i % 2 else "SW", + input_resolution, + ) + for i in range(config[6]) + ] + + self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] + + self.m_head = nn.Sequential(*self.m_head) + self.m_down1 = nn.Sequential(*self.m_down1) + self.m_down2 = nn.Sequential(*self.m_down2) + self.m_down3 = nn.Sequential(*self.m_down3) + self.m_body = nn.Sequential(*self.m_body) + self.m_up3 = nn.Sequential(*self.m_up3) + self.m_up2 = nn.Sequential(*self.m_up2) + self.m_up1 = nn.Sequential(*self.m_up1) + self.m_tail = nn.Sequential(*self.m_tail) + # self.apply(self._init_weights) + self.load_state_dict(state_dict, strict=True) + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (64 - h % 64) % 64 + mod_pad_w = (64 - w % 64) % 64 + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") + return x + + def forward(self, x0): + h, w = x0.size()[-2:] + x0 = self.check_image_size(x0) + + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + x = x[:, :, :h, :w] + return x + + def 
_init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) diff --git a/comfy_extras/chainner_models/architecture/SPSR.py b/comfy_extras/chainner_models/architecture/SPSR.py index 6f5ac458c..c3cefff19 100644 --- a/comfy_extras/chainner_models/architecture/SPSR.py +++ b/comfy_extras/chainner_models/architecture/SPSR.py @@ -60,7 +60,6 @@ class SPSRNet(nn.Module): self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0] self.scale = self.get_scale(4) - print(self.scale) self.num_filters: int = self.state["model.0.weight"].shape[0] self.supports_fp16 = True diff --git a/comfy_extras/chainner_models/architecture/SwinIR.py b/comfy_extras/chainner_models/architecture/SwinIR.py index 8cce2d0ea..1abf450bb 100644 --- a/comfy_extras/chainner_models/architecture/SwinIR.py +++ b/comfy_extras/chainner_models/architecture/SwinIR.py @@ -972,6 +972,7 @@ class SwinIR(nn.Module): self.upsampler = upsampler self.img_size = img_size self.img_range = img_range + self.resi_connection = resi_connection self.supports_fp16 = False # Too much weirdness to support this at the moment self.supports_bfp16 = True diff --git a/comfy_extras/chainner_models/architecture/mat/utils.py b/comfy_extras/chainner_models/architecture/mat/utils.py deleted file mode 100644 index 1e9445a2c..000000000 --- a/comfy_extras/chainner_models/architecture/mat/utils.py +++ /dev/null @@ -1,698 +0,0 @@ -"""Code used for this implementation of the MAT helper utils is modified from -lama-cleaner, copyright of Sanster: https://github.com/fenglinglwb/MAT""" - -import collections -from itertools import repeat -from typing import Any - -import numpy as np -import torch -from torch import conv2d, conv_transpose2d - - -def normalize_2nd_moment(x, dim=1, eps=1e-8): - return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() - - -class EasyDict(dict): - """Convenience class that behaves like a dict but allows access with the attribute syntax.""" - - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError: - raise AttributeError(name) - - def __setattr__(self, name: str, value: Any) -> None: - self[name] = value - - def __delattr__(self, name: str) -> None: - del self[name] - - -activation_funcs = { - "linear": EasyDict( - func=lambda x, **_: x, - def_alpha=0, - def_gain=1, - cuda_idx=1, - ref="", - has_2nd_grad=False, - ), - "relu": EasyDict( - func=lambda x, **_: torch.nn.functional.relu(x), - def_alpha=0, - def_gain=np.sqrt(2), - cuda_idx=2, - ref="y", - has_2nd_grad=False, - ), - "lrelu": EasyDict( - func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), - def_alpha=0.2, - def_gain=np.sqrt(2), - cuda_idx=3, - ref="y", - has_2nd_grad=False, - ), - "tanh": EasyDict( - func=lambda x, **_: torch.tanh(x), - def_alpha=0, - def_gain=1, - cuda_idx=4, - ref="y", - has_2nd_grad=True, - ), - "sigmoid": EasyDict( - func=lambda x, **_: torch.sigmoid(x), - def_alpha=0, - def_gain=1, - cuda_idx=5, - ref="y", - has_2nd_grad=True, - ), - "elu": EasyDict( - func=lambda x, **_: torch.nn.functional.elu(x), - def_alpha=0, - def_gain=1, - cuda_idx=6, - ref="y", - has_2nd_grad=True, - ), - "selu": EasyDict( - func=lambda x, **_: torch.nn.functional.selu(x), - def_alpha=0, - def_gain=1, - cuda_idx=7, - ref="y", - has_2nd_grad=True, - ), - "softplus": EasyDict( - func=lambda x, **_: torch.nn.functional.softplus(x), - def_alpha=0, - 
def_gain=1, - cuda_idx=8, - ref="y", - has_2nd_grad=True, - ), - "swish": EasyDict( - func=lambda x, **_: torch.sigmoid(x) * x, - def_alpha=0, - def_gain=np.sqrt(2), - cuda_idx=9, - ref="x", - has_2nd_grad=True, - ), -} - - -def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): - """Slow reference implementation of `bias_act()` using standard TensorFlow ops.""" - assert isinstance(x, torch.Tensor) - assert clamp is None or clamp >= 0 - spec = activation_funcs[act] - alpha = float(alpha if alpha is not None else spec.def_alpha) - gain = float(gain if gain is not None else spec.def_gain) - clamp = float(clamp if clamp is not None else -1) - - # Add bias. - if b is not None: - assert isinstance(b, torch.Tensor) and b.ndim == 1 - assert 0 <= dim < x.ndim - assert b.shape[0] == x.shape[dim] - x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]).to(x.device) - - # Evaluate activation function. - alpha = float(alpha) - x = spec.func(x, alpha=alpha) - - # Scale by gain. - gain = float(gain) - if gain != 1: - x = x * gain - - # Clamp. - if clamp >= 0: - x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type - return x - - -def bias_act( - x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref" -): - r"""Fused bias and activation function. - Adds bias `b` to activation tensor `x`, evaluates activation function `act`, - and scales the result by `gain`. Each of the steps is optional. In most cases, - the fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports first and second order gradients, - but not third order gradients. - Args: - x: Input activation tensor. Can be of any shape. - b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type - as `x`. The shape must be known, and it must match the dimension of `x` - corresponding to `dim`. - dim: The dimension in `x` corresponding to the elements of `b`. - The value of `dim` is ignored if `b` is not specified. - act: Name of the activation function to evaluate, or `"linear"` to disable. - Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. - See `activation_funcs` for a full list. `None` is not allowed. - alpha: Shape parameter for the activation function, or `None` to use the default. - gain: Scaling factor for the output tensor, or `None` to use default. - See `activation_funcs` for the default scaling of each activation function. - If unsure, consider specifying 1. - clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable - the clamping (default). - impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). - Returns: - Tensor of the same shape and datatype as `x`. - """ - assert isinstance(x, torch.Tensor) - assert impl in ["ref", "cuda"] - return _bias_act_ref( - x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp - ) - - -def setup_filter( - f, - device=torch.device("cpu"), - normalize=True, - flip_filter=False, - gain=1, - separable=None, -): - r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. - Args: - f: Torch tensor, numpy array, or python list of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), - `[]` (impulse), or - `None` (identity). - device: Result device (default: cpu). - normalize: Normalize the filter so that it retains the magnitude - for constant input signal (DC)? (default: True). - flip_filter: Flip the filter? (default: False). 
- gain: Overall scaling factor for signal magnitude (default: 1). - separable: Return a separable filter? (default: select automatically). - Returns: - Float32 tensor of the shape - `[filter_height, filter_width]` (non-separable) or - `[filter_taps]` (separable). - """ - # Validate. - if f is None: - f = 1 - f = torch.as_tensor(f, dtype=torch.float32) - assert f.ndim in [0, 1, 2] - assert f.numel() > 0 - if f.ndim == 0: - f = f[np.newaxis] - - # Separable? - if separable is None: - separable = f.ndim == 1 and f.numel() >= 8 - if f.ndim == 1 and not separable: - f = f.ger(f) - assert f.ndim == (1 if separable else 2) - - # Apply normalize, flip, gain, and device. - if normalize: - f /= f.sum() - if flip_filter: - f = f.flip(list(range(f.ndim))) - f = f * (gain ** (f.ndim / 2)) - f = f.to(device=device) - return f - - -def _get_filter_size(f): - if f is None: - return 1, 1 - - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - fw = f.shape[-1] - fh = f.shape[0] - - fw = int(fw) - fh = int(fh) - assert fw >= 1 and fh >= 1 - return fw, fh - - -def _get_weight_shape(w): - shape = [int(sz) for sz in w.shape] - return shape - - -def _parse_scaling(scaling): - if isinstance(scaling, int): - scaling = [scaling, scaling] - assert isinstance(scaling, (list, tuple)) - assert all(isinstance(x, int) for x in scaling) - sx, sy = scaling - assert sx >= 1 and sy >= 1 - return sx, sy - - -def _parse_padding(padding): - if isinstance(padding, int): - padding = [padding, padding] - assert isinstance(padding, (list, tuple)) - assert all(isinstance(x, int) for x in padding) - if len(padding) == 2: - padx, pady = padding - padding = [padx, padx, pady, pady] - padx0, padx1, pady0, pady1 = padding - return padx0, padx1, pady0, pady1 - - -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_2tuple = _ntuple(2) - - -def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" - # Validate arguments. - assert isinstance(x, torch.Tensor) and x.ndim == 4 - if f is None: - f = torch.ones([1, 1], dtype=torch.float32, device=x.device) - assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] - assert f.dtype == torch.float32 and not f.requires_grad - batch_size, num_channels, in_height, in_width = x.shape - # upx, upy = _parse_scaling(up) - # downx, downy = _parse_scaling(down) - - upx, upy = up, up - downx, downy = down, down - - # padx0, padx1, pady0, pady1 = _parse_padding(padding) - padx0, padx1, pady0, pady1 = padding[0], padding[1], padding[2], padding[3] - - # Upsample by inserting zeros. - x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) - x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) - x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) - - # Pad or crop. - x = torch.nn.functional.pad( - x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] - ) - x = x[ - :, - :, - max(-pady0, 0) : x.shape[2] - max(-pady1, 0), - max(-padx0, 0) : x.shape[3] - max(-padx1, 0), - ] - - # Setup filter. - f = f * (gain ** (f.ndim / 2)) - f = f.to(x.dtype) - if not flip_filter: - f = f.flip(list(range(f.ndim))) - - # Convolve with the filter. 
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) - if f.ndim == 4: - x = conv2d(input=x, weight=f, groups=num_channels) - else: - x = conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) - x = conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) - - # Downsample by throwing away pixels. - x = x[:, :, ::downy, ::downx] - return x - - -def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): - r"""Pad, upsample, filter, and downsample a batch of 2D images. - Performs the following sequence of operations for each channel: - 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). - 2. Pad the image with the specified number of zeros on each side (`padding`). - Negative padding corresponds to cropping the image. - 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it - so that the footprint of all output pixels lies within the input image. - 4. Downsample the image by keeping every Nth pixel (`down`). - This sequence of operations bears close resemblance to scipy.signal.upfirdn(). - The fused op is considerably more efficient than performing the same calculation - using standard PyTorch ops. It supports gradients of arbitrary order. - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - down: Integer downsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - # assert isinstance(x, torch.Tensor) - # assert impl in ['ref', 'cuda'] - return _upfirdn2d_ref( - x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain - ) - - -def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"): - r"""Upsample a batch of 2D images using the given 2D FIR filter. - By default, the result is padded so that its shape is a multiple of the input. - User-specified padding is applied on top of that, with negative values - indicating cropping. Pixels outside the image are assumed to be zero. - Args: - x: Float32/float64/float16 input tensor of the shape - `[batch_size, num_channels, in_height, in_width]`. - f: Float32 FIR filter of the shape - `[filter_height, filter_width]` (non-separable), - `[filter_taps]` (separable), or - `None` (identity). - up: Integer upsampling factor. Can be a single int or a list/tuple - `[x, y]` (default: 1). - padding: Padding with respect to the output. Can be a single number or a - list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - flip_filter: False = convolution, True = correlation (default: False). - gain: Overall scaling factor for signal magnitude (default: 1). - impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 
- """ - upx, upy = _parse_scaling(up) - # upx, upy = up, up - padx0, padx1, pady0, pady1 = _parse_padding(padding) - # padx0, padx1, pady0, pady1 = padding, padding, padding, padding - fw, fh = _get_filter_size(f) - p = [ - padx0 + (fw + upx - 1) // 2, - padx1 + (fw - upx) // 2, - pady0 + (fh + upy - 1) // 2, - pady1 + (fh - upy) // 2, - ] - return upfirdn2d( - x, - f, - up=up, - padding=p, - flip_filter=flip_filter, - gain=gain * upx * upy, - impl=impl, - ) - - -class FullyConnectedLayer(torch.nn.Module): - def __init__( - self, - in_features, # Number of input features. - out_features, # Number of output features. - bias=True, # Apply additive bias before the activation function? - activation="linear", # Activation function: 'relu', 'lrelu', etc. - lr_multiplier=1, # Learning rate multiplier. - bias_init=0, # Initial value for the additive bias. - ): - super().__init__() - self.weight = torch.nn.Parameter( - torch.randn([out_features, in_features]) / lr_multiplier - ) - self.bias = ( - torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) - if bias - else None - ) - self.activation = activation - - self.weight_gain = lr_multiplier / np.sqrt(in_features) - self.bias_gain = lr_multiplier - - def forward(self, x): - w = self.weight * self.weight_gain - b = self.bias - if b is not None and self.bias_gain != 1: - b = b * self.bias_gain - - if self.activation == "linear" and b is not None: - # out = torch.addmm(b.unsqueeze(0), x, w.t()) - x = x.matmul(w.t().to(x.device)) - out = x + b.reshape( - [-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)] - ).to(x.device) - else: - x = x.matmul(w.t().to(x.device)) - out = bias_act(x, b, act=self.activation, dim=x.ndim - 1).to(x.device) - return out - - -def _conv2d_wrapper( - x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True -): - """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.""" - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - - # Flip weight if requested. - if ( - not flip_weight - ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). - w = w.flip([2, 3]) - - # Workaround performance pitfall in cuDNN 8.0.5, triggered when using - # 1x1 kernel + memory_format=channels_last + less than 64 channels. - if ( - kw == 1 - and kh == 1 - and stride == 1 - and padding in [0, [0, 0], (0, 0)] - and not transpose - ): - if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: - if out_channels <= 4 and groups == 1: - in_shape = x.shape - x = w.squeeze(3).squeeze(2) @ x.reshape( - [in_shape[0], in_channels_per_group, -1] - ) - x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) - else: - x = x.to(memory_format=torch.contiguous_format) - w = w.to(memory_format=torch.contiguous_format) - x = conv2d(x, w, groups=groups) - return x.to(memory_format=torch.channels_last) - - # Otherwise => execute using conv2d_gradfix. - op = conv_transpose2d if transpose else conv2d - return op(x, w, stride=stride, padding=padding, groups=groups) - - -def conv2d_resample( - x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False -): - r"""2D convolution with optional up/downsampling. - Padding is performed only once at the beginning, not between the operations. - Args: - x: Input tensor of shape - `[batch_size, in_channels, in_height, in_width]`. - w: Weight tensor of shape - `[out_channels, in_channels//groups, kernel_height, kernel_width]`. - f: Low-pass filter for up/downsampling. 
Must be prepared beforehand by - calling setup_filter(). None = identity (default). - up: Integer upsampling factor (default: 1). - down: Integer downsampling factor (default: 1). - padding: Padding with respect to the upsampled image. Can be a single number - or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` - (default: 0). - groups: Split input channels into N groups (default: 1). - flip_weight: False = convolution, True = correlation (default: True). - flip_filter: False = convolution, True = correlation (default: False). - Returns: - Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. - """ - # Validate arguments. - assert isinstance(x, torch.Tensor) and (x.ndim == 4) - assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) - assert f is None or ( - isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 - ) - assert isinstance(up, int) and (up >= 1) - assert isinstance(down, int) and (down >= 1) - # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" - out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) - fw, fh = _get_filter_size(f) - # px0, px1, py0, py1 = _parse_padding(padding) - px0, px1, py0, py1 = padding, padding, padding, padding - - # Adjust padding to account for up/downsampling. - if up > 1: - px0 += (fw + up - 1) // 2 - px1 += (fw - up) // 2 - py0 += (fh + up - 1) // 2 - py1 += (fh - up) // 2 - if down > 1: - px0 += (fw - down + 1) // 2 - px1 += (fw - down) // 2 - py0 += (fh - down + 1) // 2 - py1 += (fh - down) // 2 - - # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. - if kw == 1 and kh == 1 and (down > 1 and up == 1): - x = upfirdn2d( - x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter - ) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - return x - - # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. - if kw == 1 and kh == 1 and (up > 1 and down == 1): - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - x = upfirdn2d( - x=x, - f=f, - up=up, - padding=[px0, px1, py0, py1], - gain=up**2, - flip_filter=flip_filter, - ) - return x - - # Fast path: downsampling only => use strided convolution. - if down > 1 and up == 1: - x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) - x = _conv2d_wrapper( - x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight - ) - return x - - # Fast path: upsampling with optional downsampling => use transpose strided convolution. - if up > 1: - if groups == 1: - w = w.transpose(0, 1) - else: - w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) - w = w.transpose(1, 2) - w = w.reshape( - groups * in_channels_per_group, out_channels // groups, kh, kw - ) - px0 -= kw - 1 - px1 -= kw - up - py0 -= kh - 1 - py1 -= kh - up - pxt = max(min(-px0, -px1), 0) - pyt = max(min(-py0, -py1), 0) - x = _conv2d_wrapper( - x=x, - w=w, - stride=up, - padding=[pyt, pxt], - groups=groups, - transpose=True, - flip_weight=(not flip_weight), - ) - x = upfirdn2d( - x=x, - f=f, - padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], - gain=up**2, - flip_filter=flip_filter, - ) - if down > 1: - x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - - # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 
- if up == 1 and down == 1: - if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: - return _conv2d_wrapper( - x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight - ) - - # Fallback: Generic reference implementation. - x = upfirdn2d( - x=x, - f=(f if up > 1 else None), - up=up, - padding=[px0, px1, py0, py1], - gain=up**2, - flip_filter=flip_filter, - ) - x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - if down > 1: - x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) - return x - - -class Conv2dLayer(torch.nn.Module): - def __init__( - self, - in_channels, # Number of input channels. - out_channels, # Number of output channels. - kernel_size, # Width and height of the convolution kernel. - bias=True, # Apply additive bias before the activation function? - activation="linear", # Activation function: 'relu', 'lrelu', etc. - up=1, # Integer upsampling factor. - down=1, # Integer downsampling factor. - resample_filter=[ - 1, - 3, - 3, - 1, - ], # Low-pass filter to apply when resampling activations. - conv_clamp=None, # Clamp the output to +-X, None = disable clamping. - channels_last=False, # Expect the input to have memory_format=channels_last? - trainable=True, # Update the weights of this layer during training? - ): - super().__init__() - self.activation = activation - self.up = up - self.down = down - self.register_buffer("resample_filter", setup_filter(resample_filter)) - self.conv_clamp = conv_clamp - self.padding = kernel_size // 2 - self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2)) - self.act_gain = activation_funcs[activation].def_gain - - memory_format = ( - torch.channels_last if channels_last else torch.contiguous_format - ) - weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to( - memory_format=memory_format - ) - bias = torch.zeros([out_channels]) if bias else None - if trainable: - self.weight = torch.nn.Parameter(weight) - self.bias = torch.nn.Parameter(bias) if bias is not None else None - else: - self.register_buffer("weight", weight) - if bias is not None: - self.register_buffer("bias", bias) - else: - self.bias = None - - def forward(self, x, gain=1): - w = self.weight * self.weight_gain - x = conv2d_resample( - x=x, - w=w, - f=self.resample_filter, - up=self.up, - down=self.down, - padding=self.padding, - ) - - act_gain = self.act_gain * gain - act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None - out = bias_act( - x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp - ) - return out diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 2e66e6247..e000871c1 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -1,13 +1,14 @@ import logging as logger +from .architecture.DAT import DAT from .architecture.face.codeformer import CodeFormer from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa -from .architecture.MAT import MAT from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN +from .architecture.SCUNet import SCUNet from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 from .architecture.SwiftSRGAN import Generator as SwiftSRGAN @@ -33,7 +34,6 @@ def load_state_dict(state_dict) -> 
PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) - # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -46,12 +46,14 @@ def load_state_dict(state_dict) -> PyTorchModel: and "initial.cnn.depthwise.weight" in state_dict["model"].keys() ): model = SwiftSRGAN(state_dict) - # HAT -- be sure it is above swinir - elif "layers.0.residual_group.blocks.0.conv_block.cab.0.weight" in state_dict_keys: - model = HAT(state_dict) - # SwinIR + # SwinIR, Swin2SR, HAT elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys: - if "patch_embed.proj.weight" in state_dict_keys: + if ( + "layers.0.residual_group.blocks.0.conv_block.cab.0.weight" + in state_dict_keys + ): + model = HAT(state_dict) + elif "patch_embed.proj.weight" in state_dict_keys: model = Swin2SR(state_dict) else: model = SwinIR(state_dict) @@ -78,12 +80,15 @@ def load_state_dict(state_dict) -> PyTorchModel: or "generator.model.1.bn_l.running_mean" in state_dict_keys ): model = LaMa(state_dict) - # MAT - elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: - model = MAT(state_dict) # Omni-SR elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: model = OmniSR(state_dict) + # SCUNet + elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys: + model = SCUNet(state_dict) + # DAT + elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys: + model = DAT(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 1906c0c7f..193333b9e 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -1,20 +1,32 @@ from typing import Union +from .architecture.DAT import DAT from .architecture.face.codeformer import CodeFormer from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa -from .architecture.MAT import MAT from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN +from .architecture.SCUNet import SCUNet from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) +PyTorchSRModels = ( + RealESRGANv2, + SPSR, + SwiftSRGAN, + ESRGAN, + SwinIR, + Swin2SR, + HAT, + OmniSR, + SCUNet, + DAT, +) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -24,6 +36,8 @@ PyTorchSRModel = Union[ Swin2SR, HAT, OmniSR, + SCUNet, + DAT, ] @@ -39,8 +53,8 @@ def is_pytorch_face_model(model: object): return isinstance(model, PyTorchFaceModels) -PyTorchInpaintModels = (LaMa, MAT) -PyTorchInpaintModel = Union[LaMa, MAT] +PyTorchInpaintModels = (LaMa,) +PyTorchInpaintModel = Union[LaMa] def is_pytorch_inpaint_model(model: object): From a74c5dbf3764fa598b58da8c88da823aaf8364fa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 2 Sep 2023 22:33:37 -0400 Subject: [PATCH 028/150] Move some functions to utils.py --- comfy/supported_models.py | 12 ++++++------ comfy/supported_models_base.py | 21 +++------------------ comfy/utils.py | 14 ++++++++++++++ 
3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 0b3e4bcbd..bb8ae2148 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -68,7 +68,7 @@ class SD20(supported_models_base.BASE): def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} replace_prefix[""] = "cond_stage_model.model." - state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict @@ -120,7 +120,7 @@ class SDXLRefiner(supported_models_base.BASE): keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -129,7 +129,7 @@ class SDXLRefiner(supported_models_base.BASE): if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g: state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids") replace_prefix["clip_g"] = "conditioner.embedders.0.model" - state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g def clip_target(self): @@ -167,8 +167,8 @@ class SDXL(supported_models_base.BASE): keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" - state_dict = supported_models_base.state_dict_prefix_replace(state_dict, replace_prefix) - state_dict = supported_models_base.state_dict_key_replace(state_dict, keys_to_replace) + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace) return state_dict def process_clip_state_dict_for_saving(self, state_dict): @@ -183,7 +183,7 @@ class SDXL(supported_models_base.BASE): replace_prefix["clip_g"] = "conditioner.embedders.1.model" replace_prefix["clip_l"] = "conditioner.embedders.0" - state_dict_g = supported_models_base.state_dict_prefix_replace(state_dict_g, replace_prefix) + state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix) return state_dict_g def clip_target(self): diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 395a90ab4..88a1d7fde 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -3,21 +3,6 @@ from . import model_base from . import utils from . 
import latent_formats - -def state_dict_key_replace(state_dict, keys_to_replace): - for x in keys_to_replace: - if x in state_dict: - state_dict[keys_to_replace[x]] = state_dict.pop(x) - return state_dict - -def state_dict_prefix_replace(state_dict, replace_prefix): - for rp in replace_prefix: - replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) - for x in replace: - state_dict[x[1]] = state_dict.pop(x[0]) - return state_dict - - class ClipTarget: def __init__(self, tokenizer, clip): self.clip = clip @@ -70,13 +55,13 @@ class BASE: def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} - return state_dict_prefix_replace(state_dict, replace_prefix) + return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model.diffusion_model."} - return state_dict_prefix_replace(state_dict, replace_prefix) + return utils.state_dict_prefix_replace(state_dict, replace_prefix) def process_vae_state_dict_for_saving(self, state_dict): replace_prefix = {"": "first_stage_model."} - return state_dict_prefix_replace(state_dict, replace_prefix) + return utils.state_dict_prefix_replace(state_dict, replace_prefix) diff --git a/comfy/utils.py b/comfy/utils.py index 47f4b9709..3ed32e372 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -39,6 +39,20 @@ def calculate_parameters(sd, prefix=""): params += sd[k].nelement() return params +def state_dict_key_replace(state_dict, keys_to_replace): + for x in keys_to_replace: + if x in state_dict: + state_dict[keys_to_replace[x]] = state_dict.pop(x) + return state_dict + +def state_dict_prefix_replace(state_dict, replace_prefix): + for rp in replace_prefix: + replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys()))) + for x in replace: + state_dict[x[1]] = state_dict.pop(x[0]) + return state_dict + + def transformers_convert(sd, prefix_from, prefix_to, number): keys_to_replace = { "{}positional_embedding": "{}embeddings.position_embedding.weight", From 2da73b7073dc520ee480dee8ff911b9aa83ff70a Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Sat, 2 Sep 2023 20:07:52 -0700 Subject: [PATCH 029/150] Revert changes in comfy/ldm/modules/diffusionmodules/util.py, which is unused. --- comfy/ldm/modules/diffusionmodules/util.py | 24 +++++++--------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 9d07d9359..d890c8044 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -15,7 +15,6 @@ import torch.nn as nn import numpy as np from einops import repeat -from comfy import model_management from comfy.ldm.util import instantiate_from_config import comfy.ops @@ -140,22 +139,13 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - if model_management.is_nvidia(): - with torch.enable_grad(), \ - torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. 
-                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-                output_tensors = ctx.run_function(*shallow_copies)
-        elif model_management.is_intel_xpu():
-            with torch.enable_grad(), \
-                torch.xpu.amp.autocast(**ctx.gpu_autocast_kwargs):
-                # Fixes a bug where the first op in run_function modifies the
-                # Tensor storage in place, which is not allowed for detach()'d
-                # Tensors.
-                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-                output_tensors = ctx.run_function(*shallow_copies)
+        with torch.enable_grad(), \
+            torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
         input_grads = torch.autograd.grad(
             output_tensors,
             ctx.input_tensors + ctx.input_params,

From 6f70227b8cbc61d5e228857aed991a8ae1ef1a33 Mon Sep 17 00:00:00 2001
From: Michael Abrahams
Date: Sun, 3 Sep 2023 11:51:50 -0400
Subject: [PATCH 030/150] Add support for pasting images into the graph

It can be useful to paste images from the clipboard directly into the
node graph. This commit modifies copy and paste handling to support this.

When an image file is found in the clipboard, we check whether an image
node is selected. If so, paste the image into that node. Otherwise, a new
node is created. If no image data are found in the clipboard, we call the
original Litegraph paste. To ensure that onCopy and onPaste events are
fired, we override Litegraph's ctrl+c and ctrl+v handling.

Try to detect whether the pasted image is a real file on disk, or just
pixel data copied from e.g. Photoshop. Pasted pixel data will be called
'image.png' and have a creation time of now. If it is simply pasted data,
we store it in the subfolder /input/pasted/.

This also adds support for the subfolder property in the IMAGEUPLOAD
widget.
---
 web/scripts/app.js     | 93 ++++++++++++++++++++++++++++++++++++------
 web/scripts/widgets.js | 29 +++++++++----
 2 files changed, 102 insertions(+), 20 deletions(-)

diff --git a/web/scripts/app.js b/web/scripts/app.js
index 3b7483cdf..b5114604a 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -667,11 +667,40 @@ export class ComfyApp {
 	}
 
 	/**
-	 * Adds a handler on paste that extracts and loads workflows from pasted JSON data
+	 * Adds a handler on paste that extracts and loads images or workflows from pasted JSON data
 	 */
 	#addPasteHandler() {
 		document.addEventListener("paste", (e) => {
-			let data = (e.clipboardData || window.clipboardData).getData("text/plain");
+			let data = (e.clipboardData || window.clipboardData);
+			const items = data.items;
+
+			// Look for image paste data
+			for (const item of items) {
+				if (item.type.startsWith('image/')) {
+					var imageNode = null;
+
+					// If an image node is selected, paste into it
+					if (this.canvas.current_node &&
+						this.canvas.current_node.is_selected &&
+						ComfyApp.isImageNode(this.canvas.current_node)) {
+						imageNode = this.canvas.current_node;
+					}
+
+					// No image node selected: add a new one
+					if (!imageNode) {
+						const newNode = LiteGraph.createNode("LoadImage");
+						newNode.pos = [...this.canvas.graph_mouse];
+						imageNode = this.graph.add(newNode);
+						this.graph.change();
+					}
+					const blob = item.getAsFile();
+					imageNode.pasteFile(blob);
+					return;
+				}
+			}
+
+			// No image found. 
Look for node data + data = data.getData("text/plain"); let workflow; try { data = data.slice(data.indexOf("{")); @@ -687,9 +716,29 @@ export class ComfyApp { if (workflow && workflow.version && workflow.nodes && workflow.extra) { this.loadGraphData(workflow); } + else { + // Litegraph default paste + this.canvas.pasteFromClipboard(); + } + + }); } + + /** + * Adds a handler on copy that serializes selected nodes to JSON + */ + #addCopyHandler() { + document.addEventListener("copy", (e) => { + // copy + if (this.canvas.selected_nodes) { + this.canvas.copyToClipboard(); + } + }); + } + + /** * Handle mouse * @@ -745,12 +794,6 @@ export class ComfyApp { const self = this; const origProcessKey = LGraphCanvas.prototype.processKey; LGraphCanvas.prototype.processKey = function(e) { - const res = origProcessKey.apply(this, arguments); - - if (res === false) { - return res; - } - if (!this.graph) { return; } @@ -761,9 +804,10 @@ export class ComfyApp { return; } - if (e.type == "keydown") { + if (e.type == "keydown" && !e.repeat) { + // Ctrl + M mute/unmute - if (e.keyCode == 77 && e.ctrlKey) { + if (e.key === 'm' && e.ctrlKey) { if (this.selected_nodes) { for (var i in this.selected_nodes) { if (this.selected_nodes[i].mode === 2) { // never @@ -776,7 +820,8 @@ export class ComfyApp { block_default = true; } - if (e.keyCode == 66 && e.ctrlKey) { + // Ctrl + B bypass + if (e.key === 'b' && e.ctrlKey) { if (this.selected_nodes) { for (var i in this.selected_nodes) { if (this.selected_nodes[i].mode === 4) { // never @@ -788,6 +833,28 @@ export class ComfyApp { } block_default = true; } + + // Ctrl+C Copy + if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { + if (e.shiftKey) { + this.copyToClipboard(true); + block_default = true; + } + // Trigger default onCopy + return true; + } + + // Ctrl+V Paste + if ((e.key === 'v') && (e.metaKey || e.ctrlKey)) { + if (e.shiftKey) { + this.pasteFromClipboard(true); + block_default = true; + } + else { + // Trigger default onPaste + return true; + } + } } this.graph.change(); @@ -798,7 +865,8 @@ export class ComfyApp { return false; } - return res; + // Fall through to Litegraph defaults + return origProcessKey.apply(this, arguments); }; } @@ -1110,6 +1178,7 @@ export class ComfyApp { this.#addDrawGroupsHandler(); this.#addApiUpdateHandlers(); this.#addDropHandler(); + this.#addCopyHandler(); this.#addPasteHandler(); this.#addKeyboardHandler(); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 5a4644b13..45ac9b896 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -76,7 +76,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random targetWidget.value = max; } } - return valueControl; + return valueControl; }; function seedWidget(node, inputName, inputData, app) { @@ -387,11 +387,12 @@ export const ComfyWidgets = { } }); - async function uploadFile(file, updateNode) { + async function uploadFile(file, updateNode, pasted = false) { try { // Wrap file in formdata so it includes filename const body = new FormData(); body.append("image", file); + if (pasted) body.append("subfolder", "pasted"); const resp = await api.fetchApi("/upload/image", { method: "POST", body, @@ -399,15 +400,17 @@ export const ComfyWidgets = { if (resp.status === 200) { const data = await resp.json(); - // Add the file as an option and update the widget value - if (!imageWidget.options.values.includes(data.name)) { - imageWidget.options.values.push(data.name); + // Add the file to the dropdown list and update the widget value + let path 
= data.name; + if (data.subfolder) path = data.subfolder + "/" + path; + + if (!imageWidget.options.values.includes(path)) { + imageWidget.options.values.push(path); } if (updateNode) { - showImage(data.name); - - imageWidget.value = data.name; + showImage(path); + imageWidget.value = path; } } else { alert(resp.status + " - " + resp.statusText); @@ -460,6 +463,16 @@ export const ComfyWidgets = { return handled; }; + node.pasteFile = function(file) { + if (file.type.startsWith("image/")) { + const is_pasted = (file.name === "image.png") && + (file.lastModified - Date.now() < 2000); + uploadFile(file, true, is_pasted); + return true; + } + return false; + } + return { widget: uploadWidget }; }, }; From 1938f5c5fe479996802c46d5c2233887e3598a40 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Sep 2023 00:58:18 -0400 Subject: [PATCH 031/150] Add a force argument to soft_empty_cache to force a cache empty. --- comfy/ldm/modules/attention.py | 2 +- comfy/ldm/modules/diffusionmodules/model.py | 1 + comfy/model_management.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8f953d337..34484b288 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -323,7 +323,7 @@ class CrossAttentionDoggettx(nn.Module): break except model_management.OOM_EXCEPTION as e: if first_op_done == False: - model_management.soft_empty_cache() + model_management.soft_empty_cache(True) if cleared_cache == False: cleared_cache = True print("out of memory error, emptying cache and trying again") diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 431548483..5f38640c3 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -186,6 +186,7 @@ def slice_attention(q, k, v): del s2 break except model_management.OOM_EXCEPTION as e: + model_management.soft_empty_cache(True) steps *= 2 if steps > 128: raise e diff --git a/comfy/model_management.py b/comfy/model_management.py index bdbbbd843..b663e8f59 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,14 +639,14 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return True -def soft_empty_cache(): +def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() elif is_intel_xpu(): torch.xpu.empty_cache() elif torch.cuda.is_available(): - if is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda + if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda torch.cuda.empty_cache() torch.cuda.ipc_collect() From d19684707922af9c2399307c8d9bccc1b267cc3b Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Wed, 23 Aug 2023 16:37:31 -0700 Subject: [PATCH 032/150] feat: Add support for excluded_dirs to folder_paths.recursive_search Refactored variable names to better match what they represent --- folder_paths.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index e321690dd..16de1bb66 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -121,18 +121,25 @@ def add_model_folder_path(folder_name, full_folder_path): def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] -def recursive_search(directory): +def recursive_search(directory, excluded_dir_names=None): if not os.path.isdir(directory): return 
[], {} + + if excluded_dir_names is None: + excluded_dir_names = [] + result = [] dirs = {directory: os.path.getmtime(directory)} - for root, subdir, file in os.walk(directory, followlinks=True): - for filepath in file: - #we os.path,join directory with a blank string to generate a path separator at the end. - result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - for d in subdir: - path = os.path.join(root, d) + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): + print("Checking directory: " + dirpath) + subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] + for file_name in filenames: + relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) + result.append(relative_path) + for d in subdirs: + path = os.path.join(dirpath, d) dirs[path] = os.path.getmtime(path) + print("Returning from recursive_search" + repr(result)) return result, dirs def filter_files_extensions(files, extensions): From 3e00fa433216335a874bd408de421f9d65432daf Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Wed, 23 Aug 2023 16:50:41 -0700 Subject: [PATCH 033/150] feat: Exclude .git when retrieving filename lists In the future could support user provided excluded dirs via config file --- folder_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index 16de1bb66..a18052856 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -166,7 +166,7 @@ def get_filename_list_(folder_name): folders = folder_names_and_paths[folder_name] output_folders = {} for x in folders[0]: - files, folders_all = recursive_search(x) + files, folders_all = recursive_search(x, excluded_dir_names=[".git"]) output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} From f368e5ac7d31649b22c0c1e44bc9fa8002fcb117 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 5 Sep 2023 01:22:03 -0400 Subject: [PATCH 034/150] Don't paste nodes when target is a textarea or a text box. 
--- web/scripts/app.js | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 7f5073573..9c380d3fb 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -717,8 +717,12 @@ export class ComfyApp { this.loadGraphData(workflow); } else { + if (e.target.type === "text" || e.target.type === "textarea") { + return; + } + // Litegraph default paste - this.canvas.pasteFromClipboard(); + this.canvas.pasteFromClipboard(); } From bc1f6e21856f7be25db5c5c2956b89c27db93b3d Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Tue, 5 Sep 2023 15:06:46 -0700 Subject: [PATCH 035/150] fix(ui/widgets): Only set widget forceInput option if a widget is added --- web/scripts/app.js | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9c380d3fb..a3661da64 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1228,6 +1228,7 @@ export class ComfyApp { const inputData = inputs[inputName]; const type = inputData[0]; + let widgetCreated = true; if (Array.isArray(type)) { // Enums Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {}); @@ -1240,8 +1241,10 @@ export class ComfyApp { } else { // Node connection inputs this.addInput(inputName, type); + widgetCreated = false; } - if(inputData[1]?.forceInput && config?.widget) { + + if(widgetCreated && inputData[1]?.forceInput && config?.widget) { if (!config.widget.options) config.widget.options = {}; config.widget.options.forceInput = inputData[1].forceInput; } From 21a563d385ff520e1f7fdaada722212b35fb8d95 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 5 Sep 2023 23:46:37 -0400 Subject: [PATCH 036/150] Remove prints. --- folder_paths.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index a18052856..82aedd43f 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -131,7 +131,6 @@ def recursive_search(directory, excluded_dir_names=None): result = [] dirs = {directory: os.path.getmtime(directory)} for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): - print("Checking directory: " + dirpath) subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) @@ -139,7 +138,6 @@ def recursive_search(directory, excluded_dir_names=None): for d in subdirs: path = os.path.join(dirpath, d) dirs[path] = os.path.getmtime(path) - print("Returning from recursive_search" + repr(result)) return result, dirs def filter_files_extensions(files, extensions): From f88f7f413afbe04b42c4422e9deedbaa3269ce76 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 6 Sep 2023 03:26:55 -0400 Subject: [PATCH 037/150] Add a ConditioningSetAreaPercentage node. --- comfy/samplers.py | 15 ++++++++++++--- nodes.py | 27 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 103ac33ff..3250b2edc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -390,11 +390,20 @@ def get_mask_aabb(masks): return bounding_boxes, is_empty -def resolve_cond_masks(conditions, h, w, device): +def resolve_areas_and_cond_masks(conditions, h, w, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. 
# While we're doing this, we can also resolve the mask device and scaling for performance reasons for i in range(len(conditions)): c = conditions[i] + if 'area' in c[1]: + area = c[1]['area'] + if area[0] == "percentage": + modified = c[1].copy() + area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w)) + modified['area'] = area + c = [c[0], modified] + conditions[i] = c + if 'mask' in c[1]: mask = c[1]['mask'] mask = mask.to(device=device) @@ -622,8 +631,8 @@ class KSampler: positive = positive[:] negative = negative[:] - resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) - resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) + resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) calculate_start_end_timesteps(self.model_wrap, negative) calculate_start_end_timesteps(self.model_wrap, positive) diff --git a/nodes.py b/nodes.py index fa26e5939..77d180526 100644 --- a/nodes.py +++ b/nodes.py @@ -159,6 +159,31 @@ class ConditioningSetArea: c.append(n) return (c, ) +class ConditioningSetAreaPercentage: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), + "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), + "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), + "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, width, height, x, y, strength): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['area'] = ("percentage", height, width, y, x) + n[1]['strength'] = strength + n[1]['set_area_to_bounds'] = False + c.append(n) + return (c, ) + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): @@ -1583,6 +1608,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningCombine": ConditioningCombine, "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, + "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, @@ -1644,6 +1670,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningAverage ": "Conditioning (Average)", "ConditioningConcat": "Conditioning (Concat)", "ConditioningSetArea": "Conditioning (Set Area)", + "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)", From cb080e771e1e792e18611ef63d2d6a49aa50a524 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 6 Sep 2023 16:18:02 -0400 Subject: [PATCH 038/150] Lower refresh timeout for search in litegraph. 
--- web/lib/litegraph.core.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4bb2f0d99..4a21a1b34 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -11529,7 +11529,7 @@ LGraphNode.prototype.executeAction = function(action) if (timeout) { clearInterval(timeout); } - timeout = setTimeout(refreshHelper, 250); + timeout = setTimeout(refreshHelper, 10); return; } e.preventDefault(); From adb9eb94b0d825bb904d449cf259e7da66453a17 Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 7 Sep 2023 12:10:52 +1000 Subject: [PATCH 039/150] Send class description if any --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 57d5a65df..e84e698d6 100644 --- a/server.py +++ b/server.py @@ -398,7 +398,7 @@ class PromptServer(): info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = node_class info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = '' + info['description'] = obj_class.DESCRIPTION if hasattr(node_class,'DESCRIPTION') else '' info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: info['output_node'] = True From 694c705f5225be458ce3cf1db34531c17925e20d Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 7 Sep 2023 12:20:37 +1000 Subject: [PATCH 040/150] get class description --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index e84e698d6..2ebf9e235 100644 --- a/server.py +++ b/server.py @@ -398,7 +398,7 @@ class PromptServer(): info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = node_class info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = obj_class.DESCRIPTION if hasattr(node_class,'DESCRIPTION') else '' + info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' info['category'] = 'sd' if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: info['output_node'] = True From 8be46438be1c848e01e4085f54ae997e2e918771 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Sep 2023 03:31:43 -0400 Subject: [PATCH 041/150] Support DiffBIR SwinIR models. 
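For context: DiffBIR-style SwinIR checkpoints fold a pixel_unshuffle into the
network, so the first convolution sees 3 * r**2 input channels for an unshuffle
factor r, and the net output scale becomes upscale / r (hence the new
start_unshuffle bookkeeping). A minimal sketch of the channel arithmetic the
loader relies on; the factor of 2 below is chosen purely for illustration, not
taken from any particular checkpoint:

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 64, 64)   # RGB input
    r = 2                           # hypothetical unshuffle factor
    y = F.pixel_unshuffle(x, r)     # trades spatial size for channels
    assert y.shape == (1, 3 * r * r, 32, 32)

    # Recover r from the first conv's input channel count, as the loader
    # does with conv_first.weight.shape[1]:
    assert round(math.sqrt(y.shape[1] // 3)) == r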
--- .../chainner_models/architecture/SwinIR.py | 17 ++++++++++++++++- comfy_extras/nodes_upscale_model.py | 2 ++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/comfy_extras/chainner_models/architecture/SwinIR.py b/comfy_extras/chainner_models/architecture/SwinIR.py index 1abf450bb..439dcbcb2 100644 --- a/comfy_extras/chainner_models/architecture/SwinIR.py +++ b/comfy_extras/chainner_models/architecture/SwinIR.py @@ -846,6 +846,7 @@ class SwinIR(nn.Module): num_in_ch = in_chans num_out_ch = in_chans supports_fp16 = True + self.start_unshuffle = 1 self.model_arch = "SwinIR" self.sub_type = "SR" @@ -874,6 +875,11 @@ class SwinIR(nn.Module): else 64 ) + if "conv_first.1.weight" in self.state: + self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight") + self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias") + self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3)) + num_in_ch = self.state["conv_first.weight"].shape[1] in_chans = num_in_ch if "conv_last.weight" in state_keys: @@ -968,7 +974,7 @@ class SwinIR(nn.Module): self.depths = depths self.window_size = window_size self.mlp_ratio = mlp_ratio - self.scale = upscale + self.scale = upscale / self.start_unshuffle self.upsampler = upsampler self.img_size = img_size self.img_range = img_range @@ -1101,6 +1107,9 @@ class SwinIR(nn.Module): self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) if self.upscale == 4: self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + elif self.upscale == 8: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) @@ -1157,6 +1166,9 @@ class SwinIR(nn.Module): self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range + if self.start_unshuffle > 1: + x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle) + if self.upsampler == "pixelshuffle": # for classical SR x = self.conv_first(x) @@ -1186,6 +1198,9 @@ class SwinIR(nn.Module): ) ) ) + elif self.upscale == 8: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) x = self.conv_last(self.lrelu(self.conv_hr(x))) else: # for image denoising and JPEG compression artifact reduction diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index abd182e6e..2b5e49a55 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -18,6 +18,8 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) + if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) out = model_loading.load_state_dict(sd).eval() return (out, ) From 62799c858575a2a69e671560efa7eb8001e0e275 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:42:21 +0100 Subject: [PATCH 042/150] fix crash on node with VALIDATE_INPUTS and actual inputs --- execution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index e10fdbb60..5f5d6c738 100644 --- a/execution.py +++ 
b/execution.py @@ -21,7 +21,8 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - return None + input_data_all[x] = (None,) + continue obj = outputs[input_unique_id][output_index] input_data_all[x] = obj else: From d6d1a8998fa60da9265ea3e9db35d80441cac6fd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Sep 2023 18:06:22 -0400 Subject: [PATCH 043/150] Properly check upload filename for directory traversal. --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 2ebf9e235..e58a11d86 100644 --- a/server.py +++ b/server.py @@ -170,15 +170,15 @@ class PromptServer(): subfolder = post.get("subfolder", "") full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) + filepath = os.path.join(full_output_folder, filename) - if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir: + if os.path.commonpath((upload_dir, os.path.abspath(filepath))) != upload_dir: return web.Response(status=400) if not os.path.exists(full_output_folder): os.makedirs(full_output_folder) split = os.path.splitext(filename) - filepath = os.path.join(full_output_folder, filename) if overwrite is not None and (overwrite == "true" or overwrite == "1"): pass From 9261587d8975bb0c3f929433345e9918bf659460 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Sep 2023 18:14:30 -0400 Subject: [PATCH 044/150] Small refactor. --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index e58a11d86..be33f4100 100644 --- a/server.py +++ b/server.py @@ -170,9 +170,9 @@ class PromptServer(): subfolder = post.get("subfolder", "") full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) - filepath = os.path.join(full_output_folder, filename) + filepath = os.path.abspath(os.path.join(full_output_folder, filename)) - if os.path.commonpath((upload_dir, os.path.abspath(filepath))) != upload_dir: + if os.path.commonpath((upload_dir, filepath)) != upload_dir: return web.Response(status=400) if not os.path.exists(full_output_folder): os.makedirs(full_output_folder) split = os.path.splitext(filename) From 326577d04c99590cbf91324f507fdc2c7d37832d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Sep 2023 23:37:03 -0400 Subject: [PATCH 045/150] Allow cancelling of everything with a progress bar.
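The interrupt check moves out of the sampler's inner loop in sampling_function and into the progress hook in main.py, so any long job that reports progress becomes cancellable, not just sampling. A minimal sketch of the pattern with simplified names (in ComfyUI the flag and exception live in comfy.model_management):

```python
interrupted = False  # in ComfyUI this is flipped by the /interrupt endpoint

class InterruptProcessingException(Exception):
    pass

def progress_hook(value, total):
    # Raising here unwinds whatever long-running job is reporting progress.
    if interrupted:
        raise InterruptProcessingException()
    print(f"progress {value}/{total}")

for step in range(1, 4):
    progress_hook(step, 3)
```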
--- comfy/samplers.py | 2 -- main.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 3250b2edc..c60288fd1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -263,8 +263,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con output = model_function(input_x, timestep_, **c).chunk(batch_chunks) del input_x - model_management.throw_exception_if_processing_interrupted() - for o in range(batch_chunks): if cond_or_uncond[o] == COND: out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] diff --git a/main.py b/main.py index a4038db4b..9f0f80458 100644 --- a/main.py +++ b/main.py @@ -104,6 +104,7 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): def hook(value, total, preview_image): + comfy.model_management.throw_exception_if_processing_interrupted() server.send_sync("progress", {"value": value, "max": total}, server.client_id) if preview_image is not None: server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) From 0782ac2a96fab2c436f78379db1de0df9737aa1d Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 8 Sep 2023 14:53:29 +1000 Subject: [PATCH 046/150] defaultInput --- web/extensions/core/widgetInputs.js | 2 +- web/scripts/app.js | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index f9a5b7278..606605f0a 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -142,7 +142,7 @@ app.registerExtension({ const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : undefined; if (this.widgets) { for (const w of this.widgets) { - if (w?.options?.forceInput) { + if (w?.options?.forceInput || w?.options?.defaultInput) { const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; convertToInput(this, w, config); } diff --git a/web/scripts/app.js b/web/scripts/app.js index a3661da64..40295b350 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1248,6 +1248,10 @@ export class ComfyApp { if (!config.widget.options) config.widget.options = {}; config.widget.options.forceInput = inputData[1].forceInput; } + if(widgetCreated && inputData[1]?.defaultInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + config.widget.options.defaultInput = inputData[1].defaultInput; + } } for (const o in nodeData["output"]) { From ff962098fdf4486fa5117e268187af964aaf586d Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Fri, 8 Sep 2023 08:43:17 +0200 Subject: [PATCH 047/150] Fixed Load Image preview not displaying some files (issue #1158) --- web/scripts/widgets.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 45ac9b896..975577631 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -335,7 +335,7 @@ export const ComfyWidgets = { subfolder = name.substring(0, folder_separator); name = name.substring(folder_separator + 1); } - img.src = api.apiURL(`/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`); + img.src = api.apiURL(`/view?filename=${encodeURIComponent(name)}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`); node.setSizeForImage?.(); } From 3ebe6b539a510c59004a6bb3b4cdb833a5612431 Mon Sep 17 00:00:00 2001 
From: Chris Date: Fri, 8 Sep 2023 20:37:55 +1000 Subject: [PATCH 048/150] round float widgets (by default to 0.001) --- web/scripts/widgets.js | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 45ac9b896..68862a68c 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -2,14 +2,15 @@ import { api } from "./api.js" function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; - let { min, max, step } = inputData[1]; + let { min, max, step, round } = inputData[1]; if (defaultVal == undefined) defaultVal = 0; if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; + if (round == undefined) round = 0.001; - return { val: defaultVal, config: { min, max, step: 10.0 * step } }; + return { val: defaultVal, config: { min, max, step: 10.0 * step, round } }; } export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { @@ -264,7 +265,10 @@ export const ComfyWidgets = { FLOAT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); const { val, config } = getNumberDefaults(inputData, 0.5); - return { widget: node.addWidget(widgetType, inputName, val, () => {}, config) }; + return { widget: node.addWidget(widgetType, inputName, val, + function (v) { + this.value = Math.round(v/config.round)*config.round; + }, config) }; }, INT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); From 1e6b67101cad777319e891afce3c7120e0dc1273 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 8 Sep 2023 11:36:51 -0400 Subject: [PATCH 049/150] Support diffusers format t2i adapters. --- comfy/controlnet.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 490be6bbc..af0df103e 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -449,10 +449,18 @@ class T2IAdapter(ControlBase): return c def load_t2i_adapter(t2i_data): - keys = t2i_data.keys() - if 'adapter' in keys: + if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] - keys = t2i_data.keys() + if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format + prefix_replace = {} + for i in range(4): + for j in range(2): + prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j) + prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2) + prefix_replace["adapter."] = "" + t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace) + keys = t2i_data.keys() + if "body.0.in_conv.weight" in keys: cin = t2i_data['body.0.in_conv.weight'].shape[1] model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4) From 264867bf87c37abdf794c9e1bab1bc512c2f5ff4 Mon Sep 17 00:00:00 2001 From: Michael Abrahams Date: Fri, 8 Sep 2023 11:17:45 -0400 Subject: [PATCH 050/150] Clear clipboard on copy --- web/scripts/app.js | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index a3661da64..72844a92b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -735,9 +735,17 @@ export class ComfyApp { */ #addCopyHandler() { document.addEventListener("copy", (e) => { - // copy + if (e.target.type === "text" || e.target.type === "textarea") { + // Default system copy + return; + } + // copy nodes and clear clipboard if 
(this.canvas.selected_nodes) { - this.canvas.copyToClipboard(); + this.canvas.copyToClipboard(); + e.clipboardData.clearData(); + e.preventDefault(); + e.stopImmediatePropagation(); + return false; } }); } @@ -842,10 +850,13 @@ export class ComfyApp { if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { if (e.shiftKey) { this.copyToClipboard(true); + e.clipboardData.clearData(); block_default = true; } - // Trigger default onCopy - return true; + else { + // Trigger onCopy + return true; + } } // Ctrl+V Paste @@ -855,7 +866,7 @@ export class ComfyApp { block_default = true; } else { - // Trigger default onPaste + // Trigger onPaste return true; } } From 10de64af7f1ea22e08e39267ade7ef7f8b1607fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 8 Sep 2023 14:02:03 -0400 Subject: [PATCH 051/150] Google doesn't want people to use ComfyUI on colab anymore. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index baa8cf8b6..d83b4bdac 100644 --- a/README.md +++ b/README.md @@ -77,9 +77,9 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. -## Colab Notebook +## Jupyter Notebook -To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb) +To run it on services like paperspace, kaggle or colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb) ## Manual Install (Windows, Linux) From e85be36bd2c12f335abdf75669b994c535bbb126 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 8 Sep 2023 14:06:58 -0400 Subject: [PATCH 052/150] Add a penultimate_hidden_states to the clip vision output. 
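Requesting hidden states from the CLIP vision transformer yields one tensor per block, and the second-to-last entry is what downstream image-conditioning models typically consume. A hedged sketch of where that tensor comes from, assuming the Hugging Face transformers package is available:

```python
import torch
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection

# Randomly initialized default config so the sketch runs offline.
model = CLIPVisionModelWithProjection(CLIPVisionConfig())
out = model(pixel_values=torch.randn(1, 3, 224, 224), output_hidden_states=True)

penultimate = out.hidden_states[-2]  # activations one block before the last
print(len(out.hidden_states), penultimate.shape)
```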
--- comfy/clip_vision.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index daaa2f2bf..9b95ae003 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -49,12 +49,16 @@ class ClipVisionModel(): precision_scope = lambda a, b: contextlib.nullcontext(a) with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32): - outputs = self.model(pixel_values=pixel_values) + outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) for k in outputs: t = outputs[k] if t is not None: - outputs[k] = t.cpu() + if k == 'hidden_states': + outputs["penultimate_hidden_states"] = t[-2].cpu() + else: + outputs[k] = t.cpu() + return outputs def convert_to_transformers(sd, prefix): From cc2fa311ddf5e085177e219c0fd2d1fd036551db Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Fri, 8 Sep 2023 21:11:53 -0700 Subject: [PATCH 053/150] fix(server): Disable access logs --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index be33f4100..d04060499 100644 --- a/server.py +++ b/server.py @@ -603,7 +603,7 @@ class PromptServer(): await self.send(*msg) async def start(self, address, port, verbose=True, call_on_start=None): - runner = web.AppRunner(self.app) + runner = web.AppRunner(self.app, access_log=None) await runner.setup() site = web.TCPSite(runner, address, port) await site.start() From 7372255e49b88fc0bb8416faff05ebcd88c81aba Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 9 Sep 2023 15:21:38 +1000 Subject: [PATCH 054/150] Specify the precision and rounding based on step --- web/scripts/widgets.js | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 68862a68c..8f7537b73 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -2,15 +2,19 @@ import { api } from "./api.js" function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; - let { min, max, step, round } = inputData[1]; + let { min, max, step, round, precision } = inputData[1]; if (defaultVal == undefined) defaultVal = 0; if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; - if (round == undefined) round = 0.001; +// precision is the number of decimal places to show. +// by default, display the the smallest number of decimal places such that changes of size step are visible. + if (precision == undefined) precision = Math.max(-Math.floor(Math.log10(step)),0) +// by default, round the value to those decimal places shown. + if (round == undefined) round = Math.round(1000000*Math.pow(0.1,precision))/1000000; - return { val: defaultVal, config: { min, max, step: 10.0 * step, round } }; + return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { From 07691e80c3bf9be16c629169e259105ca5327bf0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Sep 2023 03:15:31 -0400 Subject: [PATCH 055/150] Does it make sense to allow configuring the round and precision? 
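With this change both the displayed precision and the rounding granularity are derived from step alone instead of being separately configurable (a later commit in this series, 077, restores the round option). The derivation, transcribed to Python for illustration:

```python
import math

def display_defaults(step: float):
    # Show just enough decimal places that a change of `step` is visible,
    # then round stored values to that displayed precision.
    precision = max(-math.floor(math.log10(step)), 0)
    rnd = round(1e6 * 0.1 ** precision) / 1e6
    return precision, rnd

print(display_defaults(0.5))   # (1, 0.1)
print(display_defaults(0.01))  # (2, 0.01)
```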
--- web/scripts/widgets.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 4dc173b8f..30caa6a8c 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -2,7 +2,7 @@ import { api } from "./api.js" function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; - let { min, max, step, round, precision } = inputData[1]; + let { min, max, step } = inputData[1]; if (defaultVal == undefined) defaultVal = 0; if (min == undefined) min = 0; @@ -10,9 +10,9 @@ function getNumberDefaults(inputData, defaultStep) { if (step == undefined) step = defaultStep; // precision is the number of decimal places to show. // by default, display the the smallest number of decimal places such that changes of size step are visible. - if (precision == undefined) precision = Math.max(-Math.floor(Math.log10(step)),0) + let precision = Math.max(-Math.floor(Math.log10(step)),0) // by default, round the value to those decimal places shown. - if (round == undefined) round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + let round = Math.round(1000000*Math.pow(0.1,precision))/1000000; return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } From 7df822212fb2da45c8523155086456c2cd119062 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Sep 2023 02:36:04 -0400 Subject: [PATCH 056/150] Allow checkpoints with .pt and .bin extensions. --- folder_paths.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 82aedd43f..4a10c68e7 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,13 @@ import os import time -supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} base_path = os.path.dirname(os.path.realpath(__file__)) models_dir = os.path.join(base_path, "models") -folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_ckpt_extensions) +folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions) folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"]) folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions) From 9562a6b49e63e63a16f3e45ff4965f72385f51fa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Sep 2023 11:19:31 -0400 Subject: [PATCH 057/150] Fix a few clipboard issues. 
--- web/scripts/app.js | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 7ef2fc4e3..9db4e9230 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -742,7 +742,7 @@ export class ComfyApp { // copy nodes and clear clipboard if (this.canvas.selected_nodes) { this.canvas.copyToClipboard(); - e.clipboardData.clearData(); + e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard e.preventDefault(); e.stopImmediatePropagation(); return false; @@ -848,27 +848,14 @@ export class ComfyApp { // Ctrl+C Copy if ((e.key === 'c') && (e.metaKey || e.ctrlKey)) { - if (e.shiftKey) { - this.copyToClipboard(true); - e.clipboardData.clearData(); - block_default = true; - } - else { - // Trigger onCopy - return true; - } + // Trigger onCopy + return true; } // Ctrl+V Paste - if ((e.key === 'v') && (e.metaKey || e.ctrlKey)) { - if (e.shiftKey) { - this.pasteFromClipboard(true); - block_default = true; - } - else { - // Trigger onPaste - return true; - } + if ((e.key === 'v' || e.key == 'V') && (e.metaKey || e.ctrlKey)) { + // Trigger onPaste + return true; } } From 7d401ed1d0fcc78b14d61d9f585ace40b9de0ddb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Sep 2023 16:36:50 -0400 Subject: [PATCH 058/150] Add ldm format support to UNETLoader. --- comfy/sd.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8be0bcbc8..9bdb2ad64 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -454,20 +454,26 @@ def load_unet(unet_path): #load unet in diffusers format sd = comfy.utils.load_torch_file(unet_path) parameters = comfy.utils.calculate_parameters(sd) fp16 = model_management.should_use_fp16(model_params=parameters) + if "input_blocks.0.0.weight" in sd: #ldm + model_config = model_detection.model_config_from_unet(sd, "", fp16) + if model_config is None: + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + new_sd = sd - model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) - if model_config is None: - print("ERROR UNSUPPORTED UNET", unet_path) - return None + else: #diffusers + model_config = model_detection.model_config_from_diffusers_unet(sd, fp16) + if model_config is None: + print("ERROR UNSUPPORTED UNET", unet_path) + return None - diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) + diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) - new_sd = {} - for k in diffusers_keys: - if k in sd: - new_sd[diffusers_keys[k]] = sd.pop(k) - else: - print(diffusers_keys[k], k) + new_sd = {} + for k in diffusers_keys: + if k in sd: + new_sd[diffusers_keys[k]] = sd.pop(k) + else: + print(diffusers_keys[k], k) offload_device = model_management.unet_offload_device() model = model_config.get_model(new_sd, "") model = model.to(offload_device) From fb3b7282034a37dbed377055f843c9a9302fdd8c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Sep 2023 21:49:56 -0400 Subject: [PATCH 059/150] Fix issue where autocast fp32 CLIP gave different results from regular. 
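The token and position embedding tables are pinned to float32, so the autocast decision can no longer key off the embedding dtype and is taken from final_layer_norm instead. A minimal sketch of that dtype split, with stand-in modules rather than the real text model:

```python
import torch

text_model = torch.nn.Sequential(
    torch.nn.Embedding(49408, 768),  # stand-in for token_embedding
    torch.nn.LayerNorm(768),         # stand-in for final_layer_norm
).half()
text_model[0].to(torch.float32)      # embeddings promoted back to fp32

# Autocast only when the model proper is not already fp32:
needs_autocast = text_model[1].weight.dtype != torch.float32
print(text_model[0].weight.dtype, needs_autocast)  # torch.float32 True
```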
--- comfy/sd1_clip.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 477d5c309..b84a38490 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): if dtype is not None: self.transformer.to(dtype) + self.transformer.text_model.embeddings.token_embedding.to(torch.float32) + self.transformer.text_model.embeddings.position_embedding.to(torch.float32) + self.max_length = max_length if freeze: self.freeze() @@ -138,7 +141,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if backup_embeds.weight.dtype != torch.float32: + if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: precision_scope = torch.autocast else: precision_scope = lambda a, b: contextlib.nullcontext(a) From ed58730658d0213600b64849d721a6bb92c675bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Sep 2023 15:09:10 -0400 Subject: [PATCH 060/150] Don't leave very large hidden states in the clip vision output. --- comfy/clip_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 9b95ae003..1206c680d 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -56,6 +56,7 @@ class ClipVisionModel(): if t is not None: if k == 'hidden_states': outputs["penultimate_hidden_states"] = t[-2].cpu() + outputs["hidden_states"] = None else: outputs[k] = t.cpu() From 0b829fe35b3ae626494735eb149c43345e5c55a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Sep 2023 18:44:05 -0400 Subject: [PATCH 061/150] .gitignore refactor. --- .gitignore | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 0177e1d7d..98d91318d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,16 +1,16 @@ __pycache__/ *.py[cod] -output/ -input/ -!input/example.png -models/ -temp/ -custom_nodes/ +/output/ +/input/ +!/input/example.png +/models/ +/temp/ +/custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs .idea/ venv/ -web/extensions/* -!web/extensions/logging.js.example -!web/extensions/core/ +/web/extensions/* +!/web/extensions/logging.js.example +!/web/extensions/core/ From 30de95e4b420aa02d25d151271dca9867492288f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Sep 2023 01:10:31 -0400 Subject: [PATCH 062/150] Add some nodes to subtract and add model weights. 
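Both nodes are thin wrappers around add_patches, where the result per key is strength_model * base_weight + strength_patch * patch_weight. The per-key arithmetic, sketched with plain tensors under that assumption:

```python
import torch

w1, w2, base = torch.randn(4), torch.randn(4), torch.randn(4)

multiplier = 1.0
diff = multiplier * w1 - multiplier * w2  # ModelMergeSubtract: strengths (-m, m)
merged = base + diff                      # ModelMergeAdd: strengths (1.0, 1.0)
```

Subtracting with multiplier 1.0 yields a "difference" model (for example, a finetune minus its base) that ModelMergeAdd can then reapply on top of another checkpoint.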
--- comfy_extras/nodes_model_merging.py | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index bce4b3dd0..ebcbd4be9 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -27,6 +27,44 @@ class ModelMergeSimple: m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +class ModelSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, multiplier): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + +class ModelAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + class CLIPMergeSimple: @classmethod def INPUT_TYPES(s): @@ -144,6 +182,8 @@ class CheckpointSave: NODE_CLASS_MAPPINGS = { "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks, + "ModelMergeSubtract": ModelSubtract, + "ModelMergeAdd": ModelAdd, "CheckpointSave": CheckpointSave, "CLIPMergeSimple": CLIPMergeSimple, } From 3039b08eb16777431946ed9ae4a63c5466336bff Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Sep 2023 11:38:20 -0400 Subject: [PATCH 063/150] Only parse command line args when main.py is called. 
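Parsing at import time can break anything that imports comfy as a library, since argparse would consume, or reject, the host process's own sys.argv. The gate pattern, sketched standalone with an illustrative option:

```python
import argparse

args_parsing = False  # the entry point sets this True before cli_args is imported

parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8188)  # illustrative option

if args_parsing:
    args = parser.parse_args()    # normal CLI behavior for main.py
else:
    args = parser.parse_args([])  # library use: defaults only, sys.argv ignored

print(args.port)
```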
--- comfy/cli_args.py | 7 +++++-- comfy/options.py | 6 ++++++ main.py | 3 +++ 3 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 comfy/options.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fda245433..ffae81c49 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,6 +1,6 @@ import argparse import enum - +import comfy.options class EnumAction(argparse.Action): """ @@ -94,7 +94,10 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") -args = parser.parse_args() +if comfy.options.args_parsing: + args = parser.parse_args() +else: + args = parser.parse_args([]) if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/options.py b/comfy/options.py new file mode 100644 index 000000000..f7f8af41e --- /dev/null +++ b/comfy/options.py @@ -0,0 +1,6 @@ + +args_parsing = False + +def enable_args_parsing(enable=True): + global args_parsing + args_parsing = enable diff --git a/main.py b/main.py index 9f0f80458..7c5eaee0a 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,6 @@ +import comfy.options +comfy.options.enable_args_parsing() + import os import importlib.util import folder_paths From 0e4395a8a3f7b5da15c46308eee9721ce3f4f475 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 13 Sep 2023 18:42:44 +0100 Subject: [PATCH 064/150] Allow pasting nodes with connections in firefox --- web/scripts/app.js | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9db4e9230..6dd1f3edd 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -671,6 +671,10 @@ export class ComfyApp { */ #addPasteHandler() { document.addEventListener("paste", (e) => { + // ctrl+shift+v is used to paste nodes with connections + // this is handled by litegraph + if(this.shiftDown) return; + let data = (e.clipboardData || window.clipboardData); const items = data.items; From 0966d3ce823dd9e0d668bd0f4049fb5b879c6672 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Sep 2023 12:16:07 -0400 Subject: [PATCH 065/150] Don't run text encoders on xpu because there are issues. --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index b663e8f59..e38ef4eea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -451,6 +451,8 @@ def text_encoder_device(): if args.gpu_only: return get_torch_device() elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + if is_intel_xpu(): + return torch.device("cpu") if should_use_fp16(prioritize_performance=False): return get_torch_device() else: From 44361f6344f53c32b1cd902515b9071f6d08ecc7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Sep 2023 18:12:36 -0400 Subject: [PATCH 066/150] Setting the last layer on SD2.x models now uses the proper indexes. Before, I had made the last layer the penultimate layer because some checkpoints don't have it, but that's not consistent with the other models. TLDR: for SD2.x models only: CLIPSetLastLayer -1 is now -2.
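Illustrated with a 24-layer text encoder's hidden states, since only the indexing convention changes:

```python
hidden = [f"layer_{i}" for i in range(24)]  # per-layer outputs, illustrative
print(hidden[-1])  # layer_23: what CLIPSetLastLayer -1 selects after this change
print(hidden[-2])  # layer_22: what -1 silently selected before, on SD2.x only
```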
--- comfy/sd2_clip.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 818c9711e..05e50a005 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -12,16 +12,6 @@ class SD2ClipModel(sd1_clip.SD1ClipModel): super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] - def clip_layer(self, layer_idx): - if layer_idx < 0: - layer_idx -= 1 #The real last layer of SD2.x clip is the penultimate one. The last one might contain garbage. - if abs(layer_idx) >= 24: - self.layer = "hidden" - self.layer_idx = -2 - else: - self.layer = "hidden" - self.layer_idx = layer_idx - class SD2Tokenizer(sd1_clip.SD1Tokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) From 44361f6344f53c32b1cd902515b9071f6d08ecc7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 01:56:07 -0400 Subject: [PATCH 067/150] Support for text encoder models that need attention_mask. --- comfy/sd1_clip.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b84a38490..9978b6c35 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -71,6 +71,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.empty_tokens = [[49406] + [49407] * 76] self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.enable_attention_masks = False self.layer_norm_hidden_state = True if layer == "hidden": @@ -147,7 +148,17 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): precision_scope = lambda a, b: contextlib.nullcontext(a) with precision_scope(model_management.get_autocast_device(device), torch.float32): - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": From 076f3e63107fac1a9f4da705dfd18b428cb1340c Mon Sep 17 00:00:00 2001 From: karrycharon Date: Fri, 15 Sep 2023 16:37:58 +0800 Subject: [PATCH 068/150] fix structuredClone undefined error; --- web/scripts/app.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 6dd1f3edd..4beaf03ae 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1297,7 +1297,13 @@ export class ComfyApp { let reset_invalid_values = false; if (!graphData) { - graphData = structuredClone(defaultGraph); + if (typeof structuredClone === "undefined") + { + graphData = JSON.parse(JSON.stringify(defaultGraph)); + }else + { + graphData = structuredClone(defaultGraph); + } reset_invalid_values = true; } From 94e4fe39d868a0bb939c2f91746de09680e4657d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 
12:03:03 -0400 Subject: [PATCH 069/150] This isn't used anywhere. --- comfy/ldm/models/diffusion/ddim.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index 139c8e01e..befab0075 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -33,7 +33,6 @@ class DDIMSampler(object): assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) - self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) @@ -195,7 +194,7 @@ class DDIMSampler(object): temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): - device = self.model.betas.device + device = self.model.alphas_cumprod.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) From 415abb275f8ef74615cbb3c5ebc90b20d1a713b7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 19:22:47 -0400 Subject: [PATCH 070/150] Add DDPM sampler. --- comfy/k_diffusion/sampling.py | 31 +++++++++++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index eb088d92b..937c5a388 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -706,3 +706,34 @@ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disab noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) + +def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): + alpha_cumprod = 1 / ((sigma * sigma) + 1) + alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) + alpha = (alpha_cumprod / alpha_cumprod_prev) + + mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) + if sigma_prev > 0: + mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. 
- alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) + return mu + + +def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) + if sigmas[i + 1] != 0: + x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) + return x + + +@torch.no_grad() +def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): + return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) + diff --git a/comfy/samplers.py b/comfy/samplers.py index c60288fd1..7f1987167 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -546,7 +546,7 @@ class KSampler: SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model From 43d4935a1da0b78dac101a28cc98de0b7d556729 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Sep 2023 22:21:14 -0400 Subject: [PATCH 071/150] Add cond_or_uncond array to transformer_options so hooks can check what is cond and what is uncond. 
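This lets custom patches behave differently on the conditional and unconditional chunks of a batch. A hypothetical hook, assuming COND == 0 and UNCOND == 1 as defined in samplers.py and assuming the patch receives transformer_options (often named extra_options); real patch signatures depend on where the patch is attached:

```python
def my_attn_patch(q, k, v, extra_options):
    # One entry per batch chunk: 0 for cond, 1 for uncond (assumed values).
    cond_or_uncond = extra_options.get("cond_or_uncond", [])
    if 1 in cond_or_uncond:
        pass  # e.g. skip an expensive effect on the unconditional pass
    return q, k, v

print(my_attn_patch(1, 2, 3, {"cond_or_uncond": [0, 1]}))
```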
--- comfy/samplers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7f1987167..57673a029 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -255,6 +255,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con else: transformer_options["patches"] = patches + transformer_options["cond_or_uncond"] = cond_or_uncond[:] c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: From 69680fede7de62f503a59efbbd8aa058b8e50395 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sat, 16 Sep 2023 20:36:00 +0900 Subject: [PATCH 072/150] fix: thumbnail ratio fix for mixed ratio images --- web/scripts/app.js | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 4beaf03ae..84090764a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -532,7 +532,17 @@ export class ComfyApp { } } this.imageRects.push([x, y, cellWidth, cellHeight]); - ctx.drawImage(img, x, y, cellWidth, cellHeight); + + let wratio = cellWidth/img.width; + let hratio = cellHeight/img.height; + var ratio = Math.min(wratio, hratio); + + let imgHeight = ratio * img.height; + let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; + let imgWidth = ratio * img.width; + let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; + + ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight); ctx.filter = "none"; } From 4d5e057bb2e32117c945cc9dfe8039dad2329297 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sat, 16 Sep 2023 20:37:27 +0900 Subject: [PATCH 073/150] fix indent --- web/scripts/app.js | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 84090764a..f0bb8640c 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -533,14 +533,14 @@ export class ComfyApp { } this.imageRects.push([x, y, cellWidth, cellHeight]); - let wratio = cellWidth/img.width; - let hratio = cellHeight/img.height; - var ratio = Math.min(wratio, hratio); + let wratio = cellWidth/img.width; + let hratio = cellHeight/img.height; + var ratio = Math.min(wratio, hratio); - let imgHeight = ratio * img.height; - let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; - let imgWidth = ratio * img.width; - let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; + let imgHeight = ratio * img.height; + let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; + let imgWidth = ratio * img.width; + let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; ctx.drawImage(img, imgX, imgY, imgWidth, imgHeight); ctx.filter = "none"; From 61b1f67734f445aabdbd941537c22bfe6f9237aa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Sep 2023 12:59:54 -0400 Subject: [PATCH 074/150] Support models without previews. 
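Class-level defaults on LatentFormat mean a model family with no known RGB projection and no TAESD decoder now simply produces no previewer instead of crashing. The guard, sketched standalone:

```python
class LatentFormat:
    scale_factor = 1.0
    latent_rgb_factors = None  # no cheap latent-to-RGB projection known
    taesd_decoder_name = None  # no TAESD decoder trained for this space

def get_previewer(latent_format):
    if latent_format.latent_rgb_factors is None:
        return None  # previews quietly disabled for this model
    return "Latent2RGBPreviewer"  # stand-in for the real previewer object

print(get_previewer(LatentFormat()))  # None
```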
--- comfy/latent_formats.py | 4 ++++ latent_preview.py | 7 +++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 8b59cfbdc..fadc0eec7 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -1,5 +1,9 @@ class LatentFormat: + scale_factor = 1.0 + latent_rgb_factors = None + taesd_decoder_name = None + def process_in(self, latent): return latent * self.scale_factor diff --git a/latent_preview.py b/latent_preview.py index 30c1d1317..87240a582 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -53,7 +53,9 @@ def get_previewer(device, latent_format): method = args.preview_method if method != LatentPreviewMethod.NoPreviews: # TODO previewer methods - taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) + taesd_decoder_path = None + if latent_format.taesd_decoder_name is not None: + taesd_decoder_path = folder_paths.get_full_path("vae_approx", latent_format.taesd_decoder_name) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB @@ -68,7 +70,8 @@ def get_previewer(device, latent_format): print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) + if latent_format.latent_rgb_factors is not None: + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer From 0665749b1a13f149f3c1770db7f366643acafdd7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Sep 2023 02:10:06 -0400 Subject: [PATCH 075/150] Move ModelSubtract and ModelAdd to advanced/model_merging --- comfy_extras/nodes_model_merging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index ebcbd4be9..3d42d7806 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -37,7 +37,7 @@ class ModelSubtract: RETURN_TYPES = ("MODEL",) FUNCTION = "merge" - CATEGORY = "_for_testing/model_merging" + CATEGORY = "advanced/model_merging" def merge(self, model1, model2, multiplier): m = model1.clone() @@ -55,7 +55,7 @@ class ModelAdd: RETURN_TYPES = ("MODEL",) FUNCTION = "merge" - CATEGORY = "_for_testing/model_merging" + CATEGORY = "advanced/model_merging" def merge(self, model1, model2): m = model1.clone() From 321c5fa2958a2cdb05a08f6792fd2f72336e8c90 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Sep 2023 04:09:19 -0400 Subject: [PATCH 076/150] Enable pytorch attention by default on xpu. --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e38ef4eea..d8bc3bfea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -165,6 +165,9 @@ try: ENABLE_PYTORCH_ATTENTION = True if torch.cuda.is_bf16_supported(): VAE_DTYPE = torch.bfloat16 + if is_intel_xpu(): + if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + ENABLE_PYTORCH_ATTENTION = True except: pass From db63aa7e53c459b016cfa4159be004e59af84da9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Sep 2023 12:49:06 -0400 Subject: [PATCH 077/150] Nodes can now control the rounding in the UI. 
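Node authors opt in through the FLOAT input options, mirroring the updated KSampler definitions below; round defaults to the displayed precision and can be set to False to disable rounding entirely:

```python
class ExampleNode:  # illustrative node, not part of the patch
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0,
                              "step": 0.5,     # UI increment
                              "round": 0.01}), # UI rounding granularity
        }}

print(ExampleNode.INPUT_TYPES())
```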
--- custom_nodes/example_node.py.example | 8 +++++++- nodes.py | 4 ++-- web/scripts/widgets.js | 23 ++++++++++++++++------- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index e37808b03..733014f3c 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -54,7 +54,13 @@ class Example: "step": 64, #Slider's step "display": "number" # Cosmetic only: display as "number" or "slider" }), - "float_field": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "display": "number"}), + "float_field": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 10.0, + "step": 0.01, + "round": 0.001, #The value represeting the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. + "display": "number"}), "print_to_screen": (["enable", "disable"],), "string_field": ("STRING", { "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node diff --git a/nodes.py b/nodes.py index 77d180526..3bc08663e 100644 --- a/nodes.py +++ b/nodes.py @@ -1217,7 +1217,7 @@ class KSampler: {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), @@ -1243,7 +1243,7 @@ class KSamplerAdvanced: "add_noise": (["enable", "disable"], ), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), "positive": ("CONDITIONING", ), diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 30caa6a8c..40b3067b7 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -2,17 +2,22 @@ import { api } from "./api.js" function getNumberDefaults(inputData, defaultStep) { let defaultVal = inputData[1]["default"]; - let { min, max, step } = inputData[1]; + let { min, max, step, round} = inputData[1]; if (defaultVal == undefined) defaultVal = 0; if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; -// precision is the number of decimal places to show. -// by default, display the the smallest number of decimal places such that changes of size step are visible. - let precision = Math.max(-Math.floor(Math.log10(step)),0) -// by default, round the value to those decimal places shown. - let round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + + // precision is the number of decimal places to show. + // by default, display the the smallest number of decimal places such that changes of size step are visible. + let precision = Math.max(-Math.floor(Math.log10(step)),0); + + if (round == undefined || round === true) { + // by default, round the value to those decimal places shown. 
+ round = Math.round(1000000*Math.pow(0.1,precision))/1000000; + } + return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } @@ -271,7 +276,11 @@ export const ComfyWidgets = { const { val, config } = getNumberDefaults(inputData, 0.5); return { widget: node.addWidget(widgetType, inputName, val, function (v) { - this.value = Math.round(v/config.round)*config.round; + if (config.round) { + this.value = Math.round(v/config.round)*config.round; + } else { + this.value = v; + } }, config) }; }, INT(node, inputName, inputData, app) { From 01094316268cab9ed5cd53b825b359a7becb9d6c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 16:20:03 -0400 Subject: [PATCH 078/150] Lower the minimum resolution of EmptyLatentImage. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 3bc08663e..9ccf179ce 100644 --- a/nodes.py +++ b/nodes.py @@ -889,8 +889,8 @@ class EmptyLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" From b92bf8196e0d3158b3e981d056a2be15ce5ab1cd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 23:04:49 -0400 Subject: [PATCH 079/150] Do lora cast on GPU instead of CPU for higher performance. --- comfy/model_patcher.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a6ee0bae1..85bf5bd2a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -187,13 +187,13 @@ class ModelPatcher: else: weight += alpha * w1.type(weight.dtype).to(weight.device) elif len(v) == 4: #lora/locon - mat1 = v[0].float().to(weight.device) - mat2 = v[1].float().to(weight.device) + mat1 = v[0].to(weight.device).float() + mat2 = v[1].to(weight.device).float() if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].float().to(weight.device) + mat3 = v[3].to(weight.device).float() final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: @@ -212,18 +212,18 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(w1_a.float(), w1_b.float()) + w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float()) else: - w1 = w1.float().to(weight.device) + w1 = w1.to(weight.device).float() if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device)) + w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float()) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device)) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float()) else: - w2 = w2.float().to(weight.device) 
+ w2 = w2.to(weight.device).float() if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -244,11 +244,11 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device)) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device)) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float()) + m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float()) else: - m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device)) - m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device)) + m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float()) + m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float()) try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) From 26cd8405ddb32216d02b5eed23f6481cb36873c7 Mon Sep 17 00:00:00 2001 From: enzymezoo-code <103286087+enzymezoo-code@users.noreply.github.com> Date: Mon, 18 Sep 2023 22:18:06 -0500 Subject: [PATCH 080/150] Ci quality workflows (#1423) * Add inference tests * Clean up * Rename test graph file * Add readme for tests * Separate server fixture * test file name change * Assert images are generated * Clean up comments * Add __init__.py so tests can run with command line `pytest` * Fix command line args for pytest * Loop all samplers/schedulers in test_inference.py * Ci quality workflows compare (#1) * Add image comparison tests * Comparison tests do not pass with empty metadata * Ensure tests are run in correct order * Save image files with test name * Update tests readme * Reduce step counts in tests to ~halve runtime * Ci quality workflows build (#2) * Add build test github workflow --- .github/workflows/test-build.yml | 31 +++ pytest.ini | 5 + tests/README.md | 29 ++ tests/__init__.py | 0 tests/compare/conftest.py | 41 +++ tests/compare/test_quality.py | 195 ++++++++++++++ tests/conftest.py | 36 +++ tests/inference/__init__.py | 0 .../graphs/default_graph_sdxl1_0.json | 144 ++++++++++ tests/inference/test_inference.py | 247 ++++++++++++++++++ 10 files changed, 728 insertions(+) create mode 100644 .github/workflows/test-build.yml create mode 100644 pytest.ini create mode 100644 tests/README.md create mode 100644 tests/__init__.py create mode 100644 tests/compare/conftest.py create mode 100644 tests/compare/test_quality.py create mode 100644 tests/conftest.py create mode 100644 tests/inference/__init__.py create mode 100644 tests/inference/graphs/default_graph_sdxl1_0.json create mode 100644 tests/inference/test_inference.py diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml new file mode 100644 index 000000000..421dd5ee4 --- /dev/null +++ b/.github/workflows/test-build.yml @@ -0,0 +1,31 @@ +name: Build package + +# +# This workflow is a test of the python package build. +# Install Python dependencies across different Python versions. 
+# + +on: + push: + paths: + - "requirements.txt" + - ".github/workflows/test-build.yml" + +jobs: + build: + name: Build Test + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..b5a68e0f1 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +markers = + inference: mark as inference test (deselect with '-m "not inference"') +testpaths = tests +addopts = -s \ No newline at end of file diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 000000000..2005fd45b --- /dev/null +++ b/tests/README.md @@ -0,0 +1,29 @@ +# Automated Testing + +## Running tests locally + +Additional requirements for running tests: +``` +pip install pytest +pip install websocket-client==1.6.1 +opencv-python==4.6.0.66 +scikit-image==0.21.0 +``` +Run inference tests: +``` +pytest tests/inference +``` + +## Quality regression test +Compares images in 2 directories to ensure they are the same + +1) Run an inference test to save a directory of "ground truth" images +``` + pytest tests/inference --output_dir tests/inference/baseline +``` +2) Make code edits + +3) Run inference and quality comparison tests +``` +pytest +``` \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/compare/conftest.py b/tests/compare/conftest.py new file mode 100644 index 000000000..dd5078c9e --- /dev/null +++ b/tests/compare/conftest.py @@ -0,0 +1,41 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--baseline_dir', action="store", default='tests/inference/baseline', help='Directory for ground-truth images') + parser.addoption('--test_dir', action="store", default='tests/inference/samples', help='Directory for images to test') + parser.addoption('--metrics_file', action="store", default='tests/metrics.md', help='Output file for metrics') + parser.addoption('--img_output_dir', action="store", default='tests/compare/samples', help='Output directory for diff metric images') + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['baseline_dir'] = pytestconfig.getoption('baseline_dir') + args['test_dir'] = pytestconfig.getoption('test_dir') + args['metrics_file'] = pytestconfig.getoption('metrics_file') + args['img_output_dir'] = pytestconfig.getoption('img_output_dir') + + # Initialize metrics file + with open(args['metrics_file'], 'a') as f: + # if file is empty, write header + if os.stat(args['metrics_file']).st_size == 0: + f.write("| date | run | file | status | value | \n") + f.write("| --- | --- | --- | --- | --- | \n") + + return args + + +def gather_file_basenames(directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + +# Creates the list of baseline file names to use as a fixture +def pytest_generate_tests(metafunc): + if "baseline_fname" in metafunc.fixturenames: + baseline_fnames = 
gather_file_basenames(metafunc.config.getoption("baseline_dir")) + metafunc.parametrize("baseline_fname", baseline_fnames) diff --git a/tests/compare/test_quality.py b/tests/compare/test_quality.py new file mode 100644 index 000000000..92a2d5a8b --- /dev/null +++ b/tests/compare/test_quality.py @@ -0,0 +1,195 @@ +import datetime +import numpy as np +import os +from PIL import Image +import pytest +from pytest import fixture +from typing import Tuple, List + +from cv2 import imread, cvtColor, COLOR_BGR2RGB +from skimage.metrics import structural_similarity as ssim + + +""" +This test suite compares images in 2 directories by file name +The directories are specified by the command line arguments --baseline_dir and --test_dir + +""" +# ssim: Structural Similarity Index +# Returns a tuple of (ssim, diff_image) +def ssim_score(img0: np.ndarray, img1: np.ndarray) -> Tuple[float, np.ndarray]: + score, diff = ssim(img0, img1, channel_axis=-1, full=True) + # rescale the difference image to 0-255 range + diff = (diff * 255).astype("uint8") + return score, diff + +# Metrics must return a tuple of (score, diff_image) +METRICS = {"ssim": ssim_score} +METRICS_PASS_THRESHOLD = {"ssim": 0.95} + + +class TestCompareImageMetrics: + @fixture(scope="class") + def test_file_names(self, args_pytest): + test_dir = args_pytest['test_dir'] + fnames = self.gather_file_basenames(test_dir) + yield fnames + del fnames + + @fixture(scope="class", autouse=True) + def teardown(self, args_pytest): + yield + # Runs after all tests are complete + # Aggregate output files into a grid of images + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + img_output_dir = args_pytest['img_output_dir'] + metrics_file = args_pytest['metrics_file'] + + grid_dir = os.path.join(img_output_dir, "grid") + os.makedirs(grid_dir, exist_ok=True) + + for metric_dir in METRICS.keys(): + metric_path = os.path.join(img_output_dir, metric_dir) + for file in os.listdir(metric_path): + if file.endswith(".png"): + score = self.lookup_score_from_fname(file, metrics_file) + image_file_list = [] + image_file_list.append([ + os.path.join(baseline_dir, file), + os.path.join(test_dir, file), + os.path.join(metric_path, file) + ]) + # Create grid + image_list = [[Image.open(file) for file in files] for files in image_file_list] + grid = self.image_grid(image_list) + grid.save(os.path.join(grid_dir, f"{metric_dir}_{score:.3f}_{file}")) + + # Tests run for each baseline file name + @fixture() + def fname(self, baseline_fname): + yield baseline_fname + del baseline_fname + + def test_directories_not_empty(self, args_pytest): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + assert len(os.listdir(baseline_dir)) != 0, f"Baseline directory {baseline_dir} is empty" + assert len(os.listdir(test_dir)) != 0, f"Test directory {test_dir} is empty" + + def test_dir_has_all_matching_metadata(self, fname, test_file_names, args_pytest): + # Check that all files in baseline_dir have a file in test_dir with matching metadata + baseline_file_path = os.path.join(args_pytest['baseline_dir'], fname) + file_paths = [os.path.join(args_pytest['test_dir'], f) for f in test_file_names] + file_match = self.find_file_match(baseline_file_path, file_paths) + assert file_match is not None, f"Could not find a file in {args_pytest['test_dir']} with matching metadata to {baseline_file_path}" + + # For a baseline image file, finds the corresponding file name in test_dir and + # compares the images using the metrics in METRICS + 
@pytest.mark.parametrize("metric", METRICS.keys()) + def test_pipeline_compare( + self, + args_pytest, + fname, + test_file_names, + metric, + ): + baseline_dir = args_pytest['baseline_dir'] + test_dir = args_pytest['test_dir'] + metrics_output_file = args_pytest['metrics_file'] + img_output_dir = args_pytest['img_output_dir'] + + baseline_file_path = os.path.join(baseline_dir, fname) + + # Find file match + file_paths = [os.path.join(test_dir, f) for f in test_file_names] + test_file = self.find_file_match(baseline_file_path, file_paths) + + # Run metrics + sample_baseline = self.read_img(baseline_file_path) + sample_secondary = self.read_img(test_file) + + score, metric_img = METRICS[metric](sample_baseline, sample_secondary) + metric_status = score > METRICS_PASS_THRESHOLD[metric] + + # Save metric values + with open(metrics_output_file, 'a') as f: + run_info = os.path.splitext(fname)[0] + metric_status_str = "PASS ✅" if metric_status else "FAIL ❌" + date_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + f.write(f"| {date_str} | {run_info} | {metric} | {metric_status_str} | {score} | \n") + + # Save metric image + metric_img_dir = os.path.join(img_output_dir, metric) + os.makedirs(metric_img_dir, exist_ok=True) + output_filename = f'{fname}' + Image.fromarray(metric_img).save(os.path.join(metric_img_dir, output_filename)) + + assert score > METRICS_PASS_THRESHOLD[metric] + + def read_img(self, filename: str) -> np.ndarray: + cvImg = imread(filename) + cvImg = cvtColor(cvImg, COLOR_BGR2RGB) + return cvImg + + def image_grid(self, img_list: list[list[Image.Image]]): + # imgs is a 2D list of images + # Assumes the input images are a rectangular grid of equal sized images + rows = len(img_list) + cols = len(img_list[0]) + + w, h = img_list[0][0].size + grid = Image.new('RGB', size=(cols*w, rows*h)) + + for i, row in enumerate(img_list): + for j, img in enumerate(row): + grid.paste(img, box=(j*w, i*h)) + return grid + + def lookup_score_from_fname(self, + fname: str, + metrics_output_file: str + ) -> float: + fname_basestr = os.path.splitext(fname)[0] + with open(metrics_output_file, 'r') as f: + for line in f: + if fname_basestr in line: + score = float(line.split('|')[5]) + return score + raise ValueError(f"Could not find score for {fname} in {metrics_output_file}") + + def gather_file_basenames(self, directory: str): + files = [] + for file in os.listdir(directory): + if file.endswith(".png"): + files.append(file) + return files + + def read_file_prompt(self, fname:str) -> str: + # Read prompt from image file metadata + img = Image.open(fname) + img.load() + return img.info['prompt'] + + def find_file_match(self, baseline_file: str, file_paths: List[str]): + # Find a file in file_paths with matching metadata to baseline_file + baseline_prompt = self.read_file_prompt(baseline_file) + + # Do not match empty prompts + if baseline_prompt is None or baseline_prompt == "": + return None + + # Find file match + # Reorder test_file_names so that the file with matching name is first + # This is an optimization because matching file names are more likely + # to have matching metadata if they were generated with the same script + basename = os.path.basename(baseline_file) + file_path_basenames = [os.path.basename(f) for f in file_paths] + if basename in file_path_basenames: + match_index = file_path_basenames.index(basename) + file_paths.insert(0, file_paths.pop(match_index)) + + for f in file_paths: + test_file_prompt = self.read_file_prompt(f) + if baseline_prompt == test_file_prompt: 
+ return f \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..1a35880af --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +import os +import pytest + +# Command line arguments for pytest +def pytest_addoption(parser): + parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') + parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") + parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + +# This initializes args at the beginning of the test session +@pytest.fixture(scope="session", autouse=True) +def args_pytest(pytestconfig): + args = {} + args['output_dir'] = pytestconfig.getoption('output_dir') + args['listen'] = pytestconfig.getoption('listen') + args['port'] = pytestconfig.getoption('port') + + os.makedirs(args['output_dir'], exist_ok=True) + + return args + +def pytest_collection_modifyitems(items): + # Modifies items so tests run in the correct order + + LAST_TESTS = ['test_quality'] + + # Move the last items to the end + last_items = [] + for test_name in LAST_TESTS: + for item in items.copy(): + print(item.module.__name__, item) + if item.module.__name__ == test_name: + last_items.append(item) + items.remove(item) + + items.extend(last_items) diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/graphs/default_graph_sdxl1_0.json b/tests/inference/graphs/default_graph_sdxl1_0.json new file mode 100644 index 000000000..c06c6829c --- /dev/null +++ b/tests/inference/graphs/default_graph_sdxl1_0.json @@ -0,0 +1,144 @@ +{ + "4": { + "inputs": { + "ckpt_name": "sd_xl_base_1.0.safetensors" + }, + "class_type": "CheckpointLoaderSimple" + }, + "5": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage" + }, + "6": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "4", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "10": { + "inputs": { + "add_noise": "enable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 0, + "end_at_step": 32, + "return_with_leftover_noise": "enable", + "model": [ + "4", + 0 + ], + "positive": [ + "6", + 0 + ], + "negative": [ + "15", + 0 + ], + "latent_image": [ + "5", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "12": { + "inputs": { + "samples": [ + "14", + 0 + ], + "vae": [ + "4", + 2 + ] + }, + "class_type": "VAEDecode" + }, + "13": { + "inputs": { + "filename_prefix": "test_inference", + "images": [ + "12", + 0 + ] + }, + "class_type": "SaveImage" + }, + "14": { + "inputs": { + "add_noise": "disable", + "noise_seed": 42, + "steps": 20, + "cfg": 7.5, + "sampler_name": "euler", + "scheduler": "normal", + "start_at_step": 32, + "end_at_step": 10000, + "return_with_leftover_noise": "disable", + "model": [ + "16", + 0 + ], + "positive": [ + "17", + 0 + ], + "negative": [ + "20", + 0 + ], + "latent_image": [ + "10", + 0 + ] + }, + "class_type": "KSamplerAdvanced" + }, + "15": { + "inputs": { + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "ConditioningZeroOut" + }, + "16": { + "inputs": { + "ckpt_name": "sd_xl_refiner_1.0.safetensors" + }, + "class_type": 
"CheckpointLoaderSimple" + }, + "17": { + "inputs": { + "text": "a photo of a cat", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + }, + "20": { + "inputs": { + "text": "", + "clip": [ + "16", + 1 + ] + }, + "class_type": "CLIPTextEncode" + } + } \ No newline at end of file diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py new file mode 100644 index 000000000..a96f94550 --- /dev/null +++ b/tests/inference/test_inference.py @@ -0,0 +1,247 @@ +from copy import deepcopy +from io import BytesIO +from urllib import request +import numpy +import os +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Union +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse + +# Currently causes an error when running pytest with built-in pytest args +# TODO: modify cli_args.py to not parse args on import +# We will hard-code sampler and scheduler lists for now +# from comfy.samplers import KSampler + +""" +These tests generate and save images through a range of parameters +""" + +class ComfyGraph: + def __init__(self, + graph: dict, + sampler_nodes: list[str], + ): + self.graph = graph + self.sampler_nodes = sampler_nodes + + def set_prompt(self, prompt, negative_prompt=None): + # Sets the prompt for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + prompt_node = self.graph[node]['inputs']['positive'][0] + self.graph[prompt_node]['inputs']['text'] = prompt + if negative_prompt: + negative_prompt_node = self.graph[node]['inputs']['negative'][0] + self.graph[negative_prompt_node]['inputs']['text'] = negative_prompt + + def set_sampler_name(self, sampler_name:str, ): + # sets the sampler name for the sampler nodes (eg. base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['sampler_name'] = sampler_name + + def set_scheduler(self, scheduler:str): + # sets the sampler name for the sampler nodes (eg. 
base and refiner) + for node in self.sampler_nodes: + self.graph[node]['inputs']['scheduler'] = scheduler + + def set_filename_prefix(self, prefix:str): + # sets the filename prefix for the save nodes + for node in self.graph: + if self.graph[node]['class_type'] == 'SaveImage': + self.graph[node]['inputs']['filename_prefix'] = prefix + + +class ComfyClient: + # From examples/websockets_api_example.py + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def get_images(self, graph, save=True): + prompt = graph + if not save: + # Replace save nodes with preview nodes + prompt_str = json.dumps(prompt) + prompt_str = prompt_str.replace('SaveImage', 'PreviewImage') + prompt = json.loads(prompt_str) + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + output_images = {} + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['node'] is None and data['prompt_id'] == prompt_id: + break #Execution is done + else: + continue #previews are binary data + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + images_output.append(image_data) + output_images[node_id] = images_output + + return output_images + +# +# Initialize graphs +# +default_graph_file = 'tests/inference/graphs/default_graph_sdxl1_0.json' +with open(default_graph_file, 'r') as file: + default_graph = json.loads(file.read()) +DEFAULT_COMFY_GRAPH = ComfyGraph(graph=default_graph, sampler_nodes=['10','14']) +DEFAULT_COMFY_GRAPH_ID = os.path.splitext(os.path.basename(default_graph_file))[0] + +# +# Loop through these variables +# +comfy_graph_list = [DEFAULT_COMFY_GRAPH] +comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID] +prompt_list = [ + 'a painting of a cat', +] +#TODO use sampler and scheduler list from comfy.samplers.KSampler +# sampler_list = KSampler.SAMPLERS +# scheduler_list = KSampler.SCHEDULERS +# Hard coded sampler and scheduler lists for now +SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] +SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "dpmpp_2m", "dpmpp_2m_sde", 
"dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] +sampler_list = SAMPLERS +scheduler_list = SCHEDULERS +@pytest.mark.inference +@pytest.mark.parametrize("sampler", sampler_list) +@pytest.mark.parametrize("scheduler", scheduler_list) +@pytest.mark.parametrize("prompt", prompt_list) +class TestInference: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + ]) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + # + # Client and graph fixtures with server warmup + # + # Returns a "_client_graph", which is client-graph pair corresponding to an initialized server + # The "graph" is the default graph + @fixture(scope="class", params=comfy_graph_list, ids=comfy_graph_ids, autouse=True) + def _client_graph(self, request, args_pytest, _server) -> (ComfyClient, ComfyGraph): + comfy_graph = request.param + + # Start client + comfy_client = self.start_client(args_pytest["listen"], args_pytest["port"]) + + # Warm up pipeline + comfy_client.get_images(graph=comfy_graph.graph, save=False) + + yield comfy_client, comfy_graph + del comfy_client + del comfy_graph + torch.cuda.empty_cache() + + @fixture + def client(self, _client_graph): + client = _client_graph[0] + yield client + + @fixture + def comfy_graph(self, _client_graph): + # avoid mutating the graph + graph = deepcopy(_client_graph[1]) + yield graph + + def test_comfy( + self, + client, + comfy_graph, + sampler, + scheduler, + prompt, + request + ): + test_info = request.node.name + comfy_graph.set_filename_prefix(test_info) + # Settings for comfy graph + comfy_graph.set_sampler_name(sampler) + comfy_graph.set_scheduler(scheduler) + comfy_graph.set_prompt(prompt) + + # Generate + images = client.get_images(comfy_graph.graph) + + assert len(images) != 0, "No images generated" + # assert all images are not blank + for images_output in images.values(): + for image_data in images_output: + pil_image = Image.open(BytesIO(image_data)) + assert numpy.array(pil_image).any() != 0, "Image is blank" + + From 7c93afd2cd826aea7b49e49f42502b5ac03b647d Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Tue, 19 Sep 2023 05:20:00 +0200 Subject: [PATCH 081/150] Manual float precision, toggle for old behavior (#1541) * Add toggle for float rounding * Add manual precision override --- web/scripts/ui.js | 19 +++++++++++++++++++ web/scripts/widgets.js | 12 +++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index f39939bf3..1e7920167 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -577,6 +577,25 @@ export class ComfyUI { defaultValue: false, }); + this.settings.addSetting({ + id: "Comfy.DisableFloatRounding", + name: "Disable rounding floats (requires page reload).", + type: "boolean", + defaultValue: false, + }); + + this.settings.addSetting({ + id: "Comfy.FloatRoundingPrecision", + name: "Decimal places [0 = 
auto] (requires page reload).", + type: "slider", + attrs: { + min: 0, + max: 6, + step: 1, + }, + defaultValue: 0, + }); + const fileInput = $el("input", { id: "comfy-file-input", type: "file", diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 40b3067b7..942be8f36 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,6 +1,6 @@ import { api } from "./api.js" -function getNumberDefaults(inputData, defaultStep) { +function getNumberDefaults(inputData, defaultStep, app) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -8,12 +8,14 @@ function getNumberDefaults(inputData, defaultStep) { if (min == undefined) min = 0; if (max == undefined) max = 2048; if (step == undefined) step = defaultStep; - // precision is the number of decimal places to show. // by default, display the the smallest number of decimal places such that changes of size step are visible. let precision = Math.max(-Math.floor(Math.log10(step)),0); + if (app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision") > 0) { + precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + } - if (round == undefined || round === true) { + if (!app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") && (round == undefined || round === true)) { // by default, round the value to those decimal places shown. round = Math.round(1000000*Math.pow(0.1,precision))/1000000; } @@ -273,7 +275,7 @@ export const ComfyWidgets = { "INT:noise_seed": seedWidget, FLOAT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 0.5); + const { val, config } = getNumberDefaults(inputData, 0.5, app); return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { @@ -285,7 +287,7 @@ export const ComfyWidgets = { }, INT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 1); + const { val, config } = getNumberDefaults(inputData, 1, app); Object.assign(config, { precision: 0 }); return { widget: node.addWidget( From f32463936d3b8205df7b66dbd9c3f9a2fd69668a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 23:23:25 -0400 Subject: [PATCH 082/150] Unhardcode sampler and scheduler list in test. 
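The hard-coded copies of the sampler and scheduler names could silently drift out of sync with core, so the test matrix is now parametrized straight from comfy.samplers. A minimal sketch of the pattern, assuming only that KSampler exposes the SAMPLERS and SCHEDULERS lists used below (it does in this repository); the test body is illustrative only:

```
import pytest
from comfy.samplers import KSampler

@pytest.mark.parametrize("sampler", KSampler.SAMPLERS)
@pytest.mark.parametrize("scheduler", KSampler.SCHEDULERS)
def test_sampler_scheduler_pair(sampler, scheduler):
    # pytest expands the stacked decorators into one test per combination
    assert isinstance(sampler, str) and isinstance(scheduler, str)
```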
--- tests/inference/test_inference.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index a96f94550..141cc5c7e 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -16,10 +16,8 @@ import uuid import urllib.request import urllib.parse -# Currently causes an error when running pytest with built-in pytest args -# TODO: modify cli_args.py to not parse args on import -# We will hard-code sampler and scheduler lists for now -# from comfy.samplers import KSampler + +from comfy.samplers import KSampler """ These tests generate and save images through a range of parameters @@ -140,16 +138,10 @@ comfy_graph_ids = [DEFAULT_COMFY_GRAPH_ID] prompt_list = [ 'a painting of a cat', ] -#TODO use sampler and scheduler list from comfy.samplers.KSampler -# sampler_list = KSampler.SAMPLERS -# scheduler_list = KSampler.SCHEDULERS -# Hard coded sampler and scheduler lists for now -SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] -SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", - "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] -sampler_list = SAMPLERS -scheduler_list = SCHEDULERS + +sampler_list = KSampler.SAMPLERS +scheduler_list = KSampler.SCHEDULERS + @pytest.mark.inference @pytest.mark.parametrize("sampler", sampler_list) @pytest.mark.parametrize("scheduler", scheduler_list) From 6d3dee9d16254979592b95399835a54428b3cea6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Sep 2023 23:33:19 -0400 Subject: [PATCH 083/150] Clean up #1541. --- web/scripts/widgets.js | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 942be8f36..2b0239374 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,6 +1,6 @@ import { api } from "./api.js" -function getNumberDefaults(inputData, defaultStep, app) { +function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -10,17 +10,15 @@ function getNumberDefaults(inputData, defaultStep, app) { if (step == undefined) step = defaultStep; // precision is the number of decimal places to show. // by default, display the the smallest number of decimal places such that changes of size step are visible. - let precision = Math.max(-Math.floor(Math.log10(step)),0); - if (app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision") > 0) { - precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + if (precision == undefined) { + precision = Math.max(-Math.floor(Math.log10(step)),0); } - if (!app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") && (round == undefined || round === true)) { + if (enable_rounding && (round == undefined || round === true)) { // by default, round the value to those decimal places shown. 
round = Math.round(1000000*Math.pow(0.1,precision))/1000000; } - return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } @@ -275,7 +273,10 @@ export const ComfyWidgets = { "INT:noise_seed": seedWidget, FLOAT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 0.5, app); + let precision = app.ui.settings.getSettingValue("Comfy.FloatRoundingPrecision"); + let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") + if (precision == 0) precision = undefined; + const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding); return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { @@ -287,7 +288,7 @@ export const ComfyWidgets = { }, INT(node, inputName, inputData, app) { let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 1, app); + const { val, config } = getNumberDefaults(inputData, 1, 0, true); Object.assign(config, { precision: 0 }); return { widget: node.addWidget( From 2b6b17817331a24afc7106bfe9ec3e2f9b03fab1 Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Tue, 19 Sep 2023 10:40:38 +0200 Subject: [PATCH 084/150] Added support for lanczos scaling --- comfy/utils.py | 11 +++++++++++ comfy_extras/nodes_post_processing.py | 2 +- nodes.py | 4 ++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 3ed32e372..4e08bcb80 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,8 +1,10 @@ import torch +import torchvision import math import struct import comfy.checkpoint_pickle import safetensors.torch +from PIL import Image def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -346,6 +348,13 @@ def bislerp(samples, width, height): result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result +def lanczos(samples, width, height): + images = [torchvision.transforms.functional.to_pil_image(image) for image in samples] + images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] + images = [torchvision.transforms.functional.to_tensor(image) for image in images] + result = torch.stack(images) + return result + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -364,6 +373,8 @@ def common_upscale(samples, width, height, upscale_method, crop): if upscale_method == "bislerp": return bislerp(s, width, height) + elif upscale_method == "lanczos": + return lanczos(s, width, height) else: return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 51bdb24fa..3f651e594 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -211,7 +211,7 @@ class Sharpen: return (result,) class ImageScaleToTotalPixels: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod diff --git a/nodes.py b/nodes.py index 9ccf179ce..59c50a161 100644 --- a/nodes.py +++ b/nodes.py @@ -1423,7 +1423,7 @@ class LoadImageMask: return True class ImageScale: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", 
"bicubic", "lanczos"] crop_methods = ["disabled", "center"] @classmethod @@ -1444,7 +1444,7 @@ class ImageScale: return (s,) class ImageScaleBy: - upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod def INPUT_TYPES(s): From 83215924081198b7b4cd95a89046a4527951fc68 Mon Sep 17 00:00:00 2001 From: Sean Lynch Date: Tue, 19 Sep 2023 08:18:29 -0400 Subject: [PATCH 085/150] Escape paths when passing them to globs Try to prevent JS search from breaking on pathnames with square brackets. --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index d04060499..b2e16716b 100644 --- a/server.py +++ b/server.py @@ -132,12 +132,12 @@ class PromptServer(): @routes.get("/extensions") async def get_extensions(request): files = glob.glob(os.path.join( - self.web_root, 'extensions/**/*.js'), recursive=True) + glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True) extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): - files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True) + files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True) extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote( name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files))) From 7c9a92f552552cb51c9230d80d05ee42ebd8be90 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 19 Sep 2023 13:12:47 -0400 Subject: [PATCH 086/150] Don't depend on torchvision. --- comfy/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 4e08bcb80..7843b58cc 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,9 +1,9 @@ import torch -import torchvision import math import struct import comfy.checkpoint_pickle import safetensors.torch +import numpy as np from PIL import Image def load_torch_file(ckpt, safe_load=False, device=None): @@ -349,9 +349,9 @@ def bislerp(samples, width, height): return result def lanczos(samples, width, height): - images = [torchvision.transforms.functional.to_pil_image(image) for image in samples] + images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] - images = [torchvision.transforms.functional.to_tensor(image) for image in images] + images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) return result From b92a86d7370b28af6777c3859f7d486191f6379a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 13:24:08 -0400 Subject: [PATCH 087/150] Update litegraph to upstream. 
--- web/lib/litegraph.core.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4a21a1b34..8fb5d07a8 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,7 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font = o.font; + this.font_size = o.font_size; }; LGraphGroup.prototype.serialize = function() { @@ -4942,7 +4942,7 @@ LGraphNode.prototype.executeAction = function(action) Math.round(b[3]) ], color: this.color, - font: this.font + font_size: this.font_size }; }; From 1cdfb3dba4e7af11e2e05dc6a6276ba84eb1adf2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 17:52:41 -0400 Subject: [PATCH 088/150] Only do the cast on the device if the device supports it. --- comfy/model_management.py | 17 ++++++++++++++++ comfy/model_patcher.py | 43 ++++++++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d8bc3bfea..1050c13a4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -481,6 +481,23 @@ def get_autocast_device(dev): return dev.type return "cuda" +def cast_to_device(tensor, device, dtype, copy=False): + device_supports_cast = False + if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: + device_supports_cast = True + elif tensor.dtype == torch.bfloat16: + if hasattr(device, 'type') and device.type.startswith("cuda"): + device_supports_cast = True + + if device_supports_cast: + if copy: + if tensor.device == device: + return tensor.to(dtype, copy=copy) + return tensor.to(device, copy=copy).to(dtype) + else: + return tensor.to(device).to(dtype) + else: + return tensor.to(dtype).to(device, copy=copy) def xformers_enabled(): global directml_enabled diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 85bf5bd2a..10551656e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -3,6 +3,7 @@ import copy import inspect import comfy.utils +import comfy.model_management class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None): @@ -154,7 +155,7 @@ class ModelPatcher: self.backup[key] = weight.to(self.offload_device) if device_to is not None: - temp_weight = weight.float().to(device_to, copy=True) + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) @@ -185,15 +186,15 @@ class ModelPatcher: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: - weight += alpha * w1.type(weight.dtype).to(weight.device) + weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) elif len(v) == 4: #lora/locon - mat1 = v[0].to(weight.device).float() - mat2 = v[1].to(weight.device).float() + mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) + mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = v[3].to(weight.device).float() + mat3 = comfy.model_management.cast_to_device(v[3], 
weight.device, torch.float32) final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: @@ -212,18 +213,23 @@ class ModelPatcher: if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(w1_a.to(weight.device).float(), w1_b.to(weight.device).float()) + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1_b, weight.device, torch.float32)) else: - w1 = w1.to(weight.device).float() + w1 = comfy.model_management.cast_to_device(w1, weight.device, torch.float32) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(w2_a.to(weight.device).float(), w2_b.to(weight.device).float()) + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32)) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2_b.to(weight.device).float(), w2_a.to(weight.device).float()) + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2_a, weight.device, torch.float32)) else: - w2 = w2.to(weight.device).float() + w2 = comfy.model_management.cast_to_device(w2, weight.device, torch.float32) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -244,11 +250,20 @@ class ModelPatcher: if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.to(weight.device).float(), w1b.to(weight.device).float(), w1a.to(weight.device).float()) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.to(weight.device).float(), w2b.to(weight.device).float(), w2a.to(weight.device).float()) + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1a, weight.device, torch.float32)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2a, weight.device, torch.float32)) else: - m1 = torch.mm(w1a.to(weight.device).float(), w1b.to(weight.device).float()) - m2 = torch.mm(w2a.to(weight.device).float(), w2b.to(weight.device).float()) + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w1b, weight.device, torch.float32)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, torch.float32), + comfy.model_management.cast_to_device(w2b, weight.device, torch.float32)) try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) From 1122df1a2018eda31605703e7b3388ad80f209e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 17:58:54 -0400 Subject: [PATCH 089/150] Increase range of lora strengths. 
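The strength values are plain multipliers on the LoRA weight delta, so widening the widget range is safe; nothing in the math assumes [-10, 10]. A rough sketch of what strength_model does, simplified from the lora branch of calculate_weight in comfy/model_patcher.py (tensor names are illustrative):

```
import torch

def apply_lora(weight, mat1, mat2, strength):
    # the LoRA delta is a low-rank product; strength only scales it,
    # and a negative strength subtracts the learned direction
    delta = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
    return weight + strength * delta
```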
--- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 59c50a161..18d82ea80 100644 --- a/nodes.py +++ b/nodes.py @@ -543,8 +543,8 @@ class LoraLoader: return {"required": { "model": ("MODEL",), "clip": ("CLIP", ), "lora_name": (folder_paths.get_filename_list("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), }} RETURN_TYPES = ("MODEL", "CLIP") FUNCTION = "load_lora" From 4d41bd595c1e2bf55f9e3ccee0921b1213c0184a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 21:46:41 -0400 Subject: [PATCH 090/150] Fix loading group titles. --- web/lib/litegraph.core.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 8fb5d07a8..f81c83a8a 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font_size = o.font_size; + if (o.font_size) { + this.font_size = o.font_size; + } }; LGraphGroup.prototype.serialize = function() { From 0793eb926933034997cc2383adc414d080643e77 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Sep 2023 23:16:01 -0400 Subject: [PATCH 091/150] Only clear clipboard when copying nodes. --- web/scripts/app.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index f0bb8640c..5efe08c00 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -753,8 +753,9 @@ export class ComfyApp { // Default system copy return; } + // copy nodes and clear clipboard - if (this.canvas.selected_nodes) { + if (e.target.className === "litegraph" && this.canvas.selected_nodes) { this.canvas.copyToClipboard(); e.clipboardData.setData('text', ' '); //clearData doesn't remove images from clipboard e.preventDefault(); From 492db2de8db7e082addf131b40adb4a1b7535821 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Sep 2023 01:14:42 -0400 Subject: [PATCH 092/150] Allow having a different pooled output for each image in a batch. 
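Previously the encoded ADM vector was assumed to be a single row and was tiled with torch.cat([adm_out] * batch_size); that breaks when the pooled CLIP output already carries one row per image. repeat_to_batch_size handles both cases. A sketch of its behavior, written out here for illustration (the real helper lives in comfy/utils.py):

```
import math
import torch

def repeat_to_batch_size(tensor, batch_size):
    # trim when too long, tile and trim the remainder when too short
    if tensor.shape[0] > batch_size:
        return tensor[:batch_size]
    elif tensor.shape[0] < batch_size:
        return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (tensor.ndim - 1))[:batch_size]
    return tensor
```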
--- comfy/model_base.py | 4 ++-- comfy/samplers.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index ca154dba0..ed2dc83e4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -181,7 +181,7 @@ class SDXLRefiner(BaseModel): out.append(self.embedder(torch.Tensor([crop_h]))) out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([aesthetic_score]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) class SDXL(BaseModel): @@ -206,5 +206,5 @@ class SDXL(BaseModel): out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([target_height]))) out.append(self.embedder(torch.Tensor([target_width]))) - flat = torch.flatten(torch.cat(out))[None, ] + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) diff --git a/comfy/samplers.py b/comfy/samplers.py index 57673a029..e3192ca58 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,6 +7,7 @@ from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from comfy import model_base +import comfy.utils def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) return abs(a*b) // math.gcd(a, b) @@ -538,7 +539,7 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): if adm_out is not None: x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size).to(device) + x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device) return conds From 422d16c027009cd6165c86179dad937166de5312 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Sep 2023 22:23:01 -0400 Subject: [PATCH 093/150] Add some nodes to add, subtract and multiply latents. 
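All three nodes operate element-wise on the samples tensor inside the LATENT dict; reshape_latent_to upscales and batch-repeats the second operand first so latents of different sizes can still be combined. Illustrative usage with raw tensors (not node code):

```
import torch

a = {"samples": torch.randn(2, 4, 64, 64)}
b = {"samples": torch.randn(1, 4, 64, 64)}

out = a.copy()
# the second operand is repeated to the first one's batch size,
# then the op is a plain element-wise add/subtract/multiply
out["samples"] = a["samples"] + b["samples"].repeat(2, 1, 1, 1)
```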
--- comfy_extras/nodes_latent.py | 74 ++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 75 insertions(+) create mode 100644 comfy_extras/nodes_latent.py diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py new file mode 100644 index 000000000..b1823d7a1 --- /dev/null +++ b/comfy_extras/nodes_latent.py @@ -0,0 +1,74 @@ +import comfy.utils + +def reshape_latent_to(target_shape, latent): + if latent.shape[1:] != target_shape[1:]: + latent.movedim(1, -1) + latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") + latent.movedim(-1, 1) + return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + + +class LatentAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + samples_out["samples"] = s1 + s2 + return (samples_out,) + +class LatentSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + samples_out["samples"] = s1 - s2 + return (samples_out,) + +class LatentMuliply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, multiplier): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1 * multiplier + return (samples_out,) + +NODE_CLASS_MAPPINGS = { + "LatentAdd": LatentAdd, + "LatentSubtract": LatentSubtract, + "LatentMuliply": LatentMuliply, +} diff --git a/nodes.py b/nodes.py index 18d82ea80..6e0d43747 100644 --- a/nodes.py +++ b/nodes.py @@ -1772,6 +1772,7 @@ def load_custom_nodes(): print() def init_custom_nodes(): + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_latent.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) From 29ccf9f471e3b2ad4f4a08ba9f34698d357f8547 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Sep 2023 01:33:46 -0400 Subject: [PATCH 094/150] Fix typo. 
--- comfy_extras/nodes_latent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index b1823d7a1..001de39fc 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -48,7 +48,7 @@ class LatentSubtract: samples_out["samples"] = s1 - s2 return (samples_out,) -class LatentMuliply: +class LatentMultiply: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -70,5 +70,5 @@ class LatentMuliply: NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, - "LatentMuliply": LatentMuliply, + "LatentMultiply": LatentMultiply, } From afa2399f79e84919645eb69cd8e1ef1d9f1d6bd1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Sep 2023 20:26:47 -0400 Subject: [PATCH 095/150] Add a way to set output block patches to modify the h and hsp. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 6 ++++++ comfy/model_patcher.py | 3 +++ 2 files changed, 9 insertions(+) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 3ce3c2e7b..b42637c82 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -608,6 +608,7 @@ class UNetModel(nn.Module): """ transformer_options["original_shape"] = list(x.shape) transformer_options["current_index"] = 0 + transformer_patches = transformer_options.get("patches", {}) assert (y is not None) == ( self.num_classes is not None @@ -644,6 +645,11 @@ class UNetModel(nn.Module): if ctrl is not None: hsp += ctrl + if "output_block_patch" in transformer_patches: + patch = transformer_patches["output_block_patch"] + for p in patch: + h, hsp = p(h, hsp, transformer_options) + h = th.cat([h, hsp], dim=1) del hsp if len(hs) > 0: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 10551656e..ba505221e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -88,6 +88,9 @@ class ModelPatcher: def set_model_attn2_output_patch(self, patch): self.set_model_patch(patch, "attn2_output_patch") + def set_model_output_block_patch(self, patch): + self.set_model_patch(patch, "output_block_patch") + def model_patches_to(self, device): to = self.model_options["transformer_options"] if "patches" in to: From eec449ca8e4b3741032f7fed9372ba52040eb563 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Fri, 22 Sep 2023 21:11:27 -0700 Subject: [PATCH 096/150] Allow Intel GPUs to LoRA cast on GPU since it supports BF16 natively. --- comfy/model_management.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 1050c13a4..8b8963726 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -488,6 +488,8 @@ def cast_to_device(tensor, device, dtype, copy=False): elif tensor.dtype == torch.bfloat16: if hasattr(device, 'type') and device.type.startswith("cuda"): device_supports_cast = True + elif is_intel_xpu(): + device_supports_cast = True if device_supports_cast: if copy: From fd93c759e278f832b149bc5b0150a8b437c48c77 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Sep 2023 00:56:09 -0400 Subject: [PATCH 097/150] Implement FreeU: Free Lunch in Diffusion U-Net node. 
_for_testing->FreeU --- comfy_extras/nodes_freelunch.py | 56 +++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 57 insertions(+) create mode 100644 comfy_extras/nodes_freelunch.py diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py new file mode 100644 index 000000000..535eece39 --- /dev/null +++ b/comfy_extras/nodes_freelunch.py @@ -0,0 +1,56 @@ +#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) + +import torch + + +def Fourier_filter(x, threshold, scale): + # FFT + x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(x.dtype) + + +class FreeU: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), + "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), + "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), + "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, b1, b2, s1, s2): + def output_block_patch(h, hsp, transformer_options): + if h.shape[1] == 1280: + h[:,:640] = h[:,:640] * b1 + hsp = Fourier_filter(hsp, threshold=1, scale=s1) + if h.shape[1] == 640: + h[:,:320] = h[:,:320] * b2 + hsp = Fourier_filter(hsp, threshold=1, scale=s2) + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "FreeU": FreeU, +} diff --git a/nodes.py b/nodes.py index 6e0d43747..115862607 100644 --- a/nodes.py +++ b/nodes.py @@ -1782,4 +1782,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_freelunch.py")) load_custom_nodes() From 05e661e5efb64803ff9d27191185159081a05297 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Sep 2023 12:19:08 -0400 Subject: [PATCH 098/150] FreeU now works with the refiner. 
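The previous version keyed the scaling on the literal channel counts 1280 and 640, which only match UNets with model_channels=320 (SD1.x/SD2.x). Deriving the dict from model_channels makes the same node line up with other UNet widths. A quick check of the mapping, assuming 320 for SD1.x and 384 for the SDXL refiner:

```
for model_channels in (320, 384):
    scale_dict = {model_channels * 4: "(b1, s1)", model_channels * 2: "(b2, s2)"}
    print(model_channels, "->", scale_dict)
# 320 -> {1280: ..., 640: ...}, the values that were hardcoded before
# 384 -> {1536: ..., 768: ...}, the refiner's output block widths
```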
--- comfy_extras/nodes_freelunch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 535eece39..c3542a7a4 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -37,13 +37,13 @@ class FreeU: CATEGORY = "_for_testing" def patch(self, model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} def output_block_patch(h, hsp, transformer_options): - if h.shape[1] == 1280: - h[:,:640] = h[:,:640] * b1 - hsp = Fourier_filter(hsp, threshold=1, scale=s1) - if h.shape[1] == 640: - h[:,:320] = h[:,:320] * b2 - hsp = Fourier_filter(hsp, threshold=1, scale=s2) + scale = scale_dict.get(h.shape[1], None) + if scale is not None: + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) return h, hsp m = model.clone() From 76cdc809bfe562dc1026784f26ae0b9582016d6b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Sep 2023 18:47:46 -0400 Subject: [PATCH 099/150] Support more controlnet models. --- comfy/controlnet.py | 2 +- comfy/model_detection.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index af0df103e..ea219c7e5 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -354,7 +354,7 @@ def load_controlnet(ckpt_path, model=None): if controlnet_config is None: use_fp16 = comfy.model_management.should_use_fp16() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16).unet_config + controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 372d5a2df..787c78575 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,5 +1,5 @@ - -from . 
import supported_models +import comfy.supported_models +import comfy.supported_models_base def count_blocks(state_dict_keys, prefix_string): count = 0 @@ -109,17 +109,20 @@ def detect_unet_config(state_dict, key_prefix, use_fp16): return unet_config def model_config_from_unet_config(unet_config): - for model_config in supported_models.models: + for model_config in comfy.supported_models.models: if model_config.matches(unet_config): return model_config(unet_config) print("no match", unet_config) return None -def model_config_from_unet(state_dict, unet_key_prefix, use_fp16): +def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False): unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16) - return model_config_from_unet_config(unet_config) - + model_config = model_config_from_unet_config(unet_config) + if model_config is None and use_base_if_no_match: + return comfy.supported_models_base.BASE(unet_config) + else: + return model_config def unet_config_from_diffusers_unet(state_dict, use_fp16): match = {} From 593b7069e7cc3bf6ce8283849c65280369e4414b Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Sun, 24 Sep 2023 12:08:54 -0300 Subject: [PATCH 100/150] Proportional scale latent and image --- nodes.py | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/nodes.py b/nodes.py index 115862607..0882185a4 100644 --- a/nodes.py +++ b/nodes.py @@ -967,8 +967,8 @@ class LatentUpscale: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "crop": (s.crop_methods,)}} RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -976,8 +976,22 @@ class LatentUpscale: CATEGORY = "latent" def upscale(self, samples, upscale_method, width, height, crop): - s = samples.copy() - s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) + if width == 0 and height == 0: + s = samples + else: + s = samples.copy() + + if width == 0: + height = max(64, height) + width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2])) + elif height == 0: + width = max(64, width) + height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3])) + else: + width = max(64, width) + height = max(64, height) + + s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) class LatentUpscaleBy: @@ -1429,8 +1443,8 @@ class ImageScale: @classmethod def INPUT_TYPES(s): return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "crop": (s.crop_methods,)}} RETURN_TYPES = ("IMAGE",) FUNCTION = "upscale" @@ -1438,9 +1452,18 @@ class ImageScale: CATEGORY = "image/upscaling" def upscale(self, image, upscale_method, width, height, crop): - 
samples = image.movedim(-1,1) - s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop) - s = s.movedim(1,-1) + if width == 0 and height == 0: + s = image + else: + samples = image.movedim(-1,1) + + if width == 0: + width = max(1, round(samples.shape[3] * height / samples.shape[2])) + elif height == 0: + height = max(1, round(samples.shape[2] * width / samples.shape[3])) + + s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop) + s = s.movedim(1,-1) return (s,) class ImageScaleBy: From 77c124c5a17534e347bdebbc1ace807d61416147 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 24 Sep 2023 13:27:57 -0400 Subject: [PATCH 101/150] Fix typo. --- nodes.py | 2 +- web/scripts/app.js | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 115862607..4739977e4 100644 --- a/nodes.py +++ b/nodes.py @@ -1604,7 +1604,7 @@ NODE_CLASS_MAPPINGS = { "ImageBatch": ImageBatch, "ImagePadForOutpaint": ImagePadForOutpaint, "EmptyImage": EmptyImage, - "ConditioningAverage ": ConditioningAverage , + "ConditioningAverage": ConditioningAverage , "ConditioningCombine": ConditioningCombine, "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, diff --git a/web/scripts/app.js b/web/scripts/app.js index 5efe08c00..b41c12b86 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1322,6 +1322,7 @@ export class ComfyApp { for (let n of graphData.nodes) { // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; + if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { From f00471cdc8f92c930436cf288f1c12119f638a67 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 24 Sep 2023 18:09:44 -0400 Subject: [PATCH 102/150] Do FreeU fft on CPU if the device doesn't support fft functions. --- comfy_extras/nodes_freelunch.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index c3542a7a4..07a88bd96 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -39,11 +39,22 @@ class FreeU: def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + def output_block_patch(h, hsp, transformer_options): scale = scale_dict.get(h.shape[1], None) if scale is not None: h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] - hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.") + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + return h, hsp m = model.clone() From 42f6d1ebe2b1f53bf491edeac8ca18fd21a12d37 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 25 Sep 2023 01:21:28 -0400 Subject: [PATCH 103/150] Increase maximum batch sizes of empty image nodes. 
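The old cap of 64 was just a conservative widget bound; an empty latent batch itself is cheap, and real memory pressure only appears once sampling starts. Back-of-envelope for the new cap, assuming fp32 SD latents for 512x512 images (4 channels at 64x64 per image):

```
batch, c, h, w = 1024, 4, 64, 64
print(batch * c * h * w * 4 / 2**20, "MiB")  # 64.0 MiB for the whole batch
```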
--- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 4739977e4..fbe1ee1cb 100644 --- a/nodes.py +++ b/nodes.py @@ -891,7 +891,7 @@ class EmptyLatentImage: def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -1503,7 +1503,7 @@ class EmptyImage: def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024}), "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), }} RETURN_TYPES = ("IMAGE",) From 2381d36e6db8e8150e42ff2ede628db5b00ae26f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 25 Sep 2023 01:46:44 -0400 Subject: [PATCH 104/150] 1024 wasn't enough. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index fbe1ee1cb..04d9ae2ca 100644 --- a/nodes.py +++ b/nodes.py @@ -891,7 +891,7 @@ class EmptyLatentImage: def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024})}} + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -1503,7 +1503,7 @@ class EmptyImage: def INPUT_TYPES(s): return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), }} RETURN_TYPES = ("IMAGE",) From 046b4fe0eebffb2e48b1ea9ab5d245a56b2e4c49 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 25 Sep 2023 16:02:21 -0400 Subject: [PATCH 105/150] Support batches of masks in mask composite nodes. 
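For context: comfy.utils.repeat_to_batch_size is what lets a short source or
mask batch line up with a longer destination batch. A minimal sketch of the
expected semantics (an illustration, not necessarily the exact comfy.utils
implementation):

    import math
    import torch

    def repeat_to_batch_size(tensor, batch_size):
        # truncate if too long, cycle the existing entries if too short
        if tensor.shape[0] > batch_size:
            return tensor[:batch_size]
        if tensor.shape[0] < batch_size:
            reps = math.ceil(batch_size / tensor.shape[0])
            return tensor.repeat([reps] + [1] * (tensor.ndim - 1))[:batch_size]
        return tensor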
--- comfy_extras/nodes_mask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 43f623a62..b4c658a7a 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -1,6 +1,7 @@ import numpy as np from scipy.ndimage import grey_dilation import torch +import comfy.utils from nodes import MAX_RESOLUTION @@ -8,6 +9,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") + source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) + x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) @@ -18,8 +21,8 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou mask = torch.ones_like(source) else: mask = mask.clone() - mask = torch.nn.functional.interpolate(mask[None, None], size=(source.shape[2], source.shape[3]), mode="bilinear") - mask = mask.repeat((source.shape[0], source.shape[1], 1, 1)) + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) # calculate the bounds of the source that will be overlapping the destination # this prevents the source trying to overwrite latent pixels that are out of bounds @@ -122,7 +125,7 @@ class ImageToMask: def image_to_mask(self, image, channel): channels = ["red", "green", "blue"] - mask = image[0, :, :, channels.index(channel)] + mask = image[:, :, :, channels.index(channel)] return (mask,) class ImageColorToMask: From d2cec6cdbf5361413ddf624c72b0b9b2a7a156ee Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 25 Sep 2023 16:19:13 -0400 Subject: [PATCH 106/150] Make mask functions work with batches of masks and images. 
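The recurring idiom in this change is to normalize any incoming mask to a
(batch, height, width) layout before indexing, so a bare (H, W) mask and a
batch of masks take the same code path. A small standalone illustration
(shapes chosen for the demo):

    import torch

    mask = torch.rand(512, 512)  # a single (H, W) mask
    batched = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
    assert batched.shape == (1, 512, 512)  # already-batched masks pass through unchanged

Per-pixel operations then slice as output[:, y, x] instead of output[y, x].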
--- comfy_extras/nodes_mask.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index b4c658a7a..8f87e4cd8 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -144,8 +144,8 @@ class ImageColorToMask: FUNCTION = "image_to_mask" def image_to_mask(self, image, color): - temp = (torch.clamp(image[0], 0, 1.0) * 255.0).round().to(torch.int) - temp = torch.bitwise_left_shift(temp[:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,1], 8) + temp[:,:,2] + temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) + temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] mask = torch.where(temp == color, 255, 0).float() return (mask,) @@ -167,7 +167,7 @@ class SolidMask: FUNCTION = "solid" def solid(self, value, width, height): - out = torch.full((height, width), value, dtype=torch.float32, device="cpu") + out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") return (out,) class InvertMask: @@ -209,7 +209,8 @@ class CropMask: FUNCTION = "crop" def crop(self, mask, x, y, width, height): - out = mask[y:y + height, x:x + width] + mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) + out = mask[:, y:y + height, x:x + width] return (out,) class MaskComposite: @@ -232,27 +233,28 @@ class MaskComposite: FUNCTION = "combine" def combine(self, destination, source, x, y, operation): - output = destination.clone() + output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() + source = source.reshape((-1, source.shape[-2], source.shape[-1])) left, top = (x, y,) - right, bottom = (min(left + source.shape[1], destination.shape[1]), min(top + source.shape[0], destination.shape[0])) + right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) visible_width, visible_height = (right - left, bottom - top,) source_portion = source[:visible_height, :visible_width] destination_portion = destination[top:bottom, left:right] if operation == "multiply": - output[top:bottom, left:right] = destination_portion * source_portion + output[:, top:bottom, left:right] = destination_portion * source_portion elif operation == "add": - output[top:bottom, left:right] = destination_portion + source_portion + output[:, top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": - output[top:bottom, left:right] = destination_portion - source_portion + output[:, top:bottom, left:right] = destination_portion - source_portion elif operation == "and": - output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() + output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "or": - output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() + output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "xor": - output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() + output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) 
@@ -278,7 +280,7 @@ class FeatherMask: FUNCTION = "feather" def feather(self, mask, left, top, right, bottom): - output = mask.clone() + output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() left = min(left, output.shape[1]) right = min(right, output.shape[1]) @@ -287,19 +289,19 @@ class FeatherMask: for x in range(left): feather_rate = (x + 1.0) / left - output[:, x] *= feather_rate + output[:, :, x] *= feather_rate for x in range(right): feather_rate = (x + 1) / right - output[:, -x] *= feather_rate + output[:, :, -x] *= feather_rate for y in range(top): feather_rate = (y + 1) / top - output[y, :] *= feather_rate + output[:, y, :] *= feather_rate for y in range(bottom): feather_rate = (y + 1) / bottom - output[-y, :] *= feather_rate + output[:, -y, :] *= feather_rate return (output,) From e0efa78b710d0bd213e8f22220fd53c9421906d8 Mon Sep 17 00:00:00 2001 From: Michael Poutre Date: Mon, 25 Sep 2023 21:20:51 -0700 Subject: [PATCH 107/150] chore(CI): Update test-build to use updated version of actions --- .github/workflows/test-build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 421dd5ee4..444d6b254 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -20,9 +20,9 @@ jobs: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies From d76d71de3fc5e9618226c53f5a4a1a1a6c14b4fe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Sep 2023 02:45:31 -0400 Subject: [PATCH 108/150] GrowMask can now be used with negative numbers to erode it. --- comfy_extras/nodes_mask.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 8f87e4cd8..aa13cac01 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -1,5 +1,5 @@ import numpy as np -from scipy.ndimage import grey_dilation +import scipy.ndimage import torch import comfy.utils @@ -311,7 +311,7 @@ class GrowMask: return { "required": { "mask": ("MASK",), - "expand": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), "tapered_corners": ("BOOLEAN", {"default": True}), }, } @@ -328,8 +328,11 @@ class GrowMask: [1, 1, 1], [c, 1, c]]) output = mask.numpy().copy() + while expand < 0: + output = scipy.ndimage.grey_erosion(output, footprint=kernel) + expand += 1 while expand > 0: - output = grey_dilation(output, footprint=kernel) + output = scipy.ndimage.grey_dilation(output, footprint=kernel) expand -= 1 output = torch.from_numpy(output) return (output,) From 1d36dfb9fe025b716bc66d920b996381f457393d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Sep 2023 02:53:57 -0400 Subject: [PATCH 109/150] GrowMask now works with mask batches. 
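To illustrate what a single grow step does with the tapered-corner footprint,
here is a small standalone check (values chosen for the demo):

    import numpy as np
    import scipy.ndimage

    kernel = np.array([[0, 1, 0],
                       [1, 1, 1],
                       [0, 1, 0]])  # footprint when tapered_corners is True
    m = np.zeros((5, 5), dtype=np.float32)
    m[2, 2] = 1.0
    grown = scipy.ndimage.grey_dilation(m, footprint=kernel)
    # the lone pixel grows into a plus shape; grey_erosion reverses the effect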
--- comfy_extras/nodes_mask.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index aa13cac01..af7cb07bf 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -327,15 +327,19 @@ class GrowMask: kernel = np.array([[c, 1, c], [1, 1, 1], [c, 1, c]]) - output = mask.numpy().copy() - while expand < 0: - output = scipy.ndimage.grey_erosion(output, footprint=kernel) - expand += 1 - while expand > 0: - output = scipy.ndimage.grey_dilation(output, footprint=kernel) - expand -= 1 - output = torch.from_numpy(output) - return (output,) + mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) + out = [] + for m in mask: + output = m.numpy() + while expand < 0: + output = scipy.ndimage.grey_erosion(output, footprint=kernel) + expand += 1 + while expand > 0: + output = scipy.ndimage.grey_dilation(output, footprint=kernel) + expand -= 1 + output = torch.from_numpy(output) + out.append(output) + return (torch.cat(out, dim=0),) From 9546a798fba3c9fc9b6aee26cef46674a184727c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Sep 2023 02:56:40 -0400 Subject: [PATCH 110/150] Make LoadImage and LoadImageMask return masks in batch format. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 8a28e127e..4abb0d24d 100644 --- a/nodes.py +++ b/nodes.py @@ -1369,7 +1369,7 @@ class LoadImage: mask = 1. - torch.from_numpy(mask) else: mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (image, mask) + return (image, mask.unsqueeze(0)) @classmethod def IS_CHANGED(s, image): @@ -1416,7 +1416,7 @@ class LoadImageMask: mask = 1. - mask else: mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (mask,) + return (mask.unsqueeze(0),) @classmethod def IS_CHANGED(s, image, channel): From 446caf711c9e9ae4cdced65bf3609095b26fcde0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Sep 2023 13:45:15 -0400 Subject: [PATCH 111/150] Sampling code refactor. 
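The refactor turns each sampling algorithm into a small object with a common
sample() method, so samplers can be added without touching KSampler itself.
As a rough sketch of that contract, here is a hypothetical plain Euler loop
(not one of the shipped samplers), assuming the k-diffusion convention that
model_wrap(x, sigma * s_in, **extra_args) returns the denoised prediction:

    import torch

    class EulerExample(Sampler):  # Sampler base class introduced below
        def sample(self, model_wrap, sigmas, extra_args, callback, noise,
                   latent_image=None, denoise_mask=None, disable_pbar=False):
            # scale the initial noise the same way the built-in samplers do
            if self.max_denoise(model_wrap, sigmas):
                x = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
            else:
                x = noise * sigmas[0]
            if latent_image is not None:
                x += latent_image
            s_in = x.new_ones([x.shape[0]])
            for i in range(len(sigmas) - 1):
                denoised = model_wrap(x, sigmas[i] * s_in, **extra_args)
                d = (x - denoised) / sigmas[i]           # ODE derivative
                x = x + d * (sigmas[i + 1] - sigmas[i])  # explicit Euler step
                if callback is not None:
                    callback(i, denoised, x, len(sigmas) - 1)
            # inpaint masking via KSamplerX0Inpaint is omitted for brevity
            return x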
--- comfy/ldm/models/diffusion/ddim.py | 2 +- comfy/samplers.py | 261 ++++++++++++++++------------- 2 files changed, 150 insertions(+), 113 deletions(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index befab0075..433d48e30 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -59,7 +59,7 @@ class DDIMSampler(object): @torch.no_grad() def sample_custom(self, ddim_timesteps, - conditioning, + conditioning=None, callback=None, img_callback=None, quantize_x0=False, diff --git a/comfy/samplers.py b/comfy/samplers.py index e3192ca58..9afde9da7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -544,11 +544,152 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): return conds +class Sampler: + def sample(self): + pass + + def max_denoise(self, model_wrap, sigmas): + return math.isclose(float(model_wrap.sigma_max), float(sigmas[0])) + +class DDIM(Sampler): + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + timesteps = [] + for s in range(sigmas.shape[0]): + timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s])) + noise_mask = None + if denoise_mask is not None: + noise_mask = 1.0 - denoise_mask + + ddim_callback = None + if callback is not None: + total_steps = len(timesteps) - 1 + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) + + max_denoise = self.max_denoise(model_wrap, sigmas) + + ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device) + ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) + z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise) + samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps, + batch_size=noise.shape[0], + shape=noise.shape[1:], + verbose=False, + eta=0.0, + x_T=z_enc, + x0=latent_image, + img_callback=ddim_callback, + denoise_function=model_wrap.predict_eps_discrete_timestep, + extra_args=extra_args, + mask=noise_mask, + to_zero=sigmas[-1]==0, + end_step=sigmas.shape[0] - 1, + disable_pbar=disable_pbar) + return samples + +class UNIPC(Sampler): + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) + +class UNIPCBH2(Sampler): + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) + +KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] + +def ksampler(sampler_name): + class KSAMPLER(Sampler): + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + 
extra_args["denoise_mask"] = denoise_mask + model_k = KSamplerX0Inpaint(model_wrap) + model_k.latent_image = latent_image + model_k.noise = noise + + if self.max_denoise(model_wrap, sigmas): + noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) + else: + noise = noise * sigmas[0] + + k_callback = None + total_steps = len(sigmas) - 1 + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + + sigma_min = sigmas[-1] + if sigma_min == 0: + sigma_min = sigmas[-2] + + if latent_image is not None: + noise += latent_image + if sampler_name == "dpm_fast": + samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) + elif sampler_name == "dpm_adaptive": + samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) + else: + samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar) + return samples + return KSAMPLER + + +def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + positive = positive[:] + negative = negative[:] + + resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) + resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) + + model_denoise = CFGNoisePredictor(model) + if model.model_type == model_base.ModelType.V_PREDICTION: + model_wrap = CompVisVDenoiser(model_denoise, quantize=True) + else: + model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True) + + calculate_start_end_timesteps(model_wrap, negative) + calculate_start_end_timesteps(model_wrap, positive) + + #make sure each cond area has an opposite one with the same area + for c in positive: + create_cond_with_same_area_if_none(negative, c) + for c in negative: + create_cond_with_same_area_if_none(positive, c) + + pre_run_control(model_wrap, negative + positive) + + apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) + + if model.is_adm(): + positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") + negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") + + if latent_image is not None: + latent_image = model.process_latent_in(latent_image) + + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} + + cond_concat = None + if hasattr(model, 'concat_keys'): #inpaint + cond_concat = [] + for ck in model.concat_keys: + if denoise_mask is not None: + if ck == "mask": + cond_concat.append(denoise_mask[:,:1]) + elif ck == "masked_image": + cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space + else: + if ck == "mask": + cond_concat.append(torch.ones_like(noise)[:,:1]) + elif ck == "masked_image": + cond_concat.append(blank_inpaint_image_like(noise)) + extra_args["cond_concat"] = cond_concat + + samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, 
latent_image, denoise_mask, disable_pbar) + return model.process_latent_out(samples.to(torch.float32)) + class KSampler: SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] - SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", - "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", - "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "ddim", "uni_pc", "uni_pc_bh2"] + SAMPLERS = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -628,117 +769,13 @@ class KSampler: else: return torch.zeros_like(noise) - positive = positive[:] - negative = negative[:] - - resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) - resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) - - calculate_start_end_timesteps(self.model_wrap, negative) - calculate_start_end_timesteps(self.model_wrap, positive) - - #make sure each cond area has an opposite one with the same area - for c in positive: - create_cond_with_same_area_if_none(negative, c) - for c in negative: - create_cond_with_same_area_if_none(positive, c) - - pre_run_control(self.model_wrap, negative + positive) - - apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) - apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - - if self.model.is_adm(): - positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive") - negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative") - - if latent_image is not None: - latent_image = self.model.process_latent_in(latent_image) - - extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed} - - cond_concat = None - if hasattr(self.model, 'concat_keys'): #inpaint - cond_concat = [] - for ck in self.model.concat_keys: - if denoise_mask is not None: - if ck == "mask": - cond_concat.append(denoise_mask[:,:1]) - elif ck == "masked_image": - cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space - else: - if ck == "mask": - cond_concat.append(torch.ones_like(noise)[:,:1]) - elif ck == "masked_image": - cond_concat.append(blank_inpaint_image_like(noise)) - extra_args["cond_concat"] = cond_concat - - if sigmas[0] != self.sigmas[0] or (self.denoise is not None and self.denoise < 1.0): - max_denoise = False - else: - max_denoise = True - - if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) + sampler = UNIPC elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) + sampler = UNIPCBH2 elif self.sampler == "ddim": - timesteps = [] - for s in range(sigmas.shape[0]): - timesteps.insert(0, 
self.model_wrap.sigma_to_discrete_timestep(sigmas[s])) - noise_mask = None - if denoise_mask is not None: - noise_mask = 1.0 - denoise_mask - - ddim_callback = None - if callback is not None: - total_steps = len(timesteps) - 1 - ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) - - sampler = DDIMSampler(self.model, device=self.device) - sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) - z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) - samples, _ = sampler.sample_custom(ddim_timesteps=timesteps, - conditioning=positive, - batch_size=noise.shape[0], - shape=noise.shape[1:], - verbose=False, - unconditional_guidance_scale=cfg, - unconditional_conditioning=negative, - eta=0.0, - x_T=z_enc, - x0=latent_image, - img_callback=ddim_callback, - denoise_function=self.model_wrap.predict_eps_discrete_timestep, - extra_args=extra_args, - mask=noise_mask, - to_zero=sigmas[-1]==0, - end_step=sigmas.shape[0] - 1, - disable_pbar=disable_pbar) - + sampler = DDIM else: - extra_args["denoise_mask"] = denoise_mask - self.model_k.latent_image = latent_image - self.model_k.noise = noise + sampler = ksampler(self.sampler) - if max_denoise: - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] - - k_callback = None - total_steps = len(sigmas) - 1 - if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) - - if latent_image is not None: - noise += latent_image - if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) - elif self.sampler == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) - else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar) - - return self.model.process_latent_out(samples.to(torch.float32)) + return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) From 1d6dd8318463e896abf9f99cf5381438ee64d302 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Sep 2023 16:25:34 -0400 Subject: [PATCH 112/150] Scheduler code refactor. 
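With the schedule construction pulled out into a free function, sigmas can be
computed without instantiating a KSampler. A short usage sketch (assuming
model is a loaded ModelPatcher, so model.model is the inner BaseModel):

    sigmas = calculate_sigmas_scheduler(model.model, "karras", 20)
    # n steps yield n + 1 descending sigmas, with the final entry at 0
    assert sigmas.shape[0] == 21 and sigmas[-1] == 0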
--- comfy/samplers.py | 66 +++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 9afde9da7..7668d7913 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -549,7 +549,7 @@ class Sampler: pass def max_denoise(self, model_wrap, sigmas): - return math.isclose(float(model_wrap.sigma_max), float(sigmas[0])) + return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05) class DDIM(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): @@ -631,6 +631,13 @@ def ksampler(sampler_name): return samples return KSAMPLER +def wrap_model(model): + model_denoise = CFGNoisePredictor(model) + if model.model_type == model_base.ModelType.V_PREDICTION: + model_wrap = CompVisVDenoiser(model_denoise, quantize=True) + else: + model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True) + return model_wrap def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): positive = positive[:] @@ -639,11 +646,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model resolve_areas_and_cond_masks(positive, noise.shape[2], noise.shape[3], device) resolve_areas_and_cond_masks(negative, noise.shape[2], noise.shape[3], device) - model_denoise = CFGNoisePredictor(model) - if model.model_type == model_base.ModelType.V_PREDICTION: - model_wrap = CompVisVDenoiser(model_denoise, quantize=True) - else: - model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True) + model_wrap = wrap_model(model) calculate_start_end_timesteps(model_wrap, negative) calculate_start_end_timesteps(model_wrap, positive) @@ -687,19 +690,33 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return model.process_latent_out(samples.to(torch.float32)) +SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] +SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] + +def calculate_sigmas_scheduler(model, scheduler_name, steps): + model_wrap = wrap_model(model) + if scheduler_name == "karras": + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + elif scheduler_name == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + elif scheduler_name == "normal": + sigmas = model_wrap.get_sigmas(steps) + elif scheduler_name == "simple": + sigmas = simple_scheduler(model_wrap, steps) + elif scheduler_name == "ddim_uniform": + sigmas = ddim_scheduler(model_wrap, steps) + elif scheduler_name == "sgm_uniform": + sigmas = sgm_scheduler(model_wrap, steps) + else: + print("error invalid scheduler", self.scheduler) + return sigmas + class KSampler: - SCHEDULERS = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] - SAMPLERS = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] + SCHEDULERS = SCHEDULER_NAMES + SAMPLERS = SAMPLER_NAMES def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model - self.model_denoise = 
CFGNoisePredictor(self.model) - if self.model.model_type == model_base.ModelType.V_PREDICTION: - self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) - else: - self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) - - self.model_k = KSamplerX0Inpaint(self.model_wrap) self.device = device if scheduler not in self.SCHEDULERS: scheduler = self.SCHEDULERS[0] @@ -707,8 +724,6 @@ class KSampler: sampler = self.SAMPLERS[0] self.scheduler = scheduler self.sampler = sampler - self.sigma_min=float(self.model_wrap.sigma_min) - self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) self.denoise = denoise self.model_options = model_options @@ -721,20 +736,7 @@ class KSampler: steps += 1 discard_penultimate_sigma = True - if self.scheduler == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) - elif self.scheduler == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) - elif self.scheduler == "normal": - sigmas = self.model_wrap.get_sigmas(steps) - elif self.scheduler == "simple": - sigmas = simple_scheduler(self.model_wrap, steps) - elif self.scheduler == "ddim_uniform": - sigmas = ddim_scheduler(self.model_wrap, steps) - elif self.scheduler == "sgm_uniform": - sigmas = sgm_scheduler(self.model_wrap, steps) - else: - print("error invalid scheduler", self.scheduler) + sigmas = calculate_sigmas_scheduler(self.model, self.scheduler, steps) if discard_penultimate_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) @@ -752,10 +754,8 @@ class KSampler: def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): if sigmas is None: sigmas = self.sigmas - sigma_min = self.sigma_min if last_step is not None and last_step < (len(sigmas) - 1): - sigma_min = sigmas[last_step] sigmas = sigmas[:last_step + 1] if force_full_denoise: sigmas[-1] = 0 From fff491b03289ac954eb465b9a57b30f695259c41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Sep 2023 12:04:07 -0400 Subject: [PATCH 113/150] Model patches can now know which batch is positive and negative. --- comfy/ldm/modules/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 34484b288..fcae6b66a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -538,6 +538,8 @@ class BasicTransformerBlock(nn.Module): if "block" in transformer_options: block = transformer_options["block"] extra_options["block"] = block + if "cond_or_uncond" in transformer_options: + extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"] if "patches" in transformer_options: transformer_patches = transformer_options["patches"] else: From bf3fc2f1b7f5b5cf684246be84838e6fc19aeb06 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Sep 2023 16:45:22 -0400 Subject: [PATCH 114/150] Refactor sampling related code. 
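The preview/progress callback construction moves into
latent_preview.prepare_callback so callers outside common_ksampler can reuse
it. A sketch of the intended call pattern (the surrounding variables are
assumed to be prepared as in common_ksampler below):

    import comfy.sample
    import latent_preview

    x0_output = {}
    callback = latent_preview.prepare_callback(model, steps, x0_output)
    samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name,
                                  scheduler, positive, negative, latent_image,
                                  denoise=denoise, noise_mask=noise_mask,
                                  callback=callback, seed=seed)
    # x0_output["x0"] then holds the last predicted clean latent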
--- comfy/sample.py | 24 ++++++++++++++---------- latent_preview.py | 18 ++++++++++++++++++ nodes.py | 17 +---------------- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index e4730b189..fe9f4118d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -70,25 +70,29 @@ def cleanup_additional_models(models): if hasattr(m, 'cleanup'): m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): - device = comfy.model_management.get_torch_device() +def prepare_sampling(model, noise_shape, positive, negative, noise_mask): + device = model.load_device if noise_mask is not None: - noise_mask = prepare_mask(noise_mask, noise.shape, device) + noise_mask = prepare_mask(noise_mask, noise_shape, device) real_model = None models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) + comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory) real_model = model.model - noise = noise.to(device) - latent_image = latent_image.to(device) - - positive_copy = broadcast_cond(positive, noise.shape[0], device) - negative_copy = broadcast_cond(negative, noise.shape[0], device) + positive_copy = broadcast_cond(positive, noise_shape[0], device) + negative_copy = broadcast_cond(negative, noise_shape[0], device) + return real_model, positive_copy, negative_copy, noise_mask, models - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): + real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) + + noise = noise.to(model.load_device) + latent_image = latent_image.to(model.load_device) + + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() diff --git a/latent_preview.py b/latent_preview.py index 87240a582..740e08607 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -5,6 +5,7 @@ import numpy as np from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD import folder_paths +import comfy.utils MAX_PREVIEW_RESOLUTION = 512 @@ -74,4 +75,21 @@ def get_previewer(device, latent_format): previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors) return previewer +def prepare_callback(model, steps, x0_output_dict=None): + preview_format = "JPEG" 
+ if preview_format not in ["JPEG", "PNG"]: + preview_format = "JPEG" + + previewer = get_previewer(model.load_device, model.model.latent_format) + + pbar = comfy.utils.ProgressBar(steps) + def callback(step, x0, x, total_steps): + if x0_output_dict is not None: + x0_output_dict["x0"] = x0 + + preview_bytes = None + if previewer: + preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) + pbar.update_absolute(step + 1, total_steps, preview_bytes) + return callback diff --git a/nodes.py b/nodes.py index 4abb0d24d..a847db6fb 100644 --- a/nodes.py +++ b/nodes.py @@ -1189,11 +1189,8 @@ class SetLatentNoiseMask: s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) - def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): - device = comfy.model_management.get_torch_device() latent_image = latent["samples"] - if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: @@ -1204,19 +1201,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] - preview_format = "JPEG" - if preview_format not in ["JPEG", "PNG"]: - preview_format = "JPEG" - - previewer = latent_preview.get_previewer(device, model.model.latent_format) - - pbar = comfy.utils.ProgressBar(steps) - def callback(step, x0, x, total_steps): - preview_bytes = None - if previewer: - preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) - pbar.update_absolute(step + 1, total_steps, preview_bytes) - + callback = latent_preview.prepare_callback(model, steps) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) From 1adcc4c3a2f6f329c1e4e7ac3114f254f9b5f558 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Sep 2023 22:21:18 -0400 Subject: [PATCH 115/150] Add a SamplerCustom Node. This node takes a list of sigmas and a sampler object as input. This lets people easily implement custom schedulers and samplers as nodes. More nodes will be added to it in the future. 
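For example, a hypothetical scheduler node (illustrative only, not part of
this patch) just has to return a 1-D tensor of descending sigmas under the
SIGMAS type to plug into SamplerCustom:

    import torch

    class LinearScheduler:
        @classmethod
        def INPUT_TYPES(s):
            return {"required":
                        {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                         "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0}),
                        }
                    }
        RETURN_TYPES = ("SIGMAS",)
        CATEGORY = "_for_testing/custom_sampling"
        FUNCTION = "get_sigmas"

        def get_sigmas(self, steps, sigma_max):
            # any monotonically decreasing schedule ending at 0 is usable
            return (torch.linspace(sigma_max, 0.0, steps + 1),)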
--- comfy/sample.py | 12 ++++ comfy_extras/nodes_custom_sampler.py | 98 ++++++++++++++++++++++++++++ nodes.py | 4 +- 3 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_custom_sampler.py diff --git a/comfy/sample.py b/comfy/sample.py index fe9f4118d..322272766 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -99,3 +99,15 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative cleanup_additional_models(models) return samples + +def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): + real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask) + noise = noise.to(model.load_device) + latent_image = latent_image.to(model.load_device) + sigmas = sigmas.to(model.load_device) + + samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + samples = samples.cpu() + cleanup_additional_models(models) + return samples + diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py new file mode 100644 index 000000000..062852629 --- /dev/null +++ b/comfy_extras/nodes_custom_sampler.py @@ -0,0 +1,98 @@ +import comfy.samplers +import comfy.sample +from comfy.k_diffusion import sampling as k_diffusion_sampling +import latent_preview + + +class KarrasScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return (sigmas, ) + + +class KSamplerSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sampler_name": (comfy.samplers.KSAMPLER_NAMES, ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sampler" + + def get_sampler(self, sampler_name): + sampler = comfy.samplers.ksampler(sampler_name)() + return (sampler, ) + +class SamplerCustom: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "add_noise": (["enable", "disable"], ), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "sampler": ("SAMPLER", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT","LATENT") + RETURN_NAMES = ("output", "denoised_output") + + FUNCTION = "sample" + + CATEGORY = "_for_testing/custom_sampling" + + def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + if add_noise == 
"disable": + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, noise_seed, batch_inds) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = False + samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return (out, out_denoised) + +NODE_CLASS_MAPPINGS = { + "SamplerCustom": SamplerCustom, + "KarrasScheduler": KarrasScheduler, + "KSamplerSelect": KSamplerSelect, +} diff --git a/nodes.py b/nodes.py index a847db6fb..1232373be 100644 --- a/nodes.py +++ b/nodes.py @@ -1202,9 +1202,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = latent["noise_mask"] callback = latent_preview.prepare_callback(model, steps) + disable_pbar = False samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) out = latent.copy() out["samples"] = samples return (out, ) @@ -1791,4 +1792,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_freelunch.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_custom_sampler.py")) load_custom_nodes() From 1d7dfc07d5e76968c9137c17fca0f7ad77a7b9d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Sep 2023 22:32:42 -0400 Subject: [PATCH 116/150] Make add_noise in SamplerCustom a boolean. 
--- comfy_extras/nodes_custom_sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 062852629..842a9de4f 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -2,7 +2,7 @@ import comfy.samplers import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling import latent_preview - +import torch class KarrasScheduler: @classmethod @@ -45,7 +45,7 @@ class SamplerCustom: def INPUT_TYPES(s): return {"required": {"model": ("MODEL",), - "add_noise": (["enable", "disable"], ), + "add_noise": ("BOOLEAN", {"default": True}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "positive": ("CONDITIONING", ), @@ -66,7 +66,7 @@ class SamplerCustom: def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): latent = latent_image latent_image = latent["samples"] - if add_noise == "disable": + if not add_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: batch_inds = latent["batch_index"] if "batch_index" in latent else None From d234ca558a7777b607a4f81aeb9e8703ef020977 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Sep 2023 00:17:03 -0400 Subject: [PATCH 117/150] Add missing samplers to KSamplerSelect. --- comfy/samplers.py | 20 ++++++++++++-------- comfy_extras/nodes_custom_sampler.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7668d7913..a7c240f40 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -711,6 +711,17 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): print("error invalid scheduler", self.scheduler) return sigmas +def sampler_class(name): + if name == "uni_pc": + sampler = UNIPC + elif name == "uni_pc_bh2": + sampler = UNIPCBH2 + elif name == "ddim": + sampler = DDIM + else: + sampler = ksampler(name) + return sampler + class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES @@ -769,13 +780,6 @@ class KSampler: else: return torch.zeros_like(noise) - if self.sampler == "uni_pc": - sampler = UNIPC - elif self.sampler == "uni_pc_bh2": - sampler = UNIPCBH2 - elif self.sampler == "ddim": - sampler = DDIM - else: - sampler = ksampler(self.sampler) + sampler = sampler_class(self.sampler) return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 842a9de4f..1c587dbd8 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -28,7 +28,7 @@ class KSamplerSelect: @classmethod def INPUT_TYPES(s): return {"required": - {"sampler_name": (comfy.samplers.KSAMPLER_NAMES, ), + {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), } } RETURN_TYPES = ("SAMPLER",) @@ -37,7 +37,7 @@ class KSamplerSelect: FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = comfy.samplers.ksampler(sampler_name)() + sampler = comfy.samplers.sampler_class(sampler_name)() return (sampler, ) class SamplerCustom: From 2bf051fda87cfa94e5c99bbd88fc7f1434e9e1a2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 
28 Sep 2023 00:30:45 -0400 Subject: [PATCH 118/150] Add a basic node to generate sigmas from scheduler. --- comfy_extras/nodes_custom_sampler.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 1c587dbd8..aafde8f32 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -4,6 +4,26 @@ from comfy.k_diffusion import sampling as k_diffusion_sampling import latent_preview import torch + +class BasicScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, scheduler, steps): + sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() + return (sigmas, ) + + class KarrasScheduler: @classmethod def INPUT_TYPES(s): @@ -95,4 +115,5 @@ NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, "KarrasScheduler": KarrasScheduler, "KSamplerSelect": KSamplerSelect, + "BasicScheduler": BasicScheduler, } From 76e0f8fc8fe330b9568fab4b4a8049a62d141165 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Sep 2023 00:40:09 -0400 Subject: [PATCH 119/150] Add function to split sigmas. --- comfy_extras/nodes_custom_sampler.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index aafde8f32..efe03ad24 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -43,6 +43,23 @@ class KarrasScheduler: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) return (sigmas, ) +class SplitSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + "step": ("INT", {"default": 0, "min": 0, "max": 10000}), + } + } + RETURN_TYPES = ("SIGMAS","SIGMAS") + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas, step): + sigmas1 = sigmas[:step + 1] + sigmas2 = sigmas[step + 1:] + return (sigmas1, sigmas2) class KSamplerSelect: @classmethod @@ -116,4 +133,5 @@ NODE_CLASS_MAPPINGS = { "KarrasScheduler": KarrasScheduler, "KSamplerSelect": KSamplerSelect, "BasicScheduler": BasicScheduler, + "SplitSigmas": SplitSigmas, } From 71713888c4d2af38c2f25f39226933081f5f70d7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Sep 2023 00:54:57 -0400 Subject: [PATCH 120/150] Print missing VAE keys. --- comfy/sd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 9bdb2ad64..2f1b2e964 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -152,7 +152,9 @@ class VAE: sd = comfy.utils.load_torch_file(ckpt_path) if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) - self.first_stage_model.load_state_dict(sd, strict=False) + m, u = self.first_stage_model.load_state_dict(sd, strict=False) + if len(m) > 0: + print("Missing VAE keys", m) if device is None: device = model_management.vae_device() From 26b73728053a786c429356fc02a7c98868d2ba02 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Sep 2023 01:11:22 -0400 Subject: [PATCH 121/150] Fix SplitSigmas. 
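The second half has to start at sigmas[step], not sigmas[step + 1]: a run over
the first half ends at that sigma, and a chained run over the second half must
resume from the same noise level, so the boundary value belongs to both
outputs. A quick check with example values:

    import torch

    sigmas = torch.tensor([14.6, 7.0, 3.0, 1.0, 0.0])
    step = 2
    sigmas1, sigmas2 = sigmas[:step + 1], sigmas[step:]
    # sigmas1 = [14.6, 7.0, 3.0], sigmas2 = [3.0, 1.0, 0.0]
    assert sigmas1[-1] == sigmas2[0]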
--- comfy_extras/nodes_custom_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index efe03ad24..5e5ef61b5 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -58,7 +58,7 @@ class SplitSigmas: def get_sigmas(self, sigmas, step): sigmas1 = sigmas[:step + 1] - sigmas2 = sigmas[step + 1:] + sigmas2 = sigmas[step:] return (sigmas1, sigmas2) class KSamplerSelect: From 66756de1002c23ec4005504232e3f8e5096c964b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Sep 2023 21:56:23 -0400 Subject: [PATCH 122/150] Add SamplerDPMPP_2M_SDE node. --- comfy/samplers.py | 4 ++-- comfy_extras/nodes_custom_sampler.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index a7c240f40..e43f7a6fe 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -598,7 +598,7 @@ KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral" "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] -def ksampler(sampler_name): +def ksampler(sampler_name, extra_options={}): class KSAMPLER(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): extra_args["denoise_mask"] = denoise_mask @@ -627,7 +627,7 @@ def ksampler(sampler_name): elif sampler_name == "dpm_adaptive": samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar) + samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options) return samples return KSAMPLER diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 5e5ef61b5..b667afe4f 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -77,6 +77,30 @@ class KSamplerSelect: sampler = comfy.samplers.sampler_class(sampler_name)() return (sampler, ) +class SamplerDPMPP_2M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"solver_type": (['midpoint', 'heun'], ), + "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sampler" + + def get_sampler(self, solver_type, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_2m_sde" + else: + sampler_name = "dpmpp_2m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})() + return (sampler, ) + + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -132,6 +156,7 @@ NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, "KarrasScheduler": KarrasScheduler, "KSamplerSelect": KSamplerSelect, + "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "BasicScheduler": BasicScheduler, "SplitSigmas": 
SplitSigmas,
}

From 1c8ae9dbb249ed5326d61d16b4e6b5807c09c0e1 Mon Sep 17 00:00:00 2001
From: Jukka Seppänen <40791699+kijai@users.noreply.github.com>
Date: Fri, 29 Sep 2023 05:01:19 +0300
Subject: [PATCH 123/150] Allow GrowMask node to work with batches (for AnimateDiff) (#1623)

* Allow mask batches

This allows the LatentCompositeMasked node to work with AnimateDiff. I tried
to keep the old functionality too; I am not certain it is correct, but both a
single mask and a batch of masks seem to work with this change.

* Update nodes_mask.py
---
 comfy_extras/nodes_mask.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py
index af7cb07bf..cdf762ffd 100644
--- a/comfy_extras/nodes_mask.py
+++ b/comfy_extras/nodes_mask.py
@@ -331,15 +331,14 @@ class GrowMask:
         out = []
         for m in mask:
             output = m.numpy()
-            while expand < 0:
-                output = scipy.ndimage.grey_erosion(output, footprint=kernel)
-                expand += 1
-            while expand > 0:
-                output = scipy.ndimage.grey_dilation(output, footprint=kernel)
-                expand -= 1
+            for _ in range(abs(expand)):
+                if expand < 0:
+                    output = scipy.ndimage.grey_erosion(output, footprint=kernel)
+                else:
+                    output = scipy.ndimage.grey_dilation(output, footprint=kernel)
             output = torch.from_numpy(output)
             out.append(output)
-        return (torch.cat(out, dim=0),)
+        return (torch.stack(out, dim=0),)

From 0f17993d0587254fcff06bf689dfe38300ea8834 Mon Sep 17 00:00:00 2001
From: badayvedat
Date: Fri, 29 Sep 2023 06:09:59 +0300
Subject: [PATCH 124/150] fix: typo in extra sampler

---
 comfy/extra_samplers/uni_pc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py
index 7eaf6ff62..7e88bb9fa 100644
--- a/comfy/extra_samplers/uni_pc.py
+++ b/comfy/extra_samplers/uni_pc.py
@@ -688,7 +688,7 @@ class UniPC:
                 x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
         else:
             x_t_ = (
-                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
+                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                 - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
             )
             if x_t is None:

From 213976f8c3ea3f45f0c692dd8aac2fd9fea433e3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 29 Sep 2023 09:05:30 -0400
Subject: [PATCH 125/150] Add ExponentialScheduler and PolyexponentialScheduler nodes.
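The polyexponential schedule is the exponential one warped by rho in
log-sigma space, so with rho = 1.0 the two coincide (assuming the k-diffusion
definitions of both):

    import torch
    from comfy.k_diffusion import sampling as k_diffusion_sampling

    exp = k_diffusion_sampling.get_sigmas_exponential(
        n=10, sigma_min=0.0291675, sigma_max=14.614642)
    poly = k_diffusion_sampling.get_sigmas_polyexponential(
        n=10, sigma_min=0.0291675, sigma_max=14.614642, rho=1.0)
    assert torch.allclose(exp, poly, atol=1e-5)
    # rho > 1.0 concentrates more of the schedule near sigma_min (low noise)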
--- comfy_extras/nodes_custom_sampler.py | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index b667afe4f..a1dc97848 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -43,6 +43,43 @@ class KarrasScheduler: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) return (sigmas, ) +class ExponentialScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, sigma_max, sigma_min): + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) + return (sigmas, ) + +class PolyexponentialScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return (sigmas, ) + class SplitSigmas: @classmethod def INPUT_TYPES(s): @@ -155,6 +192,8 @@ class SamplerCustom: NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, "KarrasScheduler": KarrasScheduler, + "ExponentialScheduler": ExponentialScheduler, + "PolyexponentialScheduler": PolyexponentialScheduler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "BasicScheduler": BasicScheduler, From 8ab49dc0a4768f17c5a46627fd5601a484549a5b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 30 Sep 2023 01:31:52 -0400 Subject: [PATCH 126/150] DPMPP_SDE node. 
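This reuses the extra_options hook introduced for SamplerDPMPP_2M_SDE above: the dict is forwarded into the matching k_diffusion_sampling.sample_* function as keyword arguments, and noise_device simply picks between the dpmpp_sde and dpmpp_sde_gpu variants (the latter generates its per-step noise on the sampling device). A rough sketch of building the same SAMPLER object programmatically, mirroring what the node does (treat the call shape as illustrative, not a stable public API):

```python
import comfy.samplers

# eta/s_noise/r end up as keyword arguments of
# k_diffusion_sampling.sample_dpmpp_sde via **extra_options.
sampler = comfy.samplers.ksampler(
    "dpmpp_sde",  # or "dpmpp_sde_gpu" to draw the noise on the GPU
    {"eta": 1.0, "s_noise": 1.0, "r": 0.5},
)()
```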
--- comfy_extras/nodes_custom_sampler.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index a1dc97848..d2cec7f09 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -138,6 +138,29 @@ class SamplerDPMPP_2M_SDE: return (sampler, ) +class SamplerDPMPP_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, r, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_sde" + else: + sampler_name = "dpmpp_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})() + return (sampler, ) + class SamplerCustom: @classmethod def INPUT_TYPES(s): @@ -196,6 +219,7 @@ NODE_CLASS_MAPPINGS = { "PolyexponentialScheduler": PolyexponentialScheduler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, + "SamplerDPMPP_SDE": SamplerDPMPP_SDE, "BasicScheduler": BasicScheduler, "SplitSigmas": SplitSigmas, } From 2ef459b1d4d627929c84d11e5e0cbe3ded9c9f48 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 1 Oct 2023 03:48:07 -0400 Subject: [PATCH 127/150] Add VPScheduler node --- comfy_extras/nodes_custom_sampler.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d2cec7f09..42a1fd6ba 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -80,6 +80,25 @@ class PolyexponentialScheduler: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) return (sigmas, ) +class VPScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values + "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), + "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "_for_testing/custom_sampling" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, beta_d, beta_min, eps_s): + sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) + return (sigmas, ) + class SplitSigmas: @classmethod def INPUT_TYPES(s): @@ -217,6 +236,7 @@ NODE_CLASS_MAPPINGS = { "KarrasScheduler": KarrasScheduler, "ExponentialScheduler": ExponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler, + "VPScheduler": VPScheduler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, From ec454c771b8c2007fbf08602a3205bacd96272a6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 2 Oct 2023 17:26:59 -0400 Subject: [PATCH 128/150] Refactor with code from comment of #1588 --- nodes.py | 31 +++++++++++++++++++------------ 1 file 
changed, 19 insertions(+), 12 deletions(-) diff --git a/nodes.py b/nodes.py index 1232373be..919aac89e 100644 --- a/nodes.py +++ b/nodes.py @@ -1781,16 +1781,23 @@ def load_custom_nodes(): print() def init_custom_nodes(): - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_latent.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_tomesd.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_clip_sdxl.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_canny.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_freelunch.py")) - load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_custom_sampler.py")) + extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras") + extras_files = [ + "nodes_latent.py", + "nodes_hypernetwork.py", + "nodes_upscale_model.py", + "nodes_post_processing.py", + "nodes_mask.py", + "nodes_rebatch.py", + "nodes_model_merging.py", + "nodes_tomesd.py", + "nodes_clip_sdxl.py", + "nodes_canny.py", + "nodes_freelunch.py", + "nodes_custom_sampler.py" + ] + + for node_file in extras_files: + load_custom_node(os.path.join(extras_dir, node_file)) + load_custom_nodes() From fe1e2dbe9000ad3365a71986c726259c1353d304 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 3 Oct 2023 00:01:49 -0400 Subject: [PATCH 129/150] pytorch nightly is now ROCm 5.7 --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d83b4bdac..97677921a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ComfyUI ======= -A powerful and modular stable diffusion GUI and backend. +The most powerful and modular stable diffusion GUI and backend. 
----------- ![ComfyUI Screenshot](comfyui_screenshot.png) @@ -94,8 +94,8 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2``` -This is the command to install the nightly with ROCm 5.6 that supports the 7000 series and might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.6``` +This is the command to install the nightly with ROCm 5.7 that supports the 7000 series and might have some performance improvements: +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` ### NVIDIA From 1f38de1fb3c9e1d8bed81fef7901d5f37561d937 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Tue, 3 Oct 2023 18:30:38 +0900 Subject: [PATCH 130/150] If an error occurs while retrieving object_info, only the node that encountered the error should be handled as an exception, while the information for the other nodes should continue to be processed normally. --- server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index b2e16716b..63f337a87 100644 --- a/server.py +++ b/server.py @@ -413,7 +413,11 @@ class PromptServer(): async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: - out[x] = node_info(x) + try: + out[x] = node_info(x) + except Exception as e: + print(f"[ERROR] An error occurred while retrieving information for the '{x}' node.", file=sys.stderr) + traceback.print_exc() return web.json_response(out) @routes.get("/object_info/{node_class}") From 6fc73143934028771466f76818ebef3219bb1793 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 3 Oct 2023 20:19:12 +0100 Subject: [PATCH 131/150] support refreshing primitive combos no longer uses combo list as type name --- web/extensions/core/widgetInputs.js | 128 +++++++++++++++++++--------- web/scripts/app.js | 53 +++++++++++- 2 files changed, 137 insertions(+), 44 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 606605f0a..98d52b02c 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -4,6 +4,11 @@ import { app } from "../../scripts/app.js"; const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +function getConfig(widgetName) { + const { nodeData } = this.constructor; + return nodeData?.input?.required[widgetName] ?? 
nodeData?.input?.optional?.[widgetName]; +} + function isConvertableWidget(widget, config) { return (VALID_TYPES.includes(widget.type) || VALID_TYPES.includes(config[0])) && !widget.options?.forceInput; } @@ -55,12 +60,12 @@ function showWidget(widget) { function convertToInput(node, widget, config) { hideWidget(node, widget); - const { linkType } = getWidgetType(config); + const { linkType } = getWidgetType(config, `${node.comfyClass}|${widget.name}`); // Add input and store widget config for creating on primitive node const sz = node.size; node.addInput(widget.name, linkType, { - widget: { name: widget.name, config }, + widget: { name: widget.name, getConfig: () => config }, }); for (const widget of node.widgets) { @@ -84,13 +89,13 @@ function convertToWidget(node, widget) { node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]); } -function getWidgetType(config) { +function getWidgetType(config, comboType) { // Special handling for COMBO so we restrict links based on the entries let type = config[0]; let linkType = type; if (type instanceof Array) { type = "COMBO"; - linkType = linkType.join(","); + linkType = comboType; } return { type, linkType }; } @@ -116,7 +121,7 @@ app.registerExtension({ callback: () => convertToWidget(this, w), }); } else { - const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; + const config = getConfig.call(this, w.name) ?? [w.type, w.options || {}]; if (isConvertableWidget(w, config)) { toInput.push({ content: `Convert ${w.name} to input`, @@ -137,34 +142,56 @@ app.registerExtension({ return r; }; - const origOnNodeCreated = nodeType.prototype.onNodeCreated + nodeType.prototype.onGraphConfigured = function () { + if (!this.inputs) return; + + for (const input of this.inputs) { + if (input.widget) { + // Cleanup old widget config + delete input.widget.config; + + if (!input.widget.getConfig) { + input.widget.getConfig = getConfig.bind(this, input.widget.name); + } + + const config = input.widget.getConfig(); + if (config[1]?.forceInput) continue; + + const w = this.widgets.find((w) => w.name === input.widget.name); + if (w) { + hideWidget(this, w); + } else { + convertToWidget(this, input); + } + } + } + }; + + const origOnNodeCreated = nodeType.prototype.onNodeCreated; nodeType.prototype.onNodeCreated = function () { const r = origOnNodeCreated ? origOnNodeCreated.apply(this) : undefined; - if (this.widgets) { + + // When node is created, convert any force/default inputs + if (!app.configuringGraph && this.widgets) { for (const w of this.widgets) { if (w?.options?.forceInput || w?.options?.defaultInput) { - const config = nodeData?.input?.required[w.name] || nodeData?.input?.optional?.[w.name] || [w.type, w.options || {}]; + const config = getConfig.call(this, w.name) ?? [w.type, w.options || {}]; convertToInput(this, w, config); } } } - return r; - } - // On initial configure of nodes hide all converted widgets + return r; + }; + const origOnConfigure = nodeType.prototype.onConfigure; nodeType.prototype.onConfigure = function () { const r = origOnConfigure ? 
origOnConfigure.apply(this, arguments) : undefined; - - if (this.inputs) { + if (!app.configuringGraph && this.inputs) { + // On copy + paste of nodes, ensure that widget configs are set up for (const input of this.inputs) { - if (input.widget && !input.widget.config[1]?.forceInput) { - const w = this.widgets.find((w) => w.name === input.widget.name); - if (w) { - hideWidget(this, w); - } else { - convertToWidget(this, input) - } + if (input.widget && !input.widget.getConfig) { + input.widget.getConfig = getConfig.bind(this, input.widget.name); } } } @@ -190,7 +217,7 @@ app.registerExtension({ const input = this.inputs[slot]; if (!input.widget || !input[ignoreDblClick]) { // Not a widget input or already handled input - if (!(input.type in ComfyWidgets) && !(input.widget.config?.[0] instanceof Array)) { + if (!(input.type in ComfyWidgets) && !(input.widget.getConfig?.()?.[0] instanceof Array)) { return r; //also Not a ComfyWidgets input or combo (do nothing) } } @@ -262,17 +289,38 @@ app.registerExtension({ } } + refreshComboInNode() { + const widget = this.widgets?.[0]; + if (widget?.type === "combo") { + widget.options.values = this.outputs[0].widget.getConfig()[0]; + + if (!widget.options.values.includes(widget.value)) { + widget.value = widget.options.values[0]; + widget.callback(widget.value); + } + } + } + + onAfterGraphConfigured() { + if (this.outputs[0].links?.length && !this.widgets?.length) { + this.#onFirstConnection(); + + // Populate widget values from config data + for (let i = 0; i < this.widgets_values.length; i++) { + this.widgets[i].value = this.widgets_values[i]; + } + } + } + onConnectionsChange(_, index, connected) { + if (app.configuringGraph) { + // Dont run while the graph is still setting up + return; + } + if (connected) { - if (this.outputs[0].links?.length) { - if (!this.widgets?.length) { - this.#onFirstConnection(); - } - if (!this.widgets?.length && this.outputs[0].widget) { - // On first load it often cant recreate the widget as the other node doesnt exist yet - // Manually recreate it from the output info - this.#createWidget(this.outputs[0].widget.config); - } + if (this.outputs[0].links?.length && !this.widgets?.length) { + this.#onFirstConnection(); } } else if (!this.outputs[0].links?.length) { this.#onLastDisconnect(); @@ -304,23 +352,21 @@ app.registerExtension({ const input = theirNode.inputs[link.target_slot]; if (!input) return; - - var _widget; + let widget; if (!input.widget) { if (!(input.type in ComfyWidgets)) return; - _widget = { "name": input.name, "config": [input.type, {}] }//fake widget + widget = { name: input.name, getConfig: () => [input.type, {}] }; //fake widget } else { - _widget = input.widget; + widget = input.widget; } - const widget = _widget; - const { type, linkType } = getWidgetType(widget.config); + const { type, linkType } = getWidgetType(widget.getConfig(), `${theirNode.comfyClass}|${widget.name}`); // Update our output to restrict to the widget type this.outputs[0].type = linkType; this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget.config, theirNode, widget.name); + this.#createWidget(widget.getConfig(), theirNode, widget.name); } #createWidget(inputData, node, widgetName) { @@ -334,7 +380,7 @@ app.registerExtension({ if (type in ComfyWidgets) { widget = (ComfyWidgets[type](this, "value", inputData, app) || {}).widget; } else { - widget = this.addWidget(type, "value", null, () => { }, {}); + widget = this.addWidget(type, "value", null, () => {}, {}); } if (node?.widgets && widget) { 
@@ -376,8 +422,8 @@ app.registerExtension({ #isValidConnection(input) { // Only allow connections where the configs match - const config1 = this.outputs[0].widget.config; - const config2 = input.widget.config; + const config1 = this.outputs[0].widget.getConfig(); + const config2 = input.widget.getConfig(); if (config1[0] instanceof Array) { // These checks shouldnt actually be necessary as the types should match @@ -395,7 +441,7 @@ app.registerExtension({ } for (const k in config1[1]) { - if (k !== "default" && k !== 'forceInput') { + if (k !== "default" && k !== "forceInput") { if (config1[1][k] !== config2[1][k]) { return false; } diff --git a/web/scripts/app.js b/web/scripts/app.js index b41c12b86..3c29a684a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1114,6 +1114,40 @@ export class ComfyApp { }); } + #addConfigureHandler() { + const app = this; + const configure = LGraph.prototype.configure; + // Flag that the graph is configuring to prevent nodes from running checks while its still loading + LGraph.prototype.configure = function () { + app.configuringGraph = true; + try { + return configure.apply(this, arguments); + } finally { + app.configuringGraph = false; + } + }; + } + + #addAfterConfigureHandler() { + const app = this; + const onConfigure = app.graph.onConfigure; + app.graph.onConfigure = function () { + // Fire callbacks before the onConfigure, this is used by widget inputs to setup the config + for (const node of app.graph._nodes) { + node.onGraphConfigured?.(); + } + + const r = onConfigure?.apply(this, arguments); + + // Fire after onConfigure, used by primitves to generate widget using input nodes config + for (const node of app.graph._nodes) { + node.onAfterGraphConfigured?.(); + } + + return r; + }; + } + /** * Loads all extensions from the API into the window in parallel */ @@ -1147,8 +1181,12 @@ export class ComfyApp { this.#addProcessMouseHandler(); this.#addProcessKeyHandler(); + this.#addConfigureHandler(); this.graph = new LGraph(); + + this.#addAfterConfigureHandler(); + const canvas = (this.canvas = new LGraphCanvas(canvasEl, this.graph)); this.ctx = canvasEl.getContext("2d"); @@ -1285,6 +1323,7 @@ export class ComfyApp { { title: nodeData.display_name || nodeData.name, comfyClass: nodeData.name, + nodeData } ); node.prototype.comfyClass = nodeData.name; @@ -1670,13 +1709,21 @@ export class ComfyApp { async refreshComboInNodes() { const defs = await api.getNodeDefs(); + for(const nodeId in LiteGraph.registered_node_types) { + const node = LiteGraph.registered_node_types[nodeId]; + const nodeDef = defs[nodeId]; + if(!nodeDef) continue; + + node.nodeData = nodeDef; + } + for(let nodeNum in this.graph._nodes) { const node = this.graph._nodes[nodeNum]; - const def = defs[node.type]; - // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, - // and additional work is needed to consider the primitive logic in the refresh logic. 
+ // Allow primitive nodes to handle refresh + node.refreshComboInNode?.(defs); + if(!def) continue; From 9bfec2bdbf0b0d778087a9b32f79e57e2d15b913 Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:40:59 +0200 Subject: [PATCH 132/150] Fix quality loss due to low precision --- comfy/sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 2f1b2e964..f186273ea 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -183,7 +183,7 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.vae_dtype).to(self.device) - 1.).sample().float() + encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float() samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) @@ -202,7 +202,7 @@ class VAE: pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu().float() + pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) From d06cd2805d86d7a9ed7485b6a0c7e113cff27d8e Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Fri, 22 Sep 2023 23:03:22 +0200 Subject: [PATCH 133/150] Added support for Porter-Duff image compositing --- comfy_extras/nodes_compositing.py | 239 ++++++++++++++++++++++++++++++ nodes.py | 28 ++++ 2 files changed, 267 insertions(+) create mode 100644 comfy_extras/nodes_compositing.py diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py new file mode 100644 index 000000000..c4c58b64e --- /dev/null +++ b/comfy_extras/nodes_compositing.py @@ -0,0 +1,239 @@ +import numpy as np +import torch +import comfy.utils +from enum import Enum + + +class PorterDuffMode(Enum): + ADD = 0 + CLEAR = 1 + DARKEN = 2 + DST = 3 + DST_ATOP = 4 + DST_IN = 5 + DST_OUT = 6 + DST_OVER = 7 + LIGHTEN = 8 + MULTIPLY = 9 + OVERLAY = 10 + SCREEN = 11 + SRC = 12 + SRC_ATOP = 13 + SRC_IN = 14 + SRC_OUT = 15 + SRC_OVER = 16 + XOR = 17 + + +def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): + if mode == PorterDuffMode.ADD: + out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1) + out_image = torch.clamp(src_image + dst_image, 0, 1) + elif mode == PorterDuffMode.CLEAR: + out_alpha = torch.zeros_like(dst_alpha) + out_image = torch.zeros_like(dst_image) + elif mode == PorterDuffMode.DARKEN: + out_alpha = 
src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image) + elif mode == PorterDuffMode.DST: + out_alpha = dst_alpha + out_image = dst_image + elif mode == PorterDuffMode.DST_ATOP: + out_alpha = src_alpha + out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.DST_IN: + out_alpha = src_alpha * dst_alpha + out_image = dst_image * src_alpha + elif mode == PorterDuffMode.DST_OUT: + out_alpha = (1 - src_alpha) * dst_alpha + out_image = (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.DST_OVER: + out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha + out_image = dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.LIGHTEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image) + elif mode == PorterDuffMode.MULTIPLY: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_image + elif mode == PorterDuffMode.OVERLAY: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image, + src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) + elif mode == PorterDuffMode.SCREEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = src_image + dst_image - src_image * dst_image + elif mode == PorterDuffMode.SRC: + out_alpha = src_alpha + out_image = src_image + elif mode == PorterDuffMode.SRC_ATOP: + out_alpha = dst_alpha + out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.SRC_IN: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_alpha + elif mode == PorterDuffMode.SRC_OUT: + out_alpha = (1 - dst_alpha) * src_alpha + out_image = (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.SRC_OVER: + out_alpha = src_alpha + (1 - src_alpha) * dst_alpha + out_image = src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.XOR: + out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + else: + out_alpha = None + out_image = None + return out_image, out_alpha + + +class PorterDuffImageComposite: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source": ("IMAGE",), + "source_alpha": ("ALPHA",), + "destination": ("IMAGE",), + "destination_alpha": ("ALPHA",), + "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), + }, + } + + RETURN_TYPES = ("IMAGE", "ALPHA") + FUNCTION = "composite" + CATEGORY = "compositing" + + def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): + batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) + out_images = [] + out_alphas = [] + + for i in range(batch_size): + src_image = source[i] + dst_image = destination[i] + + src_alpha = source_alpha[i].unsqueeze(2) + dst_alpha = destination_alpha[i].unsqueeze(2) + + if dst_alpha.shape != dst_image.shape: + upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_image.shape != dst_image.shape: + upscale_input = 
src_image[None,:,:,:].permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_alpha.shape != dst_alpha.shape: + upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') + src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + + out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode]) + + out_images.append(out_image) + out_alphas.append(out_alpha.squeeze(2)) + + result = (torch.stack(out_images), torch.stack(out_alphas)) + return result + + +class SplitImageWithAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("IMAGE", "ALPHA") + FUNCTION = "split_image_with_alpha" + + def split_image_with_alpha(self, image: torch.Tensor): + out_images = [i[:,:,:3] for i in image] + out_alphas = [i[:,:,3] for i in image] + result = (torch.stack(out_images), torch.stack(out_alphas)) + return result + + +class JoinImageWithAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "alpha": ("ALPHA",), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("IMAGE",) + FUNCTION = "join_image_with_alpha" + + def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor): + batch_size = min(len(image), len(alpha)) + out_images = [] + + for i in range(batch_size): + out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2)) + + result = (torch.stack(out_images),) + return result + + +class ConvertAlphaToImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "alpha": ("ALPHA",), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("IMAGE",) + FUNCTION = "alpha_to_image" + + def alpha_to_image(self, alpha): + result = alpha.reshape((-1, 1, alpha.shape[-2], alpha.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + return (result,) + + +class ConvertImageToAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue", "alpha"],), + } + } + + CATEGORY = "compositing" + RETURN_TYPES = ("ALPHA",) + FUNCTION = "image_to_alpha" + + def image_to_alpha(self, image, channel): + channels = ["red", "green", "blue", "alpha"] + alpha = image[0, :, :, channels.index(channel)] + return (alpha,) + + +NODE_CLASS_MAPPINGS = { + "PorterDuffImageComposite": PorterDuffImageComposite, + "SplitImageWithAlpha": SplitImageWithAlpha, + "JoinImageWithAlpha": JoinImageWithAlpha, + "ConvertAlphaToImage": ConvertAlphaToImage, + "ConvertImageToAlpha": ConvertImageToAlpha, +} + + +NODE_DISPLAY_NAME_MAPPINGS = { + "PorterDuffImageComposite": "Porter-Duff Image Composite", + "SplitImageWithAlpha": "Split Image with Alpha", + "JoinImageWithAlpha": "Join Image with Alpha", + "ConvertAlphaToImage": "Convert Alpha to Image", + "ConvertImageToAlpha": "Convert Image to Alpha", +} diff --git a/nodes.py b/nodes.py index 919aac89e..8be332f91 100644 --- a/nodes.py +++ b/nodes.py @@ -1372,6 +1372,31 @@ class LoadImage: return True +class LoadImageWithAlpha(LoadImage): + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + 
return {"required": + {"image": (sorted(files), {"image_upload": True})}, + } + + CATEGORY = "compositing" + + RETURN_TYPES = ("IMAGE", "ALPHA") + + FUNCTION = "load_image" + def load_image(self, image): + image_path = folder_paths.get_annotated_filepath(image) + i = Image.open(image_path) + i = ImageOps.exif_transpose(i) + image = i.convert("RGBA") + alpha = np.array(image.getchannel("A")).astype(np.float32) / 255.0 + alpha = torch.from_numpy(alpha)[None,] + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + return (image, alpha) + class LoadImageMask: _color_channels = ["alpha", "red", "green", "blue"] @classmethod @@ -1606,6 +1631,7 @@ NODE_CLASS_MAPPINGS = { "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, + "LoadImageWithAlpha": LoadImageWithAlpha, "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, "ImageScaleBy": ImageScaleBy, @@ -1702,6 +1728,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveImage": "Save Image", "PreviewImage": "Preview Image", "LoadImage": "Load Image", + "LoadImageWithAlpha": "Load Image with Alpha", "LoadImageMask": "Load Image (as Mask)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", @@ -1788,6 +1815,7 @@ def init_custom_nodes(): "nodes_upscale_model.py", "nodes_post_processing.py", "nodes_mask.py", + "nodes_compositing.py", "nodes_rebatch.py", "nodes_model_merging.py", "nodes_tomesd.py", From ece69bf28c0d5872bdec1cc9e66db50f09eaa74b Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Sat, 23 Sep 2023 08:34:54 +0200 Subject: [PATCH 134/150] Change channel type to MASK (reduced redundancy, increased usability) --- comfy_extras/nodes_compositing.py | 52 +++---------------------------- comfy_extras/nodes_mask.py | 4 +-- nodes.py | 2 +- 3 files changed, 8 insertions(+), 50 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index c4c58b64e..6899e4a86 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -93,14 +93,14 @@ class PorterDuffImageComposite: return { "required": { "source": ("IMAGE",), - "source_alpha": ("ALPHA",), + "source_alpha": ("MASK",), "destination": ("IMAGE",), - "destination_alpha": ("ALPHA",), + "destination_alpha": ("MASK",), "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), }, } - RETURN_TYPES = ("IMAGE", "ALPHA") + RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "composite" CATEGORY = "compositing" @@ -148,7 +148,7 @@ class SplitImageWithAlpha: } CATEGORY = "compositing" - RETURN_TYPES = ("IMAGE", "ALPHA") + RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "split_image_with_alpha" def split_image_with_alpha(self, image: torch.Tensor): @@ -164,7 +164,7 @@ class JoinImageWithAlpha: return { "required": { "image": ("IMAGE",), - "alpha": ("ALPHA",), + "alpha": ("MASK",), } } @@ -183,50 +183,10 @@ class JoinImageWithAlpha: return result -class ConvertAlphaToImage: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "alpha": ("ALPHA",), - } - } - - CATEGORY = "compositing" - RETURN_TYPES = ("IMAGE",) - FUNCTION = "alpha_to_image" - - def alpha_to_image(self, alpha): - result = alpha.reshape((-1, 1, alpha.shape[-2], alpha.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return (result,) - - -class ConvertImageToAlpha: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue", "alpha"],), - } - } - - CATEGORY = "compositing" - RETURN_TYPES = ("ALPHA",) - FUNCTION = 
"image_to_alpha" - - def image_to_alpha(self, image, channel): - channels = ["red", "green", "blue", "alpha"] - alpha = image[0, :, :, channels.index(channel)] - return (alpha,) - - NODE_CLASS_MAPPINGS = { "PorterDuffImageComposite": PorterDuffImageComposite, "SplitImageWithAlpha": SplitImageWithAlpha, "JoinImageWithAlpha": JoinImageWithAlpha, - "ConvertAlphaToImage": ConvertAlphaToImage, - "ConvertImageToAlpha": ConvertImageToAlpha, } @@ -234,6 +194,4 @@ NODE_DISPLAY_NAME_MAPPINGS = { "PorterDuffImageComposite": "Porter-Duff Image Composite", "SplitImageWithAlpha": "Split Image with Alpha", "JoinImageWithAlpha": "Join Image with Alpha", - "ConvertAlphaToImage": "Convert Alpha to Image", - "ConvertImageToAlpha": "Convert Image to Alpha", } diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index cdf762ffd..9b0b289c1 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -114,7 +114,7 @@ class ImageToMask: return { "required": { "image": ("IMAGE",), - "channel": (["red", "green", "blue"],), + "channel": (["red", "green", "blue", "alpha"],), } } @@ -124,7 +124,7 @@ class ImageToMask: FUNCTION = "image_to_mask" def image_to_mask(self, image, channel): - channels = ["red", "green", "blue"] + channels = ["red", "green", "blue", "alpha"] mask = image[:, :, :, channels.index(channel)] return (mask,) diff --git a/nodes.py b/nodes.py index 8be332f91..9f8e58d0f 100644 --- a/nodes.py +++ b/nodes.py @@ -1383,7 +1383,7 @@ class LoadImageWithAlpha(LoadImage): CATEGORY = "compositing" - RETURN_TYPES = ("IMAGE", "ALPHA") + RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): From 585fb0475bbaf919bd340c72d339752bbb93ef55 Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Sat, 23 Sep 2023 13:19:42 +0200 Subject: [PATCH 135/150] Adding default alpha when splitting RGB images --- comfy_extras/nodes_compositing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 6899e4a86..b0ae2dfa0 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -153,7 +153,7 @@ class SplitImageWithAlpha: def split_image_with_alpha(self, image: torch.Tensor): out_images = [i[:,:,:3] for i in image] - out_alphas = [i[:,:,3] for i in image] + out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] result = (torch.stack(out_images), torch.stack(out_alphas)) return result From 214ca7197ef753bce3b40f642c6775d919568c2f Mon Sep 17 00:00:00 2001 From: MoonRide303 Date: Sun, 24 Sep 2023 00:12:55 +0200 Subject: [PATCH 136/150] Corrected joining images with alpha (for RGBA input), and checking scaling conditions --- comfy_extras/nodes_compositing.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index b0ae2dfa0..f39daa009 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -113,19 +113,21 @@ class PorterDuffImageComposite: src_image = source[i] dst_image = destination[i] + assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels + src_alpha = source_alpha[i].unsqueeze(2) dst_alpha = destination_alpha[i].unsqueeze(2) - if dst_alpha.shape != dst_image.shape: - upscale_input = dst_alpha[None,:,:,:].permute(0, 3, 1, 2) + if dst_alpha.shape[:2] != dst_image.shape[:2]: + upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = 
comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) if src_image.shape != dst_image.shape: - upscale_input = src_image[None,:,:,:].permute(0, 3, 1, 2) + upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) if src_alpha.shape != dst_alpha.shape: - upscale_input = src_alpha[None,:,:,:].permute(0, 3, 1, 2) + upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2) upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) @@ -177,7 +179,7 @@ class JoinImageWithAlpha: out_images = [] for i in range(batch_size): - out_images.append(torch.cat((image[i], alpha[i].unsqueeze(2)), dim=2)) + out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) result = (torch.stack(out_images),) return result From 9212bea87c47af5a1d9b51d59a2cf17e9a00e73f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 4 Oct 2023 14:40:17 -0400 Subject: [PATCH 137/150] Change a few things in #1578. --- comfy_extras/nodes_compositing.py | 6 +++--- nodes.py | 27 --------------------------- 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index f39daa009..f8901eca1 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -102,7 +102,7 @@ class PorterDuffImageComposite: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "composite" - CATEGORY = "compositing" + CATEGORY = "mask/compositing" def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) @@ -149,7 +149,7 @@ class SplitImageWithAlpha: } } - CATEGORY = "compositing" + CATEGORY = "mask/compositing" RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "split_image_with_alpha" @@ -170,7 +170,7 @@ class JoinImageWithAlpha: } } - CATEGORY = "compositing" + CATEGORY = "mask/compositing" RETURN_TYPES = ("IMAGE",) FUNCTION = "join_image_with_alpha" diff --git a/nodes.py b/nodes.py index 9f8e58d0f..16bf07cca 100644 --- a/nodes.py +++ b/nodes.py @@ -1372,31 +1372,6 @@ class LoadImage: return True -class LoadImageWithAlpha(LoadImage): - @classmethod - def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True})}, - } - - CATEGORY = "compositing" - - RETURN_TYPES = ("IMAGE", "MASK") - - FUNCTION = "load_image" - def load_image(self, image): - image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) - image = i.convert("RGBA") - alpha = np.array(image.getchannel("A")).astype(np.float32) / 255.0 - alpha = torch.from_numpy(alpha)[None,] - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - return (image, alpha) - class LoadImageMask: _color_channels = ["alpha", "red", "green", "blue"] @classmethod @@ -1631,7 +1606,6 @@ NODE_CLASS_MAPPINGS = { "SaveImage": 
SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, - "LoadImageWithAlpha": LoadImageWithAlpha, "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, "ImageScaleBy": ImageScaleBy, @@ -1728,7 +1702,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveImage": "Save Image", "PreviewImage": "Preview Image", "LoadImage": "Load Image", - "LoadImageWithAlpha": "Load Image with Alpha", "LoadImageMask": "Load Image (as Mask)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", From 0b9246d9fad06834e8418904eb189d57f65c8eb7 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 4 Oct 2023 20:48:55 +0100 Subject: [PATCH 138/150] allow connecting numbers merging config --- web/extensions/core/widgetInputs.js | 232 +++++++++++++++++++++++----- 1 file changed, 191 insertions(+), 41 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 98d52b02c..ccf437ed4 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -3,6 +3,7 @@ import { app } from "../../scripts/app.js"; const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +const CONFIG = Symbol(); function getConfig(widgetName) { const { nodeData } = this.constructor; @@ -154,9 +155,6 @@ app.registerExtension({ input.widget.getConfig = getConfig.bind(this, input.widget.name); } - const config = input.widget.getConfig(); - if (config[1]?.forceInput) continue; - const w = this.widgets.find((w) => w.name === input.widget.name); if (w) { hideWidget(this, w); @@ -306,9 +304,17 @@ app.registerExtension({ this.#onFirstConnection(); // Populate widget values from config data - for (let i = 0; i < this.widgets_values.length; i++) { - this.widgets[i].value = this.widgets_values[i]; + if (this.widgets) { + for (let i = 0; i < this.widgets_values.length; i++) { + const w = this.widgets[i]; + if (w) { + w.value = this.widgets_values[i]; + } + } } + + // Merge values if required + this.#mergeWidgetConfig(); } } @@ -318,12 +324,18 @@ app.registerExtension({ return; } + const links = this.outputs[0].links; if (connected) { - if (this.outputs[0].links?.length && !this.widgets?.length) { + if (links?.length && !this.widgets?.length) { this.#onFirstConnection(); } - } else if (!this.outputs[0].links?.length) { - this.#onLastDisconnect(); + } else { + // We may have removed a link that caused the constraints to change + this.#mergeWidgetConfig(); + + if (!links?.length) { + this.#onLastDisconnect(); + } } } @@ -340,7 +352,7 @@ app.registerExtension({ } } - #onFirstConnection() { + #onFirstConnection(recreating) { // First connection can fire before the graph is ready on initial load so random things can be missing const linkId = this.outputs[0].links[0]; const link = this.graph.links[linkId]; @@ -366,10 +378,10 @@ app.registerExtension({ this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget.getConfig(), theirNode, widget.name); + this.#createWidget(widget[CONFIG] ?? 
widget.getConfig(), theirNode, widget.name, recreating); } - #createWidget(inputData, node, widgetName) { + #createWidget(inputData, node, widgetName, recreating) { let type = inputData[0]; if (type instanceof Array) { @@ -404,25 +416,70 @@ app.registerExtension({ return r; }; - // Grow our node if required - const sz = this.computeSize(); - if (this.size[0] < sz[0]) { - this.size[0] = sz[0]; - } - if (this.size[1] < sz[1]) { - this.size[1] = sz[1]; - } - - requestAnimationFrame(() => { - if (this.onResize) { - this.onResize(this.size); + if (!recreating) { + // Grow our node if required + const sz = this.computeSize(); + if (this.size[0] < sz[0]) { + this.size[0] = sz[0]; } - }); + if (this.size[1] < sz[1]) { + this.size[1] = sz[1]; + } + + requestAnimationFrame(() => { + if (this.onResize) { + this.onResize(this.size); + } + }); + } } - #isValidConnection(input) { + #recreateWidget() { + const values = this.widgets.map((w) => w.value); + this.#removeWidgets(); + this.#onFirstConnection(true); + for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; + } + + #mergeWidgetConfig() { + // Merge widget configs if the node has multiple outputs + const output = this.outputs[0]; + const links = output.links; + + const hasConfig = !!output.widget[CONFIG]; + if (hasConfig) { + delete output.widget[CONFIG]; + } + + if (links?.length < 2 && hasConfig) { + // Copy the widget options from the source + if (links.length) { + this.#recreateWidget(); + } + + return; + } + + const config1 = output.widget.getConfig(); + const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; + if (!isNumber) return; + + for (const linkId of links) { + const link = app.graph.links[linkId]; + if (!link) continue; // Can be null when removing a node + + const theirNode = app.graph.getNodeById(link.target_id); + const theirInput = theirNode.inputs[link.target_slot]; + + // Call is valid connection so it can merge the configs when validating + this.#isValidConnection(theirInput, hasConfig); + } + } + + #isValidConnection(input, forceUpdate) { // Only allow connections where the configs match - const config1 = this.outputs[0].widget.getConfig(); + const output = this.outputs[0]; + const config1 = output.widget[CONFIG] ?? output.widget.getConfig(); const config2 = input.widget.getConfig(); if (config1[0] instanceof Array) { @@ -430,34 +487,117 @@ app.registerExtension({ // but double checking doesn't hurt // New input isnt a combo - if (!(config2[0] instanceof Array)) return false; + if (!(config2[0] instanceof Array)) { + console.log(`connection rejected: tried to connect combo to ${config2[0]}`); + return false; + } // New imput combo has a different size - if (config1[0].length !== config2[0].length) return false; + if (config1[0].length !== config2[0].length) { + console.log(`connection rejected: combo lists dont match`); + return false; + } // New input combo has different elements - if (config1[0].find((v, i) => config2[0][i] !== v)) return false; + if (config1[0].find((v, i) => config2[0][i] !== v)) { + console.log(`connection rejected: combo lists dont match`); + return false; + } } else if (config1[0] !== config2[0]) { - // Configs dont match + // Types dont match + console.log(`connection rejected: types dont match`, config1[0], config2[0]); return false; } - for (const k in config1[1]) { - if (k !== "default" && k !== "forceInput") { - if (config1[1][k] !== config2[1][k]) { - return false; + const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? 
{})]); + + let customConfig; + const getCustomConfig = () => { + if (!customConfig) { + if (typeof structuredClone === "undefined") { + customConfig = JSON.parse(JSON.stringify(config1[1] ?? {})); + } else { + customConfig = structuredClone(config1[1] ?? {}); } } + return customConfig; + }; + + const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; + for (const k of keys.values()) { + if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { + let v1 = config1[1][k]; + let v2 = config2[1][k]; + + if (v1 === v2 || (!v1 && !v2)) continue; + + if (isNumber) { + if (k === "min") { + const theirMax = config2[1]["max"]; + if (theirMax != null && v1 > theirMax) { + console.log("Invalid connection, min > max"); + return false; + } + getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2); + continue; + } else if (k === "max") { + const theirMin = config2[1]["min"]; + if (theirMin != null && v1 < theirMin) { + console.log("Invalid connection, max < min"); + return false; + } + getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2); + continue; + } else if (k === "step") { + let step; + if (v1 == null) { + step = v2; + } else if (v2 == null) { + step = v1; + } else { + if (v1 < v2) { + const a = v2; + v2 = v1; + v1 = a; + } + if (v1 % v2) { + console.log("Steps not divisible", "current:", v1, "new:", v2); + return false; + } + + step = v1; + } + + getCustomConfig()[k] = step; + continue; + } + } + + console.log(`connection rejected: config ${k} values dont match`, v1, v2); + return false; + } + } + + if (customConfig || forceUpdate) { + if (customConfig) { + output.widget[CONFIG] = [config1[0], customConfig]; + } + + this.#recreateWidget(); + + const widget = this.widgets[0]; + // When deleting a node this can be null + if (widget) { + const min = widget.options.min; + const max = widget.options.max; + if (min != null && widget.value < min) widget.value = min; + if (max != null && widget.value > max) widget.value = max; + widget.callback(widget.value); + } } return true; } - #onLastDisconnect() { - // We cant remove + re-add the output here as if you drag a link over the same link - // it removes, then re-adds, causing it to break - this.outputs[0].type = "*"; - this.outputs[0].name = "connect to widget input"; - delete this.outputs[0].widget; - + #removeWidgets() { if (this.widgets) { // Allow widgets to cleanup for (const w of this.widgets) { @@ -468,6 +608,16 @@ app.registerExtension({ this.widgets.length = 0; } } + + #onLastDisconnect() { + // We cant remove + re-add the output here as if you drag a link over the same link + // it removes, then re-adds, causing it to break + this.outputs[0].type = "*"; + this.outputs[0].name = "connect to widget input"; + delete this.outputs[0].widget; + + this.#removeWidgets(); + } } LiteGraph.registerNodeType( From 0e763e880f5e838e7a1e3914444cae6790c48627 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 4 Oct 2023 15:54:34 -0400 Subject: [PATCH 139/150] JoinImageWithAlpha now works with any mask shape. 
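The join node now rescales the alpha batch to the image's spatial size before concatenating it as a fourth channel, so masks no longer have to match the image resolution exactly. A small standalone illustration of the reshape/interpolate/squeeze pattern used by the resize_mask helper below (the tensor shapes are made-up example values):

```python
import torch
import torch.nn.functional as F

alpha = torch.rand(2, 64, 64)    # (batch, H, W) mask batch
target_h, target_w = 512, 512    # image size to match

# F.interpolate expects (N, C, H, W), so add a channel axis,
# resize, then drop the axis to get back to (batch, H, W).
resized = F.interpolate(
    alpha.reshape(-1, 1, *alpha.shape[-2:]),
    size=(target_h, target_w),
    mode="bilinear",
).squeeze(1)
```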
--- comfy_extras/nodes_compositing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index f8901eca1..68bfce111 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -3,6 +3,8 @@ import torch import comfy.utils from enum import Enum +def resize_mask(mask, shape): + return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) class PorterDuffMode(Enum): ADD = 0 @@ -178,6 +180,7 @@ class JoinImageWithAlpha: batch_size = min(len(image), len(alpha)) out_images = [] + alpha = resize_mask(alpha, image.shape[1:]) for i in range(batch_size): out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) From 48242be50866f5d6d22d120743d5d39cd6a0c178 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 5 Oct 2023 08:25:15 -0400 Subject: [PATCH 140/150] Update readme for pytorch 2.1 --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 97677921a..559e99ffa 100644 --- a/README.md +++ b/README.md @@ -92,16 +92,16 @@ Put your VAE in: models/vae ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.4.2``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` -This is the command to install the nightly with ROCm 5.7 that supports the 7000 series and might have some performance improvements: +This is the command to install the nightly with ROCm 5.7 that might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` ### NVIDIA -Nvidia users should install torch and xformers using this command: +Nvidia users should install pytorch using this command: -```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 xformers``` +```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` #### Troubleshooting From 80932ddf406c7da0ab97855801c468cfafa50386 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 5 Oct 2023 17:13:13 +0100 Subject: [PATCH 141/150] updated messages --- web/extensions/core/widgetInputs.js | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index ccf437ed4..271b02db3 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -533,7 +533,7 @@ app.registerExtension({ if (k === "min") { const theirMax = config2[1]["max"]; if (theirMax != null && v1 > theirMax) { - console.log("Invalid connection, min > max"); + console.log("connection rejected: min > max", v1, theirMax); return false; } getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2); @@ -541,7 +541,7 @@ app.registerExtension({ } else if (k === "max") { const theirMin = config2[1]["min"]; if (theirMin != null && v1 < theirMin) { - console.log("Invalid connection, max < min"); + console.log("connection rejected: max < min", v1, theirMin); return false; } getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? 
v1 : Math.min(v1, v2); @@ -549,17 +549,20 @@ app.registerExtension({ } else if (k === "step") { let step; if (v1 == null) { + // No current step step = v2; } else if (v2 == null) { + // No new step step = v1; } else { if (v1 < v2) { + // Ensure v1 is larger for the mod const a = v2; v2 = v1; v1 = a; } if (v1 % v2) { - console.log("Steps not divisible", "current:", v1, "new:", v2); + console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2); return false; } From b9b178b8394122651118c7453518320604a3f1f1 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 5 Oct 2023 19:16:39 +0100 Subject: [PATCH 142/150] More cleanup of old type data Fix connecting combos of same type from different types of node --- web/extensions/core/widgetInputs.js | 34 +++++++++++++++++------------ 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 271b02db3..c734ffe27 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -61,11 +61,11 @@ function showWidget(widget) { function convertToInput(node, widget, config) { hideWidget(node, widget); - const { linkType } = getWidgetType(config, `${node.comfyClass}|${widget.name}`); + const { type } = getWidgetType(config); // Add input and store widget config for creating on primitive node const sz = node.size; - node.addInput(widget.name, linkType, { + node.addInput(widget.name, type, { widget: { name: widget.name, getConfig: () => config }, }); @@ -90,15 +90,13 @@ function convertToWidget(node, widget) { node.setSize([Math.max(sz[0], node.size[0]), Math.max(sz[1], node.size[1])]); } -function getWidgetType(config, comboType) { +function getWidgetType(config) { // Special handling for COMBO so we restrict links based on the entries let type = config[0]; - let linkType = type; if (type instanceof Array) { type = "COMBO"; - linkType = comboType; } - return { type, linkType }; + return { type }; } app.registerExtension({ @@ -148,13 +146,24 @@ app.registerExtension({ for (const input of this.inputs) { if (input.widget) { - // Cleanup old widget config - delete input.widget.config; - if (!input.widget.getConfig) { input.widget.getConfig = getConfig.bind(this, input.widget.name); } + // Cleanup old widget config + if (input.widget.config) { + if (input.widget.config[0] instanceof Array) { + // If we are an old converted combo then replace the input type and the stored link data + input.type = "COMBO"; + + const link = app.graph.links[input.link]; + if (link) { + link.type = input.type; + } + } + delete input.widget.config; + } + const w = this.widgets.find((w) => w.name === input.widget.name); if (w) { hideWidget(this, w); @@ -372,9 +381,9 @@ app.registerExtension({ widget = input.widget; } - const { type, linkType } = getWidgetType(widget.getConfig(), `${theirNode.comfyClass}|${widget.name}`); + const { type } = getWidgetType(widget.getConfig()); // Update our output to restrict to the widget type - this.outputs[0].type = linkType; + this.outputs[0].type = type; this.outputs[0].name = type; this.outputs[0].widget = widget; @@ -483,9 +492,6 @@ app.registerExtension({ const config2 = input.widget.getConfig(); if (config1[0] instanceof Array) { - // These checks shouldnt actually be necessary as the types should match - // but double checking doesn't hurt - // New input isnt a combo if (!(config2[0] instanceof Array)) { console.log(`connection rejected: tried to connect combo 
From 6f464f801f718bf9c1274aa49252deb0b52fbd51 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 6 Oct 2023 03:32:00 -0400
Subject: [PATCH 143/150] Update nightly workflow to python 3.11.6

---
 .github/workflows/windows_release_nightly_pytorch.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml
index 319942e7c..b793f7fe2 100644
--- a/.github/workflows/windows_release_nightly_pytorch.yml
+++ b/.github/workflows/windows_release_nightly_pytorch.yml
@@ -20,12 +20,12 @@ jobs:
         persist-credentials: false
     - uses: actions/setup-python@v4
       with:
-        python-version: '3.11.3'
+        python-version: '3.11.6'
    - shell: bash
      run: |
        cd ..
        cp -r ComfyUI ComfyUI_copy
-        curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
+        curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip
        unzip python_embeded.zip -d python_embeded
        cd python_embeded
        echo 'import site' >> ./python311._pth

From 34b36e3207522aa1a3e48a17e628c0aae3c4c5c9 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 6 Oct 2023 10:26:51 -0400
Subject: [PATCH 144/150] More configurable workflows to package windows
 release.

---
 .../windows_release_dependencies.yml          | 53 ++++++++++
 .github/workflows/windows_release_package.yml | 96 +++++++++++++++++++
 2 files changed, 149 insertions(+)
 create mode 100644 .github/workflows/windows_release_dependencies.yml
 create mode 100644 .github/workflows/windows_release_package.yml

diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml
new file mode 100644
index 000000000..590495c65
--- /dev/null
+++ b/.github/workflows/windows_release_dependencies.yml
@@ -0,0 +1,53 @@
+name: "Windows Release dependencies"
+
+on:
+  workflow_dispatch:
+    inputs:
+      xformers:
+        description: 'xformers version'
+        required: true
+        type: string
+        default: ""
+
+      cu:
+        description: 'cuda version'
+        required: true
+        type: string
+        default: "121"
+
+      python_minor:
+        description: 'python minor version'
+        required: true
+        type: string
+        default: "11"
+
+      python_patch:
+        description: 'python patch version'
+        required: true
+        type: string
+        default: "6"
+#  push:
+#    branches:
+#      - master
+
+jobs:
+  build_dependencies:
+    runs-on: windows-latest
+    steps:
+    - uses: actions/checkout@v3
+    - uses: actions/setup-python@v4
+      with:
+        python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }}
+
+    - shell: bash
+      run: |
+        python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
+        python -m pip install --no-cache-dir ./temp_wheel_dir/*
+        echo installed basic
+        ls -lah temp_wheel_dir
+        mv temp_wheel_dir cu${{ inputs.cu }}_python_deps
+        tar cf cu${{ inputs.cu }}_python_deps.tar cu${{ inputs.cu }}_python_deps
+
+    - uses: actions/cache/save@v3
+      with:
+        path: cu${{ inputs.cu }}_python_deps.tar
+        key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml
new file mode 100644
index 000000000..bc26db282
--- /dev/null
+++ b/.github/workflows/windows_release_package.yml
@@ -0,0 +1,96 @@
+name: "Windows Release packaging"
+
+on:
+  workflow_dispatch:
+    cu:
+      description: 'cuda version'
+      required: true
+      type: string
+      default: "121"
+
+    python_minor:
+      description: 'python minor version'
+      required: true
+      type: string
+      default: "11"
+
+    python_patch:
+      description: 'python patch version'
+      required: true
+      type: string
+      default: "6"
+#  push:
+#    branches:
+#      - master
+
+jobs:
+  package_comfyui:
+    permissions:
+      contents: "write"
+      packages: "write"
+      pull-requests: "read"
+    runs-on: windows-latest
+    steps:
+    - uses: actions/cache/restore@v3
+      id: cache
+      with:
+        path: cu${{ inputs.cu }}_python_deps.tar
+        key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
+    - shell: bash
+      run: |
+        mv cu${{ inputs.cu }}_python_deps.tar ../
+        cd ..
+        tar xf cu${{ inputs.cu }}_python_deps.tar
+        pwd
+        ls
+
+    - uses: actions/checkout@v3
+      with:
+        fetch-depth: 0
+        persist-credentials: false
+    - shell: bash
+      run: |
+        cd ..
+        cp -r ComfyUI ComfyUI_copy
+        curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
+        unzip python_embeded.zip -d python_embeded
+        cd python_embeded
+        echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
+        curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+        ./python.exe get-pip.py
+        ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
+        sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
+        cd ..
+
+        git clone https://github.com/comfyanonymous/taesd
+        cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
+
+        mkdir ComfyUI_windows_portable
+        mv python_embeded ComfyUI_windows_portable
+        mv ComfyUI_copy ComfyUI_windows_portable/ComfyUI
+
+        cd ComfyUI_windows_portable
+
+        mkdir update
+        cp -r ComfyUI/.ci/update_windows/* ./update/
+        cp -r ComfyUI/.ci/update_windows_cu${{ inputs.cu }}/* ./update/
+        cp -r ComfyUI/.ci/windows_base_files/* ./
+
+        cd ..
+
+        "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
+        mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
+
+        cd ComfyUI_windows_portable
+        python_embeded/python.exe -s ComfyUI/main.py --quick-test-for-ci --cpu
+
+        ls
+
+    - name: Upload binaries to release
+      uses: svenstaro/upload-release-action@v2
+      with:
+        repo_token: ${{ secrets.GITHUB_TOKEN }}
+        file: new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z
+        tag: "latest"
+        overwrite: true

From 640d5080e53cc687384fdfa807ca0c29a16e6687 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 6 Oct 2023 10:29:52 -0400
Subject: [PATCH 145/150] Make xformers optional in packaging.

---
 .github/workflows/windows_release_dependencies.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml
index 590495c65..104639a05 100644
--- a/.github/workflows/windows_release_dependencies.yml
+++ b/.github/workflows/windows_release_dependencies.yml
@@ -5,7 +5,7 @@ on:
     inputs:
       xformers:
         description: 'xformers version'
-        required: true
+        required: false
         type: string
        default: ""

From 1497528de8fbacd400921a1c0a307356aea94abf Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 6 Oct 2023 10:43:12 -0400
Subject: [PATCH 146/150] Fix workflow.
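The parameters for this workflow's manual trigger were declared directly
under `workflow_dispatch`, which is not where GitHub Actions looks for them,
so expressions like `${{ inputs.cu }}` had nothing to resolve against when
the workflow was dispatched. Adding the missing `inputs:` level fixes the
trigger. A minimal sketch of the corrected shape (values as already declared
in this workflow):

```yaml
on:
  workflow_dispatch:
    inputs:          # this level was missing
      cu:
        description: 'cuda version'
        required: true
        type: string
        default: "121"
```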
---
 .github/workflows/windows_release_package.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml
index bc26db282..a4f36a706 100644
--- a/.github/workflows/windows_release_package.yml
+++ b/.github/workflows/windows_release_package.yml
@@ -2,6 +2,7 @@ name: "Windows Release packaging"
 
 on:
   workflow_dispatch:
+    inputs:
     cu:
       description: 'cuda version'
       required: true

From d761eaa4864e21d9302c6e58eb36daa20cecee6a Mon Sep 17 00:00:00 2001
From: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
Date: Fri, 6 Oct 2023 17:47:46 +0100
Subject: [PATCH 147/150] if the output type is an array, use combo

---
 web/scripts/app.js | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/web/scripts/app.js b/web/scripts/app.js
index 3c29a684a..5b9e76580 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1306,7 +1306,8 @@ export class ComfyApp {
 				}
 
 				for (const o in nodeData["output"]) {
-					const output = nodeData["output"][o];
+					let output = nodeData["output"][o];
+					if(output instanceof Array) output = "COMBO";
 					const outputName = nodeData["output_name"][o] || output;
 					const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ;
 					this.addOutput(outputName, output, { shape: outputShape });

From 0134d7ab49702b71af37451c647fedb8814704ac Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 6 Oct 2023 12:49:40 -0400
Subject: [PATCH 148/150] Generate update script with right settings.
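Rather than shipping a pre-baked updater from
`.ci/update_windows_cu${{ inputs.cu }}`, the dependencies workflow now
generates `update_comfyui_and_python_dependencies.bat` from the same `cu`
and `xformers` inputs it uses to build the wheels, and caches it next to
`cu${{ inputs.cu }}_python_deps.tar`. The packaging workflow restores both
from the cache and copies the generated script into the portable build's
`update/` directory, so the updater always matches the pytorch/cuda
combination that was packaged.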
---
 .../workflows/windows_release_dependencies.yml | 16 +++++++++++++++-
 .github/workflows/windows_release_package.yml  |  7 +++++--
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml
index 104639a05..f2ac94074 100644
--- a/.github/workflows/windows_release_dependencies.yml
+++ b/.github/workflows/windows_release_dependencies.yml
@@ -40,6 +40,18 @@ jobs:
 
     - shell: bash
       run: |
+        echo "@echo off
+        ..\python_embeded\python.exe .\update.py ..\ComfyUI\
+        echo
+        echo This will try to update pytorch and all python dependencies, if you get an error wait for pytorch/xformers to fix their stuff
+        echo You should not be running this anyways unless you really have to
+        echo
+        echo If you just want to update normally, close this and run update_comfyui.bat instead.
+        echo
+        pause
+        ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
+        pause" > update_comfyui_and_python_dependencies.bat
+
        python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
        python -m pip install --no-cache-dir ./temp_wheel_dir/*
        echo installed basic
        ls -lah temp_wheel_dir
        mv temp_wheel_dir cu${{ inputs.cu }}_python_deps
        tar cf cu${{ inputs.cu }}_python_deps.tar cu${{ inputs.cu }}_python_deps
@@ -49,5 +61,7 @@ jobs:
 
     - uses: actions/cache/save@v3
       with:
-        path: cu${{ inputs.cu }}_python_deps.tar
+        path: |
+          cu${{ inputs.cu }}_python_deps.tar
+          update_comfyui_and_python_dependencies.bat
         key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml
index a4f36a706..87d37c24d 100644
--- a/.github/workflows/windows_release_package.yml
+++ b/.github/workflows/windows_release_package.yml
@@ -35,11 +35,14 @@ jobs:
     - uses: actions/cache/restore@v3
       id: cache
       with:
-        path: cu${{ inputs.cu }}_python_deps.tar
+        path: |
+          cu${{ inputs.cu }}_python_deps.tar
+          update_comfyui_and_python_dependencies.bat
         key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
     - shell: bash
       run: |
         mv cu${{ inputs.cu }}_python_deps.tar ../
+        mv update_comfyui_and_python_dependencies.bat ../
         cd ..
         tar xf cu${{ inputs.cu }}_python_deps.tar
         pwd
@@ -74,8 +77,8 @@ jobs:
 
         mkdir update
         cp -r ComfyUI/.ci/update_windows/* ./update/
-        cp -r ComfyUI/.ci/update_windows_cu${{ inputs.cu }}/* ./update/
         cp -r ComfyUI/.ci/windows_base_files/* ./
+        cp ../update_comfyui_and_python_dependencies.bat ./update/
 
         cd ..

From 72188dffc3d331be41e366c4f0fa6883645f669a Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 6 Oct 2023 13:48:18 -0400
Subject: [PATCH 149/150] load_checkpoint_guess_config can now optionally
 output the model.
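`output_model` defaults to True, so existing callers are unaffected. When it
is False the unet is neither instantiated nor wrapped in a ModelPatcher, and
the first element of the returned tuple is None. A minimal usage sketch (the
checkpoint path is a placeholder) for callers that only need the CLIP and VAE:

```python
import comfy.sd

# Skipping the unet avoids loading the diffusion model weights entirely.
model_patcher, clip, vae, clipvision = comfy.sd.load_checkpoint_guess_config(
    "models/checkpoints/example.safetensors",  # placeholder path
    output_model=False,
)
assert model_patcher is None  # only clip and vae were constructed
```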
---
 comfy/sd.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index f186273ea..cfd6fb3cb 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -394,13 +394,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
 
     return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
     sd = comfy.utils.load_torch_file(ckpt_path)
     sd_keys = sd.keys()
     clip = None
     clipvision = None
     vae = None
     model = None
+    model_patcher = None
     clip_target = None
 
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
@@ -421,10 +422,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if fp16:
         dtype = torch.float16
 
-    inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
-    offload_device = model_management.unet_offload_device()
-    model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
-    model.load_model_weights(sd, "model.diffusion_model.")
+    if output_model:
+        inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
+        offload_device = model_management.unet_offload_device()
+        model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
+        model.load_model_weights(sd, "model.diffusion_model.")
 
     if output_vae:
         vae = VAE()
@@ -444,10 +446,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if len(left_over) > 0:
         print("left over keys:", left_over)
 
-    model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
-    if inital_load_device != torch.device("cpu"):
-        print("loaded straight to GPU")
-        model_management.load_model_gpu(model_patcher)
+    if output_model:
+        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        if inital_load_device != torch.device("cpu"):
+            print("loaded straight to GPU")
+            model_management.load_model_gpu(model_patcher)
 
     return (model_patcher, clip, vae, clipvision)

From ae3e4e9ad821c12b955f6b2343e6255e6d71eaf7 Mon Sep 17 00:00:00 2001
From: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
Date: Fri, 6 Oct 2023 21:48:30 +0100
Subject: [PATCH 150/150] access getConfig via a symbol so structuredClone
 works (#1677)
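Storing the config getter as an ordinary string-keyed property made any
object carrying it impossible to pass through structuredClone, because the
structured clone algorithm throws a DataCloneError when it meets a function
value. Symbol-keyed properties are simply skipped by that algorithm, so
moving the getter behind the GET_CONFIG symbol keeps it callable from code
while letting clones succeed. A minimal sketch of the behaviour this relies
on (names illustrative):

```js
const GET_CONFIG = Symbol();

// A function under a string key makes the object uncloneable:
const before = { name: "steps", getConfig: () => ["INT", {}] };
// structuredClone(before); // throws DataCloneError

// Symbol-keyed properties are ignored by the structured clone algorithm:
const after = { name: "steps", [GET_CONFIG]: () => ["INT", {}] };
const copy = structuredClone(after); // ok; copy is { name: "steps" }
```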
---
 web/extensions/core/widgetInputs.js | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js
index c734ffe27..3c9da458d 100644
--- a/web/extensions/core/widgetInputs.js
+++ b/web/extensions/core/widgetInputs.js
@@ -4,6 +4,7 @@ import { app } from "../../scripts/app.js";
 const CONVERTED_TYPE = "converted-widget";
 const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
 const CONFIG = Symbol();
+const GET_CONFIG = Symbol();
 
 function getConfig(widgetName) {
 	const { nodeData } = this.constructor;
@@ -66,7 +67,7 @@ function convertToInput(node, widget, config) {
 	// Add input and store widget config for creating on primitive node
 	const sz = node.size;
 	node.addInput(widget.name, type, {
-		widget: { name: widget.name, getConfig: () => config },
+		widget: { name: widget.name, [GET_CONFIG]: () => config },
 	});
 
 	for (const widget of node.widgets) {
@@ -146,8 +147,8 @@ app.registerExtension({
 
 			for (const input of this.inputs) {
 				if (input.widget) {
-					if (!input.widget.getConfig) {
-						input.widget.getConfig = getConfig.bind(this, input.widget.name);
+					if (!input.widget[GET_CONFIG]) {
+						input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
 					}
 
 					// Cleanup old widget config
@@ -197,8 +198,8 @@ app.registerExtension({
 			if (!app.configuringGraph && this.inputs) {
 				// On copy + paste of nodes, ensure that widget configs are set up
 				for (const input of this.inputs) {
-					if (input.widget && !input.widget.getConfig) {
-						input.widget.getConfig = getConfig.bind(this, input.widget.name);
+					if (input.widget && !input.widget[GET_CONFIG]) {
+						input.widget[GET_CONFIG] = () => getConfig.call(this, input.widget.name);
 					}
 				}
 			}
@@ -224,7 +225,7 @@ app.registerExtension({
 			const input = this.inputs[slot];
 			if (!input.widget || !input[ignoreDblClick]) {
 				// Not a widget input or already handled input
-				if (!(input.type in ComfyWidgets) && !(input.widget.getConfig?.()?.[0] instanceof Array)) {
+				if (!(input.type in ComfyWidgets) && !(input.widget[GET_CONFIG]?.()?.[0] instanceof Array)) {
 					return r; //also Not a ComfyWidgets input or combo (do nothing)
 				}
 			}
@@ -299,7 +300,7 @@ app.registerExtension({
 		refreshComboInNode() {
 			const widget = this.widgets?.[0];
 			if (widget?.type === "combo") {
-				widget.options.values = this.outputs[0].widget.getConfig()[0];
+				widget.options.values = this.outputs[0].widget[GET_CONFIG]()[0];
 
 				if (!widget.options.values.includes(widget.value)) {
 					widget.value = widget.options.values[0];
@@ -376,18 +377,18 @@ app.registerExtension({
 			let widget;
 			if (!input.widget) {
 				if (!(input.type in ComfyWidgets)) return;
-				widget = { name: input.name, getConfig: () => [input.type, {}] }; //fake widget
+				widget = { name: input.name, [GET_CONFIG]: () => [input.type, {}] }; //fake widget
 			} else {
 				widget = input.widget;
 			}
 
-			const { type } = getWidgetType(widget.getConfig());
+			const { type } = getWidgetType(widget[GET_CONFIG]());
 			// Update our output to restrict to the widget type
 			this.outputs[0].type = type;
 			this.outputs[0].name = type;
 			this.outputs[0].widget = widget;
 
-			this.#createWidget(widget[CONFIG] ?? widget.getConfig(), theirNode, widget.name, recreating);
+			this.#createWidget(widget[CONFIG] ?? widget[GET_CONFIG](), theirNode, widget.name, recreating);
 		}
 
 		#createWidget(inputData, node, widgetName, recreating) {
@@ -469,7 +470,7 @@ app.registerExtension({
 				return;
 			}
 
-			const config1 = output.widget.getConfig();
+			const config1 = output.widget[GET_CONFIG]();
 			const isNumber = config1[0] === "INT" || config1[0] === "FLOAT";
 			if (!isNumber) return;
 
@@ -488,8 +489,8 @@ app.registerExtension({
 		#isValidConnection(input, forceUpdate) {
 			// Only allow connections where the configs match
 			const output = this.outputs[0];
-			const config1 = output.widget[CONFIG] ?? output.widget.getConfig();
-			const config2 = input.widget.getConfig();
+			const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG]();
+			const config2 = input.widget[GET_CONFIG]();
 
 			if (config1[0] instanceof Array) {
 				// New input isnt a combo