From 7e51bbd07f809555cc50c4fdae3ef84720e5c86f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 4 May 2023 19:42:07 +0100 Subject: [PATCH 01/85] automatic calculation of image pos from widgets --- web/scripts/app.js | 39 ++++++++++++++++++++++++++++++--------- web/scripts/widgets.js | 9 +-------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index ada1708dc..f0c0f9de4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -263,6 +263,34 @@ export class ComfyApp { */ #addDrawBackgroundHandler(node) { const app = this; + + function getImageTop(node) { + let shiftY; + if (node.imageOffset != null) { + shiftY = node.imageOffset; + } else { + if (node.widgets?.length) { + const w = node.widgets[node.widgets.length - 1]; + shiftY = w.last_y; + if (w.computeSize) { + shiftY += w.computeSize()[1] + 4; + } else { + shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } else { + shiftY = node.computeSize()[1]; + } + } + return shiftY; + } + + node.prototype.setSizeForImage = function () { + const minHeight = getImageTop(this) + 220; + if (this.size[1] < minHeight) { + this.setSize([this.size[0], minHeight]); + } + }; + node.prototype.onDrawBackground = function (ctx) { if (!this.flags.collapsed) { const output = app.nodeOutputs[this.id + ""]; @@ -283,9 +311,7 @@ export class ComfyApp { ).then((imgs) => { if (this.images === output.images) { this.imgs = imgs.filter(Boolean); - if (this.size[1] < 100) { - this.size[1] = 250; - } + this.setSizeForImage?.(); app.graph.setDirtyCanvas(true); } }); @@ -310,12 +336,7 @@ export class ComfyApp { this.imageIndex = imageIndex = 0; } - let shiftY; - if (this.imageOffset != null) { - shiftY = this.imageOffset; - } else { - shiftY = this.computeSize()[1]; - } + const shiftY = getImageTop(this); let dw = this.size[0]; let dh = this.size[1]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index c0e73ffa1..cd471bc93 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -261,20 +261,13 @@ export const ComfyWidgets = { let uploadWidget; function showImage(name) { - // Position the image somewhere sensible - if (!node.imageOffset) { - node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75; - } - const img = new Image(); img.onload = () => { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; img.src = `/view?filename=${name}&type=input`; - if ((node.size[1] - node.imageOffset) < 100) { - node.size[1] = 250 + node.imageOffset; - } + node.setSizeForImage?.(); } // Add our own callback to the combo widget to render an image when it changes From bae4fb4a9dc944c10cca922dc4442eef57bbf583 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 May 2023 18:07:41 -0400 Subject: [PATCH 02/85] Fix imports. 
--- comfy/cldm/cldm.py | 10 +++---- comfy/gligen.py | 2 +- comfy/ldm/models/autoencoder.py | 8 +++--- comfy/ldm/models/diffusion/ddim.py | 2 +- comfy/ldm/models/diffusion/ddpm.py | 12 ++++----- comfy/ldm/modules/attention.py | 4 +-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- .../modules/diffusionmodules/openaimodel.py | 6 ++--- .../ldm/modules/diffusionmodules/upscaling.py | 4 +-- comfy/ldm/modules/diffusionmodules/util.py | 2 +- .../ldm/modules/encoders/noise_aug_modules.py | 4 +-- comfy/model_management.py | 2 +- comfy/sd.py | 26 +++++++++---------- 13 files changed, 42 insertions(+), 42 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index c60abf80b..cb660ee77 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -5,17 +5,17 @@ import torch import torch as th import torch.nn as nn -from ldm.modules.diffusionmodules.util import ( +from ..ldm.modules.diffusionmodules.util import ( conv_nd, linear, zero_module, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer -from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.util import log_txt_as_img, exists, instantiate_from_config +from ..ldm.modules.attention import SpatialTransformer +from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ..ldm.models.diffusion.ddpm import LatentDiffusion +from ..ldm.util import log_txt_as_img, exists, instantiate_from_config class ControlledUnetModel(UNetModel): diff --git a/comfy/gligen.py b/comfy/gligen.py index 8770383e5..45b674503 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,6 +1,6 @@ import torch from torch import nn, einsum -from ldm.modules.attention import CrossAttention +from .ldm.modules.attention import CrossAttention from inspect import isfunction diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index bd698621c..1fb7ed879 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -3,11 +3,11 @@ import torch import torch.nn.functional as F from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder +from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution -from ldm.util import instantiate_from_config -from ldm.modules.ema import LitEma +from comfy.ldm.util import instantiate_from_config +from comfy.ldm.modules.ema import LitEma # class AutoencoderKL(pl.LightningModule): class AutoencoderKL(torch.nn.Module): diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index deab76f21..c279f2c18 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -4,7 +4,7 @@ import torch import numpy as np from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor +from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor class DDIMSampler(object): diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index d3f0eb2b2..0f484a7f1 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ 
b/comfy/ldm/models/diffusion/ddpm.py @@ -19,12 +19,12 @@ from tqdm import tqdm from torchvision.utils import make_grid # from pytorch_lightning.utilities.distributed import rank_zero_only -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler +from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from comfy.ldm.modules.ema import LitEma +from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ..autoencoder import IdentityFirstStage, AutoencoderKL +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from .ddim import DDIMSampler __conditioning_keys__ = {'concat': 'c_concat', diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ce7180d91..5eabecd65 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -6,7 +6,7 @@ from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from ldm.modules.diffusionmodules.util import checkpoint +from .diffusionmodules.util import checkpoint from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -21,7 +21,7 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") -from cli_args import args +from comfy.cli_args import args def exists(val): return val is not None diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1599d386e..5e4d2b60f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -6,7 +6,7 @@ import numpy as np from einops import rearrange from typing import Optional, Any -from ldm.modules.attention import MemoryEfficientCrossAttention +from ..attention import MemoryEfficientCrossAttention from comfy import model_management if model_management.xformers_enabled_vae(): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 25309dbd7..4352b756d 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -6,7 +6,7 @@ import torch as th import torch.nn as nn import torch.nn.functional as F -from ldm.modules.diffusionmodules.util import ( +from .util import ( checkpoint, conv_nd, linear, @@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import ( normalization, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer -from ldm.util import exists +from ..attention import SpatialTransformer +from comfy.ldm.util import exists # dummy replace diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 038166620..709a7f52e 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -3,8 +3,8 @@ import torch.nn as nn import numpy as np from functools import partial -from ldm.modules.diffusionmodules.util import 
extract_into_tensor, make_beta_schedule -from ldm.util import default +from .util import extract_into_tensor, make_beta_schedule +from comfy.ldm.util import default class AbstractLowScaleModel(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index daf35da7b..82ea3f0a6 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -15,7 +15,7 @@ import torch.nn as nn import numpy as np from einops import repeat -from ldm.util import instantiate_from_config +from comfy.ldm.util import instantiate_from_config def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index f99e7920a..b59bf204b 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -1,5 +1,5 @@ -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation -from ldm.modules.diffusionmodules.openaimodel import Timestep +from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ..diffusionmodules.openaimodel import Timestep import torch class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): diff --git a/comfy/model_management.py b/comfy/model_management.py index db5d368e1..e89f80d69 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,6 @@ import psutil from enum import Enum -from cli_args import args +from .cli_args import args class VRAMState(Enum): CPU = 0 diff --git a/comfy/sd.py b/comfy/sd.py index 174ed35e5..3543bdb77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,8 +2,8 @@ import torch import contextlib import copy -import sd1_clip -import sd2_clip +from . import sd1_clip +from . 
import sd2_clip from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL @@ -446,10 +446,10 @@ class CLIP: else: params = {} - if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder": + if self.target_clip.endswith("FrozenOpenCLIPEmbedder"): clip = sd2_clip.SD2ClipModel tokenizer = sd2_clip.SD2Tokenizer - elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder": + elif self.target_clip.endswith("FrozenCLIPEmbedder"): clip = sd1_clip.SD1ClipModel tokenizer = sd1_clip.SD1Tokenizer @@ -896,9 +896,9 @@ def load_clip(ckpt_path, embedding_directory=None): clip_data = utils.load_torch_file(ckpt_path) config = {} if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: - config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' else: - config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder' + config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' clip = CLIP(config=config, embedding_directory=embedding_directory) clip.load_from_state_dict(clip_data) return clip @@ -974,9 +974,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clip: clip_config = {} if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys: - clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' else: - clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder' + clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' clip = CLIP(config=clip_config, embedding_directory=embedding_directory) w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] @@ -997,7 +997,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0] noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2" params["noise_schedule_config"] = noise_schedule_config - noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" + noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" if size == 1280: #h params["timestep_dim"] = 1024 elif size == 1024: #l @@ -1049,19 +1049,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1] unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] - sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} - model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} + sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} + model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} if noise_aug_config is not None: #SD2.x unclip model sd_config["noise_aug_config"] = noise_aug_config sd_config["image_size"] = 96 sd_config["embedding_dropout"] = 0.25 sd_config["conditioning_key"] = 'crossattn-adm' - model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" + 
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None - model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion" else: sd_config["conditioning_key"] = "crossattn" From 1a31020081b22cb55e573f65a11bd4c2c96f17f1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 00:16:57 -0400 Subject: [PATCH 03/85] Support softsign hypernetwork. --- comfy_extras/nodes_hypernetwork.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 0c7250e43..c19b5e4c7 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -18,6 +18,7 @@ def load_hypernetwork_patch(path, strength): "swish": torch.nn.Hardswish, "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, + "softsign": torch.nn.Softsign, } if activation_func not in valid_activation: From 6ee11d7bc00bdbc109e3b84231aa74fc1799d543 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 00:19:35 -0400 Subject: [PATCH 04/85] Fix import. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e89f80d69..3aea7ea8e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,6 @@ import psutil from enum import Enum -from .cli_args import args +from comfy.cli_args import args class VRAMState(Enum): CPU = 0 From af9cc1fb6a88e604700d3f57638ab23b9f607e9e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 01:28:48 -0400 Subject: [PATCH 05/85] Search recursively in subfolders for embeddings. 
--- comfy/sd1_clip.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 7f1217c3d..b1a392736 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -191,11 +191,20 @@ def safe_load_embed_zip(embed_path): del embed return out +def expand_directory_list(directories): + dirs = set() + for x in directories: + dirs.add(x) + for root, subdir, file in os.walk(x, followlinks=True): + dirs.add(root) + return list(dirs) def load_embed(embedding_name, embedding_directory): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] + embedding_directory = expand_directory_list(embedding_directory) + valid_file = None for embed_dir in embedding_directory: embed_path = os.path.join(embed_dir, embedding_name) From f31e31ee0a3d7da01f2b1f3b68047445c16e494a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 5 May 2023 10:12:06 +0100 Subject: [PATCH 06/85] Fix box shape Match card to litegraph selection --- web/scripts/app.js | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index ada1708dc..68eeb6329 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -703,7 +703,7 @@ export class ComfyApp { ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) - ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); + ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed)) ctx.roundRect( -6, @@ -715,12 +715,11 @@ export class ComfyApp { else if (shape == LiteGraph.CARD_SHAPE) ctx.roundRect( -6, - -6 + LiteGraph.NODE_TITLE_HEIGHT, + -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, - this.round_radius * 2, - 2 - ); + [this.round_radius * 2,2,this.round_radius * 2,2] + ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); ctx.strokeStyle = color; From de4623a8a4b8282f2d29d5a3ecbcb9840c3dc7ac Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 5 May 2023 10:34:09 +0100 Subject: [PATCH 07/85] actually fix card --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 68eeb6329..98c0e0799 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -718,7 +718,7 @@ export class ComfyApp { -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, - [this.round_radius * 2,2,this.round_radius * 2,2] + [this.round_radius * 2, this.round_radius * 2, 2, 2] ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); From cb1551b819ecaa7d9044c13d0c8e8cfa4ff72830 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 18:01:21 -0400 Subject: [PATCH 08/85] Lowvram mode for gligen and fix some lowvram issues. 
--- comfy/gligen.py | 27 +++++++++++++++---- comfy/ldm/modules/attention.py | 3 --- .../modules/diffusionmodules/openaimodel.py | 19 ++++++++++--- comfy/model_management.py | 3 +++ 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/comfy/gligen.py b/comfy/gligen.py index 45b674503..8c7cb432e 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -242,14 +242,28 @@ class Gligen(nn.Module): self.position_net = position_net self.key_dim = key_dim self.max_objs = 30 + self.lowvram = False def _set_position(self, boxes, masks, positive_embeddings): + if self.lowvram == True: + self.position_net.to(boxes.device) + objs = self.position_net(boxes, masks, positive_embeddings) - def func(key, x): - module = self.module_list[key] - return module(x, objs) - return func + if self.lowvram == True: + self.position_net.cpu() + def func_lowvram(key, x): + module = self.module_list[key] + module.to(x.device) + r = module(x, objs) + module.cpu() + return r + return func_lowvram + else: + def func(key, x): + module = self.module_list[key] + return module(x, objs) + return func def set_position(self, latent_image_shape, position_params, device): batch, c, h, w = latent_image_shape @@ -294,8 +308,11 @@ class Gligen(nn.Module): masks.to(device), conds.to(device)) + def set_lowvram(self, value=True): + self.lowvram = value + def cleanup(self): - pass + self.lowvram = False def get_models(self): return [self] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 5eabecd65..573f4e1c6 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module): x += n x = self.ff(self.norm3(x)) + x - - if current_index is not None: - transformer_options["current_index"] += 1 return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4352b756d..5aef23f33 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): x = layer(x) return x +#This is needed because accelerate makes a copy of transformer_options which breaks "current_index" +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): + for layer in ts: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context, transformer_options) + transformer_options["current_index"] += 1 + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) + else: + x = layer(x) + return x class Upsample(nn.Module): """ @@ -805,13 +818,13 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): - h = module(h, emb, context, transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context, transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -828,7 +841,7 @@ class UNetModel(nn.Module): output_shape = hs[-1].shape else: output_shape = None - h = module(h, emb, context, transformer_options, output_shape) + h = 
forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3aea7ea8e..7070912df 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -201,6 +201,9 @@ def load_controlnet_gpu(control_models): return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: + for m in control_models: + if hasattr(m, 'set_lowvram'): + m.set_lowvram(True) #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return From 7a9268185cb6456890f6fe61bcc380b5cb21f614 Mon Sep 17 00:00:00 2001 From: WAS Date: Fri, 5 May 2023 18:06:54 -0700 Subject: [PATCH 09/85] Update README.md Add quick search explanation --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3b3824714..bfa8904df 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git | Q | Toggle visibility of the queue | | H | Toggle visibility of history | | R | Refresh graph | +| Double-Click LMB | Open node quick search palette | Ctrl can also be replaced with Cmd instead for MacOS users From 8e03c789a25470a88aa05bcc73b1fe226334926b Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 6 May 2023 16:59:40 -0400 Subject: [PATCH 10/85] auto-launch cli arg --- comfy/cli_args.py | 4 ++++ main.py | 13 +++---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 764427165..cc4709f70 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -7,6 +7,7 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port. parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. 
Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") @@ -30,3 +31,6 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") args = parser.parse_args() + +if args.windows_standalone_build: + args.auto_launch = True diff --git a/main.py b/main.py index f369b82f3..eb97a2fb8 100644 --- a/main.py +++ b/main.py @@ -91,23 +91,16 @@ if __name__ == "__main__": threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() - address = args.listen - - dont_print = args.dont_print_server - - if args.output_directory: output_dir = os.path.abspath(args.output_directory) print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) - port = args.port - if args.quick_test_for_ci: exit(0) call_on_start = None - if args.windows_standalone_build: + if args.auto_launch: def startup_server(address, port): import webbrowser webbrowser.open("http://{}:{}".format(address, port)) @@ -115,10 +108,10 @@ if __name__ == "__main__": if os.name == "nt": try: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) except KeyboardInterrupt: pass else: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) cleanup_temp() From 678f933d382641933920e84414fe36f89d1da5a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 May 2023 19:00:49 -0400 Subject: [PATCH 11/85] maximum_batch_area for xformers. Remove useless code. 
--- comfy/model_management.py | 7 ++++++- nodes.py | 4 +--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7070912df..b0640d674 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -312,7 +312,12 @@ def maximum_batch_area(): return 0 memory_free = get_free_memory() / (1024 * 1024) - area = ((memory_free - 1024) * 0.9) / (0.6) + if xformers_enabled(): + #TODO: this needs to be tweaked + area = 50 * memory_free + else: + #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future + area = ((memory_free - 1024) * 0.9) / (0.6) return int(max(area, 0)) def cpu_mode(): diff --git a/nodes.py b/nodes.py index c2bc36855..ca0769ba7 100644 --- a/nodes.py +++ b/nodes.py @@ -105,15 +105,13 @@ class ConditioningSetArea: CATEGORY = "conditioning" - def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, width, height, x, y, strength): c = [] for t in conditioning: n = [t[0], t[1].copy()] n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n[1]['strength'] = strength n[1]['set_area_to_bounds'] = False - n[1]['min_sigma'] = min_sigma - n[1]['max_sigma'] = max_sigma c.append(n) return (c, ) From 6fc4917634d457c07eb8b676da4fa88e0ef4704b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 May 2023 19:58:54 -0400 Subject: [PATCH 12/85] Make maximum_batch_area take into account python2.0 attention function. More conservative xformers maximum_batch_area. --- comfy/model_management.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index b0640d674..39df8d9a7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -275,8 +275,17 @@ def xformers_enabled_vae(): return XFORMERS_ENABLED_VAE def pytorch_attention_enabled(): + global ENABLE_PYTORCH_ATTENTION return ENABLE_PYTORCH_ATTENTION +def pytorch_attention_flash_attention(): + global ENABLE_PYTORCH_ATTENTION + if ENABLE_PYTORCH_ATTENTION: + #TODO: more reliable way of checking for flash attention? + if torch.version.cuda: #pytorch flash attention only works on Nvidia + return True + return False + def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled @@ -312,9 +321,9 @@ def maximum_batch_area(): return 0 memory_free = get_free_memory() / (1024 * 1024) - if xformers_enabled(): + if xformers_enabled() or pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - area = 50 * memory_free + area = 20 * memory_free else: #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future area = ((memory_free - 1024) * 0.9) / (0.6) From ae08fdb9990956f671d658aaf72a1eaf982b5b33 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 9 May 2023 03:37:36 +0900 Subject: [PATCH 13/85] Clipspace Menu and MaskEditor application. (#548) * Add clipspace feature. * feat: copy content to clipspace * feat: paste content from clipspace Extend validation to allow for validating annotated_path in addition to other parameters. Add support for annotated_filepath in folder_paths function. Generalize the '/upload/image' API to allow for uploading images to the 'input', 'temp', or 'output' directories. * rename contentClipboard -> clipspace * Do deep copy for imgs on copy to clipspace. 
* mask painting on clipspace * add original_imgs into clipspace * Preserve the original image when 'imgs' are modified * robust patch & refactoring folder_paths about annotated_filepath * wip * Only show the Paste menu if the ComfyApp.clipspace is not empty * clipspace feature added maskeditor feature added * instant refresh on paste force triggering 'changed' on paste action * enhance mask painting smooth drawing add brush_size +/- button * robust patch use mouseup event * robust patch again... * subfolder fix on paste logic attach subfolder if subfolder isn't empty * event listener patch add ], [ key event for brush size remove listener on close * Fix button positioning issue related to window height. Change brush size from button to slider. * clean commit * clean code * various bug fixes * paste action - prevent opening upload popup - ensure rendering after widget_value update * view api update - support annotated_filepath * maskeditor layout - prevent covering button by hidden div * remove dbg message * Add cursor functionality to display brush size * refactor: Replace brush preview feature with missionfloyd implementation * missionfloyd implementation * hiding brush preview off the canvas * change brush size on wheel event * keyup -> keydown event * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Add support for channel-specific image data retrieval in /view API to fix alpha mask loading issue When loading an image with an alpha mask in JavaScript canvas, there is an issue where the alpha and RGB channels are premultiplied. To avoid reliance on JavaScript canvas, I added support for channel-specific image data retrieval in the "/view" API. This allows us to retrieve data for each channel separately and fix the alpha mask loading issue. The changes have been committed to the repository. * Enable brush preview for key and slider events * optimize * preview fix * robust patch * fix copy (clipspace) action imgs[0] copy -> whole imgs copy * support batch images on clipspace, maskeditor * copy/paste bug fixes for batch images enhance selector preview on clipspace menu add img_paste_mode option into clipspace menu * crash fix * print message if clipspace content cannot editable * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * make default img_paste_mode to 'selected' refactor space -> tab * save clipspace files to input/clipspace instead of temp * show "clipspace/filename.png" instead of 'filename.png [clipspace]' in LoadImage/LoadImageMask * refresh fix related to FILE_COMBO * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * adjust margin based on missionfloyd impelements * mouse event -> pointer event * pen, touch, mouse drawing patched and tested * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * add comment about touch event. 
--------- Co-authored-by: Lt.Dr.Data Co-authored-by: missionfloyd --- folder_paths.py | 9 + nodes.py | 8 +- server.py | 122 ++++++- web/extensions/core/clipspace.js | 166 +++++++++ web/extensions/core/maskeditor.js | 589 ++++++++++++++++++++++++++++++ web/scripts/app.js | 114 ++++-- web/scripts/ui.js | 1 + web/scripts/widgets.js | 14 + 8 files changed, 976 insertions(+), 47 deletions(-) create mode 100644 web/extensions/core/clipspace.js create mode 100644 web/extensions/core/maskeditor.js diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..0acd22674 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,6 +57,10 @@ def get_input_directory(): global input_directory return input_directory +def get_clipspace_directory(): + global input_directory + return input_directory+"/clipspace" + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -66,6 +70,8 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() + if type_name == "clipspace": + return get_clipspace_directory() return None @@ -81,6 +87,9 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] + elif name.endswith("[clipspace]"): + base_dir = get_clipspace_directory() + name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index ca0769ba7..1d9a5c872 100644 --- a/nodes.py +++ b/nodes.py @@ -973,8 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), )}, + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, } CATEGORY = "image" @@ -1014,9 +1015,10 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), ), - "channel": (s._color_channels, ),} + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + "channel": (s._color_channels, ), } } CATEGORY = "mask" diff --git a/server.py b/server.py index 1c5c17916..48644d83a 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,9 @@ import execution import uuid import json import glob +from PIL import Image +from io import BytesIO + try: import aiohttp from aiohttp import web @@ -110,19 +113,26 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) + def get_dir_by_type(dir_type): + if dir_type is None: + type_dir = folder_paths.get_input_directory() + elif dir_type == "input": + type_dir = folder_paths.get_input_directory() + elif dir_type == "clipspace": + type_dir = folder_paths.get_clipspace_directory() + elif dir_type == "temp": + type_dir = folder_paths.get_temp_directory() + elif dir_type == "output": + type_dir = folder_paths.get_output_directory() + + return type_dir + @routes.post("/upload/image") async def upload_image(request): post = await request.post() image = post.get("image") - if post.get("type") is None: - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "input": - upload_dir = 
folder_paths.get_input_directory() - elif post.get("type") == "temp": - upload_dir = folder_paths.get_temp_directory() - elif post.get("type") == "output": - upload_dir = folder_paths.get_output_directory() + upload_dir = get_dir_by_type(post.get("type")) if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -147,12 +157,62 @@ class PromptServer(): else: return web.Response(status=400) + @routes.post("/upload/mask") + async def upload_mask(request): + post = await request.post() + image = post.get("image") + original_image = post.get("original_image") + + upload_dir = get_dir_by_type(post.get("type")) + + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + + if image and image.file: + filename = image.filename + if not filename: + return web.Response(status=400) + + split = os.path.splitext(filename) + i = 1 + while os.path.exists(os.path.join(upload_dir, filename)): + filename = f"{split[0]} ({i}){split[1]}" + i += 1 + + filepath = os.path.join(upload_dir, filename) + + original_pil = Image.open(original_image.file).convert('RGBA') + mask_pil = Image.open(image.file).convert('RGBA') + + # alpha copy + new_alpha = mask_pil.getchannel('A') + original_pil.putalpha(new_alpha) + + original_pil.save(filepath) + + return web.json_response({"name": filename}) + else: + return web.Response(status=400) + @routes.get("/view") async def view_image(request): if "filename" in request.rel_url.query: - type = request.rel_url.query.get("type", "output") - output_dir = folder_paths.get_directory_by_type(type) + filename = request.rel_url.query["filename"] + filename,output_dir = folder_paths.annotated_filepath(filename) + + if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): + output_dir = folder_paths.get_clipspace_directory() + filename = filename[10:] + + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' 
in filename: + return web.Response(status=400) + + if output_dir is None: + type = request.rel_url.query.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) @@ -162,13 +222,49 @@ class PromptServer(): return web.Response(status=403) output_dir = full_output_dir - filename = request.rel_url.query["filename"] filename = os.path.basename(filename) file = os.path.join(output_dir, filename) if os.path.isfile(file): - return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) - + if 'channel' not in request.rel_url.query: + channel = 'rgba' + else: + channel = request.rel_url.query["channel"] + + if channel == 'rgb': + with Image.open(file) as img: + if img.mode == "RGBA": + r, g, b, a = img.split() + new_img = Image.merge('RGB', (r, g, b)) + else: + new_img = img.convert("RGB") + + buffer = BytesIO() + new_img.save(buffer, format='PNG') + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + + elif channel == 'a': + with Image.open(file) as img: + if img.mode == "RGBA": + _, _, _, a = img.split() + else: + a = Image.new('L', img.size, 255) + + # alpha img + alpha_img = Image.new('RGBA', img.size) + alpha_img.putalpha(a) + alpha_buffer = BytesIO() + alpha_img.save(alpha_buffer, format='PNG') + alpha_buffer.seek(0) + + return web.Response(body=alpha_buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + else: + return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) + return web.Response(status=404) @routes.get("/prompt") diff --git a/web/extensions/core/clipspace.js b/web/extensions/core/clipspace.js new file mode 100644 index 000000000..adb5877ea --- /dev/null +++ b/web/extensions/core/clipspace.js @@ -0,0 +1,166 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; + +export class ClipspaceDialog extends ComfyDialog { + static items = []; + static instance = null; + + static registerButton(name, contextPredicate, callback) { + const item = + $el("button", { + type: "button", + textContent: name, + contextPredicate: contextPredicate, + onclick: callback + }) + + ClipspaceDialog.items.push(item); + } + + static invalidatePreview() { + if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) { + const img_preview = document.getElementById("clipspace_preview"); + if(img_preview) { + img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + img_preview.style.maxHeight = "100%"; + img_preview.style.maxWidth = "100%"; + } + } + } + + static invalidate() { + if(ClipspaceDialog.instance) { + const self = ClipspaceDialog.instance; + // allow reconstruct controls when copying from non-image to image content. 
+ const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]); + + if(self.element) { + // update + self.element.removeChild(self.element.firstChild); + self.element.appendChild(children); + } + else { + // new + self.element = $el("div.comfy-modal", { parent: document.body }, [children,]); + } + + if(self.element.children[0].children.length <= 1) { + self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."])); + } + + ClipspaceDialog.invalidatePreview(); + } + } + + constructor() { + super(); + } + + createButtons(self) { + const buttons = []; + + for(let idx in ClipspaceDialog.items) { + const item = ClipspaceDialog.items[idx]; + if(!item.contextPredicate || item.contextPredicate()) + buttons.push(ClipspaceDialog.items[idx]); + } + + buttons.push( + $el("button", { + type: "button", + textContent: "Close", + onclick: () => { this.close(); } + }) + ); + + return buttons; + } + + createImgSettings() { + if(ComfyApp.clipspace.imgs) { + const combo_items = []; + const imgs = ComfyApp.clipspace.imgs; + + for(let i=0; i < imgs.length; i++) { + combo_items.push($el("option", {value:i}, [`${i}`])); + } + + const combo1 = $el("select", + {id:"clipspace_img_selector", onchange:(event) => { + ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex; + ClipspaceDialog.invalidatePreview(); + } }, combo_items); + + const row1 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]), + $el("td", {}, [combo1]) + ]); + + + const combo2 = $el("select", + {id:"clipspace_img_paste_mode", onchange:(event) => { + ComfyApp.clipspace['img_paste_mode'] = event.target.value; + } }, + [ + $el("option", {value:'selected'}, 'selected'), + $el("option", {value:'all'}, 'all') + ]); + combo2.value = ComfyApp.clipspace['img_paste_mode']; + + const row2 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]), + $el("td", {}, [combo2]) + ]); + + const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'}, + [ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]); + + const row3 = + $el("tr", {}, [td]); + + return $el("table", {}, [row1, row2, row3]); + } + else { + return []; + } + } + + createImgPreview() { + if(ComfyApp.clipspace.imgs) { + return $el("img",{id:"clipspace_preview", ondragstart:() => false}); + } + else + return []; + } + + show() { + const img_preview = document.getElementById("clipspace_preview"); + ClipspaceDialog.invalidate(); + + this.element.style.display = "block"; + } +} + +app.registerExtension({ + name: "Comfy.Clipspace", + init(app) { + app.openClipspace = + function () { + if(!ClipspaceDialog.instance) { + ClipspaceDialog.instance = new ClipspaceDialog(app); + ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate; + } + + if(ComfyApp.clipspace) { + ClipspaceDialog.instance.show(); + } + else + app.ui.dialog.show("Clipspace is Empty!"); + }; + } +}); \ No newline at end of file diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js new file mode 100644 index 000000000..c55f841b6 --- /dev/null +++ b/web/extensions/core/maskeditor.js @@ -0,0 +1,589 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; +import { ClipspaceDialog } from "/extensions/core/clipspace.js"; + +// Helper function to convert a data URL to a Blob object +function 
dataURLToBlob(dataURL) { + const parts = dataURL.split(';base64,'); + const contentType = parts[0].split(':')[1]; + const byteString = atob(parts[1]); + const arrayBuffer = new ArrayBuffer(byteString.length); + const uint8Array = new Uint8Array(arrayBuffer); + for (let i = 0; i < byteString.length; i++) { + uint8Array[i] = byteString.charCodeAt(i); + } + return new Blob([arrayBuffer], { type: contentType }); +} + +function loadedImageToBlob(image) { + const canvas = document.createElement('canvas'); + + canvas.width = image.width; + canvas.height = image.height; + + const ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0); + + const dataURL = canvas.toDataURL('image/png', 1); + const blob = dataURLToBlob(dataURL); + + return blob; +} + +async function uploadMask(filepath, formData) { + await fetch('/upload/mask', { + method: 'POST', + body: formData + }).then(response => {}).catch(error => { + console.error('Error:', error); + }); + + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; + + ClipspaceDialog.invalidatePreview(); +} + +function prepareRGB(image, backupCanvas, backupCtx) { + // paste mask data into alpha channel + backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); +} + +class MaskEditorDialog extends ComfyDialog { + static instance = null; + constructor() { + super(); + this.element = $el("div.comfy-modal", { parent: document.body }, + [ $el("div.comfy-modal-content", + [...this.createButtons()]), + ]); + MaskEditorDialog.instance = this; + } + + createButtons() { + return []; + } + + clearMask(self) { + } + + createButton(name, callback) { + var button = document.createElement("button"); + button.innerText = name; + button.addEventListener("click", callback); + return button; + } + createLeftButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "left"; + button.style.marginRight = "4px"; + return button; + } + createRightButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "right"; + button.style.marginLeft = "4px"; + return button; + } + createLeftSlider(self, name, callback) { + const divElement = document.createElement('div'); + divElement.id = "maskeditor-slider"; + divElement.style.cssFloat = "left"; + divElement.style.fontFamily = "sans-serif"; + divElement.style.marginRight = "4px"; + divElement.style.color = "var(--input-text)"; + divElement.style.backgroundColor = "var(--comfy-input-bg)"; + divElement.style.borderRadius = "8px"; + divElement.style.borderColor = "var(--border-color)"; + divElement.style.borderStyle = "solid"; + divElement.style.fontSize = "15px"; + divElement.style.height = "21px"; + divElement.style.padding = "1px 6px"; + divElement.style.display = "flex"; + divElement.style.position = "relative"; + 
divElement.style.top = "2px"; + self.brush_slider_input = document.createElement('input'); + self.brush_slider_input.setAttribute('type', 'range'); + self.brush_slider_input.setAttribute('min', '1'); + self.brush_slider_input.setAttribute('max', '100'); + self.brush_slider_input.setAttribute('value', '10'); + const labelElement = document.createElement("label"); + labelElement.textContent = name; + + divElement.appendChild(labelElement); + divElement.appendChild(self.brush_slider_input); + + self.brush_slider_input.addEventListener("change", callback); + + return divElement; + } + + setlayout(imgCanvas, maskCanvas) { + const self = this; + + // If it is specified as relative, using it only as a hidden placeholder for padding is recommended + // to prevent anomalies where it exceeds a certain size and goes outside of the window. + var placeholder = document.createElement("div"); + placeholder.style.position = "relative"; + placeholder.style.height = "50px"; + + var bottom_panel = document.createElement("div"); + bottom_panel.style.position = "absolute"; + bottom_panel.style.bottom = "0px"; + bottom_panel.style.left = "20px"; + bottom_panel.style.right = "20px"; + bottom_panel.style.height = "50px"; + + var brush = document.createElement("div"); + brush.id = "brush"; + brush.style.backgroundColor = "transparent"; + brush.style.outline = "1px dashed black"; + brush.style.boxShadow = "0 0 0 1px white"; + brush.style.borderRadius = "50%"; + brush.style.MozBorderRadius = "50%"; + brush.style.WebkitBorderRadius = "50%"; + brush.style.position = "absolute"; + brush.style.zIndex = 100; + brush.style.pointerEvents = "none"; + this.brush = brush; + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + document.body.appendChild(brush); + + var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + self.brush_size = event.target.value; + self.updateBrushPreview(self, null, null); + }); + var clearButton = this.createLeftButton("Clear", + () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); + }); + var cancelButton = this.createRightButton("Cancel", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.close(); + }); + var saveButton = this.createRightButton("Save", () => { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); + self.save(); + }); + + this.element.appendChild(imgCanvas); + this.element.appendChild(maskCanvas); + this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button + this.element.appendChild(bottom_panel); + + bottom_panel.appendChild(clearButton); + bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(cancelButton); + bottom_panel.appendChild(brush_size_slider); + + this.element.style.display = "block"; + imgCanvas.style.position = "relative"; + imgCanvas.style.top = "200"; + imgCanvas.style.left = "0"; + + maskCanvas.style.position = "absolute"; + } + + show() { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = 
document.createElement('canvas');
+
+		imgCanvas.id = "imageCanvas";
+		maskCanvas.id = "maskCanvas";
+		backupCanvas.id = "backupCanvas";
+
+		this.setlayout(imgCanvas, maskCanvas);
+
+		// prepare content
+		this.maskCanvas = maskCanvas;
+		this.backupCanvas = backupCanvas;
+		this.maskCtx = maskCanvas.getContext('2d');
+		this.backupCtx = backupCanvas.getContext('2d');
+
+		this.setImages(imgCanvas, backupCanvas);
+		this.setEventHandler(maskCanvas);
+	}
+
+	setImages(imgCanvas, backupCanvas) {
+		const imgCtx = imgCanvas.getContext('2d');
+		const backupCtx = backupCanvas.getContext('2d');
+		const maskCtx = this.maskCtx;
+		const maskCanvas = this.maskCanvas;
+
+		// image load
+		const orig_image = new Image();
+		window.addEventListener("resize", () => {
+			// repositioning
+			imgCanvas.width = window.innerWidth - 250;
+			imgCanvas.height = window.innerHeight - 200;
+
+			// redraw image
+			let drawWidth = orig_image.width;
+			let drawHeight = orig_image.height;
+			if (orig_image.width > imgCanvas.width) {
+				drawWidth = imgCanvas.width;
+				drawHeight = (drawWidth / orig_image.width) * orig_image.height;
+			}
+
+			if (drawHeight > imgCanvas.height) {
+				drawHeight = imgCanvas.height;
+				drawWidth = (drawHeight / orig_image.height) * orig_image.width;
+			}
+
+			imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
+
+			// update mask
+			backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
+			maskCanvas.width = drawWidth;
+			maskCanvas.height = drawHeight;
+			maskCanvas.style.top = imgCanvas.offsetTop + "px";
+			maskCanvas.style.left = imgCanvas.offsetLeft + "px";
+			maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
+		});
+
+		const filepath = ComfyApp.clipspace.images;
+
+		const touched_image = new Image();
+
+		touched_image.onload = function() {
+			backupCanvas.width = touched_image.width;
+			backupCanvas.height = touched_image.height;
+
+			prepareRGB(touched_image, backupCanvas, backupCtx);
+		};
+
+		const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
+		alpha_url.searchParams.delete('channel');
+		alpha_url.searchParams.set('channel', 'a');
+		touched_image.src = alpha_url;
+
+		// original image load
+		orig_image.onload = function() {
+			window.dispatchEvent(new Event('resize'));
+		};
+
+		const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
+		rgb_url.searchParams.delete('channel');
+		rgb_url.searchParams.set('channel', 'rgb');
+		orig_image.src = rgb_url;
+		this.image = orig_image;
+	}
+
+
+	setEventHandler(maskCanvas) {
+		maskCanvas.addEventListener("contextmenu", (event) => {
+			event.preventDefault();
+		});
+
+		const self = this;
+		maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
+		maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
+		document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
+		maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
+		maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
+		maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
+		maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
+		document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
+	}
+
+	brush_size = 10;
+	drawing_mode = false;
+	lastx = -1;
+	lasty = -1;
+	lasttime = 0;
+
+
+ static handleKeyDown(event) {
+ const self = MaskEditorDialog.instance;
+ if (event.key === ']') {
+ self.brush_size = Math.min(self.brush_size+2, 100);
+ } else if (event.key === '[') {
+ self.brush_size = Math.max(self.brush_size-2, 1);
+ }
+
+ self.updateBrushPreview(self);
+ }
+
+ static handlePointerUp(event) {
+ event.preventDefault();
+ MaskEditorDialog.instance.drawing_mode = false;
+ }
+
+ updateBrushPreview(self) {
+ const brush = self.brush;
+
+ var centerX = self.cursorX;
+ var centerY = self.cursorY;
+
+ brush.style.width = self.brush_size * 2 + "px";
+ brush.style.height = self.brush_size * 2 + "px";
+ brush.style.left = (centerX - self.brush_size) + "px";
+ brush.style.top = (centerY - self.brush_size) + "px";
+ }
+
+ handleWheelEvent(self, event) {
+ if(event.deltaY < 0)
+ self.brush_size = Math.min(self.brush_size+2, 100);
+ else
+ self.brush_size = Math.max(self.brush_size-2, 1);
+
+ self.brush_slider_input.value = self.brush_size;
+
+ self.updateBrushPreview(self);
+ }
+
+ draw_move(self, event) {
+ event.preventDefault();
+
+ this.cursorX = event.pageX;
+ this.cursorY = event.pageY;
+
+ self.updateBrushPreview(self);
+
+ if (event instanceof TouchEvent || event.buttons == 1) {
+ var diff = performance.now() - self.lasttime;
+
+ const maskRect = self.maskCanvas.getBoundingClientRect();
+
+ var x = event.offsetX;
+ var y = event.offsetY;
+
+ if(event.offsetX == null) {
+ x = event.targetTouches[0].clientX - maskRect.left;
+ }
+
+ if(event.offsetY == null) {
+ y = event.targetTouches[0].clientY - maskRect.top;
+ }
+
+ var brush_size = this.brush_size;
+ if(event instanceof PointerEvent && event.pointerType == 'pen') {
+ brush_size *= event.pressure;
+ this.last_pressure = event.pressure;
+ }
+ else if(event instanceof TouchEvent && diff < 20){
+ // The firing interval of PointerEvents from a pen is unreliable, so TouchEvents are used to supplement it.
+ brush_size *= this.last_pressure;
+ }
+ else {
+ brush_size = this.brush_size;
+ }
+
+ if(diff > 20 && !this.drawing_mode)
+ requestAnimationFrame(() => {
+ self.maskCtx.beginPath();
+ self.maskCtx.fillStyle = "rgb(0,0,0)";
+ self.maskCtx.globalCompositeOperation = "source-over";
+ self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
+ self.maskCtx.fill();
+ self.lastx = x;
+ self.lasty = y;
+ });
+ else
+ requestAnimationFrame(() => {
+ self.maskCtx.beginPath();
+ self.maskCtx.fillStyle = "rgb(0,0,0)";
+ self.maskCtx.globalCompositeOperation = "source-over";
+
+ var dx = x - self.lastx;
+ var dy = y - self.lasty;
+
+ var distance = Math.sqrt(dx * dx + dy * dy);
+ var directionX = dx / distance;
+ var directionY = dy / distance;
+
+ for (var i = 0; i < distance; i+=5) {
+ var px = self.lastx + (directionX * i);
+ var py = self.lasty + (directionY * i);
+ self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
+ self.maskCtx.fill();
+ }
+ self.lastx = x;
+ self.lasty = y;
+ });
+
+ self.lasttime = performance.now();
+ }
+ else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
+ const maskRect = self.maskCanvas.getBoundingClientRect();
+ const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
+ const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+
+ var brush_size = this.brush_size;
+ if(event instanceof PointerEvent && event.pointerType == 'pen') {
+ brush_size *= event.pressure;
+ this.last_pressure = event.pressure;
+ }
+ else if(event instanceof TouchEvent && diff < 20){
+ brush_size *= this.last_pressure;
+ }
+ else {
+ brush_size = this.brush_size;
+ }
+
+ if(diff > 20 && !this.drawing_mode) // drawing_mode can't be tracked for touch events
+ requestAnimationFrame(() => {
+ self.maskCtx.beginPath();
+ self.maskCtx.globalCompositeOperation = "destination-out";
+ self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
+ self.maskCtx.fill();
+ self.lastx = x;
+ self.lasty = y;
+ });
+ else
+ requestAnimationFrame(() => {
+ self.maskCtx.beginPath();
+ self.maskCtx.globalCompositeOperation = "destination-out";
+
+ var dx = x - self.lastx;
+ var dy = y - self.lasty;
+
+ var distance = Math.sqrt(dx * dx + dy * dy);
+ var directionX = dx / distance;
+ var directionY = dy / distance;
+
+ for (var i = 0; i < distance; i+=5) {
+ var px = self.lastx + (directionX * i);
+ var py = self.lasty + (directionY * i);
+ self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
+ self.maskCtx.fill();
+ }
+ self.lastx = x;
+ self.lasty = y;
+ });
+
+ self.lasttime = performance.now();
+ }
+ }
+
+ handlePointerDown(self, event) {
+ var brush_size = this.brush_size;
+ if(event instanceof PointerEvent && event.pointerType == 'pen') {
+ brush_size *= event.pressure;
+ this.last_pressure = event.pressure;
+ }
+
+ if ([0, 2, 5].includes(event.button)) {
+ self.drawing_mode = true;
+
+ event.preventDefault();
+ const maskRect = self.maskCanvas.getBoundingClientRect();
+ const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
+ const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+
+ self.maskCtx.beginPath();
+ if (event.button == 0) {
+ self.maskCtx.fillStyle = "rgb(0,0,0)";
+ self.maskCtx.globalCompositeOperation = "source-over";
+ } else {
+ self.maskCtx.globalCompositeOperation = "destination-out";
+ }
+ self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
+ self.maskCtx.fill();
+ self.lastx = x;
+ self.lasty = y;
+ self.lasttime = performance.now();
+ }
+ }
+
+ save() {
+ const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true});
+
+ backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height);
+ backupCtx.drawImage(this.maskCanvas,
+ 0, 0, this.maskCanvas.width, this.maskCanvas.height,
+ 0, 0, this.backupCanvas.width, this.backupCanvas.height);
+
+ // paste mask data into alpha channel
+ const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height);
+
+ // invert the mask: painted (opaque) pixels become transparent in the alpha channel
+ for (let i = 0; i < backupData.data.length; i += 4) {
+ if(backupData.data[i+3] == 255)
+ backupData.data[i+3] = 0;
+ else
+ backupData.data[i+3] = 255;
+
+ backupData.data[i] = 0;
+ backupData.data[i+1] = 0;
+ backupData.data[i+2] = 0;
+ }
+
+ backupCtx.globalCompositeOperation = 'source-over';
+ backupCtx.putImageData(backupData, 0, 0);
+
+ const formData = new FormData();
+ const filename = "clipspace-mask-" + performance.now() + ".png";
+
+ const item =
+ {
+ "filename": filename,
+ "subfolder": "",
+ "type": "clipspace",
+ };
+
+ if(ComfyApp.clipspace.images)
+ ComfyApp.clipspace.images[0] = item;
+
+ if(ComfyApp.clipspace.widgets) {
+ const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
+
+ if(index >= 0)
+ ComfyApp.clipspace.widgets[index].value = item;
+ }
+
+ const dataURL = this.backupCanvas.toDataURL();
+ const blob = dataURLToBlob(dataURL);
+
+ const original_blob = loadedImageToBlob(this.image);
+
+ formData.append('image', blob, filename);
+ formData.append('original_image', original_blob);
+ formData.append('type', "clipspace");
+
+ uploadMask(item, formData);
+ this.close();
+ }
+}
+
+app.registerExtension({
+ name: "Comfy.MaskEditor",
+ init(app) {
+ const callback =
+ function () {
+ let dlg = new MaskEditorDialog(app);
+ dlg.show();
+ };
+
+ const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0
+ ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback);
+ }
+});
\ No newline at end of file
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 245605484..f4f7272db 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -25,6 +25,7 @@ export class ComfyApp {
 * @type {serialized node object}
 */
 static clipspace = null;
+ static clipspace_invalidate_handler = null;

 constructor() {
 this.ui = new ComfyUI(this);
@@ -143,22 +144,34 @@ export class ComfyApp {
 callback: (obj) => {
 var widgets = null;
 if(this.widgets) {
- widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
+ widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value }));
 }
- let img = new Image();
 var imgs = undefined;
+ var orig_imgs = undefined;
 if(this.imgs != undefined) {
- img.src = this.imgs[0].src;
- imgs = [img];
+ imgs = [];
+ orig_imgs = [];
+
+ for (let i = 0; i < this.imgs.length; i++) {
+ imgs[i] = new Image();
+ imgs[i].src = this.imgs[i].src;
+ orig_imgs[i] = imgs[i];
+ }
 }

 ComfyApp.clipspace = {
 'widgets': widgets,
 'imgs': imgs,
- 'original_imgs': imgs,
- 'images': this.images
+ 'original_imgs': orig_imgs,
+ 'images': this.images,
+ 'selectedIndex': 0,
+ 'img_paste_mode': 'selected' // reset to the default img_paste_mode state on copy action
 };
+
+ if(ComfyApp.clipspace_invalidate_handler) {
+ ComfyApp.clipspace_invalidate_handler();
+ }
 }
 });

@@ -167,48 +180,82 @@ export class ComfyApp {
 {
 content: "Paste (Clipspace)",
 callback: () => {
- if(ComfyApp.clipspace != null) {
- if(ComfyApp.clipspace.widgets != null && this.widgets != null) {
- ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
- const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
- if (prop) {
- prop.callback(value);
- }
- });
- }
-
+ if(ComfyApp.clipspace) {
 // image paste
- if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) {
+ if(ComfyApp.clipspace.imgs && this.imgs) {
 var filename = "";
 if(this.images && ComfyApp.clipspace.images) {
- this.images = ComfyApp.clipspace.images;
+ if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
+ app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]];
+
+ }
+ else
+ app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images;
 }

- if(ComfyApp.clipspace.images != undefined) {
- const clip_image = ComfyApp.clipspace.images[0];
+ if(ComfyApp.clipspace.imgs) {
+ // deep-copy to cut link with clipspace
+ if(ComfyApp.clipspace['img_paste_mode'] == 'selected') {
+ const img = new Image();
+ img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src;
+ this.imgs = [img];
+ }
+ else {
+ const imgs = [];
+ for(let i=0; i<ComfyApp.clipspace.imgs.length; i++) {
+ imgs[i] = new Image();
+ imgs[i].src = ComfyApp.clipspace.imgs[i].src;
+ this.imgs = imgs;
+ }
+ }
+ }
+
+ if(ComfyApp.clipspace.images) {
+ const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']];
+ if(clip_image.subfolder != '')
+ filename = `${clip_image.subfolder}/`;
+ filename += `${clip_image.filename} [${clip_image.type}]`;
+ }
+ else if(ComfyApp.clipspace.widgets) {
 const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image');
 if(index_in_clip >= 0) {
- filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`;
+ const item = ComfyApp.clipspace.widgets[index_in_clip].value;
+ if(item.type)
+ filename = `${item.filename} [${item.type}]`;
+ else
+ filename = item.filename;
 }
 }

- const index = this.widgets.findIndex(obj => obj.name === 'image');
- if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) {
- this.imgs = ComfyApp.clipspace.imgs;
+ // for Load Image node.
+ if(this.widgets) {
+ const index = this.widgets.findIndex(obj => obj.name === 'image');
+ if(index >= 0 && filename != "") {
+ const postfix = ' [clipspace]';
+ if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') {
+ filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix));
+ }

- this.widgets[index].value = filename;
- if(this.widgets_values != undefined) {
- this.widgets_values[index] = filename;
+ this.widgets[index].value = filename;
+ if(this.widgets_values != undefined) {
+ this.widgets_values[index] = filename;
+ }
 }
 }
 }
- this.trigger('changed');
+
+ // ensure render after update widget_value
+ if(ComfyApp.clipspace.widgets && this.widgets) {
+ ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => {
+ const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name);
+ if (prop && prop.type != 'button') {
+ prop.callback(value);
+ }
+ });
+ }
 }
+
+ app.graph.setDirtyCanvas(true);
 }
 }
 );
@@ -1275,12 +1322,17 @@ export class ComfyApp {
 for(const widgetNum in node.widgets) {
 const widget = node.widgets[widgetNum]
- if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {
- widget.options.values = def["input"]["required"][widget.name][0];
+ if(def["input"]["required"][widget.name][0] == "FILE_COMBO") {
+ console.log(widget.options.values = def["input"]["required"][widget.name][1].files);
+ widget.options.values = def["input"]["required"][widget.name][1].files;
+ }
+ else
+ widget.options.values = def["input"]["required"][widget.name][0];

 if(!widget.options.values.includes(widget.value)) {
 widget.value = widget.options.values[0];
+ widget.callback(widget.value);
 }
 }
 }
diff --git a/web/scripts/ui.js b/web/scripts/ui.js
index 5accc9d86..77517aec1 100644
--- a/web/scripts/ui.js
+++ b/web/scripts/ui.js
@@ -581,6 +581,7 @@ export class ComfyUI {
 }),
 $el("button", { id: 
"comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }), $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index cd471bc93..4a72246db 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,6 +256,20 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, + FILE_COMBO(node, inputName, inputData) { + const base_dir = inputData[1].base_dir; + let defaultValue = inputData[1].files[0]; + + const files = [] + for(let i in inputData[1].files) { + files[i] = inputData[1].files[i]; + const postfix = ' [clipspace]'; + if(base_dir == 'input' && files[i].endsWith(postfix)) + files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); + } + + return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; + }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; From 850daf0416367ba39d10195540f5b735952f0ee7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 14:13:06 -0400 Subject: [PATCH 14/85] Masked editor changes. Add a way to upload to subfolders. Clean up code. Fix some issues. --- folder_paths.py | 9 ---- nodes.py | 8 ++-- server.py | 74 ++++++++++++------------------- web/extensions/core/maskeditor.js | 9 ++-- web/scripts/app.js | 66 ++++++++------------------- web/scripts/widgets.js | 52 +++++++++++++++------- 6 files changed, 93 insertions(+), 125 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 0acd22674..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,10 +57,6 @@ def get_input_directory(): global input_directory return input_directory -def get_clipspace_directory(): - global input_directory - return input_directory+"/clipspace" - #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -70,8 +66,6 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() - if type_name == "clipspace": - return get_clipspace_directory() return None @@ -87,9 +81,6 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] - elif name.endswith("[clipspace]"): - base_dir = get_clipspace_directory() - name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index 1d9a5c872..699e60ae8 100644 --- a/nodes.py +++ b/nodes.py @@ -973,9 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, + {"image": (sorted(files), )}, } CATEGORY = "image" @@ -1015,9 +1015,9 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in 
os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + {"image": (sorted(files), ), "channel": (s._color_channels, ), } } diff --git a/server.py b/server.py index 48644d83a..3d02b2f7a 100644 --- a/server.py +++ b/server.py @@ -118,8 +118,6 @@ class PromptServer(): type_dir = folder_paths.get_input_directory() elif dir_type == "input": type_dir = folder_paths.get_input_directory() - elif dir_type == "clipspace": - type_dir = folder_paths.get_clipspace_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": @@ -127,73 +125,63 @@ class PromptServer(): return type_dir - @routes.post("/upload/image") - async def upload_image(request): - post = await request.post() + def image_upload(post, image_save_function=None): image = post.get("image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) + image_upload_type = post.get("type") + upload_dir = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename if not filename: return web.Response(status=400) + subfolder = post.get("subfolder", "") + full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) + + if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir: + return web.Response(status=400) + + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder) + split = os.path.splitext(filename) + filepath = os.path.join(full_output_folder, filename) + i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): + while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" i += 1 - filepath = os.path.join(upload_dir, filename) + if image_save_function is not None: + image_save_function(image, post, filepath) + else: + with open(filepath, "wb") as f: + f.write(image.file.read()) - with open(filepath, "wb") as f: - f.write(image.file.read()) - - return web.json_response({"name" : filename}) + return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) else: return web.Response(status=400) + @routes.post("/upload/image") + async def upload_image(request): + post = await request.post() + return image_upload(post) + @routes.post("/upload/mask") async def upload_mask(request): post = await request.post() - image = post.get("image") - original_image = post.get("original_image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) - - if image and image.file: - filename = image.filename - if not filename: - return web.Response(status=400) - - split = os.path.splitext(filename) - i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): - filename = f"{split[0]} ({i}){split[1]}" - i += 1 - - filepath = os.path.join(upload_dir, filename) - - original_pil = Image.open(original_image.file).convert('RGBA') + def image_save_function(image, post, filepath): + original_pil = Image.open(post.get("original_image").file).convert('RGBA') mask_pil = Image.open(image.file).convert('RGBA') # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) - return web.json_response({"name": filename}) - else: - return web.Response(status=400) - + return image_upload(post, image_save_function) 
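The shared image_upload helper above accepts a client-supplied subfolder, so it has to reject values that would escape the upload directory. A minimal standalone sketch of that containment check, assuming upload_dir is an absolute path (the helper name here is illustrative, not part of the patch):

    import os

    def is_safe_subfolder(upload_dir, subfolder):
        # Normalize the requested subfolder, then verify the resulting
        # absolute path is still rooted inside upload_dir.
        full = os.path.abspath(os.path.join(upload_dir, os.path.normpath(subfolder)))
        return os.path.commonpath((upload_dir, full)) == upload_dir

    assert is_safe_subfolder("/data/input", "clipspace")
    assert not is_safe_subfolder("/data/input", "../output")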
@routes.get("/view") async def view_image(request): @@ -201,10 +189,6 @@ class PromptServer(): filename = request.rel_url.query["filename"] filename,output_dir = folder_paths.annotated_filepath(filename) - if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): - output_dir = folder_paths.get_clipspace_directory() - filename = filename[10:] - # validation for security: prevent accessing arbitrary path if filename[0] == '/' or '..' in filename: return web.Response(status=400) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index c55f841b6..0ffa50c69 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -546,8 +546,8 @@ class MaskEditorDialog extends ComfyDialog { const item = { "filename": filename, - "subfolder": "", - "type": "clipspace", + "subfolder": "clipspace", + "type": "input", }; if(ComfyApp.clipspace.images) @@ -567,7 +567,8 @@ class MaskEditorDialog extends ComfyDialog { formData.append('image', blob, filename); formData.append('original_image', original_blob); - formData.append('type', "clipspace"); + formData.append('type', "input"); + formData.append('subfolder', "clipspace"); uploadMask(item, formData); this.close(); diff --git a/web/scripts/app.js b/web/scripts/app.js index f4f7272db..c6c29e45b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -183,7 +183,6 @@ export class ComfyApp { if(ComfyApp.clipspace) { // image paste if(ComfyApp.clipspace.imgs && this.imgs) { - var filename = ""; if(this.images && ComfyApp.clipspace.images) { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; @@ -209,49 +208,25 @@ export class ComfyApp { } } } - - if(ComfyApp.clipspace.images) { - const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; - if(clip_image.subfolder != '') - filename = `${clip_image.subfolder}/`; - filename += `${clip_image.filename} [${clip_image.type}]`; - } - else if(ComfyApp.clipspace.widgets) { - const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); - if(index_in_clip >= 0) { - const item = ComfyApp.clipspace.widgets[index_in_clip].value; - if(item.type) - filename = `${item.filename} [${item.type}]`; - else - filename = item.filename; - } - } - - // for Load Image node. 
- if(this.widgets) { - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "") { - const postfix = ' [clipspace]'; - if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { - filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); - } - - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; - } - } - } } - // ensure render after update widget_value - if(ComfyApp.clipspace.widgets && this.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.callback(value); - } - }); + if(this.widgets) { + if(ComfyApp.clipspace.images) { + const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0) { + this.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } } } @@ -1323,12 +1298,7 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { - console.log(widget.options.values = def["input"]["required"][widget.name][1].files); - widget.options.values = def["input"]["required"][widget.name][1].files; - } - else - widget.options.values = def["input"]["required"][widget.name][0]; + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 4a72246db..65edc0392 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,20 +256,6 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, - FILE_COMBO(node, inputName, inputData) { - const base_dir = inputData[1].base_dir; - let defaultValue = inputData[1].files[0]; - - const files = [] - for(let i in inputData[1].files) { - files[i] = inputData[1].files[i]; - const postfix = ' [clipspace]'; - if(base_dir == 'input' && files[i].endsWith(postfix)) - files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); - } - - return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; - }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; @@ -280,10 +266,46 @@ export const ComfyWidgets = { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; - img.src = `/view?filename=${name}&type=input`; + let folder_separator = name.lastIndexOf("/"); + let subfolder = ""; + if (folder_separator > -1) { + subfolder = name.substring(0, folder_separator); + name = name.substring(folder_separator + 1); + } + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; node.setSizeForImage?.(); } + var default_value = imageWidget.value; + Object.defineProperty(imageWidget, "value", { + set 
: function(value) { + this._real_value = value; + }, + + get : function() { + let value = ""; + if (this._real_value) { + value = this._real_value; + } else { + return default_value; + } + + if (value.filename) { + let real_value = value; + value = ""; + if (real_value.subfolder) { + value = real_value.subfolder + "/"; + } + + value += real_value.filename; + + if(real_value.type && real_value.type !== "input") + value += ` [${real_value.type}]`; + } + return value; + } + }); + // Add our own callback to the combo widget to render an image when it changes const cb = node.callback; imageWidget.callback = function () { From a7ebd5aa1278a63f2f14852dce59b43834f6b9d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 15:52:33 -0400 Subject: [PATCH 15/85] Fix masked editor issue with firefox on windows. --- web/extensions/core/maskeditor.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 0ffa50c69..552059e86 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -368,7 +368,7 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (event instanceof TouchEvent || event.buttons == 1) { + if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -389,7 +389,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. brush_size *= this.last_pressure; } @@ -442,7 +442,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ brush_size *= this.last_pressure; } else { From a8705dbfe20ba86eaac5a669c61453775c796441 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 17:05:28 -0400 Subject: [PATCH 16/85] Speed up the mask save and fix refresh replacing copied image. 
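The value getter added to the image widget in the previous patch serializes a file reference into the "subfolder/filename [type]" string that folder_paths.annotated_filepath() parses on the server. A rough Python equivalent of that serialization, for illustration only (the function name is assumed, not from the patch):

    def annotate_filename(filename, subfolder="", file_type="input"):
        # Mirror of the widget getter: prefix the subfolder, then append
        # the type annotation unless it is the default "input".
        value = f"{subfolder}/" if subfolder else ""
        value += filename
        if file_type and file_type != "input":
            value += f" [{file_type}]"
        return value

    print(annotate_filename("clipspace-mask-123.png", "clipspace"))  # clipspace/clipspace-mask-123.png
    print(annotate_filename("photo.png", file_type="output"))        # photo.png [output]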
--- server.py | 2 +- web/scripts/app.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 3d02b2f7a..c1226f304 100644 --- a/server.py +++ b/server.py @@ -179,7 +179,7 @@ class PromptServer(): # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) + original_pil.save(filepath, compress_level=4) return image_upload(post, image_save_function) diff --git a/web/scripts/app.js b/web/scripts/app.js index c6c29e45b..2da1b5581 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1300,7 +1300,7 @@ export class ComfyApp { if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { widget.options.values = def["input"]["required"][widget.name][0]; - if(!widget.options.values.includes(widget.value)) { + if(widget.name != 'image' && !widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; widget.callback(widget.value); } From c6e34963e412e1960f73ad357d10c2b7bd1464e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 18:15:19 -0400 Subject: [PATCH 17/85] Make t2i adapter work with any latent resolution. --- comfy/t2i_adapter/adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 0221fff83..87e3d859e 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -56,7 +56,12 @@ class Downsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels - return self.op(x) + if not self.use_conv: + padding = [x.shape[2] % 2, x.shape[3] % 2] + self.op.padding = padding + + x = self.op(x) + return x class ResnetBlock(nn.Module): From d43e45ce624b82dadbe98646329d2b0fbc17edcf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 10:29:58 -0400 Subject: [PATCH 18/85] Remove print. --- nodes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nodes.py b/nodes.py index 699e60ae8..760db24e1 100644 --- a/nodes.py +++ b/nodes.py @@ -443,7 +443,6 @@ class ControlNetApply: def apply_controlnet(self, conditioning, control_net, image, strength): c = [] control_hint = image.movedim(-1,1) - print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] c_net = control_net.copy().set_cond_hint(control_hint, strength) From 314e526c5ce428a3717207c5c36a42a5c895b6a5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 12:18:18 -0400 Subject: [PATCH 19/85] Not needed anymore because sampling works with any latent size. --- comfy/samplers.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index dcf93cca2..6417f2ed4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -362,19 +362,8 @@ def resolve_cond_masks(conditions, h, w, device): else: box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) - # Make sure the height and width are divisible by 8 - if X % 8 != 0: - newx = X // 8 * 8 - W = W + (X - newx) - X = newx - if Y % 8 != 0: - newy = Y // 8 * 8 - H = H + (Y - newy) - Y = newy - if H % 8 != 0: - H = H + (8 - (H % 8)) - if W % 8 != 0: - W = W + (8 - (W % 8)) + H = max(8, H) + W = max(8, W) area = (int(H), int(W), int(Y), int(X)) modified['area'] = area From 02ca1c67f87e46e926aba325e73b2845d5244874 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 23:51:52 -0400 Subject: [PATCH 20/85] Don't print traceback when processing interrupted. 
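The t2i adapter patch above (17/85) lets the pooling Downsample path accept any latent resolution by padding odd spatial dimensions before the stride-2 average pool. A small self-contained sketch of the same trick (tensor sizes are arbitrary examples):

    import torch
    import torch.nn as nn

    op = nn.AvgPool2d(kernel_size=2, stride=2)
    x = torch.randn(1, 320, 77, 59)                # odd height and width
    op.padding = [x.shape[2] % 2, x.shape[3] % 2]  # pad each odd dim by 1
    print(op(x).shape)                             # torch.Size([1, 320, 39, 30])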
--- execution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index c19c10bc6..edf884611 100644 --- a/execution.py +++ b/execution.py @@ -194,7 +194,10 @@ class PromptExecutor: if valid: recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: - print(traceback.format_exc()) + if isinstance(e, comfy.model_management.InterruptProcessingException): + print("Processing interrupted") + else: + print(traceback.format_exc()) to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From d6dee8af1df5e7dc80463b9e45bdce76767e4119 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 00:29:31 -0400 Subject: [PATCH 21/85] Only validate each input once. --- execution.py | 40 ++++++++++++++++++---------------------- main.py | 2 +- server.py | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/execution.py b/execution.py index edf884611..3953fde3a 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def execute(self, prompt, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -172,27 +172,15 @@ class PromptExecutor: executed = set() try: to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] + for x in list(execute_outputs): + to_execute += [(0, x)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -219,8 +207,11 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -241,8 +232,9 @@ def validate_inputs(prompt, item): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) + r = validate_inputs(prompt, o_id, validated) if r[0] == False: + validated[o_id] = r return r else: if type_input == "INT": @@ -270,7 +262,10 @@ def validate_inputs(prompt, item): if isinstance(type_input, list): if val not in type_input: return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + + ret = (True, "") + validated[unique_id] = ret + return ret def validate_prompt(prompt): outputs = set() @@ -284,11 +279,12 @@ def validate_prompt(prompt): good_outputs = set() errors = [] + validated = {} for o in outputs: valid = False reason = "" try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] except Exception as e: @@ -297,7 +293,7 @@ def validate_prompt(prompt): reason = "Parsing error" if valid == True: - good_outputs.add(x) + good_outputs.add(o) else: print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") @@ -307,7 +303,7 @@ def validate_prompt(prompt): errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - return (True, "") + return (True, "", list(good_outputs)) class PromptQueue: diff --git a/main.py b/main.py index eb97a2fb8..d385df70a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) + e.execute(item[-3], item[-2], item[-1]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): diff --git a/server.py b/server.py index c1226f304..b6ac7d483 100644 --- a/server.py +++ b/server.py @@ -312,7 +312,7 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) else: resp_code = 400 out_string = valid[1] From 8e3d1cbf3b8488b319675f952e1a868aa78f1161 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 01:45:27 -0400 Subject: [PATCH 22/85] Fix bug when uploading image with the same name. --- server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server.py b/server.py index b6ac7d483..911f6a614 100644 --- a/server.py +++ b/server.py @@ -151,6 +151,7 @@ class PromptServer(): i = 1 while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) i += 1 if image_save_function is not None: From 51583164ef08d2173eb93eefa36bc50429cfe7c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 10:03:30 -0400 Subject: [PATCH 23/85] Make MaskToImage support masks with a batch size. --- comfy_extras/nodes_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 131cd6a9f..9916f3b21 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -72,7 +72,7 @@ class MaskToImage: FUNCTION = "mask_to_image" def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) class ImageToMask: From f7c0f75d1fb1c6e3657f69247eace796882c62da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 13:58:19 -0400 Subject: [PATCH 24/85] Auto batching improvements. Try batching when cond sizes don't match with smart padding. 
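The MaskToImage change above (23/85) first brings any incoming mask layout into a canonical (batch, 1, height, width) form, then moves the channel last and expands it to three channels. A shape walkthrough with illustrative sizes:

    import torch

    for shape in [(512, 512), (5, 512, 512)]:
        mask = torch.rand(shape)
        result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
        print(shape, "->", tuple(result.shape))
    # (512, 512) -> (1, 512, 512, 3)
    # (5, 512, 512) -> (5, 512, 512, 3)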
--- comfy/samplers.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6417f2ed4..aa44fa82d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,10 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +import math + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise @@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: - if c1['c_crossattn'].shape != c2['c_crossattn'].shape: - return False + s1 = c1['c_crossattn'].shape + s2 = c2['c_crossattn'].shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False @@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn = [] c_concat = [] c_adm = [] + crossattn_max_len = 0 for x in c_list: if 'c_crossattn' in x: - c_crossattn.append(x['c_crossattn']) + c = x['c_crossattn'] + if crossattn_max_len == 0: + crossattn_max_len = c.shape[1] + else: + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + c_crossattn.append(c) if 'c_concat' in x: c_concat.append(x['c_concat']) if 'c_adm' in x: c_adm.append(x['c_adm']) out = {} - if len(c_crossattn) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn)] + c_crossattn_out = [] + for c in c_crossattn: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + c_crossattn_out.append(c) + + if len(c_crossattn_out) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn_out)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] if len(c_adm) > 0: From 602095f614276dd52fad718c223e0be17d12b11e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:49:49 -0400 Subject: [PATCH 25/85] Send execution_error message on websocket on execution exception. --- execution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 3953fde3a..7ee038975 100644 --- a/execution.py +++ b/execution.py @@ -185,7 +185,11 @@ class PromptExecutor: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") else: - print(traceback.format_exc()) + message = str(traceback.format_exc()) + print(message) + if self.server.client_id is not None: + self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From 3a7c3acc72435f312a8f050d8ad3a1c902d9cff4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:59:24 -0400 Subject: [PATCH 26/85] Send websocket message with list of cached nodes right before execution. 
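The padding scheme in the auto-batching patch above (24/85) lines up conds whose token counts differ by an integer factor: each cross-attention context is repeated along the token axis up to the least common multiple of the lengths, which leaves the attention result unchanged. A sketch with the typical CLIP lengths of 77 and 154 tokens:

    import math
    import torch

    def lcm(a, b):  # math.lcm is available from Python 3.9
        return abs(a * b) // math.gcd(a, b)

    c1 = torch.randn(1, 77, 768)   # single 77-token prompt
    c2 = torch.randn(1, 154, 768)  # prompt encoded as two 77-token chunks
    target = lcm(c1.shape[1], c2.shape[1])       # 154
    c1 = c1.repeat(1, target // c1.shape[1], 1)  # repeating doesn't change the result
    print(torch.cat([c1, c2]).shape)             # torch.Size([2, 154, 768])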
--- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 7ee038975..7d18d3b65 100644 --- a/execution.py +++ b/execution.py @@ -169,6 +169,8 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) executed = set() try: to_execute = [] From 974958ff81d9af92b01490bcc99dfc93f8bb5d30 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 16:41:43 -0400 Subject: [PATCH 27/85] Make the prompt_id a uuid and return it when queueing the prompt. --- server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 911f6a614..6965ff3c1 100644 --- a/server.py +++ b/server.py @@ -81,7 +81,7 @@ class PromptServer(): # Reusing existing session, remove old self.sockets.pop(sid, None) else: - sid = uuid.uuid4().hex + sid = uuid.uuid4().hex self.sockets[sid] = ws @@ -313,7 +313,9 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) + prompt_id = str(uuid.uuid4()) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + return web.json_response({"prompt_id": prompt_id}) else: resp_code = 400 out_string = valid[1] From dfc74c19d944b4a4503e22297592fa3a537d3092 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 01:22:40 -0400 Subject: [PATCH 28/85] Add the prompt_id to some websocket messages. --- execution.py | 8 ++++---- main.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index 7d18d3b65..0ac4d462c 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}, execute_outputs=[]): + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -170,7 +170,7 @@ class PromptExecutor: current_outputs = set(self.outputs.keys()) if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() try: to_execute = [] @@ -190,7 +190,7 @@ class PromptExecutor: message = str(traceback.format_exc()) print(message) if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) to_delete = [] for o in self.outputs: @@ -207,7 +207,7 @@ class PromptExecutor: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None }, self.server.client_id) + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) gc.collect() comfy.model_management.soft_empty_cache() diff --git a/main.py b/main.py index d385df70a..00cbf3c4a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, 
item_id = q.get() - e.execute(item[-3], item[-2], item[-1]) + e.execute(item[2], item[1], item[3], item[4]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): From 8ea165dd1ef877f58f3710f31ce43f27e0f739ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 14:15:13 -0400 Subject: [PATCH 29/85] Add a way to overwrite images when uploading. --- server.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 6965ff3c1..a2bb26ad9 100644 --- a/server.py +++ b/server.py @@ -127,6 +127,7 @@ class PromptServer(): def image_upload(post, image_save_function=None): image = post.get("image") + overwrite = post.get("overwrite") image_upload_type = post.get("type") upload_dir = get_dir_by_type(image_upload_type) @@ -148,11 +149,14 @@ class PromptServer(): split = os.path.splitext(filename) filepath = os.path.join(full_output_folder, filename) - i = 1 - while os.path.exists(filepath): - filename = f"{split[0]} ({i}){split[1]}" - filepath = os.path.join(full_output_folder, filename) - i += 1 + if overwrite is not None and (overwrite == "true" or overwrite == "1"): + pass + else: + i = 1 + while os.path.exists(filepath): + filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) + i += 1 if image_save_function is not None: image_save_function(image, post, filepath) From 8a4ff5e34cc53252a9ff23e796904100d75bea55 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 12 May 2023 20:58:29 +0100 Subject: [PATCH 30/85] allow static files to be symlinks --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index a2bb26ad9..ef858a98a 100644 --- a/server.py +++ b/server.py @@ -362,7 +362,7 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ - web.static('/', self.web_root), + web.static('/', self.web_root, follow_symlinks=True), ]) def get_queue_info(self): From d9e088ddfd97663abbb933c77f79d2a6c6127851 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:49:09 +0200 Subject: [PATCH 31/85] minor changes for tiled sampler --- comfy/ldm/modules/tomesd.py | 2 +- comfy/sd.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 6a13b80c9..bb971e88f 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, """ B, N, _ = metric.shape - if r <= 0: + if r <= 0 or w == 1 or h == 1: return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather diff --git a/comfy/sd.py b/comfy/sd.py index 3543bdb77..0200f7742 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -581,10 +581,7 @@ class VAE: samples = samples.cpu() return samples -def resize_image_to(tensor, target_latent_tensor, batched_number): - tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") - target_batch_size = target_latent_tensor.shape[0] - +def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] print(current_batch_size, target_batch_size) if current_batch_size == 1: @@ -623,7 +620,9 @@ class ControlNet: if self.cond_hint is not 
None: del self.cond_hint self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -794,10 +793,14 @@ class T2IAdapter: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint + self.control_input = None self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint) self.t2i_model.cpu() From 19c014f4292863444a3d677d504ad58623395a58 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:57:40 +0200 Subject: [PATCH 32/85] comment out annoying print statement --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 0200f7742..c6be900ad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -583,7 +583,7 @@ class VAE: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - print(current_batch_size, target_batch_size) + #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor From c5c0ea666f8456b5a788092bad88528bbf34f559 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 12 May 2023 20:34:48 -0400 Subject: [PATCH 33/85] noise_mask in latent should be in a single format. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 760db24e1..c2201dafc 100644 --- a/nodes.py +++ b/nodes.py @@ -795,7 +795,7 @@ class SetLatentNoiseMask: def set_mask(self, samples, mask): s = samples.copy() - s["noise_mask"] = mask + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): From 997dd1b1312a00cbedeafaf916e49f294a73a431 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 02:07:49 -0400 Subject: [PATCH 34/85] Fix queue delete. 
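The queue-delete fix below follows from patch 27/85: once prompt ids became uuid4 strings, the old predicate's int() cast could no longer match them (and raises ValueError on a non-numeric id). A minimal illustration using the five-element queue entry introduced earlier:

    import uuid

    prompt_id = str(uuid.uuid4())
    queue_entry = (0, prompt_id, {}, {}, [])
    delete_func = lambda a: a[1] == prompt_id  # plain string comparison, no int() cast
    print(delete_func(queue_entry))            # True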
--- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index a2bb26ad9..8435d091b 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) + delete_func = lambda a: a[1] == id_to_delete self.prompt_queue.delete_queue_item(delete_func) - + return web.Response(status=200) @routes.post("/interrupt") From 1201d2eae5820bb8124beb22b712d743415fd47d Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 13 May 2023 17:15:45 +0200 Subject: [PATCH 35/85] Make nodes map over input lists (#579) * allow nodes to map over lists * make work with IS_CHANGED and VALIDATE_INPUTS * give list outputs distinct socket shape * add rebatch node * add batch index logic * add repeat latent batch * deal with noise mask edge cases in latentfrombatch --- comfy/sample.py | 17 ++++-- comfy_extras/nodes_rebatch.py | 108 ++++++++++++++++++++++++++++++++++ execution.py | 90 +++++++++++++++++++++++----- nodes.py | 57 +++++++++++++++--- server.py | 1 + web/scripts/app.js | 3 +- 6 files changed, 250 insertions(+), 26 deletions(-) create mode 100644 comfy_extras/nodes_rebatch.py diff --git a/comfy/sample.py b/comfy/sample.py index bd38585ac..284efca61 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,17 +2,26 @@ import torch import comfy.model_management import comfy.samplers import math +import numpy as np -def prepare_noise(latent_image, seed, skip=0): +def prepare_noise(latent_image, seed, noise_inds=None): """ creates random noise given a latent image and a seed. optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - for _ in range(skip): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1]+1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - return noise + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..0a9daf272 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,108 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != 
shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if dimensions gone above target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/execution.py b/execution.py index 0ac4d462c..cf2e5ea71 100644 --- a/execution.py +++ b/execution.py @@ -26,20 +26,81 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for 
x in h: if h[x] == "PROMPT": - input_data_all[x] = prompt + input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": - input_data_all[x] = unique_id + input_data_all[x] = [unique_id] return input_data_all +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + input_is_list = False + if hasattr(obj, "INPUT_IS_LIST"): + input_is_list = obj.INPUT_IS_LIST + + max_len_input = max([len(x) for x in input_data_all.values()]) + + # get a slice of inputs, repeat last input when list isn't long enough + def slice_dict(d, i): + d_new = dict() + for k,v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + results = [] + if input_is_list: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + ui = dict() + if len(uis) > 0: + ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + return output, ui + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -63,13 +124,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute server.send_sync("executing", { "node": unique_id }, server.client_id) obj = class_def() - nodes.before_node_execution() - outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) - if "ui" in outputs[unique_id]: + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) - if "result" in outputs[unique_id]: - outputs[unique_id] = outputs[unique_id]["result"] + server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -261,9 +321,11 @@ def validate_inputs(prompt,
item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): if val not in type_input: diff --git a/nodes.py b/nodes.py index c2201dafc..509dc0697 100644 --- a/nodes.py +++ b/nodes.py @@ -629,18 +629,57 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - FUNCTION = "rotate" + FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def rotate(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() - s["batch_index"] = batch_index + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() + if "noise_mask" in samples: + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() + if "batch_index" not in s: + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + else: + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: @@ -805,8 +844,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - skip = latent["batch_index"] if "batch_index" in latent else 0 - noise = comfy.sample.prepare_noise(latent_image, seed, skip) + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -1170,6 +1209,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": 
RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1244,6 +1284,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image", @@ -1299,3 +1341,4 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) diff --git a/server.py b/server.py index 8435d091b..cb66cc618 100644 --- a/server.py +++ b/server.py @@ -268,6 +268,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = x info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x diff --git a/web/scripts/app.js b/web/scripts/app.js index 2da1b5581..1a4a18b94 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -976,7 +976,8 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - this.addOutput(outputName, output); + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); } const s = this.computeSize(); From 44f9f9baf170ddf27891b240002300d8aa09fb2a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:17:16 -0400 Subject: [PATCH 36/85] Add the prompt id to some websocket messages. 
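Before the diff: a standalone sketch of the list-mapping rule that map_node_over_list introduces above. The helper names here are hypothetical, but the behaviour matches the patch: every input arrives as a list, the node function runs once per index, and shorter lists repeat their last element.

def slice_dict(d, i):
    # take the i-th element of every input list, clamping to the last element
    return {k: v[i] if i < len(v) else v[-1] for k, v in d.items()}

def map_over_list(func, input_data_all):
    max_len = max(len(v) for v in input_data_all.values())
    return [func(**slice_dict(input_data_all, i)) for i in range(max_len)]

# map_over_list(lambda a, b: a + b, {"a": [1, 2, 3], "b": [10]}) -> [11, 12, 13]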
--- execution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index cf2e5ea71..b9548229c 100644 --- a/execution.py +++ b/execution.py @@ -101,7 +101,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -116,19 +116,19 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = class_def() output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -215,6 +215,9 @@ class PromptExecutor: else: self.server.client_id = None + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] @@ -242,7 +245,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") From cb4b8223981ec9e090ebf44205f5ce16d72f01cb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:54:45 -0400 Subject: [PATCH 37/85] Print custom nodes that take too much time to import. 
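The diff above threads prompt_id through the "executing"/"executed" messages. A hypothetical client-side sketch of why that helps, assuming the third-party websocket-client package and a server on the default port: with the id in every message, a client can wait for its own prompt and ignore the rest of the queue.

import json
import websocket  # pip install websocket-client

def wait_for_prompt(prompt_id, server="127.0.0.1:8188", client_id="example"):
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server, client_id))
    while True:
        msg = json.loads(ws.recv())
        data = msg.get("data", {})
        if data.get("prompt_id") != prompt_id:
            continue  # message belongs to some other queued prompt
        if msg["type"] == "executing" and data["node"] is None:
            return  # node None is the server's "this prompt finished" signal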
--- nodes.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nodes.py b/nodes.py index 509dc0697..bc7968308 100644 --- a/nodes.py +++ b/nodes.py @@ -6,6 +6,7 @@ import json import hashlib import traceback import math +import time from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -1325,6 +1326,7 @@ def load_custom_node(module_path): def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") + node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if "__pycache__" in possible_modules: @@ -1333,7 +1335,16 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + time_before = time.time() load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path)) + + slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) + if len(slow_nodes) > 0: + print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + for n in sorted(slow_nodes): + print("{:6.1f} seconds to import:".format(n[0]), n[1]) + print() def init_custom_nodes(): load_custom_nodes() From cf439709b6b3ffae5ad15a9f7e59fedc214d5f1c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 12:50:21 -0400 Subject: [PATCH 38/85] Load nodes in comfy_extras before custom nodes. Change the slow import message. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index bc7968308..956b739d9 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,15 +1341,15 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + print("\nDetected some custom nodes that were slow to import:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() def init_custom_nodes(): - load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_nodes() From 92bf1cb61efcab45961d1119cb7ec7a076caf24e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:05:52 -0400 Subject: [PATCH 39/85] Change message. 
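The timing pattern above, snapshot a clock, run the import, append an (elapsed, label) pair, generalizes to any callable. A stdlib-only sketch, using time.perf_counter, which a later patch in this series also switches to:

import time

def timed(label, func, timings, *args, **kwargs):
    # wrap any callable and record (elapsed_seconds, label)
    start = time.perf_counter()
    result = func(*args, **kwargs)
    timings.append((time.perf_counter() - start, label))
    return result

timings = []
timed("example workload", sum, timings, range(10**6))
for seconds, label in sorted(timings, reverse=True):
    print("{:6.1f} seconds: {}".format(seconds, label))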
--- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 956b739d9..28215127c 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,7 +1341,7 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import:") + print("\nImport times for custom nodes:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From 2ac744f6628d107b3534177eeca5ef06f6668609 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:15:31 -0400 Subject: [PATCH 40/85] Print all custom node import times. --- nodes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 28215127c..f3b7da1a9 100644 --- a/nodes.py +++ b/nodes.py @@ -1339,10 +1339,9 @@ def load_custom_nodes(): load_custom_node(module_path) node_import_times.append((time.time() - time_before, module_path)) - slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) - if len(slow_nodes) > 0: + if len(node_import_times) > 0: print("\nImport times for custom nodes:") - for n in sorted(slow_nodes): + for n in sorted(node_import_times): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From db4d3a8494a4a7dbb6f911ae126a92abec6bf91b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:23:42 -0400 Subject: [PATCH 41/85] Print if custom nodes imported successfully or not. --- nodes.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index f3b7da1a9..63d9adc3d 100644 --- a/nodes.py +++ b/nodes.py @@ -1318,11 +1318,14 @@ def load_custom_node(module_path): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) + return True else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + return False except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + return False def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") @@ -1336,13 +1339,17 @@ def load_custom_nodes(): module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue time_before = time.time() - load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path)) + success = load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") for n in sorted(node_import_times): - print("{:6.1f} seconds to import:".format(n[0]), n[1]) + if n[2]: + import_message = "" + else: + import_message = " (IMPORT FAILED)" + print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) print() def init_custom_nodes(): From b0505eb7ab8af1986dabd97c23fae83a0539303d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 15:31:22 -0400 Subject: [PATCH 42/85] Return right type when none specified in upload route. Switch time.time to time.perf_counter for custom node import times. 
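The time.perf_counter switch named in the subject above lands in the next diff, and it is worth a note: time.time() follows the wall clock, which NTP can step backwards mid-measurement, while time.perf_counter() is monotonic and uses the highest-resolution timer available. The difference is visible from the standard library alone:

import time

print(time.get_clock_info("time").monotonic)          # False: can jump backwards
print(time.get_clock_info("perf_counter").monotonic)  # True: safe for durations

t0 = time.perf_counter()
time.sleep(0.01)
print("elapsed: {:.4f}s".format(time.perf_counter() - t0))  # always >= 0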
--- nodes.py | 4 ++-- server.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index 63d9adc3d..c4aff1012 100644 --- a/nodes.py +++ b/nodes.py @@ -1338,9 +1338,9 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - time_before = time.time() + time_before = time.perf_counter() success = load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path, success)) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") diff --git a/server.py b/server.py index d1079dd83..ba4dcba03 100644 --- a/server.py +++ b/server.py @@ -115,22 +115,23 @@ class PromptServer(): def get_dir_by_type(dir_type): if dir_type is None: - type_dir = folder_paths.get_input_directory() - elif dir_type == "input": + dir_type = "input" + + if dir_type == "input": type_dir = folder_paths.get_input_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": type_dir = folder_paths.get_output_directory() - return type_dir + return type_dir, dir_type def image_upload(post, image_save_function=None): image = post.get("image") overwrite = post.get("overwrite") image_upload_type = post.get("type") - upload_dir = get_dir_by_type(image_upload_type) + upload_dir, image_upload_type = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename From 3a1f47764d76bb9878b55e82657044b3faceda9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 17:11:27 -0400 Subject: [PATCH 43/85] Print the torch device that is used on startup. 
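The upload-route rework above normalizes a missing type to "input" first and returns the resolved type alongside the directory, so the JSON response can echo the type that was actually used instead of the None the client sent. A standalone sketch of that normalize-then-dispatch shape, with hypothetical paths:

import os

BASE = "/tmp/comfy"  # hypothetical root, for illustration only

def get_dir_by_type(dir_type=None):
    dir_type = dir_type or "input"  # normalize before dispatching
    dirs = {"input": "input", "temp": "temp", "output": "output"}
    # return the directory *and* the resolved type so callers can report it
    return os.path.join(BASE, dirs[dir_type]), dir_type

upload_dir, resolved = get_dir_by_type(None)
print(resolved)  # "input" -- the response can now state the real type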
--- comfy/model_management.py | 42 ++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39df8d9a7..c15323219 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -127,6 +127,32 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() + +def get_torch_device_name(device): + if hasattr(device, 'type'): + return "{}".format(device.type) + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + +try: + print("Using device:", get_torch_device_name(get_torch_device())) +except: + print("Could not pick default device.") + current_loaded_model = None current_gpu_controlnets = [] @@ -233,22 +259,6 @@ def unload_if_low_vram(model): return model.cpu() return model -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type From e7b9d2c02cffd59fecca4ee617137ea38641078a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:30:58 -0400 Subject: [PATCH 44/85] /prompt endpoint error is now in json format. --- server.py | 7 +++---- web/scripts/api.js | 2 +- web/scripts/app.js | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index ba4dcba03..f52117f10 100644 --- a/server.py +++ b/server.py @@ -323,12 +323,11 @@ class PromptServer(): self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) return web.json_response({"prompt_id": prompt_id}) else: - resp_code = 400 - out_string = valid[1] print("invalid prompt:", valid[1]) + return web.json_response({"error": valid[1]}, status=400) + else: + return web.json_response({"error": "no prompt"}, status=400) - return web.Response(body=out_string, status=resp_code) - @routes.post("/queue") async def post_queue(request): json_data = await request.json() diff --git a/web/scripts/api.js b/web/scripts/api.js index d29faa5ba..4f061c358 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -163,7 +163,7 @@ class ComfyApi extends EventTarget { if (res.status !== 200) { throw { - response: await res.text(), + response: await res.json(), }; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index 1a4a18b94..00d3c9746 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1222,7 +1222,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response || error.toString()); + this.ui.dialog.show(error.response.error || error.toString()); break; } From 9bf67c4c5a5c8b8d1efc2d4ce7e7ab1eccce1fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:34:25 -0400 Subject: [PATCH 45/85] Print prompt execution time. 
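The startup report above surfaces whatever get_torch_device resolves to. As a rough standalone approximation, assuming only a stock torch install; ComfyUI's real logic also covers DirectML, XPU and its own VRAM states, so treat this as a sketch:

import torch

def pick_device():
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

dev = pick_device()
if dev.type == "cuda":
    print("Using device: CUDA {}: {}".format(dev.index, torch.cuda.get_device_name(dev)))
else:
    print("Using device:", dev.type)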
--- execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/execution.py b/execution.py index b9548229c..dd88029bc 100644 --- a/execution.py +++ b/execution.py @@ -6,6 +6,7 @@ import threading import heapq import traceback import gc +import time import torch import nodes @@ -215,6 +216,7 @@ class PromptExecutor: else: self.server.client_id = None + execution_start_time = time.perf_counter() if self.server.client_id is not None: self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) @@ -272,6 +274,7 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() comfy.model_management.soft_empty_cache() From d926f65f56217e7828ad27ec5b646c74398593c4 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sun, 14 May 2023 23:21:22 +0900 Subject: [PATCH 46/85] Feature/maskeditor context menu (#649) * add "Open in MaskEditor" to context menu * change save button name to 'Save to node' if open in node. clear clipspace_return_node after auto paste * * leak patch: prevent infinite duplication of MaskEditorDialog instance on every dialog open * prevent conflict of multiple opening of MaskEditorDialog * name of save button fix * patch: brushPreview hiding by dialog * consider close by 'esc' key on maskeditor. * bugfix about last patch * patch: invalid close detection * 'enter' key as save action * * batch support enhance - pick index based on imageIndex on copy action * paste fix on batch image node * typo --------- Co-authored-by: Lt.Dr.Data --- web/extensions/core/maskeditor.js | 120 ++++++++++++---- web/scripts/app.js | 226 +++++++++++++++++------------- 2 files changed, 221 insertions(+), 125 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 552059e86..4b0c12747 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) { class MaskEditorDialog extends ComfyDialog { static instance = null; + + static getInstance() { + if(!MaskEditorDialog.instance) { + MaskEditorDialog.instance = new MaskEditorDialog(app); + } + + return MaskEditorDialog.instance; + } + + is_layout_created = false; + constructor() { super(); this.element = $el("div.comfy-modal", { parent: document.body }, [ $el("div.comfy-modal-content", [...this.createButtons()]), ]); - MaskEditorDialog.instance = this; } createButtons() { return []; } - clearMask(self) { - } - createButton(name, callback) { var button = document.createElement("button"); button.innerText = name; button.addEventListener("click", callback); return button; } + createLeftButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "left"; button.style.marginRight = "4px"; return button; } + createRightButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "right"; button.style.marginLeft = "4px"; return button; } + createLeftSlider(self, name, callback) { const divElement = document.createElement('div'); divElement.id = "maskeditor-slider"; @@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog { brush.style.MozBorderRadius = "50%"; brush.style.WebkitBorderRadius = "50%"; brush.style.position = "absolute"; - brush.style.zIndex = 100; + 
brush.style.zIndex = 8889; brush.style.pointerEvents = "none"; this.brush = brush; this.element.appendChild(imgCanvas); @@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog { document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.close(); }); - var saveButton = this.createRightButton("Save", () => { + + this.saveButton = this.createRightButton("Save", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.save(); @@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); - bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(brush_size_slider); - this.element.style.display = "block"; imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; @@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog { } show() { - // layout - const imgCanvas = document.createElement('canvas'); - const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); + if(!this.is_layout_created) { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); - imgCanvas.id = "imageCanvas"; - maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; - this.setlayout(imgCanvas, maskCanvas); + this.setlayout(imgCanvas, maskCanvas); - // prepare content - this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + // prepare content + this.imgCanvas = imgCanvas; + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); - this.setImages(imgCanvas, backupCanvas); - this.setEventHandler(maskCanvas); + this.setEventHandler(maskCanvas); + + this.is_layout_created = true; + + // replacement of onClose hook since close is not real close + const self = this; + const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + ComfyApp.onClipspaceEditorClosed(); + } + + self.last_display_style = self.element.style.display; + } + }); + }); + + const config = { attributes: true }; + observer.observe(this.element, config); + } + + this.setImages(this.imgCanvas, this.backupCanvas); + + if(ComfyApp.clipspace_return_node) { + this.saveButton.innerText = "Save to node"; + } + else { + this.saveButton.innerText = "Save"; + } + this.saveButton.disabled = false; + + this.element.style.display = "block"; + this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. 
+ } + + isOpened() { + return this.element.style.display == "block"; } setImages(imgCanvas, backupCanvas) { @@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog { const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); + maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); + // image load const orig_image = new Image(); window.addEventListener("resize", () => { @@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog { rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; - }g - + } setEventHandler(maskCanvas) { maskCanvas.addEventListener("contextmenu", (event) => { @@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog { self.brush_size = Math.min(self.brush_size+2, 100); } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + } else if(event.key === 'Enter') { + self.save(); } self.updateBrushPreview(self); @@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog { } } - save() { + async save() { const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); @@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog { formData.append('type', "input"); formData.append('subfolder', "clipspace"); - uploadMask(item, formData); + this.saveButton.innerText = "Saving..."; + this.saveButton.disabled = true; + await uploadMask(item, formData); + ComfyApp.onClipspaceEditorSave(); this.close(); } } @@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog { app.registerExtension({ name: "Comfy.MaskEditor", init(app) { - const callback = + ComfyApp.open_maskeditor = function () { - let dlg = new MaskEditorDialog(app); - dlg.show(); + const dlg = MaskEditorDialog.getInstance(); + if(!dlg.isOpened()) { + dlg.show(); + } }; const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 - ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); } }); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 00d3c9746..87c5e30ca 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -26,6 +26,8 @@ export class ComfyApp { */ static clipspace = null; static clipspace_invalidate_handler = null; + static open_maskeditor = null; + static clipspace_return_node = null; constructor() { this.ui = new ComfyUI(this); @@ -49,6 +51,114 @@ export class ComfyApp { this.shiftDown = false; } + static isImageNode(node) { + return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); + } + + static onClipspaceEditorSave() { + if(ComfyApp.clipspace_return_node) { + ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); + } + } + + static onClipspaceEditorClosed() { + ComfyApp.clipspace_return_node = null; + } + + static copyToClipspace(node) { + var widgets = null; + if(node.widgets) { + widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + var imgs = undefined; + var orig_imgs = undefined; + if(node.imgs != undefined) { + imgs = []; + orig_imgs = []; + + for (let i = 0; i < node.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = node.imgs[i].src; + orig_imgs[i] = 
imgs[i]; + } + } + + var selectedIndex = 0; + if(node.imageIndex) { + selectedIndex = node.imageIndex; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': orig_imgs, + 'images': node.images, + 'selectedIndex': selectedIndex, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action + }; + + ComfyApp.clipspace_return_node = null; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } + } + + static pasteFromClipspace(node) { + if(ComfyApp.clipspace) { + // image paste + if(ComfyApp.clipspace.imgs && node.imgs) { + if(node.images && ComfyApp.clipspace.images) { + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + } + else + app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + node.imgs = [img]; + node.imageIndex = 0; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); + if(index >= 0) { + node.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } + } + + app.graph.setDirtyCanvas(true); + } + } + /** * Invoke an extension callback * @param {keyof ComfyExtension} method The extension callback to execute @@ -138,102 +248,30 @@ export class ComfyApp { } } - options.push( - { - content: "Copy (Clipspace)", - callback: (obj) => { - var widgets = null; - if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); - } - - var imgs = undefined; - var orig_imgs = undefined; - if(this.imgs != undefined) { - imgs = []; - orig_imgs = []; + // prevent conflict of clipspace content + if(!ComfyApp.clipspace_return_node) { + options.push({ + content: "Copy (Clipspace)", + callback: (obj) => { ComfyApp.copyToClipspace(this); } + }); - for (let i = 0; i < this.imgs.length; i++) { - imgs[i] = new Image(); - imgs[i].src = this.imgs[i].src; - orig_imgs[i] = imgs[i]; + if(ComfyApp.clipspace != null) { + options.push({ + content: "Paste (Clipspace)", + callback: () => { ComfyApp.pasteFromClipspace(this); } + }); + } + + if(ComfyApp.isImageNode(this)) { + options.push({ + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); } - } - - ComfyApp.clipspace = { - 'widgets': widgets, - 'imgs': imgs, - 'original_imgs': orig_imgs, - 'images': this.images, - 'selectedIndex': 0, - 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action - }; - - if(ComfyApp.clipspace_invalidate_handler) { - ComfyApp.clipspace_invalidate_handler(); - } - } - }); - - if(ComfyApp.clipspace != null) { - options.push( - { - content: "Paste (Clipspace)", - callback: () => { - if(ComfyApp.clipspace) { - // image paste - if(ComfyApp.clipspace.imgs && this.imgs) { - if(this.images && ComfyApp.clipspace.images) { - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - 
app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; - - } - else - app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; - } - - if(ComfyApp.clipspace.imgs) { - // deep-copy to cut link with clipspace - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - const img = new Image(); - img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; - this.imgs = [img]; - } - else { - const imgs = []; - for(let i=0; i obj.name === 'image'); - if(index >= 0) { - this.widgets[index].value = clip_image; - } - } - if(ComfyApp.clipspace.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.value = value; - prop.callback(value); - } - }); - } - } - } - - app.graph.setDirtyCanvas(true); - } - } - ); + }); + } } }; } From acff543d669dba9b03fb500a10010f2da8739ff3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 12:50:21 -0400 Subject: [PATCH 47/85] Remove useless code. --- nodes.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/nodes.py b/nodes.py index c4aff1012..bc23e5c17 100644 --- a/nodes.py +++ b/nodes.py @@ -146,9 +146,6 @@ class ConditioningSetMask: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -161,9 +158,6 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -176,9 +170,6 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -203,9 +194,6 @@ class VAEEncode: return ({"samples":t}, ) class VAEEncodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -220,9 +208,6 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} From 587f89fe5a8e2bcb389fb4919dc33c330320fa41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 15:10:40 -0400 Subject: [PATCH 48/85] Enable safe loading for upscale models. 
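Context for the one-line diff below: a .ckpt/.pt checkpoint is a pickle, and unpickling can execute arbitrary code, so a downloaded upscale model is an attack vector. Restricting the load to plain tensors closes that hole. A rough sketch of the idea, not ComfyUI's actual loader, assuming torch >= 1.13 for the weights_only flag:

import torch

def load_weights(path):
    if path.endswith(".safetensors"):
        import safetensors.torch
        return safetensors.torch.load_file(path)  # format cannot embed code
    # refuse to unpickle anything but tensors and primitive containers
    return torch.load(path, map_location="cpu", weights_only=True)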
--- comfy_extras/nodes_upscale_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ab5b0ccfc..f9252ea0b 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -17,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) out = model_loading.load_state_dict(sd).eval() return (out, ) From 84ea21c815d426000c233e0c7b8c542764335cc8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 17:02:40 -0400 Subject: [PATCH 49/85] Update litegraph from upstream. --- web/lib/litegraph.core.js | 145 +++++++++++++++++++++++++++++++++++--- 1 file changed, 137 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2bc6af0c3..6c81c3ffd 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action) //when clicked on top of a node //and it is not interactive - if (node && this.allow_interaction && !skip_action && !this.read_only) { + if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) { if (!this.live_mode && !node.flags.pinned) { this.bringToFront(node); } //if it wasn't selected? //not dragging mouse to connect two slots - if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { + if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) @@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action) } //double clicking - if (is_double_click && this.selected_nodes[node.id]) { + if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) { //double click node if (node.onDblClick) { node.onDblClick( e, pos, this ); @@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action) this.dirty_canvas = true; } + //get node over + var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); + if (this.dragging_rectangle) { this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0]; @@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action) this.ds.offset[1] += delta[1] / this.ds.scale; this.dirty_canvas = true; this.dirty_bgcanvas = true; - } else if (this.allow_interaction && !this.read_only) { + } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) { if (this.connecting_node) { this.dirty_canvas = true; } - //get node over - var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); - //remove mouseover flag for (var i = 0, l = this.graph._nodes.length; i < l; ++i) { if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) { @@ -9911,7 +9911,7 @@ LGraphNode.prototype.executeAction = function(action) event, active_widget ) { - if (!node.widgets || !node.widgets.length) { + if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) { return null; } @@ -10300,6 +10300,119 @@ LGraphNode.prototype.executeAction = function(action) canvas.graph.add(group); }; + /** + * Determines the furthest nodes in each 
direction + * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.getBoundaryNodes = function(nodes) { + let top = null; + let right = null; + let bottom = null; + let left = null; + for (const nID in nodes) { + const node = nodes[nID]; + const [x, y] = node.pos; + const [width, height] = node.size; + + if (top === null || y < top.pos[1]) { + top = node; + } + if (right === null || x + width > right.pos[0] + right.size[0]) { + right = node; + } + if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) { + bottom = node; + } + if (left === null || x < left.pos[0]) { + left = node; + } + } + + return { + "top": top, + "right": right, + "bottom": bottom, + "left": left + }; + } + /** + * Determines the furthest nodes in each direction for the currently selected nodes + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.prototype.boundaryNodesForSelection = function() { + return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes)); + } + + /** + * + * @param {LGraphNode[]} nodes a list of nodes + * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes + * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction) + */ + LGraphCanvas.alignNodes = function (nodes, direction, align_to) { + if (!nodes) { + return; + } + + const canvas = LGraphCanvas.active_canvas; + let boundaryNodes = [] + if (align_to === undefined) { + boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes) + } else { + boundaryNodes = { + "top": align_to, + "right": align_to, + "bottom": align_to, + "left": align_to + } + } + + for (const [_, node] of Object.entries(canvas.selected_nodes)) { + switch (direction) { + case "right": + node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0]; + break; + case "left": + node.pos[0] = boundaryNodes["left"].pos[0]; + break; + case "top": + node.pos[1] = boundaryNodes["top"].pos[1]; + break; + case "bottom": + node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1]; + break; + } + } + + canvas.dirty_canvas = true; + canvas.dirty_bgcanvas = true; + }; + + LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node); + } + } + + LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase()); + } + } + LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) { var canvas = LGraphCanvas.active_canvas; @@ -12900,6 +13013,14 @@ LGraphNode.prototype.executeAction = function(action) options.push({ content: "Options", callback: that.showShowGraphOptionsPanel }); }*/ + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align", + has_submenu: true, + callback: LGraphCanvas.onGroupAlign, + }) + } + if (this._graph_stack && 
this._graph_stack.length > 0) { options.push(null, { content: "Close subgraph", @@ -13014,6 +13135,14 @@ LGraphNode.prototype.executeAction = function(action) callback: LGraphCanvas.onMenuNodeToSubgraph }); + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align Selected To", + has_submenu: true, + callback: LGraphCanvas.onNodeAlign, + }) + } + options.push(null, { content: "Remove", disabled: !(node.removable !== false && !node.block_delete ), From 1dd846a7bad8cfab679a0976e201c722871c6917 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:27:28 -0400 Subject: [PATCH 50/85] Fix outputs gone from history. --- execution.py | 16 +++++++++++----- main.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index dd88029bc..0e2cc15c1 100644 --- a/execution.py +++ b/execution.py @@ -102,7 +102,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -117,7 +117,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -128,6 +128,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) @@ -205,6 +206,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + self.outputs_ui = {} self.old_prompt = {} self.server = server @@ -234,6 +236,11 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) + del d + if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() @@ -247,7 +254,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -413,8 +420,7 @@ class PromptQueue: prompt = 
self.currently_running.pop(item_id) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: - if "ui" in outputs[o]: - self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"] + self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 00cbf3c4a..50d3b9a62 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,7 @@ def prompt_worker(q, server): while True: item, item_id = q.get() e.execute(item[2], item[1], item[3], item[4]) - q.task_done(item_id, e.outputs) + q.task_done(item_id, e.outputs_ui) async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From ef815ba1e24eef45041adec8a55ecd628b20476f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:29:56 -0400 Subject: [PATCH 51/85] Switch default scheduler to normal. --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index aa44fa82d..fccf254ec 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,7 +495,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] From c02a554bcf6ef50f8e252c89dc0a56c08d4955c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:25:24 -0400 Subject: [PATCH 52/85] Make DiffusersLoader work with subfolders. 
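The fix below stops assuming a diffusers model sits exactly one directory level deep and instead walks each search path for the model_index.json marker file that identifies a model root. That discovery step in isolation, with a hypothetical search path:

import os

def find_diffusers_models(search_path):
    paths = []
    for root, subdirs, files in os.walk(search_path, followlinks=True):
        if "model_index.json" in files:
            # keep paths relative to the search root, as the dropdown shows them
            paths.append(os.path.relpath(root, start=search_path))
    return paths

# find_diffusers_models("models/diffusers") might return
# ["sd15", "nested/sd21-base"] -- nesting at any depth now works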
--- nodes.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index bc23e5c17..797ad6c9c 100644 --- a/nodes.py +++ b/nodes.py @@ -282,7 +282,10 @@ class DiffusersLoader: paths = [] for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths += next(os.walk(search_path))[1] + for root, subdir, files in os.walk(search_path, followlinks=True): + if "model_index.json" in files: + paths.append(os.path.relpath(root, start=search_path)) + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -292,9 +295,9 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] - if model_path in paths: - model_path = os.path.join(search_path, model_path) + path = os.path.join(search_path, model_path) + if os.path.exists(path): + model_path = path break return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) From 2ec6d1c6e364ab92e3d8149a83873ac47c797248 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:31:03 -0400 Subject: [PATCH 53/85] Don't import custom nodes when the folder ends with .disabled --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 797ad6c9c..e8b36c24a 100644 --- a/nodes.py +++ b/nodes.py @@ -1326,6 +1326,7 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + if module_path.endswith(".disabled"): continue time_before = time.perf_counter() success = load_custom_node(module_path) node_import_times.append((time.perf_counter() - time_before, module_path, success)) From 5f7968f1fafb2cf5d15fe049fc53265ad0fc6696 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 01:12:44 -0400 Subject: [PATCH 54/85] Print the endpoint ip for localtunnel in the colab notebook. 
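Localtunnel gates each tunnel behind the tunnel creator's public IP as a password; the notebook change below looks that IP up before launching the tunnel. The lookup alone, stdlib-only and assuming outbound network access:

import urllib.request

def public_ip():
    # icanhazip echoes back the caller's public IPv4 address
    with urllib.request.urlopen("https://ipv4.icanhazip.com") as resp:
        return resp.read().decode("utf8").strip()

print("The password/endpoint ip for localtunnel is:", public_ip())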
--- notebooks/comfyui_colab.ipynb | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index fecfa6707..c5a209eec 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -175,6 +175,8 @@ "import threading\n", "import time\n", "import socket\n", + "import urllib.request\n", + "\n", "def iframe_thread(port):\n", " while True:\n", " time.sleep(0.5)\n", @@ -183,7 +185,9 @@ " if result == 0:\n", " break\n", " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n", + " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", + "\n", + " print(\"The password/enpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", " for line in p.stdout:\n", " print(line.decode(), end='')\n", From 13d94caf49b21bd129ec867b04641973e3a102da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 03:18:11 -0400 Subject: [PATCH 55/85] Add control_after_generate to combo primitive. --- web/extensions/core/widgetInputs.js | 2 +- web/scripts/widgets.js | 80 +++++++++++++++++++---------- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..4fe0a6013 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -300,7 +300,7 @@ app.registerExtension({ } } - if (widget.type === "number") { + if (widget.type === "number" || widget.type === "combo") { addValueControlWidget(this, widget, "fixed"); } diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 65edc0392..3d1acc53e 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,35 +19,61 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - let min = targetWidget.options.min; - let max = targetWidget.options.max; - // limit to something that javascript can handle - max = Math.min(1125899906842624, max); - min = Math.max(-1125899906842624, min); - let range = (max - min) / (targetWidget.options.step / 10); + console.log(targetWidget); + if (targetWidget.type == "combo" && v !== "fixed") { + let current_index = targetWidget.options.values.indexOf(targetWidget.value); + let current_length = targetWidget.options.values.length; - //adjust values based on valueControl Behaviour - switch (v) { - case "fixed": - break; - case "increment": - targetWidget.value += targetWidget.options.step / 10; - break; - case "decrement": - targetWidget.value -= targetWidget.options.step / 10; - break; - case "randomize": - targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; - default: - break; + switch (v) { + case "increment": + current_index += 1; + break; + case "decrement": + current_index -= 1; + break; + case "randomize": + current_index = Math.floor(Math.random() * current_length); + default: + break; + } + current_index = Math.max(0, current_index); + current_index = Math.min(current_length - 1, current_index); + if (current_index >= 0) { + let value = targetWidget.options.values[current_index]; + targetWidget.value = value; + targetWidget.callback(value); + } + } else { //number + let min 
= targetWidget.options.min; + let max = targetWidget.options.max; + // limit to something that javascript can handle + max = Math.min(1125899906842624, max); + min = Math.max(-1125899906842624, min); + let range = (max - min) / (targetWidget.options.step / 10); + + //adjust values based on valueControl Behaviour + switch (v) { + case "fixed": + break; + case "increment": + targetWidget.value += targetWidget.options.step / 10; + break; + case "decrement": + targetWidget.value -= targetWidget.options.step / 10; + break; + case "randomize": + targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; + default: + break; + } + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) + targetWidget.value = min; + + if (targetWidget.value > max) + targetWidget.value = max; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; - - if (targetWidget.value > max) - targetWidget.value = max; } return valueControl; }; From 7ada9e7d85f93495aa5006468a45220932f5e988 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Tue, 16 May 2023 22:55:00 +0900 Subject: [PATCH 56/85] allows touch drag --- web/scripts/app.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 87c5e30ca..ef3b44c83 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -902,7 +902,9 @@ export class ComfyApp { await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM - const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + const mainCanvas = document.createElement("canvas") + mainCanvas.style.touchAction = "none" + const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" })); canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); From 11e7168d56e0987e52d0afb620189f08bda2b454 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 11:55:16 -0400 Subject: [PATCH 57/85] Remove print. --- web/scripts/widgets.js | 1 - 1 file changed, 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 3d1acc53e..94988d0f2 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,7 +19,6 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - console.log(targetWidget); if (targetWidget.type == "combo" && v !== "fixed") { let current_index = targetWidget.options.values.indexOf(targetWidget.value); let current_length = targetWidget.options.values.length; From 4088e61aa6b8943e28ee243c0b1265c41974ef67 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 15:35:07 -0400 Subject: [PATCH 58/85] Update litegraph from upstream. 
--- web/lib/litegraph.core.js | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 6c81c3ffd..95f4a2735 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action) if (show_text) { ctx.textAlign = "center"; ctx.fillStyle = text_color; - ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7); + ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7); } break; case "toggle": @@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); if (show_text) { ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = w.value ? text_color : secondary_text_color; ctx.textAlign = "right"; @@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed(3), widget_width * 0.5, y + H * 0.7 ); @@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); } ctx.fillStyle = secondary_text_color; - ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7); + ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7); ctx.fillStyle = text_color; ctx.textAlign = "right"; if (w.type == "number") { @@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action) //ctx.stroke(); ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = text_color; ctx.textAlign = "right"; From e7f2816c6f1da22e2018cf088bd45110ff265c79 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 18 May 2023 12:40:28 +0900 Subject: [PATCH 59/85] feat:Latent Save/Load (#662) * wip * latent dir * fix * fix * now working * mark todo * remove server.py changes to separate PRt --------- Co-authored-by: Lt.Dr.Data --- input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index e8b36c24a..a2c7713aa 100644 --- a/nodes.py +++ b/nodes.py @@ -29,6 +29,8 @@ import importlib import folder_paths +import safetensors.torch as sft + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -246,6 +248,91 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") + self.type = "output" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, 
filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(self.output_dir, subfolder) + + if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: + print("Saving latent outside the 'input/latents' folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"workflow": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + sft.save_file(samples, file, metadata=metadata) + + return {} + + +class LoadLatent: + input_dir = os.path.join(folder_paths.get_input_directory(), "latents") + + @classmethod + def INPUT_TYPES(s): + files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + file = folder_paths.get_annotated_filepath(latent, self.input_dir) + + latent = sft.load_file(file, device="cpu") + + return (latent, ) + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -1235,6 +1322,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = { From a7375103b9c80bb7607f85faa4afbf11ab5a5685 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:04:40 -0400 Subject: [PATCH 60/85] Some small changes to Load/SaveLatent. 
--- nodes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index a2c7713aa..7255621d7 100644 --- a/nodes.py +++ b/nodes.py @@ -11,6 +11,7 @@ import time from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np +import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -29,7 +30,6 @@ import importlib import folder_paths -import safetensors.torch as sft def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -307,7 +307,10 @@ class SaveLatent: file = f"{filename}_{counter:05}_.latent" file = os.path.join(full_output_folder, file) - sft.save_file(samples, file, metadata=metadata) + output = {} + output["latent_tensor"] = samples["samples"] + + safetensors.torch.save_file(output, file, metadata=metadata) return {} @@ -328,9 +331,10 @@ class LoadLatent: def load(self, latent): file = folder_paths.get_annotated_filepath(latent, self.input_dir) - latent = sft.load_file(file, device="cpu") + latent = safetensors.torch.load_file(file, device="cpu") + samples = {"samples": latent["latent_tensor"]} - return (latent, ) + return (samples, ) class CheckpointLoader: From faf899ad5ae32f770f0dae6a9df457e81d2b5c38 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:43:59 -0400 Subject: [PATCH 61/85] LoadLatent and SaveLatent should behave like the LoadImage and SaveImage. --- folder_paths.py | 33 +++++++ input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++-------------- 3 files changed, 55 insertions(+), 68 deletions(-) delete mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..28f117824 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -147,4 +147,37 @@ def get_filename_list(folder_name): output_list.update(filter_files_extensions(recursive_search(x), folders[1])) return sorted(list(output_list)) +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + def compute_vars(input, image_width, image_height): + input = input.replace("%width%", str(image_width)) + input = input.replace("%height%", str(image_height)) + return input + + filename_prefix = compute_vars(filename_prefix, image_width, image_height) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(output_dir, subfolder) + + if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: + print("Saving image outside the output folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + return full_output_folder, filename, counter, subfolder, filename_prefix diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here deleted file mode 100644 index e69de29bb..000000000 diff --git a/nodes.py b/nodes.py index 7255621d7..7b450df38 100644 --- a/nodes.py +++ b/nodes.py @@ -251,13 +251,12 
@@ class VAEEncodeForInpaint: class SaveLatent: def __init__(self): - self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") - self.type = "output" + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } RETURN_TYPES = () @@ -268,31 +267,7 @@ class SaveLatent: CATEGORY = "_for_testing" def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving latent outside the 'input/latents' folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) # support save metadata for latent sharing prompt_info = "" @@ -316,11 +291,10 @@ class SaveLatent: class LoadLatent: - input_dir = os.path.join(folder_paths.get_input_directory(), "latents") - @classmethod def INPUT_TYPES(s): - files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] return {"required": {"latent": [sorted(files), ]}, } CATEGORY = "_for_testing" @@ -329,13 +303,25 @@ class LoadLatent: FUNCTION = "load" def load(self, latent): - file = folder_paths.get_annotated_filepath(latent, self.input_dir) - - latent = safetensors.torch.load_file(file, device="cpu") + latent_path = folder_paths.get_annotated_filepath(latent) + latent = safetensors.torch.load_file(latent_path, device="cpu") samples = {"samples": latent["latent_tensor"]} - return (samples, ) + @classmethod + def IS_CHANGED(s, latent): + image_path = folder_paths.get_annotated_filepath(latent) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, latent): + if not folder_paths.exists_annotated_filepath(latent): + return "Invalid latent file: {}".format(latent) + return True + class CheckpointLoader: @classmethod @@ -1020,39 +1006,7 @@ class SaveImage: CATEGORY = "image" def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(input): - input = 
input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - filename_prefix = compute_vars(filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving image outside the output folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() for image in images: i = 255. * image.cpu().numpy() From 62a371e12b4763bf6f9aeb42ff4928138df6ae26 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 18 May 2023 02:41:21 -0400 Subject: [PATCH 62/85] Load workflow from latent file. --- nodes.py | 2 +- web/scripts/app.js | 7 ++++++- web/scripts/pnginfo.js | 16 ++++++++++++++++ web/scripts/ui.js | 2 +- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 7b450df38..3c61cd2ec 100644 --- a/nodes.py +++ b/nodes.py @@ -274,7 +274,7 @@ class SaveLatent: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"workflow": prompt_info} + metadata = {"prompt": prompt_info} if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) diff --git a/web/scripts/app.js b/web/scripts/app.js index ef3b44c83..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111 } from "./pnginfo.js"; +import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -1308,6 +1308,11 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); + } else if (file.name?.endsWith(".latent")) { + const info = await getLatentMetadata(file); + if (info.workflow) { + this.loadGraphData(JSON.parse(info.workflow)); + } } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 209b562a6..8ddb7a1c5 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,6 +47,22 @@ export function getPngMetadata(file) { }); } +export function getLatentMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const safetensorsData = new Uint8Array(event.target.result); + const dataView = new DataView(safetensorsData.buffer); + let header_size = dataView.getUint32(0, true); + let offset = 8; + let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + r(header.__metadata__); + }; + + reader.readAsArrayBuffer(file); + }); +} + export async function importA1111(graph, parameters) { const p = parameters.lastIndexOf("\nSteps:"); if (p > -1) { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 77517aec1..2c9043d00 100644 --- 
a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,7 +465,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png", + accept: ".json,image/png,.latent", style: { display: "none" }, parent: document.body, onchange: () => { From 8bbd9815a976ef43e2665d45c5afb4a21c06c831 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 02:15:32 -0400 Subject: [PATCH 63/85] Support loading fp16 latent files. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 3c61cd2ec..878e0b955 100644 --- a/nodes.py +++ b/nodes.py @@ -305,7 +305,7 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) latent = safetensors.torch.load_file(latent_path, device="cpu") - samples = {"samples": latent["latent_tensor"]} + samples = {"samples": latent["latent_tensor"].float()} return (samples, ) @classmethod From 2998e232cb26b66e7ba42a53ada3a8285fcb2c15 Mon Sep 17 00:00:00 2001 From: malern <701073+malern@users.noreply.github.com> Date: Fri, 19 May 2023 19:57:15 +0100 Subject: [PATCH 64/85] Make multiline widget work with different canvas dimensions. It now scales the textarea positioning using the canvas height/width. --- web/scripts/widgets.js | 20 +++++++++++++------- web/style.css | 2 ++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 94988d0f2..82168b08b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -155,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) { computeSize(node.size); } const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const t = ctx.getTransform(); const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + Object.assign(this.inputEl.style, { - left: `${t.a * margin + t.e}px`, - top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, - width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, - background: (!node.color)?'':node.color, - height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, + transformOrigin: "0 0", + transform: transform, + left: "0px", + top: "0px", + width: `${widgetWidth - (margin * 2)}px`, + height: `${this.parent.inputHeight - (margin * 2)}px`, position: "absolute", + background: (!node.color)?'':node.color, color: (!node.color)?'':'white', zIndex: app.graph._nodes.indexOf(node), - fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; }, diff --git a/web/style.css b/web/style.css index df220cc02..87f096e14 100644 --- a/web/style.css +++ b/web/style.css @@ -39,6 +39,8 @@ body { padding: 2px; resize: none; border: none; + box-sizing: border-box; + font-size: 10px; } .comfy-modal { From e6e1999f96adbe0f4b041d265837e00bde9283ab Mon Sep 17 00:00:00 2001 From: malern <701073+malern@users.noreply.github.com> Date: Fri, 19 May 2023 20:04:36 +0100 Subject: [PATCH 65/85] Render UI at a higher resolution when viewing with a higher pixel ratio --- web/scripts/app.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..514ca3958 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -921,8 +921,9 @@ export class ComfyApp { this.graph.start(); function resizeCanvas() { - canvasEl.width = 
canvasEl.offsetWidth; - canvasEl.height = canvasEl.offsetHeight; + canvasEl.width = canvasEl.offsetWidth * window.devicePixelRatio; + canvasEl.height = canvasEl.offsetHeight * window.devicePixelRatio; + canvasEl.getContext("2d").scale(window.devicePixelRatio, window.devicePixelRatio); canvas.draw(true, true); } From b9daf4e30f32e76a00628add0849d84b5ec2fe76 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 22:40:28 -0400 Subject: [PATCH 66/85] Add a /object_info/{node_class} route to get only the info of one node. --- server.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/server.py b/server.py index f52117f10..18ce54306 100644 --- a/server.py +++ b/server.py @@ -261,23 +261,34 @@ class PromptServer(): async def get_prompt(request): return web.json_response(self.get_queue_info()) + def node_info(node_class): + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + return info + @routes.get("/object_info") async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: - obj_class = nodes.NODE_CLASS_MAPPINGS[x] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = x - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - out[x] = info + out[x] = node_info(x) + return web.json_response(out) + + @routes.get("/object_info/{node_class}") + async def get_object_info_node(request): + node_class = request.match_info.get("node_class", None) + out = {} + if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): + out[node_class] = node_info(node_class) return web.json_response(out) @routes.get("/history") From 36af98d75580292d4afbeafb5e7ba7f010145436 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 20 May 2023 15:23:28 +0200 Subject: [PATCH 67/85] improve sharpen and blur nodes --- comfy_extras/nodes_post_processing.py | 35 ++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ba699e2b8..37c824bde 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -59,6 +59,12 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) +def gaussian_kernel(kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, 
kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + class Blur: def __init__(self): pass @@ -88,12 +94,6 @@ class Blur: CATEGORY = "image/postprocessing" - def gaussian_kernel(self, kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): if blur_radius == 0: return (image,) @@ -101,10 +101,11 @@ class Blur: batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 - kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') + blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) @@ -167,9 +168,15 @@ class Sharpen: "max": 31, "step": 1 }), - "alpha": ("FLOAT", { + "sigma": ("FLOAT", { "default": 1.0, "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.0, "max": 5.0, "step": 0.1 }), @@ -181,21 +188,21 @@ class Sharpen: CATEGORY = "image/postprocessing" - def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): if sharpen_radius == 0: return (image,) batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) center = kernel_size // 2 - kernel[center, center] = kernel_size**2 - kernel *= alpha + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = sharpened.permute(0, 2, 3, 1) result = torch.clamp(sharpened, 0, 1) From 71666f248f769af073408a3475dd7a82a29d8247 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 10:08:47 -0400 Subject: [PATCH 68/85] Fix padding in Blur. 
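Sketched in isolation, the corrected pattern is: reflect-pad by the blur radius, run the grouped gaussian convolution, then crop the same radius back off so the output matches the input size (shapes illustrative; radius assumed > 0, since the node early-returns when blur_radius == 0):

    import torch
    import torch.nn.functional as F

    def reflect_blur(image, kernel, radius):
        # image: (B, C, H, W); kernel: (C, 1, k, k) with k = 2 * radius + 1
        k = kernel.shape[-1]
        padded = F.pad(image, (radius, radius, radius, radius), 'reflect')
        out = F.conv2d(padded, kernel, padding=k // 2, groups=image.shape[1])
        return out[:, :, radius:-radius, radius:-radius]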
--- comfy_extras/nodes_post_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 37c824bde..3be141dfe 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -105,7 +105,7 @@ class Blur: image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] + blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) From 797c4e8d3b56559bb205ff8aab97d97dca424b9a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 15:07:21 -0400 Subject: [PATCH 69/85] Simplify and improve some vae attention code. --- comfy/ldm/modules/diffusionmodules/model.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5e4d2b60f..05caf7312 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -331,25 +331,13 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.transpose(2, 3).reshape(B, C, H, W) out = self.proj_out(out) return x+out From b8636a44aacd83ec6a9a19a6d3d3f5b76fc863c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 15:43:39 -0400 Subject: [PATCH 70/85] Make scaled_dot_product switch to sliced attention on OOM. 
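Condensed, the sliced path chunks the query rows so the full (HW x HW) attention matrix never has to exist at once; roughly as below, except that slice_attention sizes the chunk from free memory and halves it on further OOMs (the fixed chunk size here is illustrative):

    import torch

    def sliced_attention(q, k, v, chunk=1024):
        # q: (B, N, C), k: (B, C, N), v: (B, C, N) -> out: (B, C, N)
        scale = q.shape[-1] ** -0.5
        out = torch.zeros_like(k)
        for i in range(0, q.shape[1], chunk):
            s = torch.softmax(torch.bmm(q[:, i:i + chunk], k) * scale, dim=2)
            out[:, :, i:i + chunk] = torch.bmm(v, s.transpose(1, 2))
        return out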
--- comfy/ldm/modules/diffusionmodules/model.py | 79 +++++++++++---------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 05caf7312..91e7d60ec 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -146,6 +146,41 @@ class ResnetBlock(nn.Module): return x+h +def slice_attention(q, k, v): + r1 = torch.zeros_like(k, device=q.device) + scale = (int(q.shape[-1])**(-0.5)) + + mem_free_total = model_management.get_free_memory(q.device) + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = torch.bmm(q[:, i:end], k) * scale + + s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) + del s1 + + r1[:, :, i:end] = torch.bmm(v, s2) + del s2 + break + except model_management.OOM_EXCEPTION as e: + steps *= 2 + if steps > 128: + raise e + print("out of memory error, increasing steps and trying again", steps) + + return r1 class AttnBlock(nn.Module): def __init__(self, in_channels): @@ -183,48 +218,15 @@ class AttnBlock(nn.Module): # compute attention b,c,h,w = q.shape - scale = (int(c)**(-0.5)) q = q.reshape(b,c,h*w) q = q.permute(0,2,1) # b,hw,c k = k.reshape(b,c,h*w) # b,c,hw v = v.reshape(b,c,h*w) - r1 = torch.zeros_like(k, device=q.device) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = torch.bmm(q[:, i:end], k) * scale - - s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) - del s1 - - r1[:, :, i:end] = torch.bmm(v, s2) - del s2 - break - except model_management.OOM_EXCEPTION as e: - steps *= 2 - if steps > 128: - raise e - print("out of memory error, increasing steps and trying again", steps) - + r1 = slice_attention(q, k, v) h_ = r1.reshape(b,c,h,w) del r1 - h_ = self.proj_out(h_) return x+h_ @@ -335,9 +337,14 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = out.transpose(2, 3).reshape(B, C, H, W) + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out From 3c76f43057f140c583962327c18c7d5257e7495c Mon Sep 17 00:00:00 2001 From: 
comfyanonymous Date: Sat, 20 May 2023 23:06:33 -0400 Subject: [PATCH 71/85] Cleaner code. --- server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 18ce54306..701c0e7a7 100644 --- a/server.py +++ b/server.py @@ -331,7 +331,8 @@ class PromptServer(): extra_data["client_id"] = json_data["client_id"] if valid[0]: prompt_id = str(uuid.uuid4()) - self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + outputs_to_execute = valid[2] + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) From 516119ad835841bd176055cbc888843b418b8004 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 00:24:28 -0400 Subject: [PATCH 72/85] Print min and max values in validation error message. --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 0e2cc15c1..35f044346 100644 --- a/execution.py +++ b/execution.py @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) From 069657fbf3d8d977ead39ab206d8c917bbcc4997 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 01:35:08 -0400 Subject: [PATCH 73/85] Add DPM-Solver++(2M) SDE and exponential scheduler. exponential scheduler is the one recommended with this sampler. 
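The exponential schedule is just sigmas spaced evenly in log space; in essence (mirroring k-diffusion's get_sigmas_exponential, trailing zero included):

    import math
    import torch

    def sigmas_exponential(n, sigma_min, sigma_max):
        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
        return torch.cat([sigmas, sigmas.new_zeros([1])])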
--- comfy/k_diffusion/sampling.py | 43 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 6 +++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c809d39fb..94d7a5762 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -605,3 +605,46 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised return x + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index fccf254ec..1fb928f8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", - "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -532,6 +532,8 @@ class KSampler: if self.scheduler == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) elif self.scheduler == "normal": sigmas = self.model_wrap.get_sigmas(steps) elif self.scheduler == "simple": From 4796e615dd7faad38429fc8e716e3a817a28c526 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 10:34:26 -0400 Subject: [PATCH 74/85] Revert DPI fix since it caused more issues than it solved. 
--- web/scripts/app.js | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 514ca3958..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -921,9 +921,8 @@ export class ComfyApp { this.graph.start(); function resizeCanvas() { - canvasEl.width = canvasEl.offsetWidth * window.devicePixelRatio; - canvasEl.height = canvasEl.offsetHeight * window.devicePixelRatio; - canvasEl.getContext("2d").scale(window.devicePixelRatio, window.devicePixelRatio); + canvasEl.width = canvasEl.offsetWidth; + canvasEl.height = canvasEl.offsetHeight; canvas.draw(true, true); } From dc198650c0d2d281c9d87b23a8917b457a94d837 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 11:34:29 -0400 Subject: [PATCH 75/85] sample_dpmpp_2m_sde no longer crashes when step == 1. --- comfy/k_diffusion/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 94d7a5762..c540d7411 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -628,6 +628,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised + h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() From 6cc450579b3314fe314e64af22c4be81afd1f87d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 00:22:24 -0400 Subject: [PATCH 76/85] Auto transpose images from exif data. --- comfy/k_diffusion/sampling.py | 2 +- nodes.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c540d7411..26930428f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -620,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = None h_last = None + h = None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -628,7 +629,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised - h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() diff --git a/nodes.py b/nodes.py index 878e0b955..bae330bc9 100644 --- a/nodes.py +++ b/nodes.py @@ -8,7 +8,7 @@ import traceback import math import time -from PIL import Image +from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -1057,6 +1057,7 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1100,6 +1101,7 @@ class LoadImageMask: def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") mask = None From ffc56c53c9cccfcc21c92fe14cb095bb32ea2744 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:22:38 -0400 Subject: [PATCH 77/85] Add a node_errors to the /prompt error json response. "node_errors" contains a dict keyed by node ids. The contents are a message and a list of dependent outputs. 
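A rejected POST to /prompt now comes back as a 400 with a body shaped roughly like this (the ids, message, and outputs are hypothetical):

    {
        "error": "Prompt has no properly connected outputs\n ...",
        "node_errors": {
            "5": {
                "message": "Value 99 bigger than max of 64. EmptyLatentImage, batch_size",
                "dependent_outputs": ["9"]
            }
        }
    }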
--- execution.py | 27 ++++++++++++++++----------- server.py | 4 ++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index 35f044346..212e789ca 100644 --- a/execution.py +++ b/execution.py @@ -299,18 +299,18 @@ def validate_inputs(prompt, item, validated): required_inputs = class_inputs['required'] for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) + return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) + return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) r = validate_inputs(prompt, o_id, validated) if r[0] == False: validated[o_id] = r @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) @@ -338,13 +338,13 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for r in ret: if r != True: - return (False, "{}, {}".format(class_type, r)) + return (False, "{}, {}".format(class_type, r), unique_id) else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) - ret = (True, "") + ret = (True, "", unique_id) validated[unique_id] = ret return ret @@ -356,10 +356,11 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs") + return (False, "Prompt has no outputs", [], []) good_outputs = set() errors = [] + node_errors = {} validated = {} for o in outputs: valid = False @@ -368,6 +369,7 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] + node_id = m[2] except Exception as e: print(traceback.format_exc()) valid = False @@ -379,12 +381,15 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) - return (True, "", list(good_outputs)) + return (True, "", list(good_outputs), node_errors) class PromptQueue: diff --git a/server.py b/server.py index 701c0e7a7..8429a63fb 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) - return web.json_response({"error": valid[1]}, status=400) + return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt"}, status=400) + return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @routes.post("/queue") async def post_queue(request): From db27b0405a31983916d6801cf84f7f1fc4503e6a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:25:50 -0400 Subject: [PATCH 78/85] object_info now returns if node is an output_node or not. 
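This lets clients tell output nodes apart without heuristics; for example, /object_info/SaveImage would now include something like this (abridged, other fields omitted):

    {
        "SaveImage": {
            "name": "SaveImage",
            "display_name": "Save Image",
            "category": "image",
            "output_node": true
        }
    }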
--- server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server.py b/server.py index 8429a63fb..c0f79cbd5 100644 --- a/server.py +++ b/server.py @@ -272,6 +272,11 @@ class PromptServer(): info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = '' info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + if hasattr(obj_class, 'CATEGORY'): info['category'] = obj_class.CATEGORY return info From bfb13f5eee48545f1c4b0b8a377de80be84bb100 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 17:05:23 -0400 Subject: [PATCH 79/85] Remove useless call to /object_info --- web/extensions/core/colorPalette.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 2f2238a2b..bfcd847a3 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,7 +174,7 @@ const els = {} // const ctxMenu = LiteGraph.ContextMenu; app.registerExtension({ name: id, - init() { + addCustomNodeDefs(node_defs) { const sortObjectKeys = (unordered) => { return Object.keys(unordered).sort().reduce((obj, key) => { obj[key] = unordered[key]; @@ -182,10 +182,10 @@ app.registerExtension({ }, {}); }; - const getSlotTypes = async () => { + function getSlotTypes() { var types = []; - const defs = await api.getNodeDefs(); + const defs = node_defs; for (const nodeId in defs) { const nodeData = defs[nodeId]; @@ -212,8 +212,8 @@ app.registerExtension({ return types; }; - const completeColorPalette = async (colorPalette) => { - var types = await getSlotTypes(); + function completeColorPalette(colorPalette) { + var types = getSlotTypes(); for (const type of types) { if (!colorPalette.colors.node_slot[type]) { From 48fcc5b777b3a1ab5d6dc5fec6adaebeb32c2c93 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 20:51:30 -0400 Subject: [PATCH 80/85] Parsing error crash. --- execution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 212e789ca..25f2fcacd 100644 --- a/execution.py +++ b/execution.py @@ -374,6 +374,7 @@ def validate_prompt(prompt): print(traceback.format_exc()) valid = False reason = "Parsing error" + node_id = None if valid == True: good_outputs.add(o) @@ -381,9 +382,10 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + if node_id is not None: + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) From 34887b888546716b5c5507606289ca2728bf3123 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 03:12:56 -0400 Subject: [PATCH 81/85] Add experimental bislerp algorithm for latent upscaling. It's like bilinear but with slerp. 
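The building block is plain slerp between the latent vectors of neighboring pixels, applied along x and then along y. A condensed sketch of that per-pair step (the clamp guards acos and zero-norm vectors are assumed away here; the full version below also handles coincident neighbors):

    import torch

    def slerp(a, b, t):
        # a, b: (B, C) vectors of two neighboring latent pixels; t in [0, 1]
        a_n = a / a.norm(dim=1, keepdim=True)
        b_n = b / b.norm(dim=1, keepdim=True)
        omega = torch.acos((a_n * b_n).sum(1).clamp(-0.9995, 0.9995))
        so = torch.sin(omega)
        return (torch.sin((1.0 - t) * omega) / so).unsqueeze(1) * a \
             + (torch.sin(t * omega) / so).unsqueeze(1) * b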
--- comfy/utils.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++- nodes.py | 2 +- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 09e05d4ed..0f7b34503 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,6 +46,65 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +#slow and inefficient, should be optimized +def bislerp(samples, width, height): + shape = list(samples.shape) + width_scale = (shape[3]) / (width ) + height_scale = (shape[2]) / (height ) + + shape[3] = width + shape[2] = height + out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + + def algorithm(in1, w1, in2, w2): + dims = in1.shape + val = w2 + + #flatten to batches + low = in1.reshape(dims[0], -1) + high = in2.reshape(dims[0], -1) + + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + + # in case we divide by zero + low_norm[low_norm != low_norm] = 0.0 + high_norm[high_norm != high_norm] = 0.0 + + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res.reshape(dims) + + for x_dest in range(shape[3]): + for y_dest in range(shape[2]): + y = (y_dest) * height_scale + x = (x_dest) * width_scale + + x1 = max(math.floor(x), 0) + x2 = min(x1 + 1, samples.shape[3] - 1) + y1 = max(math.floor(y), 0) + y2 = min(y1 + 1, samples.shape[2] - 1) + + in1 = samples[:,:,y1,x1] + in2 = samples[:,:,y1,x2] + in3 = samples[:,:,y2,x1] + in4 = samples[:,:,y2,x2] + + if (x1 == x2) and (y1 == y2): + out_value = in1 + elif (x1 == x2): + out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + elif (y1 == y2): + out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + else: + o1 = algorithm(in1, (x2 - x), in2, (x - x1)) + o2 = algorithm(in3, (x2 - x), in4, (x - x1)) + out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + + out1[:,:,y_dest,x_dest] = out_value + return out1 + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -61,7 +120,11 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples[:,:,y:old_height-y,x:old_width-x] else: s = samples - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if upscale_method == "bislerp": + return bislerp(s, width, height) + else: + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) diff --git a/nodes.py b/nodes.py index bae330bc9..e5cec2632 100644 --- a/nodes.py +++ b/nodes.py @@ -749,7 +749,7 @@ class RepeatLatentBatch: return (s,) class LatentUpscale: - upscale_methods = ["nearest-exact", "bilinear", "area"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] crop_methods = ["disabled", "center"] @classmethod From 451fb4169ad900e5d33b540f039f56ced9a76157 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:35:32 -0400 Subject: [PATCH 82/85] Fix 'git pull' not working on the standalones. 
--- .github/workflows/windows_release_cu118_package.yml | 1 + .github/workflows/windows_release_nightly_pytorch.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml index 15322c86a..2d6048a23 100644 --- a/.github/workflows/windows_release_cu118_package.yml +++ b/.github/workflows/windows_release_cu118_package.yml @@ -30,6 +30,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - shell: bash run: | cd .. diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b6a18ec0a..767a7216b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - uses: actions/setup-python@v4 with: python-version: '3.11.3' From b8ccbec6d893d34dab90d2418a3fe00969251fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:40:24 -0400 Subject: [PATCH 83/85] Various improvements to bislerp. --- comfy/utils.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 0f7b34503..300eda6aa 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -56,35 +56,42 @@ def bislerp(samples, width, height): shape[2] = height out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) - def algorithm(in1, w1, in2, w2): + def algorithm(in1, in2, t): dims = in1.shape - val = w2 + val = t #flatten to batches low = in1.reshape(dims[0], -1) high = in2.reshape(dims[0], -1) - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) + low_weight = torch.norm(low, dim=1, keepdim=True) + low_weight[low_weight == 0] = 0.0000000001 + low_norm = low/low_weight + high_weight = torch.norm(high, dim=1, keepdim=True) + high_weight[high_weight == 0] = 0.0000000001 + high_norm = high/high_weight - # in case we divide by zero - low_norm[low_norm != low_norm] = 0.0 - high_norm[high_norm != high_norm] = 0.0 - - omega = torch.acos((low_norm*high_norm).sum(1)) + dot_prod = (low_norm*high_norm).sum(1) + dot_prod[dot_prod > 0.9995] = 0.9995 + dot_prod[dot_prod < -0.9995] = -0.9995 + omega = torch.acos(dot_prod) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm + res *= (low_weight * (1.0-val) + high_weight * val) return res.reshape(dims) for x_dest in range(shape[3]): for y_dest in range(shape[2]): - y = (y_dest) * height_scale - x = (x_dest) * width_scale + y = (y_dest + 0.5) * height_scale - 0.5 + x = (x_dest + 0.5) * width_scale - 0.5 x1 = max(math.floor(x), 0) x2 = min(x1 + 1, samples.shape[3] - 1) + wx = x - math.floor(x) + y1 = max(math.floor(y), 0) y2 = min(y1 + 1, samples.shape[2] - 1) + wy = y - math.floor(y) in1 = samples[:,:,y1,x1] in2 = samples[:,:,y1,x2] @@ -94,13 +101,13 @@ def bislerp(samples, width, height): if (x1 == x2) and (y1 == y2): out_value = in1 elif (x1 == x2): - out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + out_value = algorithm(in1, in3, wy) elif (y1 == y2): - out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + out_value = algorithm(in1, in2, wx) else: - o1 = 
algorithm(in1, (x2 - x), in2, (x - x1)) - o2 = algorithm(in3, (x2 - x), in4, (x - x1)) - out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + o1 = algorithm(in1, in2, wx) + o2 = algorithm(in3, in4, wx) + out_value = algorithm(o1, o2, wy) out1[:,:,y_dest,x_dest] = out_value return out1 From c00bb1a0b78f0d2cf2e4ec2dd9ae7d61cb07a637 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 12:53:38 -0400 Subject: [PATCH 84/85] Add a latent upscale by node. --- nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nodes.py b/nodes.py index e5cec2632..f0a93ebd5 100644 --- a/nodes.py +++ b/nodes.py @@ -768,6 +768,25 @@ class LatentUpscale: s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) +class LatentUpscaleBy: + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "upscale" + + CATEGORY = "latent" + + def upscale(self, samples, upscale_method, scale_by): + s = samples.copy() + width = round(samples["samples"].shape[3] * scale_by) + height = round(samples["samples"].shape[2] * scale_by) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") + return (s,) + class LatentRotate: @classmethod def INPUT_TYPES(s): @@ -1244,6 +1263,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentUpscaleBy": LatentUpscaleBy, "LatentFromBatch": LatentFromBatch, "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, @@ -1322,6 +1342,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", + "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", From 7310290f17aad79480edb92f22cd58f0997db964 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 22:26:50 -0400 Subject: [PATCH 85/85] Pull in latest upscale model code from chainner. 
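The reworked bislerp in PATCH 83 above switches to half-pixel-centred sampling: a destination index maps back through src = (dst + 0.5) * scale - 0.5, the same convention standard bilinear resampling uses, which removes the half-texel shift of the first version. It also clamps the dot product before acos and rescales the result by the interpolated vector norms, so slerp between near-parallel or near-zero latents stays finite. A minimal standalone sketch of the coordinate mapping (function and names are illustrative):

    import math

    def src_coords(dst_index: int, scale: float, src_size: int):
        """Half-pixel-centred source coordinate plus its two clamped neighbours."""
        x = (dst_index + 0.5) * scale - 0.5
        x1 = max(math.floor(x), 0)
        x2 = min(x1 + 1, src_size - 1)
        w = x - math.floor(x)  # interpolation weight toward x2
        return x1, x2, w

    # e.g. downscaling by 2 (scale = 2.0): destination pixel 0 samples between
    # source pixels 0 and 1 with weight 0.5, matching the patched bislerp.
    print(src_coords(0, 2.0, 8))  # (0, 1, 0.5)

PATCH 84's LatentUpscaleBy builds on the same path: it derives the target latent size as round(shape * scale_by) and delegates to comfy.utils.common_upscale with crop disabled, so "bislerp" becomes selectable for relative upscales as well.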
--- .../architecture/OmniSR/ChannelAttention.py | 110 ++++ .../architecture/OmniSR/LICENSE | 201 ++++++ .../architecture/OmniSR/OSA.py | 577 ++++++++++++++++++ .../architecture/OmniSR/OSAG.py | 60 ++ .../architecture/OmniSR/OmniSR.py | 133 ++++ .../architecture/OmniSR/esa.py | 294 +++++++++ .../architecture/OmniSR/layernorm.py | 70 +++ .../architecture/OmniSR/pixelshuffle.py | 31 + .../chainner_models/architecture/RRDB.py | 17 +- .../chainner_models/architecture/block.py | 30 + comfy_extras/chainner_models/model_loading.py | 5 + comfy_extras/chainner_models/types.py | 4 +- 12 files changed, 1530 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/LICENSE create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSA.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSAG.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/esa.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/layernorm.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py new file mode 100644 index 000000000..f4d52aa1e --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py @@ -0,0 +1,110 @@ +import math + +import torch.nn as nn + + +class CA_layer(nn.Module): + def __init__(self, channel, reduction=16): + super(CA_layer, self).__init__() + # global average pooling + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False), + nn.GELU(), + nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False), + # nn.Sigmoid() + ) + + def forward(self, x): + y = self.fc(self.gap(x)) + return x * y.expand_as(x) + + +class Simple_CA_layer(nn.Module): + def __init__(self, channel): + super(Simple_CA_layer, self).__init__() + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=channel, + out_channels=channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return x * self.fc(self.gap(x)) + + +class ECA_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ECA_MaxPool_layer(nn.Module): + """Constructs a ECA module. 
+ Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_MaxPool_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.max_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
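The OSA.py added next builds Omni-SR's Omni Self-Attention block out of alternating block-like window attention, grid-like attention, and channel attention, all driven by einops rearranges that fold fixed windows into the batch dimension. A small self-contained sketch of the block-window round trip (shapes are illustrative):

    import torch
    from einops import rearrange

    w = 8                                  # window_size used by OSA_Block
    feat = torch.randn(2, 64, 32, 32)      # (batch, dim, H, W); H and W divisible by w

    # block-like partition: every w x w tile becomes one attention sequence
    tiles = rearrange(feat, "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w)
    print(tiles.shape)                     # torch.Size([2, 4, 4, 8, 8, 64])

    # the inverse pattern restores the feature map exactly
    assert torch.equal(feat, rearrange(tiles, "b x y w1 w2 d -> b d (x w1) (y w2)"))

Grid-like attention uses the transposed pattern "b d (w1 x) (w2 y) -> b x y w1 w2 d", so each sequence gathers pixels that are w apart, which is how the block mixes information globally at a fixed per-window cost.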
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py new file mode 100644 index 000000000..d7a129696 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSA.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:07:42 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce +from torch import einsum, nn + +from .layernorm import LayerNorm2d + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, length=1): + return val if isinstance(val, tuple) else ((val,) * length) + + +# helper classes + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class Conv_PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = LayerNorm2d(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU(), + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim, 1, 1, 0), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Gated_Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=1, bias=False, dropout=0.0): + super().__init__() + + hidden_features = int(dim * mult) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +# MBConv + + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + Rearrange("b c -> b c 1 1"), + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + 
+ +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = ( + torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() + > self.prob + ) + return x * keep_mask / (1 - self.prob) + + +def MBConv( + dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0 +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d( + hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim + ), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + # nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout=dropout) + + return net + + +# attention related classes +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.Dropout(dropout) + ) + + # relative positional bias + if self.with_pe: + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, "c i j -> (i j) c") + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange( + grid, "j ... -> 1 j ..." + ) + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum( + dim=-1 + ) + + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = ( + *x.shape, + x.device, + self.heads, + ) + + # flatten + + x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d") + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # add positional bias + if self.with_pe: + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, "i j h -> h i j") + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + + out = rearrange( + out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width + ) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, "(b x y) ... 
-> b x y ...", x=height, y=width) + + +class Block_Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + bias=False, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.ps = window_size + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + # project for queries, keys, values + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # split heads + + q, k, v = map( + lambda t: rearrange( + t, + "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d", + h=self.heads, + w1=self.ps, + w2=self.ps, + ), + (q, k, v), + ) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # attention + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + out = rearrange( + out, + "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)", + x=h // self.ps, + y=w // self.ps, + head=self.heads, + w1=self.ps, + w2=self.ps, + ) + + out = self.to_out(out) + return out + + +class Channel_Attention(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class Channel_Attention_grid(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention_grid, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = 
F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class OSA_Block(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + ffn_bias=True, + window_size=8, + with_pe=False, + dropout=0.0, + ): + super(OSA_Block, self).__init__() + + w = window_size + + self.layer = nn.Sequential( + MBConv( + channel_num, + channel_num, + downsample=False, + expansion_rate=1, + shrinkage_rate=0.25, + ), + Rearrange( + "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w + ), # block-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + Rearrange( + "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w + ), # grid-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention_grid( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + ) + + def forward(self, x): + out = self.layer(x) + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py new file mode 100644 index 000000000..477e81f9d --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSAG.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:08:49 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn + +from .esa import ESA +from .OSA import OSA_Block + + +class OSAG(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + block_num=4, + ffn_bias=False, + window_size=0, + pe=False, + ): + super(OSAG, self).__init__() + + # print("window_size: %d" % (window_size)) + # print("with_pe", pe) + # print("ffn_bias: %d" % (ffn_bias)) + + # block_script_name = kwargs.get("block_script_name", "OSA") + # block_class_name = kwargs.get("block_class_name", "OSA_Block") + + # script_name = "." 
+ block_script_name + # package = __import__(script_name, fromlist=True) + block_class = OSA_Block # getattr(package, block_class_name) + group_list = [] + for _ in range(block_num): + temp_res = block_class( + channel_num, + bias, + ffn_bias=ffn_bias, + window_size=window_size, + with_pe=pe, + ) + group_list.append(temp_res) + group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias)) + self.residual_layer = nn.Sequential(*group_list) + esa_channel = max(channel_num // 4, 16) + self.esa = ESA(esa_channel, channel_num) + + def forward(self, x): + out = self.residual_layer(x) + out = out + x + return self.esa(out) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py new file mode 100644 index 000000000..dec169520 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OmniSR.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:06:36 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .OSAG import OSAG +from .pixelshuffle import pixelshuffle_block + + +class OmniSR(nn.Module): + def __init__( + self, + state_dict, + **kwargs, + ): + super(OmniSR, self).__init__() + self.state = state_dict + + bias = True # Fine to assume this for now + block_num = 1 # Fine to assume this for now + ffn_bias = True + pe = True + + num_feat = state_dict["input.weight"].shape[0] or 64 + num_in_ch = state_dict["input.weight"].shape[1] or 3 + num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + + pixelshuffle_shape = state_dict["up.0.weight"].shape[0] + up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) + if up_scale - int(up_scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + up_scale = int(up_scale) + res_num = 0 + for key in state_dict.keys(): + if "residual_layer" in key: + temp_res_num = int(key.split(".")[1]) + if temp_res_num > res_num: + res_num = temp_res_num + res_num = res_num + 1 # zero-indexed + + residual_layer = [] + self.res_num = res_num + + self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + self.up_scale = up_scale + + for _ in range(res_num): + temp_res = OSAG( + channel_num=num_feat, + bias=bias, + block_num=block_num, + ffn_bias=ffn_bias, + window_size=self.window_size, + pe=pe, + ) + residual_layer.append(temp_res) + self.residual_layer = nn.Sequential(*residual_layer) + self.input = nn.Conv2d( + in_channels=num_in_ch, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.output = nn.Conv2d( + in_channels=num_feat, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias) + + # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. 
/ n)) + + # chaiNNer specific stuff + self.model_arch = "OmniSR" + self.sub_type = "SR" + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.scale = up_scale + + self.supports_fp16 = True # TODO: Test this + self.supports_bfp16 = True + self.min_size_restriction = 16 + + self.load_state_dict(state_dict, strict=False) + + def check_image_size(self, x): + _, _, h, w = x.size() + # import pdb; pdb.set_trace() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0) + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + residual = self.input(x) + out = self.residual_layer(residual) + + # origin + out = torch.add(self.output(out), residual) + out = self.up(out) + + out = out[:, :, : H * self.up_scale, : W * self.up_scale] + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py new file mode 100644 index 000000000..f9ce7f7a6 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/esa.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: esa.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:06 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layernorm import LayerNorm2d + + +def moment(x, dim=(2, 3), k=2): + assert len(x.size()) == 4 + mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1) + mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim) + return mk + + +class ESA(nn.Module): + """ + Modification of Enhanced Spatial Attention (ESA), which is proposed by + `Residual Feature Aggregation Network for Image Super-Resolution` + Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes + are deleted. 
+ """ + + def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = conv(f, f, kernel_size=3, padding=1) + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA_LN(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA_LN, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.norm = LayerNorm2d(n_feats) + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.norm(x) + c1_ = self.conv1(c1_) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + 
self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class AdaGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaGuidedFilter, self).__init__() + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=n_feats, + out_channels=1, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + self.r = 5 + + def box_filter(self, x, r): + channel = x.shape[1] + kernel_size = 2 * r + 1 + weight = 1.0 / (kernel_size**2) + box_kernel = weight * torch.ones( + (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device + ) + output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) + return output + + def forward(self, x): + _, _, H, W = x.shape + N = self.box_filter( + torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r + ) + + # epsilon = self.fc(self.gap(x)) + # epsilon = torch.pow(epsilon, 2) + epsilon = 1e-2 + + mean_x = self.box_filter(x, self.r) / N + var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x + + A = var_x / (var_x + epsilon) + b = (1 - A) * mean_x + m = A * x + b + + # mean_A = self.box_filter(A, self.r) / N + # mean_b = self.box_filter(b, self.r) / N + # m = mean_A * x + mean_b + return x * m + + +class AdaConvGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaConvGuidedFilter, self).__init__() + f = esa_channels + + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=f, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=f, + bias=bias, + ) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + y = self.vec_conv(x) + y = self.hor_conv(y) + + sigma = torch.pow(y, 2) + epsilon = self.fc(self.gap(y)) + + weight = sigma / (sigma + epsilon) + + m = weight * x + (1 - weight) + + return x * m diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py new file mode 100644 index 000000000..731a25f75 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: layernorm.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:20 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, 
grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py new file mode 100644 index 000000000..4260fb7c9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: pixelshuffle.py +# Created Date: Friday July 1st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 1st July 2022 10:18:39 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch.nn as nn + + +def pixelshuffle_block( + in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False +): + """ + Upsample features according to `upscale_factor`. 
+ """ + padding = kernel_size // 2 + conv = nn.Conv2d( + in_channels, + out_channels * (upscale_factor**2), + kernel_size, + padding=1, + bias=bias, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + return nn.Sequential(*[conv, pixel_shuffle]) diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py index 4d52f05dd..b50db7c24 100644 --- a/comfy_extras/chainner_models/architecture/RRDB.py +++ b/comfy_extras/chainner_models/architecture/RRDB.py @@ -79,6 +79,12 @@ class RRDBNet(nn.Module): self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None @@ -105,11 +111,15 @@ class RRDBNet(nn.Module): out_nc=self.num_filters, upscale_factor=3, act_type=self.act, + c2x2=c2x2, ) else: upsample_blocks = [ upsample_block( - in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] @@ -122,6 +132,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), B.ShortcutBlock( B.sequential( @@ -138,6 +149,7 @@ class RRDBNet(nn.Module): act_type=self.act, mode="CNA", plus=self.plus, + c2x2=c2x2, ) for _ in range(self.num_blocks) ], @@ -149,6 +161,7 @@ class RRDBNet(nn.Module): norm_type=self.norm, act_type=None, mode=self.mode, + c2x2=c2x2, ), ) ), @@ -160,6 +173,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=self.act, + c2x2=c2x2, ), # hr_conv1 B.conv_block( @@ -168,6 +182,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), ) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 214642cc4..d7bc5d227 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -141,6 +141,19 @@ def sequential(*args): ConvMode = Literal["CNA", "NAC", "CNAC"] +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + def conv_block( in_nc: int, out_nc: int, @@ -153,12 +166,17 @@ def conv_block( norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", + c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None @@ -285,6 +303,7 @@ class RRDB(nn.Module): _convtype="Conv2D", _spectral_norm=False, plus=False, + c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( @@ -298,6 +317,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, @@ -310,6 +330,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) 
self.RDB3 = ResidualDenseBlock_5C( nf, @@ -322,6 +343,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) def forward(self, x): @@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module): act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, + c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() @@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, @@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, @@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, @@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) if mode == "CNA": last_act = None @@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=last_act, mode=mode, + c2x2=c2x2, ) def forward(self, x): @@ -499,6 +527,7 @@ def upconv_block( norm_type: str | None = None, act_type="relu", mode="nearest", + c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ @@ -512,5 +541,6 @@ def upconv_block( pad_type=pad_type, norm_type=norm_type, act_type=act_type, + c2x2=c2x2, ) return sequential(upsample, conv) diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 8234ac5d1..2e66e6247 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) + # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel: # MAT elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: model = MAT(state_dict) + # Omni-SR + elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: + model = OmniSR(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 8e2bef47a..1906c0c7f 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, 
SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT) +PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -22,6 +23,7 @@ PyTorchSRModel = Union[ SwinIR, Swin2SR, HAT, + OmniSR, ]
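Taken together, the loader changes above mean OmniSR checkpoints are auto-detected purely from state-dict keys (the "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" probe in model_loading.py), with the upscale factor inferred from the pixelshuffle head as the square root of up.0.weight's output channels over out_nc. A hedged end-to-end usage sketch, assuming an OmniSR x4 checkpoint on disk (the filename is illustrative):

    import torch
    from comfy_extras.chainner_models.model_loading import load_state_dict

    # filename is illustrative; any plain OmniSR checkpoint should be detected
    state_dict = torch.load("OmniSR_X4_DIV2K.pth", map_location="cpu")
    model = load_state_dict(state_dict).eval()
    print(model.model_arch, model.scale)   # expected: OmniSR 4

    with torch.no_grad():
        lr = torch.rand(1, model.in_nc, 64, 64)
        sr = model(lr)                     # pads to a multiple of window_size, then crops
    print(sr.shape)                        # expected: torch.Size([1, 3, 256, 256])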