From 472b1cc0d881c4009e5a89e0893c5835f3a4c47d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 18 Apr 2023 19:34:07 -0400 Subject: [PATCH 001/208] Add a github action to use pip xformers package for dependencies. --- .../windows_release_cu118_dependencies_2.yml | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/windows_release_cu118_dependencies_2.yml diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml new file mode 100644 index 000000000..a88449527 --- /dev/null +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -0,0 +1,30 @@ +name: "Windows Release cu118 dependencies 2" + +on: + workflow_dispatch: +# push: +# branches: +# - master + +jobs: + build_dependencies: + runs-on: windows-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10.9' + + - shell: bash + run: | + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip install --no-cache-dir ./temp_wheel_dir/* + echo installed basic + ls -lah temp_wheel_dir + mv temp_wheel_dir cu118_python_deps + tar cf cu118_python_deps.tar cu118_python_deps + + - uses: actions/cache/save@v3 + with: + path: cu118_python_deps.tar + key: ${{ runner.os }}-build-cu118 From 3696d1699a6fece2485c063317cf65abbcddb79b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 09:36:19 -0400 Subject: [PATCH 002/208] Add support for GLIGEN textbox model. --- comfy/gligen.py | 343 ++++++++++++++++++ comfy/ldm/modules/attention.py | 16 + .../modules/diffusionmodules/openaimodel.py | 2 + comfy/model_management.py | 6 +- comfy/samplers.py | 57 ++- comfy/sd.py | 22 +- folder_paths.py | 2 + models/gligen/put_gligen_models_here | 0 nodes.py | 71 +++- 9 files changed, 491 insertions(+), 28 deletions(-) create mode 100644 comfy/gligen.py create mode 100644 models/gligen/put_gligen_models_here diff --git a/comfy/gligen.py b/comfy/gligen.py new file mode 100644 index 000000000..8770383e5 --- /dev/null +++ b/comfy/gligen.py @@ -0,0 +1,343 @@ +import torch +from torch import nn, einsum +from ldm.modules.attention import CrossAttention +from inspect import isfunction + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * torch.nn.functional.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +class GatedCrossAttentionDense(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + self.attn = CrossAttention( + query_dim=query_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head) + 
self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + x = x + self.scale * \ + torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs) + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class GatedSelfAttentionDense(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj + # feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = CrossAttention( + query_dim=query_dim, + context_dim=query_dim, + heads=n_heads, + dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + N_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn( + self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :] + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) + + return x + + +class GatedSelfAttentionDense2(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, d_head): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj + # feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = CrossAttention( + query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + self.ff = FeedForward(query_dim, glu=True) + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) + self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) + + # this can be useful: we can externally change magnitude of tanh(alpha) + # for example, when it is set to 0, then the entire model is same as + # original one + self.scale = 1 + + def forward(self, x, objs): + + B, N_visual, _ = x.shape + B, N_ground, _ = objs.shape + + objs = self.linear(objs) + + # sanity check + size_v = math.sqrt(N_visual) + size_g = math.sqrt(N_ground) + assert int(size_v) == size_v, "Visual tokens must be square rootable" + assert int(size_g) == size_g, "Grounding tokens must be square rootable" + size_v = int(size_v) + size_g = int(size_g) + + # select grounding token and resize it to visual token size as residual + out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[ + :, N_visual:, :] + out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g) + out = torch.nn.functional.interpolate( + out, (size_v, size_v), mode='bicubic') + residual = out.reshape(B, -1, N_visual).permute(0, 2, 1) + + # add residual to visual feature + x = x + self.scale * torch.tanh(self.alpha_attn) * residual + x = x + self.scale * \ + torch.tanh(self.alpha_dense) * 
self.ff(self.norm2(x)) + + return x + + +class FourierEmbedder(): + def __init__(self, num_freqs=64, temperature=100): + + self.num_freqs = num_freqs + self.temperature = temperature + self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) + + @torch.no_grad() + def __call__(self, x, cat_dim=-1): + "x: arbitrary shape of tensor. dim: cat dim" + out = [] + for freq in self.freq_bands: + out.append(torch.sin(freq * x)) + out.append(torch.cos(freq * x)) + return torch.cat(out, cat_dim) + + +class PositionNet(nn.Module): + def __init__(self, in_dim, out_dim, fourier_freqs=8): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) + self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy + + self.linears = nn.Sequential( + nn.Linear(self.in_dim + self.position_dim, 512), + nn.SiLU(), + nn.Linear(512, 512), + nn.SiLU(), + nn.Linear(512, out_dim), + ) + + self.null_positive_feature = torch.nn.Parameter( + torch.zeros([self.in_dim])) + self.null_position_feature = torch.nn.Parameter( + torch.zeros([self.position_dim])) + + def forward(self, boxes, masks, positive_embeddings): + B, N, _ = boxes.shape + masks = masks.unsqueeze(-1) + + # embedding position (it may includes padding as placeholder) + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C + + # learnable null embedding + positive_null = self.null_positive_feature.view(1, 1, -1) + xyxy_null = self.null_position_feature.view(1, 1, -1) + + # replace padding with learnable null embedding + positive_embeddings = positive_embeddings * \ + masks + (1 - masks) * positive_null + xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null + + objs = self.linears( + torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) + assert objs.shape == torch.Size([B, N, self.out_dim]) + return objs + + +class Gligen(nn.Module): + def __init__(self, modules, position_net, key_dim): + super().__init__() + self.module_list = nn.ModuleList(modules) + self.position_net = position_net + self.key_dim = key_dim + self.max_objs = 30 + + def _set_position(self, boxes, masks, positive_embeddings): + objs = self.position_net(boxes, masks, positive_embeddings) + + def func(key, x): + module = self.module_list[key] + return module(x, objs) + return func + + def set_position(self, latent_image_shape, position_params, device): + batch, c, h, w = latent_image_shape + masks = torch.zeros([self.max_objs], device="cpu") + boxes = [] + positive_embeddings = [] + for p in position_params: + x1 = (p[4]) / w + y1 = (p[3]) / h + x2 = (p[4] + p[2]) / w + y2 = (p[3] + p[1]) / h + masks[len(boxes)] = 1.0 + boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)] + positive_embeddings += [p[0]] + append_boxes = [] + append_conds = [] + if len(boxes) < self.max_objs: + append_boxes = [torch.zeros( + [self.max_objs - len(boxes), 4], device="cpu")] + append_conds = [torch.zeros( + [self.max_objs - len(boxes), self.key_dim], device="cpu")] + + box_out = torch.cat( + boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1) + masks = masks.unsqueeze(0).repeat(batch, 1) + conds = torch.cat(positive_embeddings + + append_conds).unsqueeze(0).repeat(batch, 1, 1) + return self._set_position( + box_out.to(device), + masks.to(device), + conds.to(device)) + + def set_empty(self, latent_image_shape, device): + batch, c, h, w = latent_image_shape + masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1) + box_out = torch.zeros([self.max_objs, 4], + 
device="cpu").repeat(batch, 1, 1) + conds = torch.zeros([self.max_objs, self.key_dim], + device="cpu").repeat(batch, 1, 1) + return self._set_position( + box_out.to(device), + masks.to(device), + conds.to(device)) + + def cleanup(self): + pass + + def get_models(self): + return [self] + +def load_gligen(sd): + sd_k = sd.keys() + output_list = [] + key_dim = 768 + for a in ["input_blocks", "middle_block", "output_blocks"]: + for b in range(20): + k_temp = filter(lambda k: "{}.{}.".format(a, b) + in k and ".fuser." in k, sd_k) + k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp) + + n_sd = {} + for k in k_temp: + n_sd[k[1]] = sd[k[0]] + if len(n_sd) > 0: + query_dim = n_sd["linear.weight"].shape[0] + key_dim = n_sd["linear.weight"].shape[1] + + if key_dim == 768: # SD1.x + n_heads = 8 + d_head = query_dim // n_heads + else: + d_head = 64 + n_heads = query_dim // d_head + + gated = GatedSelfAttentionDense( + query_dim, key_dim, n_heads, d_head) + gated.load_state_dict(n_sd, strict=False) + output_list.append(gated) + + if "position_net.null_positive_feature" in sd_k: + in_dim = sd["position_net.null_positive_feature"].shape[0] + out_dim = sd["position_net.linears.4.weight"].shape[0] + + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.position_net = PositionNet(in_dim, out_dim) + w.load_state_dict(sd, strict=False) + + gligen = Gligen(output_list, w.position_net, key_dim) + return gligen diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c83387348..98dbda635 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -510,6 +510,14 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) def _forward(self, x, context=None, transformer_options={}): + current_index = None + if "current_index" in transformer_options: + current_index = transformer_options["current_index"] + if "patches" in transformer_options: + transformer_patches = transformer_options["patches"] + else: + transformer_patches = {} + n = self.norm1(x) if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) @@ -518,11 +526,19 @@ class BasicTransformerBlock(nn.Module): n = self.attn1(n, context=context if self.disable_self_attn else None) x += n + if "middle_patch" in transformer_patches: + patch = transformer_patches["middle_patch"] + for p in patch: + x = p(current_index, x) + n = self.norm2(x) n = self.attn2(n, context=context) x += n x = self.ff(self.norm3(x)) + x + + if current_index is not None: + transformer_options["current_index"] += 1 return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 8a4e8b3e1..4c69c8567 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -782,6 +782,8 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. 
""" transformer_options["original_shape"] = list(x.shape) + transformer_options["current_index"] = 0 + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" diff --git a/comfy/model_management.py b/comfy/model_management.py index 76455e4a2..a0d1313d2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -176,7 +176,7 @@ def load_model_gpu(model): model_accelerated = True return current_loaded_model -def load_controlnet_gpu(models): +def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state if vram_state == VRAMState.CPU: @@ -186,6 +186,10 @@ def load_controlnet_gpu(models): #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return + models = [] + for m in control_models: + models += m.get_models() + for m in current_gpu_controlnets: if m not in models: m.cpu() diff --git a/comfy/samplers.py b/comfy/samplers.py index 05af6fe88..31968e185 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -70,7 +70,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con control = None if 'control' in cond[1]: control = cond[1]['control'] - return (input_x, mult, conditionning, area, control) + + patches = None + if 'gligen' in cond[1]: + gligen = cond[1]['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + return (input_x, mult, conditionning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -91,12 +105,21 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con def can_concat_cond(c1, c2): if c1[0].shape != c2[0].shape: return False + + #control if (c1[4] is None) != (c2[4] is None): return False if c1[4] is not None: if c1[4] is not c2[4]: return False + #patches + if (c1[5] is None) != (c2[5] is None): + return False + if (c1[5] is not None): + if c1[5] is not c2[5]: + return False + return cond_equal_size(c1[2], c2[2]) def cond_cat(c_list): @@ -166,6 +189,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con cond_or_uncond = [] area = [] control = None + patches = None for x in to_batch: o = to_run.pop(x) p = o[0] @@ -175,6 +199,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con area += [p[3]] cond_or_uncond += [o[1]] control = p[4] + patches = p[5] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) @@ -184,8 +209,14 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if control is not None: c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'], len(cond_or_uncond)) + transformer_options = {} if 'transformer_options' in model_options: - c['transformer_options'] = model_options['transformer_options'] + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + transformer_options["patches"] = patches + + c['transformer_options'] = transformer_options output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x @@ -309,8 +340,7 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] - -def apply_control_net_to_equal_area(conds, uncond): +def 
apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] cond_other = [] uncond_cnets = [] @@ -318,15 +348,15 @@ def apply_control_net_to_equal_area(conds, uncond): for t in range(len(conds)): x = conds[t] if 'area' not in x[1]: - if 'control' in x[1] and x[1]['control'] is not None: - cond_cnets.append(x[1]['control']) + if name in x[1] and x[1][name] is not None: + cond_cnets.append(x[1][name]) else: cond_other.append((x, t)) for t in range(len(uncond)): x = uncond[t] if 'area' not in x[1]: - if 'control' in x[1] and x[1]['control'] is not None: - uncond_cnets.append(x[1]['control']) + if name in x[1] and x[1][name] is not None: + uncond_cnets.append(x[1][name]) else: uncond_other.append((x, t)) @@ -336,15 +366,16 @@ def apply_control_net_to_equal_area(conds, uncond): for x in range(len(cond_cnets)): temp = uncond_other[x % len(uncond_other)] o = temp[0] - if 'control' in o[1] and o[1]['control'] is not None: + if name in o[1] and o[1][name] is not None: n = o[1].copy() - n['control'] = cond_cnets[x] + n[name] = uncond_fill_func(cond_cnets, x) uncond += [[o[0], n]] else: n = o[1].copy() - n['control'] = cond_cnets[x] + n[name] = uncond_fill_func(cond_cnets, x) uncond[temp[1]] = [o[0], n] + def encode_adm(noise_augmentor, conds, batch_size, device): for t in range(len(conds)): x = conds[t] @@ -378,6 +409,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds + class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", @@ -466,7 +498,8 @@ class KSampler: for c in negative: create_cond_with_same_area_if_none(positive, c) - apply_control_net_to_equal_area(positive, negative) + apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if self.model.model.diffusion_model.dtype == torch.float16: precision_scope = torch.autocast diff --git a/comfy/sd.py b/comfy/sd.py index 1d7774742..211acd70e 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -13,6 +13,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision +from . 
import gligen def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -378,7 +379,7 @@ class CLIP: def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def encode_from_tokens(self, tokens): + def encode_from_tokens(self, tokens, return_pooled=False): if self.layer_idx is not None: self.cond_stage_model.clip_layer(self.layer_idx) try: @@ -388,6 +389,10 @@ class CLIP: except Exception as e: self.patcher.unpatch_model() raise e + if return_pooled: + eos_token_index = max(range(len(tokens[0])), key=tokens[0].__getitem__) + pooled = cond[:, eos_token_index] + return cond, pooled return cond def encode(self, text): @@ -564,10 +569,10 @@ class ControlNet: c.strength = self.strength return c - def get_control_models(self): + def get_models(self): out = [] if self.previous_controlnet is not None: - out += self.previous_controlnet.get_control_models() + out += self.previous_controlnet.get_models() out.append(self.control_model) return out @@ -737,10 +742,10 @@ class T2IAdapter: del self.cond_hint self.cond_hint = None - def get_control_models(self): + def get_models(self): out = [] if self.previous_controlnet is not None: - out += self.previous_controlnet.get_control_models() + out += self.previous_controlnet.get_models() return out def load_t2i_adapter(t2i_data): @@ -787,6 +792,13 @@ def load_clip(ckpt_path, embedding_directory=None): clip.load_from_state_dict(clip_data) return clip +def load_gligen(ckpt_path): + data = utils.load_torch_file(ckpt_path) + model = gligen.load_gligen(data) + if model_management.should_use_fp16(): + model = model.half() + return model + def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): with open(config_path, 'r') as stream: config = yaml.safe_load(stream) diff --git a/folder_paths.py b/folder_paths.py index 61f446c96..3c4ad3711 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -26,6 +26,8 @@ folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")] folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) +folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) + folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions) folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) diff --git a/models/gligen/put_gligen_models_here b/models/gligen/put_gligen_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 06b69f453..8555f272a 100644 --- a/nodes.py +++ b/nodes.py @@ -490,6 +490,51 @@ class unCLIPConditioning: c.append(n) return (c, ) +class GLIGENLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"), )}} + + RETURN_TYPES = ("GLIGEN",) + FUNCTION = "load_gligen" + + CATEGORY = "_for_testing/gligen" + + def load_gligen(self, gligen_name): + gligen_path = folder_paths.get_full_path("gligen", gligen_name) + gligen = comfy.sd.load_gligen(gligen_path) + return (gligen,) + +class GLIGENTextBoxApply: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning_to": ("CONDITIONING", ), + "clip": ("CLIP", ), + "gligen_textbox_model": 
("GLIGEN", ), + "text": ("STRING", {"multiline": True}), + "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "_for_testing/gligen" + + def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): + c = [] + cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True) + for t in conditioning_to: + n = [t[0], t[1].copy()] + position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)] + prev = [] + if "gligen" in n[1]: + prev = n[1]['gligen'][2] + + n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params) + c.append(n) + return (c, ) class EmptyLatentImage: def __init__(self, device="cpu"): @@ -731,27 +776,30 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative_copy = [] control_nets = [] + def get_models(cond): + models = [] + for c in cond: + if 'control' in c[1]: + models += [c[1]['control']] + if 'gligen' in c[1]: + models += [c[1]['gligen'][1]] + return models + for p in positive: t = p[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) - if 'control' in p[1]: - control_nets += [p[1]['control']] positive_copy += [[t] + p[1:]] for n in negative: t = n[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) - if 'control' in n[1]: - control_nets += [n[1]['control']] negative_copy += [[t] + n[1:]] - control_net_models = [] - for x in control_nets: - control_net_models += x.get_control_models() - comfy.model_management.load_controlnet_gpu(control_net_models) + models = get_models(positive) + get_models(negative) + comfy.model_management.load_controlnet_gpu(models) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -761,8 +809,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - for c in control_nets: - c.cleanup() + for m in models: + m.cleanup() out = latent.copy() out["samples"] = samples @@ -1128,6 +1176,9 @@ NODE_CLASS_MAPPINGS = { "VAEEncodeTiled": VAEEncodeTiled, "TomePatchModel": TomePatchModel, "unCLIPCheckpointLoader": unCLIPCheckpointLoader, + "GLIGENLoader": GLIGENLoader, + "GLIGENTextBoxApply": GLIGENTextBoxApply, + "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, } From 781b724ac667e42900c331988f356a85670c0ec5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 11:30:18 -0400 Subject: [PATCH 003/208] Add GLIGEN model link to colab. 
--- notebooks/comfyui_colab.ipynb | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c088de89c..c1982d8be 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -138,6 +138,11 @@ "# Controlnet Preprocessor nodes by Fannovel16\n", "#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n", "\n", + "\n", + "# GLIGEN\n", + "#!wget -c https://huggingface.co/comfyanonymous/GLIGEN_pruned_safetensors/resolve/main/gligen_sd14_textbox_pruned_fp16.safetensors -P ./models/gligen/\n", + "\n", + "\n", "# ESRGAN upscale model\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth -P ./models/upscale_models/\n", "#!wget -c https://huggingface.co/sberbank-ai/Real-ESRGAN/resolve/main/RealESRGAN_x4.pth -P ./models/upscale_models/\n", From 2d546d510d1f7919bbae3ac08108e0d05e9c0bae Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 11:47:49 -0400 Subject: [PATCH 004/208] Add gligen entry to extra_model_paths example. --- extra_model_paths.yaml.example | 1 + 1 file changed, 1 insertion(+) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index f421f54dc..ac1ffe9d2 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -18,6 +18,7 @@ a111: #other_ui: # base_path: path/to/ui # checkpoints: models/checkpoints +# gligen: models/gligen # custom_nodes: path/custom_nodes From 96b57a9ad6447b95921b91e5f52fb3684f73514f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 19 Apr 2023 21:11:38 -0400 Subject: [PATCH 005/208] Don't pass adm to model when it doesn't support it. --- comfy/samplers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 31968e185..19ebc97d9 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -36,8 +36,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con strength = cond[1]['strength'] adm_cond = None - if 'adm' in cond[1]: - adm_cond = cond[1]['adm'] + if 'adm_encoded' in cond[1]: + adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] mult = torch.ones_like(input_x) * strength @@ -405,7 +405,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): else: adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) x[1] = x[1].copy() - x[1]["adm"] = torch.cat([adm_out] * batch_size) + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) return conds From 94e9798a4b614627805be197aa3da415a1de7ee4 Mon Sep 17 00:00:00 2001 From: omar92 Date: Thu, 20 Apr 2023 06:19:56 +0200 Subject: [PATCH 006/208] when drag from node input or output show all possible nodes that you can connect --- web/extensions/core/slotDefaults.js | 50 ++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 0b6a0a150..3ff5fdb06 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -1,21 +1,49 @@ import { app } from "/scripts/app.js"; - +import { ComfyWidgets } from "/scripts/widgets.js"; // Adds defaults for quickly adding nodes with middle click on the input/output app.registerExtension({ name: "Comfy.SlotDefaults", init() { LiteGraph.middle_click_slot_add_default_node = true; - 
LiteGraph.slot_types_default_in = { - MODEL: "CheckpointLoaderSimple", - LATENT: "EmptyLatentImage", - VAE: "VAELoader", - }; + }, + async beforeRegisterNodeDef(nodeType, nodeData, app) { + var nodeId = nodeData.name; + var inputs = []; + //if (nodeData["input"]["optional"] != undefined) { + // inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]); + //} else { + inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logica to create node with optional inputs + //} + for (const inputKey in inputs) { + var input = (inputs[inputKey]); + //make sure input[0] is a string + if (typeof input[0] !== "string") continue; + + // for (const slotKey in inputs[inputKey]) { + var type = input[0] + if (type in ComfyWidgets) { + var customProperties = input[1] + //console.log(customProperties) + if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input + } + + if (!(type in LiteGraph.slot_types_default_out)) { + LiteGraph.slot_types_default_out[type] = ["Reroute"]; + } + if (LiteGraph.slot_types_default_out[type].includes(nodeId)) continue; + LiteGraph.slot_types_default_out[type].push(nodeId); + // } + } + + var outputs = nodeData["output"]; + for (const key in outputs) { + var type = outputs[key]; + if (!(type in LiteGraph.slot_types_default_in)) { + LiteGraph.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() + } + LiteGraph.slot_types_default_in[type].push(nodeId); + } - LiteGraph.slot_types_default_out = { - LATENT: "VAEDecode", - IMAGE: "SaveImage", - CLIP: "CLIPTextEncode", - }; }, }); From 5229c1f972b4130d5d0ddc19362604c6ec57d1fd Mon Sep 17 00:00:00 2001 From: omar92 Date: Thu, 20 Apr 2023 21:13:14 +0200 Subject: [PATCH 007/208] add option on the settings to change the number of the suggestions --- web/extensions/core/slotDefaults.js | 61 ++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 3ff5fdb06..04baadc6a 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -4,46 +4,69 @@ import { ComfyWidgets } from "/scripts/widgets.js"; app.registerExtension({ name: "Comfy.SlotDefaults", + suggestionsNumber: null, init() { LiteGraph.middle_click_slot_add_default_node = true; + this.suggestionsNumber = app.ui.settings.addSetting({ + id: "Comfy.NodeSuggestions.number", + name: "number of nodes suggestions", + type: "slider", + attrs: { + min: 1, + max: 100, + step: 1, + }, + defaultValue: 5, + onChange: (newVal, oldVal) => { + this.setDefaults(newVal); + } + }); }, + slot_types_default_out: {}, + slot_types_default_in: {}, async beforeRegisterNodeDef(nodeType, nodeData, app) { - var nodeId = nodeData.name; + var nodeId = nodeData.name; var inputs = []; - //if (nodeData["input"]["optional"] != undefined) { - // inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]); - //} else { - inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logica to create node with optional inputs - //} + inputs = nodeData["input"]["required"]; //only show required inputs to reduce the mess also not logical to create node with optional inputs for (const inputKey in inputs) { var input = (inputs[inputKey]); - //make sure input[0] is a string if (typeof input[0] !== "string") continue; - // for (const slotKey in inputs[inputKey]) { var type = 
input[0] if (type in ComfyWidgets) { var customProperties = input[1] - //console.log(customProperties) if (!(customProperties?.forceInput)) continue; //ignore widgets that don't force input } - if (!(type in LiteGraph.slot_types_default_out)) { - LiteGraph.slot_types_default_out[type] = ["Reroute"]; + if (!(type in this.slot_types_default_out)) { + this.slot_types_default_out[type] = ["Reroute"]; } - if (LiteGraph.slot_types_default_out[type].includes(nodeId)) continue; - LiteGraph.slot_types_default_out[type].push(nodeId); - // } - } + if (this.slot_types_default_out[type].includes(nodeId)) continue; + this.slot_types_default_out[type].push(nodeId); + } var outputs = nodeData["output"]; for (const key in outputs) { var type = outputs[key]; - if (!(type in LiteGraph.slot_types_default_in)) { - LiteGraph.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() + if (!(type in this.slot_types_default_in)) { + this.slot_types_default_in[type] = ["Reroute"];// ["Reroute", "Primitive"]; primitive doesn't always work :'() } - LiteGraph.slot_types_default_in[type].push(nodeId); - } + this.slot_types_default_in[type].push(nodeId); + } + var maxNum = this.suggestionsNumber ? this.suggestionsNumber.value : 5; + this.setDefaults(maxNum); }, + setDefaults(maxNum) { + + LiteGraph.slot_types_default_out = {}; + LiteGraph.slot_types_default_in = {}; + + for (const type in this.slot_types_default_out) { + LiteGraph.slot_types_default_out[type] = this.slot_types_default_out[type].slice(0, maxNum); + } + for (const type in this.slot_types_default_in) { + LiteGraph.slot_types_default_in[type] = this.slot_types_default_in[type].slice(0, maxNum); + } + } }); From 31e60adb2802874a5889623a83149faa32924a98 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 20 Apr 2023 17:30:10 -0400 Subject: [PATCH 008/208] Add GLIGEN example to README. --- README.md | 1 + nodes.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index be2cb8ec5..bf16006bf 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [ControlNet and T2I-Adapter](https://comfyanonymous.github.io/ComfyUI_examples/controlnet/) - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) +- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - Starts up very fast. - Works fully offline: will never download anything. - [Config file](extra_model_paths.yaml.example) to set the search paths for models. 
diff --git a/nodes.py b/nodes.py index 8555f272a..48c3ee9c3 100644 --- a/nodes.py +++ b/nodes.py @@ -498,7 +498,7 @@ class GLIGENLoader: RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" - CATEGORY = "_for_testing/gligen" + CATEGORY = "loaders" def load_gligen(self, gligen_name): gligen_path = folder_paths.get_full_path("gligen", gligen_name) @@ -520,7 +520,7 @@ class GLIGENTextBoxApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "_for_testing/gligen" + CATEGORY = "conditioning/gligen" def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] From d2ef3465ca838e528008cb5e20b40d25079d5176 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 20 Apr 2023 18:23:51 -0600 Subject: [PATCH 009/208] Improve current word selection --- web/extensions/core/editAttention.js | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index bebc80b12..cc51a04e5 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -89,24 +89,17 @@ app.registerExtension({ end = nearestEnclosure.end; selectedText = inputField.value.substring(start, end); } else { - // Select the current word, find the start and end of the word (first space before and after) - const wordStart = inputField.value.substring(0, start).lastIndexOf(" ") + 1; - const wordEnd = inputField.value.substring(end).indexOf(" "); - // If there is no space after the word, select to the end of the string - if (wordEnd === -1) { - end = inputField.value.length; - } else { - end += wordEnd; + // Select the current word, find the start and end of the word + const delimiters = " .,\\/!?%^*;:{}=-_`~()\r\n\t"; + + while (!delimiters.includes(inputField.value[start - 1]) && start > 0) { + start--; + } + + while (!delimiters.includes(inputField.value[end]) && end < inputField.value.length) { + end++; } - start = wordStart; - // Remove all punctuation at the end and beginning of the word - while (inputField.value[start].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { - start++; - } - while (inputField.value[end - 1].match(/[.,\/#!$%\^&\*;:{}=\-_`~()]/)) { - end--; - } selectedText = inputField.value.substring(start, end); if (!selectedText) return; } From 907010e0824eeab12c5948e5afa4df6d0934be9a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 20 Apr 2023 23:58:25 -0400 Subject: [PATCH 010/208] Remove some useless code. 
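
The CFGDenoiser wrapper removed below duplicated logic that already lives in
sampling_function(), which batches cond/uncond itself. For reference, the
classifier-free guidance combination it implemented reduces to this (a
simplified sketch; the live path in sampling_function() additionally handles
areas, strengths and cond batching):

    # eps_cond / eps_uncond are the model's noise predictions for the
    # conditional and unconditional prompts at the same sigma.
    def cfg_combine(eps_cond, eps_uncond, cond_scale):
        return eps_uncond + (eps_cond - eps_uncond) * cond_scale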
--- comfy/samplers.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 19ebc97d9..15527224e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -7,23 +7,6 @@ from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -class CFGDenoiser(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - if len(uncond[0]) == len(cond[0]) and x.shape[0] * x.shape[2] * x.shape[3] < (96 * 96): #TODO check memory instead - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - else: - cond = self.inner_model(x, sigma, cond=cond) - uncond = self.inner_model(x, sigma, cond=uncond) - return uncond + (cond - uncond) * cond_scale - - #The main sampling function shared by all the samplers #Returns predicted noise def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None, model_options={}): From 98ae4bbfdee1ea9da62e3d22a3c6428032a78398 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Thu, 20 Apr 2023 23:55:20 -0600 Subject: [PATCH 011/208] Remove brackets if weight == 1 --- web/extensions/core/editAttention.js | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/editAttention.js b/web/extensions/core/editAttention.js index cc51a04e5..b937bb103 100644 --- a/web/extensions/core/editAttention.js +++ b/web/extensions/core/editAttention.js @@ -128,8 +128,13 @@ app.registerExtension({ // Increment the weight const weightDelta = event.key === "ArrowUp" ? delta : -delta; - const updatedText = selectedText.replace(/(.*:)(\d+(\.\d+)?)(.*)/, (match, prefix, weight, _, suffix) => { - return prefix + incrementWeight(weight, weightDelta) + suffix; + const updatedText = selectedText.replace(/\((.*):(\d+(?:\.\d+)?)\)/, (match, text, weight) => { + weight = incrementWeight(weight, weightDelta); + if (weight == 1) { + return text; + } else { + return `(${text}:${weight})`; + } }); inputField.setRangeText(updatedText, start, end, "select"); From 989acd769a6b5f3a5d6e3cd03fafbd9668c2dbdf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 21 Apr 2023 23:43:38 -0400 Subject: [PATCH 012/208] Cleanup. --- web/extensions/core/slotDefaults.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 04baadc6a..3ec605900 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -54,7 +54,7 @@ app.registerExtension({ this.slot_types_default_in[type].push(nodeId); } - var maxNum = this.suggestionsNumber ? this.suggestionsNumber.value : 5; + var maxNum = this.suggestionsNumber.value; this.setDefaults(maxNum); }, setDefaults(maxNum) { From 6908f9c94992b32fbb96be0f6cd8c5b362d72a77 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 22 Apr 2023 14:30:39 -0400 Subject: [PATCH 013/208] This makes pytorch2.0 attention perform a bit faster. 
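
The old unsqueeze/reshape/permute chains folded the heads into the batch
dimension and forced contiguous copies on every call.
scaled_dot_product_attention accepts (batch, heads, seq_len, dim_head)
tensors directly, so a view plus transpose is enough on the way in and a
single transpose/reshape on the way out. Roughly (a sketch assuming q, k, v
of shape (b, seq_len, heads * dim_head); the sdpa arguments shown are the
PyTorch 2.0 defaults, not copied from this file):

    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)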
--- comfy/ldm/modules/attention.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 98dbda635..c27d032a3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module): b, _, _ = q.shape q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), + lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), (q, k, v), ) @@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module): if exists(mask): raise NotImplementedError out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) + out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) ) return self.to_out(out) From ee030d281bbd25d385ba9ca10badb66b487cca21 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 22 Apr 2023 16:02:26 -0700 Subject: [PATCH 014/208] Add support for multiple unique inpainting masks This enables workflows like "Inpaint at full resolution" when using batch sizes greater than 1. --- nodes.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index 48c3ee9c3..9335d5243 100644 --- a/nodes.py +++ b/nodes.py @@ -171,24 +171,28 @@ class VAEEncodeForInpaint: def encode(self, vae, pixels, mask): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 - mask = torch.nn.functional.interpolate(mask[None,None,], size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")[0][0] + if len(mask.shape) < 3: + mask = mask.unsqueeze(0).unsqueeze(0) + elif len(mask.shape) < 4: + mask = mask.unsqueeze(1) + mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - mask = mask[:x,:y] + mask = mask[:,:x,:y,:] #grow mask by a few pixels to keep things seamless in latent space kernel_tensor = torch.ones((1, 1, 6, 6)) - mask_erosion = torch.clamp(torch.nn.functional.conv2d((mask.round())[None], kernel_tensor, padding=3), 0, 1) - m = (1.0 - mask.round()) + mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) + m = (1.0 - mask.round()).squeeze(1) for i in range(3): pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] *= m pixels[:,:,:,i] += 0.5 t = vae.encode(pixels) - return ({"samples":t, "noise_mask": (mask_erosion[0][:x,:y].round())}, ) + return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, ) class CheckpointLoader: @classmethod @@ -759,10 +763,15 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + if len(noise_mask.shape) < 3: + noise_mask = noise_mask.unsqueeze(0).unsqueeze(0) + elif len(noise_mask.shape) < 4: + noise_mask = noise_mask.unsqueeze(1) + noise_mask = torch.nn.functional.interpolate(noise_mask, size=(noise.shape[2], noise.shape[3]), mode="bilinear") noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) + if noise_mask.shape[0] < 
latent_image.shape[0]: + noise_mask = noise_mask.repeat(latent_image.shape[0] // noise_mask.shape[0], 1, 1, 1) noise_mask = noise_mask.to(device) real_model = None From c8355ed39ff39a10eb7a3d262f278dc99ad2e73b Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 23 Apr 2023 10:31:21 +0100 Subject: [PATCH 015/208] use window.name instead of session storage - prevents duplicate stealing session id --- web/scripts/api.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/scripts/api.js b/web/scripts/api.js index 2b90c2abc..d29faa5ba 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -35,7 +35,7 @@ class ComfyApi extends EventTarget { } let opened = false; - let existingSession = sessionStorage["Comfy.SessionId"] || ""; + let existingSession = window.name; if (existingSession) { existingSession = "?clientId=" + existingSession; } @@ -75,7 +75,7 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; - sessionStorage["Comfy.SessionId"] = this.clientId; + window.name = this.clientId; } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break; From 5282f5643476ba0f55197c3ca8b72ce43525b025 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 12:35:25 -0400 Subject: [PATCH 016/208] Implement Linear hypernetworks. Add a HypernetworkLoader node to use hypernetworks. --- comfy/ldm/modules/attention.py | 69 +++++++++++++--- comfy/model_management.py | 3 + comfy/samplers.py | 10 ++- comfy/sd.py | 23 ++++++ comfy/utils.py | 7 +- comfy_extras/nodes_hypernetwork.py | 87 +++++++++++++++++++++ folder_paths.py | 1 + models/hypernetworks/put_hypernetworks_here | 0 nodes.py | 1 + 9 files changed, 185 insertions(+), 16 deletions(-) create mode 100644 comfy_extras/nodes_hypernetwork.py create mode 100644 models/hypernetworks/put_hypernetworks_here diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c27d032a3..ce7180d91 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads query = self.to_q(x) context = default(context, x) key = self.to_k(context) - value = self.to_v(context) + if value is not None: + value = self.to_v(value) + else: + value = self.to_v(context) + del context, x query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) @@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) - v_in = self.to_v(context) + if value is not None: + v_in = self.to_v(value) + del value + else: + v_in = self.to_v(context) del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) @@ -350,13 +358,17 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) q, k, v = map(lambda t: 
rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module): transformer_patches = {} n = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) + if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) - n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) else: - n = self.attn1(n, context=context if self.disable_self_attn else None) + n = self.attn1(n, context=context_attn1, value=value_attn1) x += n if "middle_patch" in transformer_patches: @@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module): x = p(current_index, x) n = self.norm2(x) - n = self.attn2(n, context=context) + + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) + + n = self.attn2(n, context=context_attn2, value=value_attn2) x += n x = self.ff(self.norm3(x)) + x diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d1313d2..6e3a03530 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -133,6 +133,7 @@ def unload_model(): #never unload models from GPU on high vram if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() + current_loaded_model.model_patches_to("cpu") current_loaded_model.unpatch_model() current_loaded_model = None @@ -156,6 +157,8 @@ def load_model_gpu(model): except Exception as e: model.unpatch_model() raise e + + model.model_patches_to(get_torch_device()) current_loaded_model = model if vram_state == VRAMState.CPU: pass diff --git a/comfy/samplers.py b/comfy/samplers.py index 15527224e..b860f25f1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con transformer_options = model_options['transformer_options'].copy() if patches is not None: - 
transformer_options["patches"] = patches + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches c['transformer_options'] = transformer_options diff --git a/comfy/sd.py b/comfy/sd.py index 211acd70e..92dbb931d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -254,6 +254,29 @@ class ModelPatcher: def set_model_sampler_cfg_function(self, sampler_cfg_function): self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + def model_dtype(self): return self.model.diffusion_model.dtype diff --git a/comfy/utils.py b/comfy/utils.py index 0380b91dd..68f93403c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,11 +1,14 @@ import torch -def load_torch_file(ckpt): +def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: - pl_sd = torch.load(ckpt, map_location="cpu") + if safe_load: + pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + else: + pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py new file mode 100644 index 000000000..db2f8695c --- /dev/null +++ b/comfy_extras/nodes_hypernetwork.py @@ -0,0 +1,87 @@ +import comfy.utils +import folder_paths +import torch + +def load_hypernetwork_patch(path, strength): + sd = comfy.utils.load_torch_file(path, safe_load=True) + activation_func = sd.get('activation_func', 'linear') + is_layer_norm = sd.get('is_layer_norm', False) + use_dropout = sd.get('use_dropout', False) + activate_output = sd.get('activate_output', False) + last_layer_dropout = sd.get('last_layer_dropout', False) + + if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: + print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + return None + + out = {} + + for d in sd: + try: + dim = int(d) + except: + continue + + output = [] + for index in [0, 1]: + attn_weights = sd[dim][index] + keys = attn_weights.keys() + + linears = filter(lambda a: a.endswith(".weight"), keys) + linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) + layers = [] + + for lin_name in linears: + lin_weight = attn_weights['{}.weight'.format(lin_name)] + lin_bias = attn_weights['{}.bias'.format(lin_name)] + layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) + layer.load_state_dict({"weight": 
lin_weight, "bias": lin_bias}) + layers += [layer] + + output.append(torch.nn.Sequential(*layers)) + out[dim] = torch.nn.ModuleList(output) + + class hypernetwork_patch: + def __init__(self, hypernet, strength): + self.hypernet = hypernet + self.strength = strength + def __call__(self, current_index, q, k, v): + dim = k.shape[-1] + if dim in self.hypernet: + hn = self.hypernet[dim] + k = k + hn[0](k) * self.strength + v = v + hn[1](v) * self.strength + + return q, k, v + + def to(self, device): + for d in self.hypernet.keys(): + self.hypernet[d] = self.hypernet[d].to(device) + return self + + return hypernetwork_patch(out, strength) + +class HypernetworkLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_hypernetwork" + + CATEGORY = "_for_testing" + + def load_hypernetwork(self, model, hypernetwork_name, strength): + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) + model_hypernetwork = model.clone() + patch = load_hypernetwork_patch(hypernetwork_path, strength) + if patch is not None: + model_hypernetwork.set_model_attn1_patch(patch) + model_hypernetwork.set_model_attn2_patch(patch) + return (model_hypernetwork,) + +NODE_CLASS_MAPPINGS = { + "HypernetworkLoader": HypernetworkLoader +} diff --git a/folder_paths.py b/folder_paths.py index 3c4ad3711..bb0d65524 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) +folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/models/hypernetworks/put_hypernetworks_here b/models/hypernetworks/put_hypernetworks_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 48c3ee9c3..6ca73fa0c 100644 --- a/nodes.py +++ b/nodes.py @@ -1268,6 +1268,7 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) From 2a09e2aa27620c492f694b66cc10c5f41b101c12 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sun, 23 Apr 2023 20:02:08 +0200 Subject: [PATCH 017/208] refactor/split various bits of code for sampling --- comfy/sample.py | 62 +++++++++++++++++++++++++++++++++++++++++++++ comfy/samplers.py | 64 +++++++++++++++++++++++++++-------------------- nodes.py | 60 +++++++------------------------------------- 3 files changed, 108 insertions(+), 78 deletions(-) create mode 100644 comfy/sample.py diff --git a/comfy/sample.py b/comfy/sample.py new file mode 100644 
index 000000000..ede89890b --- /dev/null +++ b/comfy/sample.py @@ -0,0 +1,62 @@ +import torch +import comfy.model_management + + +def prepare_noise(latent, seed, disable_noise): + latent_image = latent["samples"] + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] + + generator = torch.manual_seed(seed) + for i in range(batch_index): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + return noise + +def create_mask(latent, noise): + noise_mask = None + device = comfy.model_management.get_torch_device() + if "noise_mask" in latent: + noise_mask = latent['noise_mask'] + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) + return noise_mask + +def broadcast_cond(cond, noise): + device = comfy.model_management.get_torch_device() + copy = [] + for p in cond: + t = p[0] + if t.shape[0] < noise.shape[0]: + t = torch.cat([t] * noise.shape[0]) + t = t.to(device) + copy += [[t] + p[1:]] + return copy + +def load_c_nets(positive, negative): + def get_models(cond): + models = [] + for c in cond: + if 'control' in c[1]: + models += [c[1]['control']] + if 'gligen' in c[1]: + models += [c[1]['gligen'][1]] + return models + + return get_models(positive) + get_models(negative) + +def load_additional_models(positive, negative): + models = load_c_nets(positive, negative) + comfy.model_management.load_controlnet_gpu(models) + return models + +def cleanup_additional_models(models): + for m in models: + m.cleanup() \ No newline at end of file diff --git a/comfy/samplers.py b/comfy/samplers.py index 15527224e..541a8db8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -392,6 +392,38 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds +def calculate_sigmas(model, steps, scheduler, sampler): + """ + Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique + """ + if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)): + model = CFGNoisePredictor(model) + if model.inner_model.parameterization == "v": + model = CompVisVDenoiser(model, quantize=True) + else: + model = k_diffusion_external.CompVisDenoiser(model, quantize=True) + + sigmas = None + + discard_penultimate_sigma = False + if sampler in ['dpm_2', 'dpm_2_ancestral']: + steps += 1 + discard_penultimate_sigma = True + + if scheduler == "karras": + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max)) + elif scheduler == "normal": + sigmas = model.get_sigmas(steps) + elif scheduler == "simple": + sigmas = simple_scheduler(model, steps) + elif scheduler == "ddim_uniform": + sigmas = ddim_scheduler(model, steps) + else: + print("error invalid scheduler", scheduler) + + if discard_penultimate_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + return sigmas class KSampler: SCHEDULERS = 
["karras", "normal", "simple", "ddim_uniform"] @@ -421,41 +453,19 @@ class KSampler: self.denoise = denoise self.model_options = model_options - def _calculate_sigmas(self, steps): - sigmas = None - - discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral']: - steps += 1 - discard_penultimate_sigma = True - - if self.scheduler == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) - elif self.scheduler == "normal": - sigmas = self.model_wrap.get_sigmas(steps).to(self.device) - elif self.scheduler == "simple": - sigmas = simple_scheduler(self.model_wrap, steps).to(self.device) - elif self.scheduler == "ddim_uniform": - sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device) - else: - print("error invalid scheduler", self.scheduler) - - if discard_penultimate_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - return sigmas - def set_steps(self, steps, denoise=None): self.steps = steps if denoise is None or denoise > 0.9999: - self.sigmas = self._calculate_sigmas(steps) + self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device) else: new_steps = int(steps/denoise) - sigmas = self._calculate_sigmas(new_steps) + sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): - sigmas = self.sigmas + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None): + if sigmas is None: + sigmas = self.sigmas sigma_min = self.sigma_min if last_step is not None and last_step < (len(sigmas) - 1): diff --git a/nodes.py b/nodes.py index 48c3ee9c3..601661864 100644 --- a/nodes.py +++ b/nodes.py @@ -16,6 +16,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "co import comfy.diffusers_convert import comfy.samplers +import comfy.sample import comfy.sd import comfy.utils @@ -739,31 +740,12 @@ class SetLatentNoiseMask: s["noise_mask"] = mask return (s,) - def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): - latent_image = latent["samples"] - noise_mask = None device = comfy.model_management.get_torch_device() + latent_image = latent["samples"] - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] - - generator = torch.manual_seed(seed) - for i in range(batch_index): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - - if "noise_mask" in latent: - noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) - 
noise_mask = noise_mask.to(device) + noise = comfy.sample.prepare_noise(latent, seed, disable_noise) + noise_mask = comfy.sample.create_mask(latent, noise) real_model = None comfy.model_management.load_model_gpu(model) @@ -772,34 +754,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise = noise.to(device) latent_image = latent_image.to(device) - positive_copy = [] - negative_copy = [] + positive_copy = comfy.sample.broadcast_cond(positive, noise) + negative_copy = comfy.sample.broadcast_cond(negative, noise) - control_nets = [] - def get_models(cond): - models = [] - for c in cond: - if 'control' in c[1]: - models += [c[1]['control']] - if 'gligen' in c[1]: - models += [c[1]['gligen'][1]] - return models - - for p in positive: - t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - positive_copy += [[t] + p[1:]] - for n in negative: - t = n[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(device) - negative_copy += [[t] + n[1:]] - - models = get_models(positive) + get_models(negative) - comfy.model_management.load_controlnet_gpu(models) + models = comfy.sample.load_additional_models(positive, negative) if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -809,8 +767,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() - for m in models: - m.cleanup() + + comfy.sample.cleanup_additional_models(models) out = latent.copy() out["samples"] = samples From 5818539743bd390a282a19d7e480177c31bc222b Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sun, 23 Apr 2023 20:09:09 +0200 Subject: [PATCH 018/208] add docstrings --- comfy/sample.py | 25 ++++++++++++++----------- nodes.py | 6 +++++- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index ede89890b..981781b53 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,22 +2,21 @@ import torch import comfy.model_management -def prepare_noise(latent, seed, disable_noise): +def prepare_noise(latent, seed): + """creates random noise given a LATENT and a seed""" latent_image = latent["samples"] - if disable_noise: - noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") - else: - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] + batch_index = 0 + if "batch_index" in latent: + batch_index = latent["batch_index"] - generator = torch.manual_seed(seed) - for i in range(batch_index): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + generator = torch.manual_seed(seed) + for i in range(batch_index): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn(latent_image.size(), 
dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise def create_mask(latent, noise): + """creates a mask for a given LATENT and noise""" noise_mask = None device = comfy.model_management.get_torch_device() if "noise_mask" in latent: @@ -30,6 +29,7 @@ def create_mask(latent, noise): return noise_mask def broadcast_cond(cond, noise): + """broadcasts conditioning to the noise batch size""" device = comfy.model_management.get_torch_device() copy = [] for p in cond: @@ -41,6 +41,7 @@ def broadcast_cond(cond, noise): return copy def load_c_nets(positive, negative): + """loads control nets in positive and negative conditioning""" def get_models(cond): models = [] for c in cond: @@ -53,10 +54,12 @@ def load_c_nets(positive, negative): return get_models(positive) + get_models(negative) def load_additional_models(positive, negative): + """loads additional models in positive and negative conditioning""" models = load_c_nets(positive, negative) comfy.model_management.load_controlnet_gpu(models) return models def cleanup_additional_models(models): + """cleanup additional models that were loaded""" for m in models: m.cleanup() \ No newline at end of file diff --git a/nodes.py b/nodes.py index a70668fd7..b8c6d350f 100644 --- a/nodes.py +++ b/nodes.py @@ -744,7 +744,11 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, device = comfy.model_management.get_torch_device() latent_image = latent["samples"] - noise = comfy.sample.prepare_noise(latent, seed, disable_noise) + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + noise = comfy.sample.prepare_noise(latent, seed) + noise_mask = comfy.sample.create_mask(latent, noise) real_model = None From f7a821881476cbd52a513877a9ffe35e6702b850 Mon Sep 17 00:00:00 2001 From: ltdrdata <128333288+ltdrdata@users.noreply.github.com> Date: Mon, 24 Apr 2023 04:58:55 +0900 Subject: [PATCH 019/208] Add clipspace feature. (#541) * Add clipspace feature. * feat: copy content to clipspace * feat: paste content from clipspace Extend validation to allow for validating annotated_path in addition to other parameters. Add support for annotated_filepath in folder_paths function. Generalize the '/upload/image' API to allow for uploading images to the 'input', 'temp', or 'output' directories. * rename contentClipboard -> clipspace * Do deep copy for imgs on copy to clipspace. 
* add original_imgs into clipspace * Preserve the original image when 'imgs' are modified * robust patch & refactoring folder_paths about annotated_filepath * Only show the Paste menu if the ComfyApp.clipspace is not empty * instant refresh on paste force triggering 'changed' on paste action * subfolder fix on paste logic attach subfolder if subfolder isn't empty --------- Co-authored-by: Lt.Dr.Data --- execution.py | 8 ++++- folder_paths.py | 40 ++++++++++++++++++++++ nodes.py | 8 ++--- server.py | 15 ++++++--- web/scripts/app.js | 83 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 9 deletions(-) diff --git a/execution.py b/execution.py index 73be6db03..b062deeb1 100644 --- a/execution.py +++ b/execution.py @@ -11,6 +11,7 @@ import torch import nodes import comfy.model_management +import folder_paths def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -250,7 +251,12 @@ def validate_inputs(prompt, item): return (False, "Value bigger than max. {}, {}".format(class_type, x)) if isinstance(type_input, list): - if val not in type_input: + is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]") + if is_annotated_path: + if not folder_paths.exists_annotated_filepath(val): + return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val)) + + elif val not in type_input: return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") diff --git a/folder_paths.py b/folder_paths.py index bb0d65524..99a016695 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -69,6 +69,46 @@ def get_directory_by_type(type_name): return None +# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format +# otherwise use default_path as base_dir +def touch_annotated_filepath(name): + if name.endswith("[output]"): + base_dir = get_output_directory() + name = name[:-9] + elif name.endswith("[input]"): + base_dir = get_input_directory() + name = name[:-8] + elif name.endswith("[temp]"): + base_dir = get_temp_directory() + name = name[:-7] + else: + return name, None + + return name, base_dir + + +def get_annotated_filepath(name, default_dir=None): + name, base_dir = touch_annotated_filepath(name) + + if base_dir is None: + if default_dir is not None: + base_dir = default_dir + else: + base_dir = get_input_directory() # fallback path + + return os.path.join(base_dir, name) + + +def exists_annotated_filepath(name): + name, base_dir = touch_annotated_filepath(name) + + if base_dir is None: + base_dir = get_input_directory() # fallback path + + filepath = os.path.join(base_dir, name) + return os.path.exists(filepath) + + def add_model_folder_path(folder_name, full_folder_path): global folder_names_and_paths if folder_name in folder_names_and_paths: diff --git a/nodes.py b/nodes.py index 6ca73fa0c..b8b6280d6 100644 --- a/nodes.py +++ b/nodes.py @@ -975,7 +975,7 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -990,7 +990,7 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = 
folder_paths.get_annotated_filepath(image, input_dir) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) @@ -1011,7 +1011,7 @@ class LoadImageMask: FUNCTION = "load_image" def load_image(self, image, channel): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1029,7 +1029,7 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): input_dir = folder_paths.get_input_directory() - image_path = os.path.join(input_dir, image) + image_path = folder_paths.get_annotated_filepath(image, input_dir) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) diff --git a/server.py b/server.py index b5403670f..1c5c17916 100644 --- a/server.py +++ b/server.py @@ -112,13 +112,20 @@ class PromptServer(): @routes.post("/upload/image") async def upload_image(request): - upload_dir = folder_paths.get_input_directory() + post = await request.post() + image = post.get("image") + + if post.get("type") is None: + upload_dir = folder_paths.get_input_directory() + elif post.get("type") == "input": + upload_dir = folder_paths.get_input_directory() + elif post.get("type") == "temp": + upload_dir = folder_paths.get_temp_directory() + elif post.get("type") == "output": + upload_dir = folder_paths.get_output_directory() if not os.path.exists(upload_dir): os.makedirs(upload_dir) - - post = await request.post() - image = post.get("image") if image and image.file: filename = image.filename diff --git a/web/scripts/app.js b/web/scripts/app.js index f158f3457..b3e88d46f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -20,6 +20,12 @@ export class ComfyApp { */ #processingQueue = false; + /** + * Content Clipboard + * @type {serialized node object} + */ + static clipspace = null; + constructor() { this.ui = new ComfyUI(this); @@ -130,6 +136,83 @@ export class ComfyApp { ); } } + + options.push( + { + content: "Copy (Clipspace)", + callback: (obj) => { + var widgets = null; + if(this.widgets) { + widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + let img = new Image(); + var imgs = undefined; + if(this.imgs != undefined) { + img.src = this.imgs[0].src; + imgs = [img]; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': imgs, + 'images': this.images + }; + } + }); + + if(ComfyApp.clipspace != null) { + options.push( + { + content: "Paste (Clipspace)", + callback: () => { + if(ComfyApp.clipspace != null) { + if(ComfyApp.clipspace.widgets != null && this.widgets != null) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop) { + prop.value = value; + } + }); + } + + // image paste + if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { + var filename = ""; + if(this.images && ComfyApp.clipspace.images) { + this.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.images != undefined) { + const clip_image = ComfyApp.clipspace.images[0]; + if(clip_image.subfolder != '') + filename = `${clip_image.subfolder}/`; + filename += `${clip_image.filename} [${clip_image.type}]`; + } + else if(ComfyApp.clipspace.widgets != undefined) { + const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + 
if(index_in_clip >= 0) { + filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; + } + } + + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { + this.imgs = ComfyApp.clipspace.imgs; + + this.widgets[index].value = filename; + if(this.widgets_values != undefined) { + this.widgets_values[index] = filename; + } + } + } + this.trigger('changed'); + } + } + } + ); + } }; } From ccad603b2e6862a4a719bc34dc6bd32e65a539ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 16:03:26 -0400 Subject: [PATCH 020/208] Add a way for nodes to validate their own inputs. --- execution.py | 21 +++++++++++---------- folder_paths.py | 6 +++--- nodes.py | 32 +++++++++++++++++++++++--------- web/scripts/app.js | 2 +- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/execution.py b/execution.py index b062deeb1..115efcbda 100644 --- a/execution.py +++ b/execution.py @@ -11,7 +11,6 @@ import torch import nodes import comfy.model_management -import folder_paths def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -250,14 +249,15 @@ def validate_inputs(prompt, item): if "max" in info[1] and val > info[1]["max"]: return (False, "Value bigger than max. {}, {}".format(class_type, x)) - if isinstance(type_input, list): - is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]") - if is_annotated_path: - if not folder_paths.exists_annotated_filepath(val): - return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val)) - - elif val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + ret = obj_class.VALIDATE_INPUTS(**input_data_all) + if ret != True: + return (False, "{}, {}".format(class_type, ret)) + else: + if isinstance(type_input, list): + if val not in type_input: + return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") def validate_prompt(prompt): @@ -279,7 +279,8 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o) valid = m[0] reason = m[1] - except: + except Exception as e: + print(traceback.format_exc()) valid = False reason = "Parsing error" diff --git a/folder_paths.py b/folder_paths.py index 99a016695..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -71,7 +71,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def touch_annotated_filepath(name): +def annotated_filepath(name): if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -88,7 +88,7 @@ def touch_annotated_filepath(name): def get_annotated_filepath(name, default_dir=None): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: if default_dir is not None: @@ -100,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None): def exists_annotated_filepath(name): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: base_dir = get_input_directory() # fallback path diff --git a/nodes.py b/nodes.py index b8b6280d6..d1133d1d8 100644 --- a/nodes.py +++ b/nodes.py @@ -974,8 +974,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -989,20 +988,27 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + return True + class LoadImageMask: + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() return {"required": {"image": (sorted(os.listdir(input_dir)), ), - "channel": (["alpha", "red", "green", "blue"], ),} + "channel": (s._color_channels, ),} } CATEGORY = "mask" @@ -1010,8 +1016,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1028,13 +1033,22 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image, channel): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image 
file: {}".format(image) + + if channel not in s._color_channels: + return "Invalid color channel: {}".format(channel) + + return True + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] diff --git a/web/scripts/app.js b/web/scripts/app.js index b3e88d46f..a161bf40e 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -172,7 +172,7 @@ export class ComfyApp { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); if (prop) { - prop.value = value; + prop.callback(value); } }); } From 0ac319fd81bcecea2aa35743da28088832e44707 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 23 Apr 2023 22:44:38 -0400 Subject: [PATCH 021/208] Don't delete all outputs when execution gets interrupted. --- execution.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/execution.py b/execution.py index 115efcbda..31a208e78 100644 --- a/execution.py +++ b/execution.py @@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = unique_id return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data={}): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return [] - - executed = [] + return for x in inputs: input_data = inputs[x] @@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) if "result" in outputs[unique_id]: outputs[unique_id] = outputs[unique_id]["result"] - return executed + [unique_id] + executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item @@ -158,7 +156,7 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) - executed = [] + executed = set() try: to_execute = [] for x in prompt: @@ -181,12 +179,12 @@ class PromptExecutor: except: valid = False if valid: - executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: print(traceback.format_exc()) to_delete = [] for o in self.outputs: - if o not in current_outputs: + if (o not in current_outputs) and (o not in executed): to_delete += [o] if o in self.old_prompt: d = self.old_prompt.pop(o) @@ -194,11 +192,9 @@ class PromptExecutor: for o in to_delete: d = self.outputs.pop(o) del d - else: - executed = set(executed) + finally: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) - finally: self.server.last_node_id = None if self.server.client_id 
is not None: self.server.send_sync("executing", { "node": None }, self.server.client_id) From f1b87f50fa9c274f2dd9dbe24b082aa83ef0b028 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 01:50:56 -0400 Subject: [PATCH 022/208] Add hypernetworks path config to extra_model_paths.yaml.example --- extra_model_paths.yaml.example | 1 + 1 file changed, 1 insertion(+) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index ac1ffe9d2..fa5418a68 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -13,6 +13,7 @@ a111: models/ESRGAN models/SwinIR embeddings: embeddings + hypernetworks: models/hypernetworks controlnet: models/ControlNet #other_ui: From 4e345b31f692d5fb89009bf3352c922c2abe30e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 02:36:06 -0400 Subject: [PATCH 023/208] Support all known hypernetworks. --- comfy_extras/nodes_hypernetwork.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index db2f8695c..c08c2c811 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength): activate_output = sd.get('activate_output', False) last_layer_dropout = sd.get('last_layer_dropout', False) - if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: + valid_activation = { + "linear": torch.nn.Identity, + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, + } + + if activation_func not in valid_activation: print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) return None @@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength): keys = attn_weights.keys() linears = filter(lambda a: a.endswith(".weight"), keys) - linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) + linears = list(map(lambda a: a[:-len(".weight")], linears)) layers = [] - for lin_name in linears: + for i in range(len(linears)): + lin_name = linears[i] + last_layer = (i == (len(linears) - 1)) + penultimate_layer = (i == (len(linears) - 2)) + lin_weight = attn_weights['{}.weight'.format(lin_name)] lin_bias = attn_weights['{}.bias'.format(lin_name)] layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) - layers += [layer] + layers.append(layer) + if activation_func != "linear": + if (not last_layer) or (activate_output): + layers.append(valid_activation[activation_func]()) + if is_layer_norm: + layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + if use_dropout: + if (not last_layer) and (not penultimate_layer or last_layer_dropout): + layers.append(torch.nn.Dropout(p=0.3)) output.append(torch.nn.Sequential(*layers)) out[dim] = torch.nn.ModuleList(output) From 463bde66a1d22b02858ac6f148d7fa3e6d9c4322 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 03:08:51 -0400 Subject: [PATCH 024/208] Add hypernetwork example link to readme. Move hypernetwork loader node to loaders. 
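Note for reviewers: patch 023 above replaces the old "linear only" check with a table of supported activation functions and rebuilds each per-dimension MLP from the flags stored in the checkpoint. A condensed, runnable sketch of those assembly rules, paraphrased for illustration rather than copied from comfy_extras/nodes_hypernetwork.py:

    import torch

    def build_hypernet_mlp(dims, activation=torch.nn.ReLU, activate_output=False,
                           is_layer_norm=False, use_dropout=False,
                           last_layer_dropout=False):
        # dims: [(d_in, d_out), ...] taken from consecutive ".weight" shapes
        layers = []
        for i, (d_in, d_out) in enumerate(dims):
            last = i == len(dims) - 1
            penultimate = i == len(dims) - 2
            layers.append(torch.nn.Linear(d_in, d_out))
            # activation follows every layer except the last, unless
            # activate_output is set; the "linear" case appends nothing
            if (not last) or activate_output:
                layers.append(activation())
            if is_layer_norm:
                layers.append(torch.nn.LayerNorm(d_out))
            # dropout is skipped on the last layer, and on the penultimate
            # one unless last_layer_dropout is set
            if use_dropout and (not last) and ((not penultimate) or last_layer_dropout):
                layers.append(torch.nn.Dropout(p=0.3))
        return torch.nn.Sequential(*layers)

The resulting modules are then applied as strength-scaled residuals on the attention k and v tensors, as in the hypernetwork_patch object from patch 016.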
--- README.md | 1 + comfy_extras/nodes_hypernetwork.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bf16006bf..5b6346a67 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/) - Loading full workflows (with seeds) from generated PNG files. - Saving/Loading workflows as Json files. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index c08c2c811..0c7250e43 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -93,7 +93,7 @@ class HypernetworkLoader: RETURN_TYPES = ("MODEL",) FUNCTION = "load_hypernetwork" - CATEGORY = "_for_testing" + CATEGORY = "loaders" def load_hypernetwork(self, model, hypernetwork_name, strength): hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) From d9b1595f8552384dd08374d34c4d4127e0b1a4e6 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Mon, 24 Apr 2023 12:53:10 +0200 Subject: [PATCH 025/208] made sample functions more explicit --- comfy/sample.py | 55 +++++++++++++++++++++---------------------------- nodes.py | 7 +++++-- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 981781b53..84eefcb7b 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,30 +2,25 @@ import torch import comfy.model_management -def prepare_noise(latent, seed): - """creates random noise given a LATENT and a seed""" - latent_image = latent["samples"] - batch_index = 0 - if "batch_index" in latent: - batch_index = latent["batch_index"] - +def prepare_noise(latent_image, seed, skip=0): + """ + creates random noise given a latent image and a seed. 
+ optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ generator = torch.manual_seed(seed) - for i in range(batch_index): + for _ in range(skip): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise -def create_mask(latent, noise): - """creates a mask for a given LATENT and noise""" - noise_mask = None +def prepare_mask(noise_mask, noise): + """ensures noise mask is of proper dimensions""" device = comfy.model_management.get_torch_device() - if "noise_mask" in latent: - noise_mask = latent['noise_mask'] - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) - noise_mask = noise_mask.to(device) + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) return noise_mask def broadcast_cond(cond, noise): @@ -40,22 +35,20 @@ def broadcast_cond(cond, noise): copy += [[t] + p[1:]] return copy -def load_c_nets(positive, negative): - """loads control nets in positive and negative conditioning""" - def get_models(cond): - models = [] - for c in cond: - if 'control' in c[1]: - models += [c[1]['control']] - if 'gligen' in c[1]: - models += [c[1]['gligen'][1]] - return models - - return get_models(positive) + get_models(negative) +def get_models_from_cond(cond, model_type): + models = [] + for c in cond: + if model_type in c[1]: + models += [c[1][model_type]] + return models def load_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" - models = load_c_nets(positive, negative) + models = [] + models += get_models_from_cond(positive, "control") + models += get_models_from_cond(negative, "control") + models += get_models_from_cond(positive, "gligen") + models += get_models_from_cond(negative, "gligen") comfy.model_management.load_controlnet_gpu(models) return models diff --git a/nodes.py b/nodes.py index b8c6d350f..f9bedc97e 100644 --- a/nodes.py +++ b/nodes.py @@ -747,9 +747,12 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - noise = comfy.sample.prepare_noise(latent, seed) + skip = latent["batch_index"] if "batch_index" in latent else 0 + noise = comfy.sample.prepare_noise(latent_image, seed, skip) - noise_mask = comfy.sample.create_mask(latent, noise) + noise_mask = None + if "noise_mask" in latent: + noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise) real_model = None comfy.model_management.load_model_gpu(model) From c8c9926eeb0b25dba86f3d9e574e8527c090fc37 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 24 Apr 2023 11:55:44 +0100 Subject: [PATCH 026/208] Add progress to vae decode tiled --- comfy/sd.py | 12 +++++++++--- comfy/utils.py | 4 +++- 2 
files changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 92dbb931d..2aadefadc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,6 +1,7 @@ import torch import contextlib import copy +from tqdm.auto import tqdm import sd1_clip import sd2_clip @@ -437,11 +438,16 @@ class VAE: self.device = device def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): + it_1 = -(samples.shape[2] // -(tile_y * 2 - overlap)) * -(samples.shape[3] // -(tile_x // 2 - overlap)) + it_2 = -(samples.shape[2] // -(tile_y // 2 - overlap)) * -(samples.shape[3] // -(tile_x * 2 - overlap)) + it_3 = -(samples.shape[2] // -(tile_y - overlap)) * -(samples.shape[3] // -(tile_x - overlap)) + pbar = tqdm(total=samples.shape[0] * (it_1 + it_2 + it_3)) + decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( - (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) + - utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) + - utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8)) + (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + + utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403c..c7c6a08c5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -63,7 +63,7 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) @torch.inference_mode() -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3): +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") for b in range(samples.shape[0]): s = samples[b:b+1] @@ -83,6 +83,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) out[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += ps * mask out_div[:,:,round(y*upscale_amount):round((y+tile_y)*upscale_amount),round(x*upscale_amount):round((x+tile_x)*upscale_amount)] += mask + if pbar is not None: + pbar.update(1) output[b:b+1] = out/out_div return output From 0b07b2cc0f94fc2b8ebe656dfb3768c6f67866f1 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Mon, 24 Apr 2023 21:47:57 +0200 Subject: [PATCH 027/208] gligen tuple --- comfy/sample.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 84eefcb7b..09ab20cd2 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,11 +44,10 @@ def get_models_from_cond(cond, model_type): def load_additional_models(positive, negative): """loads additional models in positive and negative conditioning""" - models = [] - models += get_models_from_cond(positive, "control") - models += get_models_from_cond(negative, "control") - 
models += get_models_from_cond(positive, "gligen") - models += get_models_from_cond(negative, "gligen") + control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control") + gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen") + gligen = [x[1] for x in gligen] + models = control_nets + gligen comfy.model_management.load_controlnet_gpu(models) return models From 36acce58e71bbe1bf835c2ec380dc7ac0c5b4752 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 18:13:18 -0400 Subject: [PATCH 028/208] Auto increase the size of the image upload widget when there's an image. --- web/scripts/widgets.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 2acc5f2c0..238ad59dd 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -270,6 +270,9 @@ export const ComfyWidgets = { app.graph.setDirtyCanvas(true); }; img.src = `/view?filename=${name}&type=input`; + if ((node.size[1] - node.imageOffset) < 100) { + node.size[1] = 250 + node.imageOffset; + } } // Add our own callback to the combo widget to render an image when it changes From 7983b3a975c26b93601c8b6fa9a0a333b35794bd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 22:45:35 -0400 Subject: [PATCH 029/208] This is cleaner this way. --- comfy/samplers.py | 59 ++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 46bdb82a0..26597ebba 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -400,38 +400,6 @@ def encode_adm(noise_augmentor, conds, batch_size, device): return conds -def calculate_sigmas(model, steps, scheduler, sampler): - """ - Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique - """ - if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)): - model = CFGNoisePredictor(model) - if model.inner_model.parameterization == "v": - model = CompVisVDenoiser(model, quantize=True) - else: - model = k_diffusion_external.CompVisDenoiser(model, quantize=True) - - sigmas = None - - discard_penultimate_sigma = False - if sampler in ['dpm_2', 'dpm_2_ancestral']: - steps += 1 - discard_penultimate_sigma = True - - if scheduler == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max)) - elif scheduler == "normal": - sigmas = model.get_sigmas(steps) - elif scheduler == "simple": - sigmas = simple_scheduler(model, steps) - elif scheduler == "ddim_uniform": - sigmas = ddim_scheduler(model, steps) - else: - print("error invalid scheduler", scheduler) - - if discard_penultimate_sigma: - sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) - return sigmas class KSampler: SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] @@ -461,13 +429,36 @@ class KSampler: self.denoise = denoise self.model_options = model_options + def calculate_sigmas(self, steps): + sigmas = None + + discard_penultimate_sigma = False + if self.sampler in ['dpm_2', 'dpm_2_ancestral']: + steps += 1 + discard_penultimate_sigma = True + + if self.scheduler == "karras": + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "normal": + sigmas = self.model_wrap.get_sigmas(steps) + elif self.scheduler == "simple": + sigmas = 
simple_scheduler(self.model_wrap, steps) + elif self.scheduler == "ddim_uniform": + sigmas = ddim_scheduler(self.model_wrap, steps) + else: + print("error invalid scheduler", self.scheduler) + + if discard_penultimate_sigma: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + return sigmas + def set_steps(self, steps, denoise=None): self.steps = steps if denoise is None or denoise > 0.9999: - self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device) + self.sigmas = self.calculate_sigmas(steps).to(self.device) else: new_steps = int(steps/denoise) - sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device) + sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] From c50208a703c6eba2363b08c4cb62e903a3012710 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 24 Apr 2023 23:25:51 -0400 Subject: [PATCH 030/208] Refactor more code to sample.py --- comfy/sample.py | 47 ++++++++++++++++++++++++++++++++++++----------- nodes.py | 28 ++++------------------------ 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 09ab20cd2..d6848f9d5 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,5 +1,6 @@ import torch import comfy.model_management +import comfy.samplers def prepare_noise(latent_image, seed, skip=0): @@ -13,24 +14,22 @@ def prepare_noise(latent_image, seed, skip=0): noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return noise -def prepare_mask(noise_mask, noise): +def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" - device = comfy.model_management.get_torch_device() - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(shape[2], shape[3]), mode="bilinear") noise_mask = noise_mask.round() - noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = torch.cat([noise_mask] * shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * shape[0]) noise_mask = noise_mask.to(device) return noise_mask -def broadcast_cond(cond, noise): - """broadcasts conditioning to the noise batch size""" - device = comfy.model_management.get_torch_device() +def broadcast_cond(cond, batch, device): + """broadcasts conditioning to the batch size""" copy = [] for p in cond: t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) + if t.shape[0] < batch: + t = torch.cat([t] * batch) t = t.to(device) copy += [[t] + p[1:]] return copy @@ -54,4 +53,30 @@ def load_additional_models(positive, negative): def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: - m.cleanup() \ No newline at end of file + m.cleanup() + +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None): + device = comfy.model_management.get_torch_device() + + if noise_mask is not None: + noise_mask = prepare_mask(noise_mask, noise.shape, device) + + real_model = None + comfy.model_management.load_model_gpu(model) + real_model = model.model + + noise = noise.to(device) + latent_image = 
latent_image.to(device) + + positive_copy = broadcast_cond(positive, noise.shape[0], device) + negative_copy = broadcast_cond(negative, noise.shape[0], device) + + models = load_additional_models(positive, negative) + + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) + + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas) + samples = samples.cpu() + + cleanup_additional_models(models) + return samples diff --git a/nodes.py b/nodes.py index f787fcf8a..0083f6ef8 100644 --- a/nodes.py +++ b/nodes.py @@ -752,31 +752,11 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = None if "noise_mask" in latent: - noise_mask = comfy.sample.prepare_mask(latent["noise_mask"], noise) - - real_model = None - comfy.model_management.load_model_gpu(model) - real_model = model.model - - noise = noise.to(device) - latent_image = latent_image.to(device) - - positive_copy = comfy.sample.broadcast_cond(positive, noise) - negative_copy = comfy.sample.broadcast_cond(negative, noise) - - models = comfy.sample.load_additional_models(positive, negative) - - if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - else: - #other samplers - pass - - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) - samples = samples.cpu() - - comfy.sample.cleanup_additional_models(models) + noise_mask = latent["noise_mask"] + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask) out = latent.copy() out["samples"] = samples return (out, ) From aa57136dae83887e005ab6b0222dce4667b61bee Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Apr 2023 01:12:40 -0400 Subject: [PATCH 031/208] Some fixes to the batch masks PR. 
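The central fix is that noise masks are now normalized through one reshape before being broadcast, so (H, W), (1, H, W) and (B, H, W) inputs all take the same path. A toy illustration of the shape flow in prepare_mask below (toy sizes only; the real call sites pass latent shapes):

    import math
    import torch

    noise_mask = torch.rand(3, 32, 32)   # e.g. three per-image masks
    shape = (8, 4, 16, 16)               # target latent batch: B, C, H, W

    m = noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1]))
    m = torch.nn.functional.interpolate(m, size=(shape[2], shape[3]), mode="bilinear")
    m = m.round()
    m = torch.cat([m] * shape[1], dim=1)      # match the latent channel count
    if m.shape[0] < shape[0]:                 # tile the masks across the batch
        m = m.repeat(math.ceil(shape[0] / m.shape[0]), 1, 1, 1)[:shape[0]]

    assert m.shape == shape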
--- comfy/sample.py | 7 ++++--- nodes.py | 10 +++------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index d6848f9d5..5e4d26142 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,7 +1,7 @@ import torch import comfy.model_management import comfy.samplers - +import math def prepare_noise(latent_image, seed, skip=0): """ @@ -16,10 +16,11 @@ def prepare_noise(latent_image, seed, skip=0): def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(shape[2], shape[3]), mode="bilinear") + noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) - noise_mask = torch.cat([noise_mask] * shape[0]) + if noise_mask.shape[0] < shape[0]: + noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]] noise_mask = noise_mask.to(device) return noise_mask diff --git a/nodes.py b/nodes.py index b0b61d676..0a9513bed 100644 --- a/nodes.py +++ b/nodes.py @@ -172,16 +172,12 @@ class VAEEncodeForInpaint: def encode(self, vae, pixels, mask): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 - if len(mask.shape) < 3: - mask = mask.unsqueeze(0).unsqueeze(0) - elif len(mask.shape) < 4: - mask = mask.unsqueeze(1) - mask = torch.nn.functional.interpolate(mask, size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - mask = mask[:,:x,:y,:] + mask = mask[:,:,:x,:y] #grow mask by a few pixels to keep things seamless in latent space kernel_tensor = torch.ones((1, 1, 6, 6)) @@ -193,7 +189,7 @@ class VAEEncodeForInpaint: pixels[:,:,:,i] += 0.5 t = vae.encode(pixels) - return ({"samples":t, "noise_mask": (mask_erosion[:,:x,:y,:].round())}, ) + return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) class CheckpointLoader: @classmethod From 07194297fd41729f8b95352a710b9039ca2c99e8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Apr 2023 14:02:17 -0400 Subject: [PATCH 032/208] Python 3.7 support. --- comfy_extras/chainner_models/architecture/block.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 1abe1ed8f..214642cc4 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -4,7 +4,10 @@ from __future__ import annotations from collections import OrderedDict -from typing import Literal +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal import torch import torch.nn as nn From ee3a12d283d76212f6771a9cace21d4a469c1ee8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 25 Apr 2023 19:18:50 -0400 Subject: [PATCH 033/208] Update litegraph from upstream. 
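Besides the sync itself, two behavior fixes ride along: the slider widget now fires its change callback only when the value actually changed, and the numeric prompt's validation regex additionally accepts plain decimals. The new pattern can be sanity-checked outside the browser; Python's re module is used here purely to illustrate the same expression:

    import re

    # the updated litegraph pattern: /^[0-9+\-*/()\s]+|\d+\.\d+$/
    pattern = re.compile(r"^[0-9+\-*/()\s]+|\d+\.\d+$")

    assert pattern.search("1+2*(3-1)")    # integer expressions: accepted as before
    assert pattern.search("3.14")         # plain decimals: newly accepted
    assert not pattern.search("abc")      # non-numeric input: still rejected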
--- web/lib/litegraph.core.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 4189a48c0..20ec35476 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9953,11 +9953,11 @@ LGraphNode.prototype.executeAction = function(action) } break; case "slider": - var range = w.options.max - w.options.min; + var old_value = w.value; var nvalue = Math.clamp((x - 15) / (widget_width - 30), 0, 1); if(w.options.read_only) break; w.value = w.options.min + (w.options.max - w.options.min) * nvalue; - if (w.callback) { + if (old_value != w.value) { setTimeout(function() { inner_value_change(w, w.value); }, 20); @@ -10044,7 +10044,7 @@ LGraphNode.prototype.executeAction = function(action) if (event.click_time < 200 && delta == 0) { this.prompt("Value",w.value,function(v) { // check if v is a valid equation or a number - if (/^[0-9+\-*/()\s]+$/.test(v)) { + if (/^[0-9+\-*/()\s]+|\d+\.\d+$/.test(v)) { try {//solve the equation if possible v = eval(v); } catch (e) { } From 54251ad85e484d4e36df849dcd529837c775d690 Mon Sep 17 00:00:00 2001 From: Jake D <122334950+jwd-dev@users.noreply.github.com> Date: Wed, 26 Apr 2023 01:22:36 -0400 Subject: [PATCH 034/208] Colored MultilineWidget (#524) * fixes colors and z-index * light mode fix * Update widgets.js --- web/scripts/widgets.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 238ad59dd..c0e73ffa1 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -136,9 +136,11 @@ function addMultilineWidget(node, name, opts, app) { left: `${t.a * margin + t.e}px`, top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, + background: (!node.color)?'':node.color, height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, position: "absolute", - zIndex: 1, + color: (!node.color)?'':'white', + zIndex: app.graph._nodes.indexOf(node), fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; From 951c0c2bbe11e48956a7c619faf0c2cc6e3abff5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Apr 2023 02:05:57 -0400 Subject: [PATCH 035/208] Don't keep cached outputs for removed nodes. --- execution.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/execution.py b/execution.py index 31a208e78..2c97e70d2 100644 --- a/execution.py +++ b/execution.py @@ -152,6 +152,15 @@ class PromptExecutor: self.server.client_id = None with torch.inference_mode(): + #delete cached outputs if nodes don't exist for them + to_delete = [] + for o in self.outputs: + if o not in prompt: + to_delete += [o] + for o in to_delete: + d = self.outputs.pop(o) + del d + for x in prompt: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) From 3a1f9dba20c89038b71d6ff74d4e600d375283b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Apr 2023 02:13:56 -0400 Subject: [PATCH 036/208] If IS_CHANGED returns exception delete the output instead of crashing. 
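Previously a node whose IS_CHANGED raised would crash prompt execution during the cache check; the diff below catches the exception and instead marks the node for deletion, so only that node's cached output is dropped. A toy reproduction of the rule (hypothetical node, not the ComfyUI API):

    outputs = {"5": ("cached result",)}   # stand-in for the executor's output cache

    class BrokenNode:
        @classmethod
        def IS_CHANGED(cls):
            raise RuntimeError("bug inside the node")

    to_delete = False
    try:
        is_changed = BrokenNode.IS_CHANGED()
    except Exception:
        to_delete = True   # treat the node as changed instead of crashing

    if to_delete:
        outputs.pop("5", None)

    assert "5" not in outputs   # cache invalidated; execution continues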
--- execution.py | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/execution.py b/execution.py index 2c97e70d2..c19c10bc6 100644 --- a/execution.py +++ b/execution.py @@ -97,40 +97,44 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item is_changed_old = '' is_changed = '' + to_delete = False if hasattr(class_def, 'IS_CHANGED'): if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: is_changed_old = old_prompt[unique_id]['is_changed'] if 'is_changed' not in prompt[unique_id]: input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: - is_changed = class_def.IS_CHANGED(**input_data_all) - prompt[unique_id]['is_changed'] = is_changed + try: + is_changed = class_def.IS_CHANGED(**input_data_all) + prompt[unique_id]['is_changed'] = is_changed + except: + to_delete = True else: is_changed = prompt[unique_id]['is_changed'] if unique_id not in outputs: return True - to_delete = False - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] + if not to_delete: + if is_changed != is_changed_old: + to_delete = True + elif unique_id not in old_prompt: + to_delete = True + elif inputs == old_prompt[unique_id]['inputs']: + for x in inputs: + input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True + if isinstance(input_data, list): + input_unique_id = input_data[0] + output_index = input_data[1] + if input_unique_id in outputs: + to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) + else: + to_delete = True + if to_delete: + break + else: + to_delete = True if to_delete: d = outputs.pop(unique_id) From 5a971cecdbacb849340f2ea7b3bcd80cc6032d1a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 27 Apr 2023 04:38:44 -0400 Subject: [PATCH 037/208] Add callback to sampler function. Callback format is: callback(step, x0, x) --- comfy/extra_samplers/uni_pc.py | 6 ++++-- comfy/sample.py | 4 ++-- comfy/samplers.py | 22 ++++++++++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index e96cfc93a..2952be62d 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -712,7 +712,7 @@ class UniPC: def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, corrector=False, + atol=0.0078, rtol=0.05, corrector=False, callback=None ): t_0 = 1. 
/ self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start @@ -766,6 +766,8 @@ class UniPC: if model_x is None: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x + if callback is not None: + callback(step_index, model_prev_list[-1], x) else: raise NotImplementedError() if denoise_to_zero: @@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True) + x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/sample.py b/comfy/sample.py index 5e4d26142..f4132bbed 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -56,7 +56,7 @@ def cleanup_additional_models(models): for m in models: m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None): +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None): device = comfy.model_management.get_torch_device() if noise_mask is not None: @@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) samples = samples.cpu() cleanup_additional_models(models) diff --git a/comfy/samplers.py b/comfy/samplers.py index 26597ebba..fc19ddcfc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -462,7 +462,7 @@ class KSampler: self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -527,9 +527,9 @@ class KSampler: with precision_scope(model_management.get_autocast_device(self.device)): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, 
sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback) elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2') elif self.sampler == "ddim": timesteps = [] for s in range(sigmas.shape[0]): @@ -537,6 +537,11 @@ class KSampler: noise_mask = None if denoise_mask is not None: noise_mask = 1.0 - denoise_mask + + ddim_callback = None + if callback is not None: + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None) + sampler = DDIMSampler(self.model, device=self.device) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) @@ -550,6 +555,7 @@ class KSampler: eta=0.0, x_T=z_enc, x0=latent_image, + img_callback=ddim_callback, denoise_function=sampling_function, extra_args=extra_args, mask=noise_mask, @@ -563,13 +569,17 @@ class KSampler: noise = noise * sigmas[0] + k_callback = None + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"]) + if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback) elif self.sampler == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args) + samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) return samples.to(torch.float32) From e958dfdd4d34ad160c50a32e01b5ce08c4e62a29 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 27 Apr 2023 10:59:47 -0400 Subject: [PATCH 038/208] Make notebook work on python3.7 --- notebooks/comfyui_colab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index c1982d8be..fecfa6707 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -47,7 +47,7 @@ " !git pull\n", "\n", "!echo -= Install dependencies =-\n", - "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118" + "!pip install xformers!=0.0.18 -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu117" ] }, { From e214c917ae889b278a05fa6e8b8c42d2cc8818fa Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Tue, 25 Apr 2023 00:15:25 -0700 Subject: [PATCH 039/208] Add Condition by Mask node This PR adds support for a Condition by Mask 
node. This node allows conditioning to be limited to a non-rectangle area. --- comfy/samplers.py | 88 +++++++++++++++++++++++++++++++++++++++-------- nodes.py | 28 +++++++++++++++ 2 files changed, 101 insertions(+), 15 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index fc19ddcfc..6fa754b90 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,7 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -23,21 +24,34 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con adm_cond = cond[1]['adm_encoded'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - mult = torch.ones_like(input_x) * strength + if 'mask' in cond[1]: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask = cond[1]['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if mask.shape[0] != input_x.shape[0]: + mask = mask.repeat(input_x.shape[0], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in cond[1]: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} conditionning['c_crossattn'] = cond[0] if cond_concat_in is not None and len(cond_concat_in) > 0: @@ -301,6 +315,47 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def resolve_cond_masks(conditions, h, w, device): + # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. 
+ # While we're doing this, we can also resolve the mask device and scaling for performance reasons + for i in range(len(conditions)): + c = conditions[i] + if 'mask' in c[1]: + mask = c[1]['mask'] + mask = mask.to(device=device) + modified = c[1].copy() + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if mask.shape[2] != h or mask.shape[3] != w: + mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) + + if 'area' not in modified: + bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) + if torch.max(bounds) == 0: + # Handle the edge-case of an all black mask (where masks_to_boxes would error) + area = (0, 0, 0, 0) + else: + box = masks_to_boxes(bounds)[0].type(torch.int) + H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) + # Make sure the height and width are divisible by 8 + if X % 8 != 0: + newx = X // 8 * 8 + W = W + (X - newx) + X = newx + if Y % 8 != 0: + newy = Y // 8 * 8 + H = H + (Y - newy) + Y = newy + if H % 8 != 0: + H = H + (8 - (H % 8)) + if W % 8 != 0: + W = W + (8 - (W % 8)) + area = (int(H), int(W), int(Y), (X)) + modified['area'] = area + + modified['mask'] = mask + conditions[i] = [c[0], modified] + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -461,7 +516,6 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): if sigmas is None: sigmas = self.sigmas @@ -484,6 +538,10 @@ class KSampler: positive = positive[:] negative = negative[:] + + resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device) + resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) diff --git a/nodes.py b/nodes.py index 0a9513bed..be02f4676 100644 --- a/nodes.py +++ b/nodes.py @@ -85,6 +85,32 @@ class ConditioningSetArea: c.append(n) return (c, ) +class ConditioningSetMask: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "mask": ("MASK", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + c = [] + if len(mask.shape) < 3: + mask = mask.unsqueeze(0) + for t in conditioning: + n = [t[0], t[1].copy()] + _, h, w = mask.shape + n[1]['mask'] = mask + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) + return (c, ) + class VAEDecode: def __init__(self, device="cpu"): self.device = device @@ -1115,6 +1141,7 @@ NODE_CLASS_MAPPINGS = { "ImagePadForOutpaint": ImagePadForOutpaint, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, + "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, @@ -1164,6 +1191,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", "ConditioningSetArea": "Conditioning (Set Area)", + "ConditioningSetMask": "Conditioning (Set Mask)", 
"ControlNetApply": "Apply ControlNet", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", From 27bf9392ac1ef07776d31895b748c7ea84969115 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 08:35:20 -0400 Subject: [PATCH 040/208] Switch stable standalone dependencies to stable xformers. Switch nightly standalone to cu121. --- .github/workflows/windows_release_cu118_dependencies_2.yml | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/windows_release_cu118_dependencies_2.yml b/.github/workflows/windows_release_cu118_dependencies_2.yml index a88449527..42adee9e7 100644 --- a/.github/workflows/windows_release_cu118_dependencies_2.yml +++ b/.github/workflows/windows_release_cu118_dependencies_2.yml @@ -17,7 +17,7 @@ jobs: - shell: bash run: | - python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir + python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip install --no-cache-dir ./temp_wheel_dir/* echo installed basic ls -lah temp_wheel_dir diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 291d754e3..32d2f320b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -30,7 +30,7 @@ jobs: echo 'import site' >> ./python310._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python310._pth From e543ecad6991fc7e71dd2042b439aefb9c0722de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 08:50:12 -0400 Subject: [PATCH 041/208] Fix the nightly build not being packaged correctly. 
--- .ci/nightly/update_windows/update.py | 65 ------------------- .ci/nightly/update_windows/update_comfyui.bat | 2 - ...update_comfyui_and_python_dependencies.bat | 2 +- .../README_VERY_IMPORTANT.txt | 27 -------- .ci/nightly/windows_base_files/run_cpu.bat | 2 - .../windows_release_nightly_pytorch.yml | 2 + 6 files changed, 3 insertions(+), 97 deletions(-) delete mode 100755 .ci/nightly/update_windows/update.py delete mode 100755 .ci/nightly/update_windows/update_comfyui.bat delete mode 100755 .ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt delete mode 100755 .ci/nightly/windows_base_files/run_cpu.bat diff --git a/.ci/nightly/update_windows/update.py b/.ci/nightly/update_windows/update.py deleted file mode 100755 index c09f29a80..000000000 --- a/.ci/nightly/update_windows/update.py +++ /dev/null @@ -1,65 +0,0 @@ -import pygit2 -from datetime import datetime -import sys - -def pull(repo, remote_name='origin', branch='master'): - for remote in repo.remotes: - if remote.name == remote_name: - remote.fetch() - remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target - merge_result, _ = repo.merge_analysis(remote_master_id) - # Up to date, do nothing - if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE: - return - # We can just fastforward - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD: - repo.checkout_tree(repo.get(remote_master_id)) - try: - master_ref = repo.lookup_reference('refs/heads/%s' % (branch)) - master_ref.set_target(remote_master_id) - except KeyError: - repo.create_branch(branch, repo.get(remote_master_id)) - repo.head.set_target(remote_master_id) - elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL: - repo.merge(remote_master_id) - - if repo.index.conflicts is not None: - for conflict in repo.index.conflicts: - print('Conflicts found in:', conflict[0].path) - raise AssertionError('Conflicts, ahhhhh!!') - - user = repo.default_signature - tree = repo.index.write_tree() - commit = repo.create_commit('HEAD', - user, - user, - 'Merge!', - tree, - [repo.head.target, remote_master_id]) - # We need to do this or git CLI will think we are still merging. 
- repo.state_cleanup() - else: - raise AssertionError('Unknown merge analysis result') - - -repo = pygit2.Repository(str(sys.argv[1])) -ident = pygit2.Signature('comfyui', 'comfy@ui') -try: - print("stashing current changes") - repo.stash(ident) -except KeyError: - print("nothing to stash") -backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) -print("creating backup branch: {}".format(backup_branch_name)) -repo.branches.local.create(backup_branch_name, repo.head.peel()) - -print("checking out master branch") -branch = repo.lookup_branch('master') -ref = repo.lookup_reference(branch.name) -repo.checkout(ref) - -print("pulling latest changes") -pull(repo) - -print("Done!") - diff --git a/.ci/nightly/update_windows/update_comfyui.bat b/.ci/nightly/update_windows/update_comfyui.bat deleted file mode 100755 index 60d1e694f..000000000 --- a/.ci/nightly/update_windows/update_comfyui.bat +++ /dev/null @@ -1,2 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -pause diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c5e0c6be7..c345a6992 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt b/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt deleted file mode 100755 index 656b9db43..000000000 --- a/.ci/nightly/windows_base_files/README_VERY_IMPORTANT.txt +++ /dev/null @@ -1,27 +0,0 @@ -HOW TO RUN: - -if you have a NVIDIA gpu: - -run_nvidia_gpu.bat - - - -To run it in slow CPU mode: - -run_cpu.bat - - - -IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints - -You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt - - - -RECOMMENDED WAY TO UPDATE: -To update the ComfyUI code: update\update_comfyui.bat - - - -To update ComfyUI with the python dependencies: -update\update_comfyui_and_python_dependencies.bat diff --git a/.ci/nightly/windows_base_files/run_cpu.bat b/.ci/nightly/windows_base_files/run_cpu.bat deleted file mode 100755 index c3ba41721..000000000 --- a/.ci/nightly/windows_base_files/run_cpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build -pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 32d2f320b..4d686ded8 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -46,6 +46,8 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ + cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ + cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ cd .. 
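The order of the two added cp lines is what fixes the packaging: the nightly-specific files are copied after the shared base files, so wherever both trees contain the same filename the nightly variant wins (here, the nightly update_comfyui_and_python_dependencies.bat replaces the stable one). A rough Python equivalent of that override pattern, with illustrative paths:

    import shutil

    # Later copies overwrite earlier ones with the same name, so the
    # nightly tree takes precedence over the shared base tree.
    for src in ("ComfyUI/.ci/update_windows", "ComfyUI/.ci/nightly/update_windows"):
        shutil.copytree(src, "update", dirs_exist_ok=True)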
From ab9a9deff48b5780bd105dfd6d19f5f8333ef608 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 09:03:39 -0400 Subject: [PATCH 042/208] Fix nightly CI builds. No cu121 builds for windows yet. --- .../update_windows/update_comfyui_and_python_dependencies.bat | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index c345a6992..b4989534f 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index 4d686ded8..f23cae6d5 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -30,7 +30,7 @@ jobs: echo 'import site' >> ./python310._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* sed -i '1i../ComfyUI' ./python310._pth From 3baded9892a6ac02f57caaf68053791ec0e14c5a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 14:28:57 -0400 Subject: [PATCH 043/208] Basic torch_directml support. Use --directml to use it. --- comfy/cli_args.py | 1 + comfy/model_management.py | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index b24054ce0..05b9c5e08 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") +parser.add_argument("--directml", action="store_true", help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. 
Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 6e3a03530..339111c4d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -20,6 +20,13 @@ total_vram_available_mb = -1 accelerate_enabled = False xpu_available = False +directml_enabled = False +if args.directml: + import torch_directml + print("Using directml") + directml_enabled = True + # torch_directml.disable_tiled_resources(True) + try: import torch try: @@ -217,6 +224,9 @@ def unload_if_low_vram(model): def get_torch_device(): global xpu_available + global directml_enabled + if directml_enabled: + return torch_directml.device() if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: @@ -234,8 +244,14 @@ def get_autocast_device(dev): def xformers_enabled(): + global xpu_available + global directml_enabled if vram_state == VRAMState.CPU: return False + if xpu_available: + return False + if directml_enabled: + return False return XFORMERS_IS_AVAILABLE @@ -251,6 +267,7 @@ def pytorch_attention_enabled(): def get_free_memory(dev=None, torch_free_too=False): global xpu_available + global directml_enabled if dev is None: dev = get_torch_device() @@ -258,7 +275,10 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = psutil.virtual_memory().available mem_free_torch = mem_free_total else: - if xpu_available: + if directml_enabled: + mem_free_total = 1024 * 1024 * 1024 #TODO + mem_free_torch = mem_free_total + elif xpu_available: mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) mem_free_torch = mem_free_total else: @@ -293,9 +313,14 @@ def mps_mode(): def should_use_fp16(): global xpu_available + global directml_enabled + if FORCE_FP32: return False + if directml_enabled: + return False + if cpu_mode() or mps_mode() or xpu_available: return False #TODO ? From 0306371e54ddb7472622eb43ed2180a109be6e6b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:18:54 -0400 Subject: [PATCH 044/208] Add "Installing" link to top of readme. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 5b6346a67..00b228497 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) +### [Installing](#installing) + ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Fully supports SD1.x and SD2.x From cab80973d187903d9c415cfcc2575e4616befaa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:19:56 -0400 Subject: [PATCH 045/208] Fix Readme. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 00b228497..3b3824714 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A powerful and modular stable diffusion GUI and backend. This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. 
For some workflow examples and see what ComfyUI can do you can check out: ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) -### [Installing](#installing) +### [Installing ComfyUI](#installing) ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. From 2ca934f7d4df3e4fa5a74172e5bbc1dd5e1a2ff9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 28 Apr 2023 16:51:35 -0400 Subject: [PATCH 046/208] You can now select the device index with: --directml id Like this for example: --directml 1 --- comfy/cli_args.py | 2 +- comfy/model_management.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 05b9c5e08..764427165 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -10,7 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") -parser.add_argument("--directml", action="store_true", help="Use torch-directml.") +parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 339111c4d..9497ae7af 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -21,10 +21,15 @@ accelerate_enabled = False xpu_available = False directml_enabled = False -if args.directml: +if args.directml is not None: import torch_directml - print("Using directml") directml_enabled = True + device_index = args.directml + if device_index < 0: + directml_device = torch_directml.device() + else: + directml_device = torch_directml.device(device_index) + print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) try: @@ -226,7 +231,8 @@ def get_torch_device(): global xpu_available global directml_enabled if directml_enabled: - return torch_directml.device() + global directml_device + return directml_device if vram_state == VRAMState.MPS: return torch.device("mps") if vram_state == VRAMState.CPU: From 056e5545ffafc7c396cd18d0737a9d5e40f81552 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 00:28:48 -0400 Subject: [PATCH 047/208] Don't try to get vram from xpu or cuda when directml is enabled. 
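torch_directml exposes no memory-introspection call that fits here, and on a DirectML-only install the torch.cuda and torch.xpu probes can themselves raise, so the VRAM probe now short-circuits before touching either backend. A condensed sketch of the resulting probe order; the 4097 MB figure is the placeholder the patch itself marks as TODO:

    import torch

    def probe_total_vram_mb(directml_enabled: bool) -> float:
        if directml_enabled:
            return 4097  # TODO placeholder; DirectML offers no query here
        try:
            import intel_extension_for_pytorch  # noqa: F401  (Intel XPU build)
            if torch.xpu.is_available():
                props = torch.xpu.get_device_properties(torch.xpu.current_device())
                return props.total_memory / (1024 * 1024)
        except Exception:
            pass
        return torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)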
--- comfy/model_management.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9497ae7af..db5d368e1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -34,13 +34,16 @@ if args.directml is not None: try: import torch - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) + if directml_enabled: + total_vram = 4097 #TODO + else: + try: + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True + total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) + except: + total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) if not args.normalvram and not args.cpu: if total_vram <= 4096: From af02393c2a7134861df57e5843fc17498c65a795 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 29 Apr 2023 00:16:58 -0700 Subject: [PATCH 048/208] Default to sampling entire image By default, when applying a mask to a condition, the entire image will still be used for sampling. The new "set_area_to_bounds" option on the node will allow the user to automatically limit conditioning to the bounds of the mask. I've also removed the dependency on torchvision for calculating bounding boxes. I've taken the opportunity to fix some frustrating details in the other version: 1. An all-0 mask will no longer cause an error 2. Indices are returned as integers instead of floats so they can be used to index into tensors. 
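The torchvision replacement is a few lines of plain PyTorch. A toy run demonstrates both fixes called out above; the mask shape and box values are illustrative:

    import torch

    mask = torch.zeros(1, 64, 64)
    mask[0, 10:30, 20:40] = 1.0

    # torch.where on a 2D mask returns (row, col) index tensors
    y, x = torch.where(mask[0])
    box = (int(x.min()), int(y.min()), int(x.max()), int(y.max()))
    print(box)  # (20, 10, 39, 29): plain ints, safe for tensor indexing

    empty = torch.zeros(64, 64)
    print(torch.max(empty != 0) == False)  # tensor(True): flagged empty, no exception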
--- comfy/samplers.py | 42 ++++++++++++++++++++++++++++++++---------- nodes.py | 4 +++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fa754b90..f8701c879 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,7 +6,6 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if mask.shape[0] != input_x.shape[0]: - mask = mask.repeat(input_x.shape[0], 1, 1) + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def get_mask_aabb(masks): + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device, dtype=torch.int) + + b = masks.shape[0] + + bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int) + is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool) + for i in range(b): + mask = masks[i] + if mask.numel() == 0: + continue + if torch.max(mask != 0) == False: + is_empty[i] = True + continue + y, x = torch.where(mask) + bounding_boxes[i, 0] = torch.min(x) + bounding_boxes[i, 1] = torch.min(y) + bounding_boxes[i, 2] = torch.max(x) + bounding_boxes[i, 3] = torch.max(y) + + return bounding_boxes, is_empty + def resolve_cond_masks(conditions, h, w, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. # While we're doing this, we can also resolve the mask device and scaling for performance reasons @@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device): if mask.shape[2] != h or mask.shape[3] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) - if 'area' not in modified: + if modified.get("set_area_to_bounds", False): bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) - if torch.max(bounds) == 0: - # Handle the edge-case of an all black mask (where masks_to_boxes would error) - area = (0, 0, 0, 0) + boxes, is_empty = get_mask_aabb(bounds) + if is_empty[0]: + # Use the minimum possible size for efficiency reasons. 
(Since the mask is all-0, this becomes a noop anyway) + modified['area'] = (8, 8, 0, 0) else: - box = masks_to_boxes(bounds)[0].type(torch.int) + box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) # Make sure the height and width are divisible by 8 if X % 8 != 0: @@ -350,8 +372,8 @@ def resolve_cond_masks(conditions, h, w, device): H = H + (8 - (H % 8)) if W % 8 != 0: W = W + (8 - (W % 8)) - area = (int(H), int(W), int(Y), (X)) - modified['area'] = area + area = (int(H), int(W), int(Y), int(X)) + modified['area'] = area modified['mask'] = mask conditions[i] = [c[0], modified] diff --git a/nodes.py b/nodes.py index be02f4676..12fa7e5a3 100644 --- a/nodes.py +++ b/nodes.py @@ -90,6 +90,7 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), + "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -97,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -105,6 +106,7 @@ class ConditioningSetMask: n = [t[0], t[1].copy()] _, h, w = mask.shape n[1]['mask'] = mask + n[1]['set_area_to_bounds'] = set_area_to_bounds n[1]['strength'] = strength n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma From ffd0f9f417d94bce03ea863131df9e6a86a89ada Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:19:14 +0100 Subject: [PATCH 049/208] Search filter by type --- web/extensions/core/slotDefaults.js | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/web/extensions/core/slotDefaults.js b/web/extensions/core/slotDefaults.js index 3ec605900..9401678b0 100644 --- a/web/extensions/core/slotDefaults.js +++ b/web/extensions/core/slotDefaults.js @@ -6,6 +6,7 @@ app.registerExtension({ name: "Comfy.SlotDefaults", suggestionsNumber: null, init() { + LiteGraph.search_filter_enabled = true; LiteGraph.middle_click_slot_add_default_node = true; this.suggestionsNumber = app.ui.settings.addSetting({ id: "Comfy.NodeSuggestions.number", @@ -43,6 +44,14 @@ app.registerExtension({ } if (this.slot_types_default_out[type].includes(nodeId)) continue; this.slot_types_default_out[type].push(nodeId); + + // Input types have to be stored as lower case + // Store each node that can handle this input type + const lowerType = type.toLocaleLowerCase(); + if (!(lowerType in LiteGraph.registered_slot_in_types)) { + LiteGraph.registered_slot_in_types[lowerType] = { nodes: [] }; + } + LiteGraph.registered_slot_in_types[lowerType].nodes.push(nodeType.comfyClass); } var outputs = nodeData["output"]; @@ -53,6 +62,16 @@ app.registerExtension({ } this.slot_types_default_in[type].push(nodeId); + + // Store each node that can handle this output type + if (!(type in LiteGraph.registered_slot_out_types)) { + LiteGraph.registered_slot_out_types[type] = { nodes: [] }; + } + LiteGraph.registered_slot_out_types[type].nodes.push(nodeType.comfyClass); + + if(!LiteGraph.slot_types_out.includes(type)) { + LiteGraph.slot_types_out.push(type); + } } var maxNum = this.suggestionsNumber.value; this.setDefaults(maxNum); From 15a4c0db3b11c75350268950d8d0da175e72440d Mon Sep 17 00:00:00 2001 From: pythongosssss 
<125205205+pythongosssss@users.noreply.github.com> Date: Sat, 29 Apr 2023 17:29:07 +0100 Subject: [PATCH 050/208] - button hover style - ensure context menu is always above everything --- web/style.css | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/web/style.css b/web/style.css index 2cbf02c0c..eced33d29 100644 --- a/web/style.css +++ b/web/style.css @@ -120,7 +120,7 @@ body { .comfy-menu > button, .comfy-menu-btns button, .comfy-menu .comfy-list button, -.comfy-modal button{ +.comfy-modal button { color: var(--input-text); background-color: var(--comfy-input-bg); border-radius: 8px; @@ -129,6 +129,15 @@ body { margin-top: 2px; } +.comfy-menu > button:hover, +.comfy-menu-btns button:hover, +.comfy-menu .comfy-list button:hover, +.comfy-modal button:hover, +.comfy-settings-btn:hover { + filter: brightness(1.2); + cursor: pointer; +} + .comfy-menu span.drag-handle { width: 10px; height: 20px; @@ -284,4 +293,7 @@ button.comfy-queue-btn { top: 0; right: 2px; } - \ No newline at end of file + + .litecontextmenu { + z-index: 9999 !important; +} \ No newline at end of file From 071011aebed2b636865dacacf6213d6714d6d80c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:06:53 -0400 Subject: [PATCH 051/208] Mask strength should be separate from area strength. --- comfy/samplers.py | 5 ++++- nodes.py | 6 ++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f8701c879..10527fb1c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -26,10 +26,13 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if 'mask' in cond[1]: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in cond[1]: + mask_strength = cond[1]["mask_strength"] mask = cond[1]['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) diff --git a/nodes.py b/nodes.py index 12fa7e5a3..b4069c836 100644 --- a/nodes.py +++ b/nodes.py @@ -98,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -107,9 +107,7 @@ class ConditioningSetMask: _, h, w = mask.shape n[1]['mask'] = mask n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['strength'] = strength - n[1]['min_sigma'] = min_sigma - n[1]['max_sigma'] = max_sigma + n[1]['mask_strength'] = strength c.append(n) return (c, ) From c66db067630c57ec037b906b6b3f766d1153522b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:19:14 -0400 Subject: [PATCH 052/208] Make ConditioningSetMask area option a bit more clear. Make ConditioningSetArea override the set_area_to_bounds. 
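With this change the raw boolean input is replaced by a two-value combo, and an explicit ConditioningSetArea now always wins because it forces set_area_to_bounds back to False. A minimal sketch of driving the node directly, assuming nodes.py is importable and a CONDITIONING list and MASK tensor are already in hand:

    from nodes import ConditioningSetMask

    # "default"     -> the mask gates strength, the whole latent is still sampled
    # "mask bounds" -> the conditioning area is clamped to the mask's bounding box
    (conditioning,) = ConditioningSetMask().append(
        conditioning, mask, strength=1.0, set_cond_area="mask bounds")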
--- nodes.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index b4069c836..c9d660738 100644 --- a/nodes.py +++ b/nodes.py @@ -80,6 +80,7 @@ class ConditioningSetArea: n = [t[0], t[1].copy()] n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) n[1]['strength'] = strength + n[1]['set_area_to_bounds'] = False n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma c.append(n) @@ -90,16 +91,19 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), - "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "set_cond_area": (["default", "mask bounds"],), }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" CATEGORY = "conditioning" - def append(self, conditioning, mask, set_area_to_bounds, strength): + def append(self, conditioning, mask, set_cond_area, strength): c = [] + set_area_to_bounds = False + if set_cond_area != "default": + set_area_to_bounds = True if len(mask.shape) < 3: mask = mask.unsqueeze(0) for t in conditioning: From 4cea9aecdab6bbd7b5801c64c27368ee3203a9ad Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 29 Apr 2023 20:53:03 -0400 Subject: [PATCH 053/208] Make nodes easier to resize. --- web/lib/litegraph.core.js | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 20ec35476..d471c0f50 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,10 +5880,10 @@ LGraphNode.prototype.executeAction = function(action) node.resizable !== false && isInsideRectangle( e.canvasX, e.canvasY, - node.pos[0] + node.size[0] - 5, - node.pos[1] + node.size[1] - 5, - 10, - 10 + node.pos[0] + node.size[0] - 15, + node.pos[1] + node.size[1] - 15, + 20, + 20 ) ) { this.graph.beforeChange(); @@ -6428,10 +6428,10 @@ LGraphNode.prototype.executeAction = function(action) isInsideRectangle( e.canvasX, e.canvasY, - node.pos[0] + node.size[0] - 5, - node.pos[1] + node.size[1] - 5, - 5, - 5 + node.pos[0] + node.size[0] - 15, + node.pos[1] + node.size[1] - 15, + 15, + 15 ) ) { this.canvas.style.cursor = "se-resize"; From a2e18b15046456c86b0d550d515c737f976d03d6 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sun, 30 Apr 2023 18:59:58 +0200 Subject: [PATCH 054/208] allow disabling of progress bar when sampling --- comfy/samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 10527fb1c..1b486f803 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -541,7 +541,7 @@ class KSampler: sigmas = self.calculate_sigmas(new_steps).to(self.device) self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -610,9 +610,9 @@ class KSampler: with precision_scope(model_management.get_autocast_device(self.device)): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, 
max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) elif self.sampler == "uni_pc_bh2": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2') + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) elif self.sampler == "ddim": timesteps = [] for s in range(sigmas.shape[0]): @@ -659,10 +659,10 @@ class KSampler: if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) elif self.sampler == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback) + samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar) return samples.to(torch.float32) From 20123624933cd559dc903f0b7c97566113018a1b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 13:02:07 -0400 Subject: [PATCH 055/208] Adjust node resize area depending on outputs. --- web/lib/litegraph.core.js | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index d471c0f50..2bc6af0c3 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -3628,6 +3628,18 @@ return size; }; + LGraphNode.prototype.inResizeCorner = function(canvasX, canvasY) { + var rows = this.outputs ? this.outputs.length : 1; + var outputs_offset = (this.constructor.slot_start_y || 0) + rows * LiteGraph.NODE_SLOT_HEIGHT; + return isInsideRectangle(canvasX, + canvasY, + this.pos[0] + this.size[0] - 15, + this.pos[1] + Math.max(this.size[1] - 15, outputs_offset), + 20, + 20 + ); + } + /** * returns all the info available about a property of this node. 
* @@ -5877,14 +5889,7 @@ LGraphNode.prototype.executeAction = function(action) if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && - node.resizable !== false && - isInsideRectangle( e.canvasX, - e.canvasY, - node.pos[0] + node.size[0] - 15, - node.pos[1] + node.size[1] - 15, - 20, - 20 - ) + node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) ) { this.graph.beforeChange(); this.resizing_node = node; @@ -6424,16 +6429,7 @@ LGraphNode.prototype.executeAction = function(action) //Search for corner if (this.canvas) { - if ( - isInsideRectangle( - e.canvasX, - e.canvasY, - node.pos[0] + node.size[0] - 15, - node.pos[1] + node.size[1] - 15, - 15, - 15 - ) - ) { + if (node.inResizeCorner(e.canvasX, e.canvasY)) { this.canvas.style.cursor = "se-resize"; } else { this.canvas.style.cursor = "crosshair"; From 29c8f1a3442aad7d615430f8484b85de995c141c Mon Sep 17 00:00:00 2001 From: FizzleDorf <1fizzledorf@gmail.com> Date: Sun, 30 Apr 2023 17:33:15 -0400 Subject: [PATCH 056/208] Conditioning Average (#495) * first commit * fixed a bunch of things missing in initial commit. * parameters renamed for clarity * renamed node, attempted update cond list * to_strength removed, it is now normalized * removed comments and prints. Attempted to apply to every cond in list again but no luck * fixed repeating frames after batch using deepcopy * Revert "fixed repeating frames after batch using deepcopy" This reverts commit 1086d6a0e1f5c5c9247312872402ff8e60358fe1. * Rewrite addWeighted to use torch.mul iteratively. --------- Co-authored-by: City <125218114+city96@users.noreply.github.com> --- nodes.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/nodes.py b/nodes.py index c9d660738..fc3d2f183 100644 --- a/nodes.py +++ b/nodes.py @@ -59,6 +59,27 @@ class ConditioningCombine: def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) +class ConditioningAverage : + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), + "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "addWeighted" + + CATEGORY = "conditioning" + + def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): + out = [] + for i in range(min(len(conditioning_from),len(conditioning_to))): + t0 = conditioning_from[i] + t1 = conditioning_to[i] + tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength) + n = [tw, t0[1].copy()] + out.append(n) + return (out, ) + class ConditioningSetArea: @classmethod def INPUT_TYPES(s): @@ -1143,6 +1164,7 @@ NODE_CLASS_MAPPINGS = { "ImageScale": ImageScale, "ImageInvert": ImageInvert, "ImagePadForOutpaint": ImagePadForOutpaint, + "ConditioningAverage ": ConditioningAverage , "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetMask": ConditioningSetMask, @@ -1194,6 +1216,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CLIPTextEncode": "CLIP Text Encode (Prompt)", "CLIPSetLastLayer": "CLIP Set Last Layer", "ConditioningCombine": "Conditioning (Combine)", + "ConditioningAverage ": "Conditioning (Average)", "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetMask": "Conditioning (Set Mask)", "ControlNetApply": "Apply ControlNet", From 0aa667ed33aae800880153a91c283ac457d0b31c 
Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 17:28:55 -0400 Subject: [PATCH 057/208] Fix ConditioningAverage. --- nodes.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/nodes.py b/nodes.py index fc3d2f183..53e0f74bf 100644 --- a/nodes.py +++ b/nodes.py @@ -62,21 +62,30 @@ class ConditioningCombine: class ConditioningAverage : @classmethod def INPUT_TYPES(s): - return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ), - "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1}) + return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ), + "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) }} RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" CATEGORY = "conditioning" - def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength): + def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength): out = [] - for i in range(min(len(conditioning_from),len(conditioning_to))): - t0 = conditioning_from[i] - t1 = conditioning_to[i] - tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength) - n = [tw, t0[1].copy()] + + if len(conditioning_from) > 1: + print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.") + + cond_from = conditioning_from[0][0] + + for i in range(len(conditioning_to)): + t1 = conditioning_to[i][0] + t0 = cond_from[:,:t1.shape[1]] + if t0.shape[1] < t1.shape[1]: + t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1) + + tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength)) + n = [tw, conditioning_to[i][1].copy()] out.append(n) return (out, ) From b04e16ef5a7cd9cbf80d272a455bd34e869a6ec8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Apr 2023 18:19:03 -0400 Subject: [PATCH 058/208] Make default workflow use an existing checkpoint if no SD1.5 checkpoint. 
--- web/scripts/app.js | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index a161bf40e..ada1708dc 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -971,8 +971,10 @@ export class ComfyApp { loadGraphData(graphData) { this.clean(); + let reset_invalid_values = false; if (!graphData) { graphData = structuredClone(defaultGraph); + reset_invalid_values = true; } const missingNodeTypes = []; @@ -1058,6 +1060,13 @@ export class ComfyApp { } } } + if (reset_invalid_values) { + if (widget.type == "combo") { + if (!widget.options.values.includes(widget.value) && widget.options.values.length > 0) { + widget.value = widget.options.values[0]; + } + } + } } } From 6aae1f497f680355b0e51242c4195cf75803056d Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 1 May 2023 13:16:19 -0400 Subject: [PATCH 059/208] style context menu fix graphdialog background, and palette template --- web/extensions/core/colorPalette.js | 17 +++++++++++++++ web/style.css | 34 ++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 41541a8d8..2f2238a2b 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -232,10 +232,27 @@ app.registerExtension({ "name": "My Color Palette", "colors": { "node_slot": { + }, + "litegraph_base": { + }, + "comfy_base": { } } }; + // Copy over missing keys from default color palette + const defaultColorPalette = colorPalettes[defaultColorPaletteId]; + for (const key in defaultColorPalette.colors.litegraph_base) { + if (!colorPalette.colors.litegraph_base[key]) { + colorPalette.colors.litegraph_base[key] = ""; + } + } + for (const key in defaultColorPalette.colors.comfy_base) { + if (!colorPalette.colors.comfy_base[key]) { + colorPalette.colors.comfy_base[key] = ""; + } + } + return completeColorPalette(colorPalette); }; diff --git a/web/style.css b/web/style.css index eced33d29..6ef3a4c21 100644 --- a/web/style.css +++ b/web/style.css @@ -257,8 +257,11 @@ button.comfy-queue-btn { } } +/* Input popup */ + .graphdialog { min-height: 1em; + background-color: var(--comfy-menu-bg); } .graphdialog .name { @@ -282,18 +285,37 @@ button.comfy-queue-btn { border-radius: 12px 0 0 12px; } +/* Context menu */ + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; - } +} - .litemenu-entry.has_submenu::after { +.litemenu-entry.has_submenu::after { content: ">"; position: absolute; top: 0; right: 2px; - } - - .litecontextmenu { +} + +.litecontextmenu { z-index: 9999 !important; -} \ No newline at end of file +} + +.litegraph.litecontextmenu { + background-color: var(--comfy-menu-bg) !important; + filter: brightness(95%); + color: var(--input-text) !important; +} + +.litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { + background-color: var(--comfy-menu-bg) !important; + filter: brightness(155%); + color: var(--input-text) !important; +} + +.litegraph.litecontextmenu .litemenu-entry.submenu { + background-color: var(--comfy-menu-bg) !important; + color: var(--input-text) !important; +} From d3293c833947928456cd69a67c5e7d602216f997 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 May 2023 15:47:10 -0400 Subject: [PATCH 060/208] Properly disable all progress bars when disable_pbar=True --- comfy/extra_samplers/uni_pc.py | 8 ++++---- comfy/ldm/models/diffusion/ddim.py | 8 +++++--- comfy/sample.py | 4 ++-- comfy/samplers.py | 3 ++- 4 files 
changed, 13 insertions(+), 10 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 2952be62d..78bab5936 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -712,7 +712,7 @@ class UniPC: def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, corrector=False, callback=None + atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False ): t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_T = self.noise_schedule.T if t_start is None else t_start @@ -723,7 +723,7 @@ class UniPC: # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): - for step_index in trange(steps): + for step_index in trange(steps, disable=disable_pbar): if self.noise_mask is not None: x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: @@ -835,7 +835,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -879,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) - x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback) + x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index e00ffd3f5..deab76f21 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -81,6 +81,7 @@ class DDIMSampler(object): extra_args=None, to_zero=True, end_step=None, + disable_pbar=False, **kwargs ): self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) @@ -103,7 +104,8 @@ class DDIMSampler(object): denoise_function=denoise_function, extra_args=extra_args, to_zero=to_zero, - end_step=end_step + end_step=end_step, + disable_pbar=disable_pbar ) return samples, intermediates @@ -185,7 +187,7 @@ class DDIMSampler(object): mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): + ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): device = 
self.model.betas.device b = shape[0] if x_T is None: @@ -204,7 +206,7 @@ class DDIMSampler(object): total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] # print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step) + iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar) for i, step in enumerate(iterator): index = total_steps - i - 1 diff --git a/comfy/sample.py b/comfy/sample.py index f4132bbed..bd38585ac 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -56,7 +56,7 @@ def cleanup_additional_models(models): for m in models: m.cleanup() -def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None): +def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False): device = comfy.model_management.get_torch_device() if noise_mask is not None: @@ -76,7 +76,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar) samples = samples.cpu() cleanup_additional_models(models) diff --git a/comfy/samplers.py b/comfy/samplers.py index 1b486f803..b30fc3d9b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -643,7 +643,8 @@ class KSampler: extra_args=extra_args, mask=noise_mask, to_zero=sigmas[-1]==0, - end_step=sigmas.shape[0] - 1) + end_step=sigmas.shape[0] - 1, + disable_pbar=disable_pbar) else: extra_args["denoise_mask"] = denoise_mask From 81bee39ca0540aa7bbab275bb6bb9f156e72addd Mon Sep 17 00:00:00 2001 From: EllangoK Date: Mon, 1 May 2023 15:57:10 -0400 Subject: [PATCH 061/208] style everything styles searchbox, should be actually everything --- web/style.css | 43 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/web/style.css b/web/style.css index 6ef3a4c21..df220cc02 100644 --- a/web/style.css +++ b/web/style.css @@ -299,23 +299,52 @@ button.comfy-queue-btn { right: 2px; } -.litecontextmenu { +.litegraph.litecontextmenu, +.litegraph.litecontextmenu.dark { z-index: 9999 !important; -} - -.litegraph.litecontextmenu { background-color: var(--comfy-menu-bg) !important; filter: brightness(95%); - color: var(--input-text) !important; } .litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { background-color: var(--comfy-menu-bg) !important; filter: brightness(155%); + color: var(--input-text); +} + +.litegraph.litecontextmenu .litemenu-entry.submenu, +.litegraph.litecontextmenu.dark .litemenu-entry.submenu { + 
background-color: var(--comfy-menu-bg) !important; + color: var(--input-text); +} + +.litegraph.litecontextmenu input { + background-color: var(--comfy-input-bg) !important; color: var(--input-text) !important; } -.litegraph.litecontextmenu .litemenu-entry.submenu { +/* Search box */ + +.litegraph.litesearchbox { + z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; - color: var(--input-text) !important; + overflow: hidden; +} + +.litegraph.litesearchbox input, +.litegraph.litesearchbox select { + background-color: var(--comfy-input-bg) !important; + color: var(--input-text); +} + +.litegraph.lite-search-item { + color: var(--input-text); + background-color: var(--comfy-input-bg); + filter: brightness(80%); + padding-left: 0.2em; +} + +.litegraph.lite-search-item.generic_type { + color: var(--input-text); + filter: brightness(50%); } From 9c335a553fd9f8d4c3c97eeaec5dca89a2a900f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 May 2023 18:11:58 -0400 Subject: [PATCH 062/208] LoKR support. --- comfy/sd.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 92dbb931d..3eb50cc95 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -111,6 +111,8 @@ def load_lora(path, to_load): loaded_keys.add(A_name) loaded_keys.add(B_name) + + ######## loha hada_w1_a_name = "{}.hada_w1_a".format(x) hada_w1_b_name = "{}.hada_w1_b".format(x) hada_w2_a_name = "{}.hada_w2_a".format(x) @@ -132,6 +134,54 @@ def load_lora(path, to_load): loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_b_name) + + ######## lokr + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + for x in lora.keys(): if x not in loaded_keys: print("lora key not loaded", x) @@ -315,6 +365,33 @@ class ModelPatcher: final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) + elif len(v) == 8: #lokr + w1 = v[0] + w2 = v[1] 
+ w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(w1_a.float(), w1_b.float()) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(w2_a.float(), w2_b.float()) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float()) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha *= v[2] / dim + + weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device) else: #loha w1a = v[0] w1b = v[1] From 35f636b6c741045d25d645ecb95a6e8e2c04d6eb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 00:53:15 -0400 Subject: [PATCH 063/208] Expose grow_mask_by in VAEEncodeForInpaint. The mask is dilated by grow_mask_by pixels after being applied to the pixel space image. This helps reduce seams caused by inpainting. Higher value means less seams. --- nodes.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 53e0f74bf..4f0b7bfe8 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import sys import json import hashlib import traceback +import math from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -223,13 +224,13 @@ class VAEEncodeForInpaint: @classmethod def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} + return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "latent/inpaint" - def encode(self, vae, pixels, mask): + def encode(self, vae, pixels, mask, grow_mask_by=6): x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") @@ -240,8 +241,14 @@ class VAEEncodeForInpaint: mask = mask[:,:,:x,:y] #grow mask by a few pixels to keep things seamless in latent space - kernel_tensor = torch.ones((1, 1, 6, 6)) - mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) + if grow_mask_by == 0: + mask_erosion = mask + else: + kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by)) + padding = math.ceil((grow_mask_by - 1) / 2) + + mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1) + m = (1.0 - mask.round()).squeeze(1) for i in range(3): pixels[:,:,:,i] -= 0.5 From a307c3f12c7816885802ae4ad2ffc6a14e550540 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 09:40:57 -0400 Subject: [PATCH 064/208] Update nightly pytorch standalone to python 3.11.3 cu121. 
--- .../update_comfyui_and_python_dependencies.bat | 2 +- .github/workflows/windows_release_nightly_pytorch.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat index b4989534f..94f5d1023 100755 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 +..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index f23cae6d5..b6a18ec0a 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -19,21 +19,21 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v4 with: - python-version: '3.10.9' + python-version: '3.11.3' - shell: bash run: | cd .. cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip + curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip unzip python_embeded.zip -d python_embeded cd python_embeded - echo 'import site' >> ./python310._pth + echo 'import site' >> ./python311._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* - sed -i '1i../ComfyUI' ./python310._pth + sed -i '1i../ComfyUI' ./python311._pth cd .. From 66c8aa5c3ee601dbca396f66fe86703977b908b5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 13:31:43 -0400 Subject: [PATCH 065/208] Make unet work with any input shape. --- .../modules/diffusionmodules/openaimodel.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c69c8567..0393dc013 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): support it as an extra input. 
""" - def forward(self, x, emb, context=None, transformer_options={}): + def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) else: x = layer(x) return x @@ -105,14 +107,21 @@ class Upsample(nn.Module): if use_conv: self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - def forward(self, x): + def forward(self, x, output_shape=None): + print("upsample", output_shape) assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2] + if output_shape is not None: + shape[1] = output_shape[3] + shape[2] = output_shape[4] else: - x = F.interpolate(x, scale_factor=2, mode="nearest") + shape = [x.shape[2] * 2, x.shape[3] * 2] + if output_shape is not None: + shape[0] = output_shape[2] + shape[1] = output_shape[3] + + x = F.interpolate(x, size=shape, mode="nearest") if self.use_conv: x = self.conv(x) return x @@ -813,9 +822,14 @@ class UNetModel(nn.Module): ctrl = control['output'].pop() if ctrl is not None: hsp += ctrl + h = th.cat([h, hsp], dim=1) del hsp - h = module(h, emb, context, transformer_options) + if len(hs) > 0: + output_shape = hs[-1].shape + else: + output_shape = None + h = module(h, emb, context, transformer_options, output_shape) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) From ba8a4c3667eda95649d8bfa906186d42e9ac6835 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 14:16:27 -0400 Subject: [PATCH 066/208] Change latent resolution step to 8. 
--- .../modules/diffusionmodules/openaimodel.py | 1 - nodes.py | 72 +++++++++---------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 0393dc013..25309dbd7 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -108,7 +108,6 @@ class Upsample(nn.Module): self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) def forward(self, x, output_shape=None): - print("upsample", output_shape) assert x.shape[1] == self.channels if self.dims == 3: shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2] diff --git a/nodes.py b/nodes.py index 4f0b7bfe8..80d508854 100644 --- a/nodes.py +++ b/nodes.py @@ -94,10 +94,10 @@ class ConditioningSetArea: @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), - "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), + "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -188,16 +188,21 @@ class VAEEncode: CATEGORY = "latent" - def encode(self, vae, pixels): - x = (pixels.shape[1] // 64) * 64 - y = (pixels.shape[2] // 64) * 64 + @staticmethod + def vae_encode_crop_pixels(pixels): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:,:x,:y,:] + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + + def encode(self, vae, pixels): + pixels = self.vae_encode_crop_pixels(pixels) t = vae.encode(pixels[:,:,:,:3]) - return ({"samples":t}, ) - class VAEEncodeTiled: def __init__(self, device="cpu"): self.device = device @@ -211,13 +216,10 @@ class VAEEncodeTiled: CATEGORY = "_for_testing" def encode(self, vae, pixels): - x = (pixels.shape[1] // 64) * 64 - y = (pixels.shape[2] // 64) * 64 - if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:,:x,:y,:] + pixels = VAEEncode.vae_encode_crop_pixels(pixels) t = vae.encode_tiled(pixels[:,:,:,:3]) - return ({"samples":t}, ) + class VAEEncodeForInpaint: def __init__(self, device="cpu"): self.device = device @@ -231,14 +233,16 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 64) * 64 - y = (pixels.shape[2] // 64) * 64 + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - pixels = pixels[:,:x,:y,:] - mask = mask[:,:,:x,:y] + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + 
x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] #grow mask by a few pixels to keep things seamless in latent space if grow_mask_by == 0: @@ -610,8 +614,8 @@ class EmptyLatentImage: @classmethod def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + return {"required": { "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}} RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -649,8 +653,8 @@ class LatentUpscale: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), - "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), "crop": (s.crop_methods,)}} RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" @@ -752,8 +756,8 @@ class LatentCrop: @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), - "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), + "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), }} @@ -778,16 +782,6 @@ class LatentCrop: new_width = width // 8 to_x = new_width + x to_y = new_height + y - def enforce_image_dim(d, to_d, max_d): - if to_d > max_d: - leftover = (to_d - max_d) % 8 - to_d = max_d - d -= leftover - return (d, to_d) - - #make sure size is always multiple of 64 - x, to_x = enforce_image_dim(x, to_x, samples.shape[3]) - y, to_y = enforce_image_dim(y, to_y, samples.shape[2]) s['samples'] = samples[:,:,y:to_y, x:to_x] return (s,) @@ -1105,10 +1099,10 @@ class ImagePadForOutpaint: return { "required": { "image": ("IMAGE",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 64}), + "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}), } } From 06ad35b4932fe6cc4382d8b1dfa79fef8284362a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 2 May 2023 19:18:07 +0100 Subject: [PATCH 067/208] added progress to encode + upscale --- comfy/sd.py | 12 +++++++++--- comfy_extras/nodes_upscale_model.py | 8 +++++++- 2 files changed, 
16 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 2aadefadc..06d6c1a56 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -491,9 +491,15 @@ class VAE: model_management.unload_model() self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4) - samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4) + + it_1 = -(pixel_samples.shape[2] // -(tile_y * 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x // 2 - overlap)) + it_2 = -(pixel_samples.shape[2] // -(tile_y // 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x * 2 - overlap)) + it_3 = -(pixel_samples.shape[2] // -(tile_y - overlap)) * -(pixel_samples.shape[3] // -(tile_x - overlap)) + pbar = tqdm(total=(it_1 + it_2 + it_3)) + + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples /= 3.0 self.first_stage_model = self.first_stage_model.cpu() samples = samples.cpu() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d8754698c..4fc7dcd77 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,6 +4,7 @@ from comfy import model_management import torch import comfy.utils import folder_paths +from tqdm.auto import tqdm class UpscaleModelLoader: @classmethod @@ -37,7 +38,12 @@ class ImageUpscaleWithModel: device = model_management.get_torch_device() upscale_model.to(device) in_img = image.movedim(-1,-3).to(device) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=128 + 64, tile_y=128 + 64, overlap = 8, upscale_amount=upscale_model.scale) + + tile = 128 + 64 + overlap = 8 + its = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + pbar = tqdm(total=its) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return (s,) From 93c64afaa92b425fc863b80ee0b7c618705d7d60 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 May 2023 23:00:49 -0400 Subject: [PATCH 068/208] Use sampler callback instead of tqdm hook for progress bar. 
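Instead of monkey-patching tqdm.update, progress now flows through an explicit hook that samplers and tiled operations feed via ProgressBar. A minimal usage sketch of the API added in comfy/utils.py (standalone; the server installs its real hook in hijack_progress):

    import comfy.utils

    comfy.utils.set_progress_bar_global_hook(
        lambda value, total: print(f"progress {value}/{total}"))

    pbar = comfy.utils.ProgressBar(20)  # total expected steps
    for step in range(20):
        pbar.update(1)                  # relative; update_absolute(n) also works
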
--- comfy/utils.py | 23 +++++++++++++++++++++++ main.py | 12 ++++-------- nodes.py | 6 +++++- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 68f93403c..7f3c3978c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,3 +86,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output + + +PROGRESS_BAR_HOOK = None +def set_progress_bar_global_hook(function): + global PROGRESS_BAR_HOOK + PROGRESS_BAR_HOOK = function + +class ProgressBar: + def __init__(self, total): + global PROGRESS_BAR_HOOK + self.total = total + self.current = 0 + self.hook = PROGRESS_BAR_HOOK + + def update_absolute(self, value): + if value > self.total: + value = self.total + self.current = value + if self.hook is not None: + self.hook(self.current, self.total) + + def update(self, value): + self.update_absolute(self.current + value) diff --git a/main.py b/main.py index 02c700ebc..f369b82f3 100644 --- a/main.py +++ b/main.py @@ -5,6 +5,7 @@ import shutil import threading from comfy.cli_args import args +import comfy.utils if os.name == "nt": import logging @@ -39,14 +40,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): - from tqdm.auto import tqdm - orig_func = getattr(tqdm, "update") - def wrapped_func(*args, **kwargs): - pbar = args[0] - v = orig_func(*args, **kwargs) - server.send_sync("progress", { "value": pbar.n, "max": pbar.total}, server.client_id) - return v - setattr(tqdm, "update", wrapped_func) + def hook(value, total): + server.send_sync("progress", { "value": value, "max": total}, server.client_id) + comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") diff --git a/nodes.py b/nodes.py index 80d508854..90c943fe3 100644 --- a/nodes.py +++ b/nodes.py @@ -815,9 +815,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] + pbar = comfy.utils.ProgressBar(steps) + def callback(step, x0, x): + pbar.update_absolute(step + 1) + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, - force_full_denoise=force_full_denoise, noise_mask=noise_mask) + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback) out = latent.copy() out["samples"] = samples return (out, ) From 27df74101e6e5bb761364b718d57313388b49182 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 17:33:19 +0100 Subject: [PATCH 069/208] reduce duplication --- comfy/sd.py | 14 +++++--------- comfy/utils.py | 6 ++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 06d6c1a56..87b380b1c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -438,10 +438,8 @@ class VAE: self.device = device def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): - it_1 = -(samples.shape[2] // -(tile_y * 2 - overlap)) * -(samples.shape[3] // -(tile_x // 2 - overlap)) - it_2 = -(samples.shape[2] // -(tile_y // 2 - overlap)) * -(samples.shape[3] // -(tile_x * 2 - overlap)) - it_3 = -(samples.shape[2] // -(tile_y - overlap)) * 
-(samples.shape[3] // -(tile_x - overlap)) - pbar = tqdm(total=samples.shape[0] * (it_1 + it_2 + it_3)) + steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + pbar = tqdm(total=steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( @@ -492,11 +490,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - it_1 = -(pixel_samples.shape[2] // -(tile_y * 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x // 2 - overlap)) - it_2 = -(pixel_samples.shape[2] // -(tile_y // 2 - overlap)) * -(pixel_samples.shape[3] // -(tile_x * 2 - overlap)) - it_3 = -(pixel_samples.shape[2] // -(tile_y - overlap)) * -(pixel_samples.shape[3] // -(tile_x - overlap)) - pbar = tqdm(total=(it_1 + it_2 + it_3)) - + steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + pbar = tqdm(total=steps) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy/utils.py b/comfy/utils.py index c7c6a08c5..82d3aa0d8 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -62,6 +62,12 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) +def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): + it_1 = -(height // -(tile_y * 2 - overlap)) * -(width // -(tile_x // 2 - overlap)) + it_2 = -(height // -(tile_y // 2 - overlap)) * -(width // -(tile_x * 2 - overlap)) + it_3 = -(height // -(tile_y - overlap)) * -(width // -(tile_x - overlap)) + return it_1 + it_2 + it_3 + @torch.inference_mode() def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") From 908dc1d5a8717073f44d136d6d2b4f983ea07d40 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 May 2023 12:58:10 -0400 Subject: [PATCH 070/208] Add a total_steps value to sampler callback. 
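Samplers do not always run exactly the node's steps value worth of iterations (denoise < 1.0 truncates the sigma schedule, for example), so the callback now reports the sampler's actual total and ProgressBar.update_absolute accepts it. A sketch of a consumer under the new four-argument signature (the print body is illustrative):

    def callback(step, x0, x, total_steps):
        # x0: current denoised prediction; x: noisy latent, None for DDIM;
        # total_steps: reported by the sampler itself
        print(f"step {step + 1}/{total_steps}")
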
--- comfy/extra_samplers/uni_pc.py | 2 +- comfy/samplers.py | 8 +++++--- comfy/utils.py | 4 +++- nodes.py | 4 ++-- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 78bab5936..2ff10caf1 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -767,7 +767,7 @@ class UniPC: model_x = self.model_fn(x, vec_t) model_prev_list[-1] = model_x if callback is not None: - callback(step_index, model_prev_list[-1], x) + callback(step_index, model_prev_list[-1], x, steps) else: raise NotImplementedError() if denoise_to_zero: diff --git a/comfy/samplers.py b/comfy/samplers.py index b30fc3d9b..dcf93cca2 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -623,7 +623,8 @@ class KSampler: ddim_callback = None if callback is not None: - ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None) + total_steps = len(timesteps) - 1 + ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) sampler = DDIMSampler(self.model, device=self.device) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) @@ -654,13 +655,14 @@ class KSampler: noise = noise * sigmas[0] k_callback = None + total_steps = len(sigmas) - 1 if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"]) + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) if latent_image is not None: noise += latent_image if self.sampler == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) elif self.sampler == "dpm_adaptive": samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) else: diff --git a/comfy/utils.py b/comfy/utils.py index 7f3c3978c..f1ff97792 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -100,7 +100,9 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK - def update_absolute(self, value): + def update_absolute(self, value, total=None): + if total is not None: + self.total = total if value > self.total: value = self.total self.current = value diff --git a/nodes.py b/nodes.py index 90c943fe3..c2bc36855 100644 --- a/nodes.py +++ b/nodes.py @@ -816,8 +816,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = latent["noise_mask"] pbar = comfy.utils.ProgressBar(steps) - def callback(step, x0, x): - pbar.update_absolute(step + 1) + def callback(step, x0, x, total_steps): + pbar.update_absolute(step + 1, total_steps) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, From 8912623ea9929848b813f1aeafee0fa9e1281817 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 18:19:22 +0100 Subject: [PATCH 071/208] use comfy progress bar --- comfy/sd.py | 6 +++--- comfy_extras/nodes_upscale_model.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 32499f600..e4c5282d7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -516,7 +516,7 @@ 
class VAE: def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) - pbar = tqdm(total=steps) + pbar = utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) output = torch.clamp(( @@ -568,8 +568,8 @@ class VAE: pixel_samples = pixel_samples.movedim(-1,1).to(self.device) steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) - pbar = tqdm(total=steps) - + pbar = utils.ProgressBar(steps) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4fc7dcd77..dfd1994a6 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -41,8 +41,8 @@ class ImageUpscaleWithModel: tile = 128 + 64 overlap = 8 - its = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) - pbar = tqdm(total=its) + steps = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) From 5eeecf3fd5adedfa5a92d3549f77a78be714c2a3 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 3 May 2023 18:21:23 +0100 Subject: [PATCH 072/208] remove unused import --- comfy/sd.py | 1 - comfy_extras/nodes_upscale_model.py | 1 - 2 files changed, 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e4c5282d7..d60b908b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,7 +1,6 @@ import torch import contextlib import copy -from tqdm.auto import tqdm import sd1_clip import sd2_clip diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index dfd1994a6..f774b4b77 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -4,7 +4,6 @@ from comfy import model_management import torch import comfy.utils import folder_paths -from tqdm.auto import tqdm class UpscaleModelLoader: @classmethod From fcf513e0b6b599e23b7d6f9bde315be6f991652b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 May 2023 17:48:35 -0400 Subject: [PATCH 073/208] Refactor. 
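The three inline ceil-division products collapse into one get_tiled_scale_steps helper, and the callers now account for batch size and sum the three tile orientations explicitly. A quick worked check of the helper (tile and overlap values taken from decode_tiled_'s defaults):

    import math

    def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
        return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))

    # one 64x64 latent (a 512x512 image) with 64-cell tiles, overlap 16:
    # each axis needs ceil(64 / 48) = 2 tiles, so 4 tiles per orientation
    print(get_tiled_scale_steps(64, 64, 64, 64, 16))  # -> 4
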
--- comfy/sd.py | 6 +++++- comfy/utils.py | 6 ++---- comfy_extras/nodes_upscale_model.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index d60b908b8..174ed35e5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -515,6 +515,8 @@ class VAE: def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) + steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0) @@ -566,7 +568,9 @@ class VAE: self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1).to(self.device) - steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) + steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = utils.ProgressBar(steps) samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) diff --git a/comfy/utils.py b/comfy/utils.py index 5c7143fd9..09e05d4ed 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,4 +1,5 @@ import torch +import math def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -63,10 +64,7 @@ def common_upscale(samples, width, height, upscale_method, crop): return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): - it_1 = -(height // -(tile_y * 2 - overlap)) * -(width // -(tile_x // 2 - overlap)) - it_2 = -(height // -(tile_y // 2 - overlap)) * -(width // -(tile_x * 2 - overlap)) - it_3 = -(height // -(tile_y - overlap)) * -(width // -(tile_x - overlap)) - return it_1 + it_2 + it_3 + return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index f774b4b77..ab5b0ccfc 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -40,7 +40,7 @@ class ImageUpscaleWithModel: tile = 128 + 64 overlap = 8 - steps = -(in_img.shape[2] // -(tile - overlap)) * -(in_img.shape[3] // -(tile - overlap)) + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) upscale_model.cpu() From 
7e51bbd07f809555cc50c4fdae3ef84720e5c86f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 4 May 2023 19:42:07 +0100 Subject: [PATCH 074/208] automatic calculation of image pos from widgets --- web/scripts/app.js | 39 ++++++++++++++++++++++++++++++--------- web/scripts/widgets.js | 9 +-------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index ada1708dc..f0c0f9de4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -263,6 +263,34 @@ export class ComfyApp { */ #addDrawBackgroundHandler(node) { const app = this; + + function getImageTop(node) { + let shiftY; + if (node.imageOffset != null) { + shiftY = node.imageOffset; + } else { + if (node.widgets?.length) { + const w = node.widgets[node.widgets.length - 1]; + shiftY = w.last_y; + if (w.computeSize) { + shiftY += w.computeSize()[1] + 4; + } else { + shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } else { + shiftY = node.computeSize()[1]; + } + } + return shiftY; + } + + node.prototype.setSizeForImage = function () { + const minHeight = getImageTop(this) + 220; + if (this.size[1] < minHeight) { + this.setSize([this.size[0], minHeight]); + } + }; + node.prototype.onDrawBackground = function (ctx) { if (!this.flags.collapsed) { const output = app.nodeOutputs[this.id + ""]; @@ -283,9 +311,7 @@ export class ComfyApp { ).then((imgs) => { if (this.images === output.images) { this.imgs = imgs.filter(Boolean); - if (this.size[1] < 100) { - this.size[1] = 250; - } + this.setSizeForImage?.(); app.graph.setDirtyCanvas(true); } }); @@ -310,12 +336,7 @@ export class ComfyApp { this.imageIndex = imageIndex = 0; } - let shiftY; - if (this.imageOffset != null) { - shiftY = this.imageOffset; - } else { - shiftY = this.computeSize()[1]; - } + const shiftY = getImageTop(this); let dw = this.size[0]; let dh = this.size[1]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index c0e73ffa1..cd471bc93 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -261,20 +261,13 @@ export const ComfyWidgets = { let uploadWidget; function showImage(name) { - // Position the image somewhere sensible - if (!node.imageOffset) { - node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 25 : 75; - } - const img = new Image(); img.onload = () => { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; img.src = `/view?filename=${name}&type=input`; - if ((node.size[1] - node.imageOffset) < 100) { - node.size[1] = 250 + node.imageOffset; - } + node.setSizeForImage?.(); } // Add our own callback to the combo widget to render an image when it changes From bae4fb4a9dc944c10cca922dc4442eef57bbf583 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 4 May 2023 18:07:41 -0400 Subject: [PATCH 075/208] Fix imports. 
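Absolute ldm.* imports only resolved when the comfy directory itself sat on sys.path; every module now imports through the comfy package, as the hunks below show. The resulting convention, sketched (assumes the ComfyUI root is importable):

    # fully qualified, works from anywhere:
    from comfy.ldm.util import instantiate_from_config
    # inside the package the same modules are reached relatively, e.g.:
    #   comfy/gligen.py:    from .ldm.modules.attention import CrossAttention
    #   comfy/cldm/cldm.py: from ..ldm.modules.attention import SpatialTransformer
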
--- comfy/cldm/cldm.py | 10 +++---- comfy/gligen.py | 2 +- comfy/ldm/models/autoencoder.py | 8 +++--- comfy/ldm/models/diffusion/ddim.py | 2 +- comfy/ldm/models/diffusion/ddpm.py | 12 ++++----- comfy/ldm/modules/attention.py | 4 +-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- .../modules/diffusionmodules/openaimodel.py | 6 ++--- .../ldm/modules/diffusionmodules/upscaling.py | 4 +-- comfy/ldm/modules/diffusionmodules/util.py | 2 +- .../ldm/modules/encoders/noise_aug_modules.py | 4 +-- comfy/model_management.py | 2 +- comfy/sd.py | 26 +++++++++---------- 13 files changed, 42 insertions(+), 42 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index c60abf80b..cb660ee77 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -5,17 +5,17 @@ import torch import torch as th import torch.nn as nn -from ldm.modules.diffusionmodules.util import ( +from ..ldm.modules.diffusionmodules.util import ( conv_nd, linear, zero_module, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer -from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.util import log_txt_as_img, exists, instantiate_from_config +from ..ldm.modules.attention import SpatialTransformer +from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ..ldm.models.diffusion.ddpm import LatentDiffusion +from ..ldm.util import log_txt_as_img, exists, instantiate_from_config class ControlledUnetModel(UNetModel): diff --git a/comfy/gligen.py b/comfy/gligen.py index 8770383e5..45b674503 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,6 +1,6 @@ import torch from torch import nn, einsum -from ldm.modules.attention import CrossAttention +from .ldm.modules.attention import CrossAttention from inspect import isfunction diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index bd698621c..1fb7ed879 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -3,11 +3,11 @@ import torch import torch.nn.functional as F from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder +from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution -from ldm.util import instantiate_from_config -from ldm.modules.ema import LitEma +from comfy.ldm.util import instantiate_from_config +from comfy.ldm.modules.ema import LitEma # class AutoencoderKL(pl.LightningModule): class AutoencoderKL(torch.nn.Module): diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index deab76f21..c279f2c18 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -4,7 +4,7 @@ import torch import numpy as np from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor +from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor class DDIMSampler(object): diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index d3f0eb2b2..0f484a7f1 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ 
b/comfy/ldm/models/diffusion/ddpm.py @@ -19,12 +19,12 @@ from tqdm import tqdm from torchvision.utils import make_grid # from pytorch_lightning.utilities.distributed import rank_zero_only -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler +from comfy.ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from comfy.ldm.modules.ema import LitEma +from comfy.ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ..autoencoder import IdentityFirstStage, AutoencoderKL +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from .ddim import DDIMSampler __conditioning_keys__ = {'concat': 'c_concat', diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index ce7180d91..5eabecd65 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -6,7 +6,7 @@ from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from ldm.modules.diffusionmodules.util import checkpoint +from .diffusionmodules.util import checkpoint from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -21,7 +21,7 @@ if model_management.xformers_enabled(): import os _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") -from cli_args import args +from comfy.cli_args import args def exists(val): return val is not None diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1599d386e..5e4d2b60f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -6,7 +6,7 @@ import numpy as np from einops import rearrange from typing import Optional, Any -from ldm.modules.attention import MemoryEfficientCrossAttention +from ..attention import MemoryEfficientCrossAttention from comfy import model_management if model_management.xformers_enabled_vae(): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 25309dbd7..4352b756d 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -6,7 +6,7 @@ import torch as th import torch.nn as nn import torch.nn.functional as F -from ldm.modules.diffusionmodules.util import ( +from .util import ( checkpoint, conv_nd, linear, @@ -15,8 +15,8 @@ from ldm.modules.diffusionmodules.util import ( normalization, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer -from ldm.util import exists +from ..attention import SpatialTransformer +from comfy.ldm.util import exists # dummy replace diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 038166620..709a7f52e 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -3,8 +3,8 @@ import torch.nn as nn import numpy as np from functools import partial -from ldm.modules.diffusionmodules.util import 
extract_into_tensor, make_beta_schedule -from ldm.util import default +from .util import extract_into_tensor, make_beta_schedule +from comfy.ldm.util import default class AbstractLowScaleModel(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index daf35da7b..82ea3f0a6 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -15,7 +15,7 @@ import torch.nn as nn import numpy as np from einops import repeat -from ldm.util import instantiate_from_config +from comfy.ldm.util import instantiate_from_config def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index f99e7920a..b59bf204b 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -1,5 +1,5 @@ -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation -from ldm.modules.diffusionmodules.openaimodel import Timestep +from ..diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ..diffusionmodules.openaimodel import Timestep import torch class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): diff --git a/comfy/model_management.py b/comfy/model_management.py index db5d368e1..e89f80d69 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,6 @@ import psutil from enum import Enum -from cli_args import args +from .cli_args import args class VRAMState(Enum): CPU = 0 diff --git a/comfy/sd.py b/comfy/sd.py index 174ed35e5..3543bdb77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -2,8 +2,8 @@ import torch import contextlib import copy -import sd1_clip -import sd2_clip +from . import sd1_clip +from . 
import sd2_clip from comfy import model_management from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL @@ -446,10 +446,10 @@ class CLIP: else: params = {} - if self.target_clip == "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder": + if self.target_clip.endswith("FrozenOpenCLIPEmbedder"): clip = sd2_clip.SD2ClipModel tokenizer = sd2_clip.SD2Tokenizer - elif self.target_clip == "ldm.modules.encoders.modules.FrozenCLIPEmbedder": + elif self.target_clip.endswith("FrozenCLIPEmbedder"): clip = sd1_clip.SD1ClipModel tokenizer = sd1_clip.SD1Tokenizer @@ -896,9 +896,9 @@ def load_clip(ckpt_path, embedding_directory=None): clip_data = utils.load_torch_file(ckpt_path) config = {} if "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data: - config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' else: - config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder' + config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' clip = CLIP(config=config, embedding_directory=embedding_directory) clip.load_from_state_dict(clip_data) return clip @@ -974,9 +974,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if output_clip: clip_config = {} if "cond_stage_model.model.transformer.resblocks.22.attn.out_proj.weight" in sd_keys: - clip_config['target'] = 'ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' + clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder' else: - clip_config['target'] = 'ldm.modules.encoders.modules.FrozenCLIPEmbedder' + clip_config['target'] = 'comfy.ldm.modules.encoders.modules.FrozenCLIPEmbedder' clip = CLIP(config=clip_config, embedding_directory=embedding_directory) w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] @@ -997,7 +997,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o noise_schedule_config["timesteps"] = sd[noise_aug_key].shape[0] noise_schedule_config["beta_schedule"] = "squaredcos_cap_v2" params["noise_schedule_config"] = noise_schedule_config - noise_aug_config['target'] = "ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" + noise_aug_config['target'] = "comfy.ldm.modules.encoders.noise_aug_modules.CLIPEmbeddingNoiseAugmentation" if size == 1280: #h params["timestep_dim"] = 1024 elif size == 1024: #l @@ -1049,19 +1049,19 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1] unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] - sd_config["unet_config"] = {"target": "ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} - model_config = {"target": "ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} + sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} + model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} if noise_aug_config is not None: #SD2.x unclip model sd_config["noise_aug_config"] = noise_aug_config sd_config["image_size"] = 96 sd_config["embedding_dropout"] = 0.25 sd_config["conditioning_key"] = 'crossattn-adm' - model_config["target"] = "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" + 
model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None - model_config["target"] = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion" else: sd_config["conditioning_key"] = "crossattn" From 1a31020081b22cb55e573f65a11bd4c2c96f17f1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 00:16:57 -0400 Subject: [PATCH 076/208] Support softsign hypernetwork. --- comfy_extras/nodes_hypernetwork.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 0c7250e43..c19b5e4c7 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -18,6 +18,7 @@ def load_hypernetwork_patch(path, strength): "swish": torch.nn.Hardswish, "tanh": torch.nn.Tanh, "sigmoid": torch.nn.Sigmoid, + "softsign": torch.nn.Softsign, } if activation_func not in valid_activation: From 6ee11d7bc00bdbc109e3b84231aa74fc1799d543 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 00:19:35 -0400 Subject: [PATCH 077/208] Fix import. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index e89f80d69..3aea7ea8e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,6 @@ import psutil from enum import Enum -from .cli_args import args +from comfy.cli_args import args class VRAMState(Enum): CPU = 0 From af9cc1fb6a88e604700d3f57638ab23b9f607e9e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 01:28:48 -0400 Subject: [PATCH 078/208] Search recursively in subfolders for embeddings. 
--- comfy/sd1_clip.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 7f1217c3d..b1a392736 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -191,11 +191,20 @@ def safe_load_embed_zip(embed_path): del embed return out +def expand_directory_list(directories): + dirs = set() + for x in directories: + dirs.add(x) + for root, subdir, file in os.walk(x, followlinks=True): + dirs.add(root) + return list(dirs) def load_embed(embedding_name, embedding_directory): if isinstance(embedding_directory, str): embedding_directory = [embedding_directory] + embedding_directory = expand_directory_list(embedding_directory) + valid_file = None for embed_dir in embedding_directory: embed_path = os.path.join(embed_dir, embedding_name) From f31e31ee0a3d7da01f2b1f3b68047445c16e494a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 5 May 2023 10:12:06 +0100 Subject: [PATCH 079/208] Fix box shape Match card to litegraph selection --- web/scripts/app.js | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index ada1708dc..68eeb6329 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -703,7 +703,7 @@ export class ComfyApp { ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) - ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); + ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT); else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed)) ctx.roundRect( -6, @@ -715,12 +715,11 @@ export class ComfyApp { else if (shape == LiteGraph.CARD_SHAPE) ctx.roundRect( -6, - -6 + LiteGraph.NODE_TITLE_HEIGHT, + -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, - this.round_radius * 2, - 2 - ); + [this.round_radius * 2,2,this.round_radius * 2,2] + ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); ctx.strokeStyle = color; From de4623a8a4b8282f2d29d5a3ecbcb9840c3dc7ac Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 5 May 2023 10:34:09 +0100 Subject: [PATCH 080/208] actually fix card --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 68eeb6329..98c0e0799 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -718,7 +718,7 @@ export class ComfyApp { -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, - [this.round_radius * 2,2,this.round_radius * 2,2] + [this.round_radius * 2, this.round_radius * 2, 2, 2] ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); From cb1551b819ecaa7d9044c13d0c8e8cfa4ff72830 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 5 May 2023 18:01:21 -0400 Subject: [PATCH 081/208] Lowvram mode for gligen and fix some lowvram issues. 
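
The lowvram path keeps each GLIGEN module on the CPU and only moves it to the
GPU for the duration of a single call. A minimal sketch of the pattern,
assuming a plain torch nn.Module (the real wrapper is func_lowvram in the
diff below):

    import torch

    def run_lowvram(module, x):
        # Move the module next to its input just for this call,
        # then push it back to the CPU to free VRAM.
        module.to(x.device)
        try:
            return module(x)
        finally:
            module.cpu()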
--- comfy/gligen.py | 27 +++++++++++++++---- comfy/ldm/modules/attention.py | 3 --- .../modules/diffusionmodules/openaimodel.py | 19 ++++++++++--- comfy/model_management.py | 3 +++ 4 files changed, 41 insertions(+), 11 deletions(-) diff --git a/comfy/gligen.py b/comfy/gligen.py index 45b674503..8c7cb432e 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -242,14 +242,28 @@ class Gligen(nn.Module): self.position_net = position_net self.key_dim = key_dim self.max_objs = 30 + self.lowvram = False def _set_position(self, boxes, masks, positive_embeddings): + if self.lowvram == True: + self.position_net.to(boxes.device) + objs = self.position_net(boxes, masks, positive_embeddings) - def func(key, x): - module = self.module_list[key] - return module(x, objs) - return func + if self.lowvram == True: + self.position_net.cpu() + def func_lowvram(key, x): + module = self.module_list[key] + module.to(x.device) + r = module(x, objs) + module.cpu() + return r + return func_lowvram + else: + def func(key, x): + module = self.module_list[key] + return module(x, objs) + return func def set_position(self, latent_image_shape, position_params, device): batch, c, h, w = latent_image_shape @@ -294,8 +308,11 @@ class Gligen(nn.Module): masks.to(device), conds.to(device)) + def set_lowvram(self, value=True): + self.lowvram = value + def cleanup(self): - pass + self.lowvram = False def get_models(self): return [self] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 5eabecd65..573f4e1c6 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module): x += n x = self.ff(self.norm3(x)) + x - - if current_index is not None: - transformer_options["current_index"] += 1 return x diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4352b756d..5aef23f33 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): x = layer(x) return x +#This is needed because accelerate makes a copy of transformer_options which breaks "current_index" +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): + for layer in ts: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context, transformer_options) + transformer_options["current_index"] += 1 + elif isinstance(layer, Upsample): + x = layer(x, output_shape=output_shape) + else: + x = layer(x) + return x class Upsample(nn.Module): """ @@ -805,13 +818,13 @@ class UNetModel(nn.Module): h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): - h = module(h, emb, context, transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options) if control is not None and 'input' in control and len(control['input']) > 0: ctrl = control['input'].pop() if ctrl is not None: h += ctrl hs.append(h) - h = self.middle_block(h, emb, context, transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: h += control['middle'].pop() @@ -828,7 +841,7 @@ class UNetModel(nn.Module): output_shape = hs[-1].shape else: output_shape = None - h = module(h, emb, context, transformer_options, output_shape) + h = 
forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3aea7ea8e..7070912df 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -201,6 +201,9 @@ def load_controlnet_gpu(control_models): return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: + for m in control_models: + if hasattr(m, 'set_lowvram'): + m.set_lowvram(True) #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after return From 7a9268185cb6456890f6fe61bcc380b5cb21f614 Mon Sep 17 00:00:00 2001 From: WAS Date: Fri, 5 May 2023 18:06:54 -0700 Subject: [PATCH 082/208] Update README.md Add quick search explanation --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 3b3824714..bfa8904df 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git | Q | Toggle visibility of the queue | | H | Toggle visibility of history | | R | Refresh graph | +| Double-Click LMB | Open node quick search palette | Ctrl can also be replaced with Cmd instead for MacOS users From 8e03c789a25470a88aa05bcc73b1fe226334926b Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 6 May 2023 16:59:40 -0400 Subject: [PATCH 083/208] auto-launch cli arg --- comfy/cli_args.py | 4 ++++ main.py | 13 +++---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 764427165..cc4709f70 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -7,6 +7,7 @@ parser.add_argument("--port", type=int, default=8188, help="Set the listen port. parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") +parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. 
Can boost speed but increase the chances of black images.") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") @@ -30,3 +31,6 @@ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") args = parser.parse_args() + +if args.windows_standalone_build: + args.auto_launch = True diff --git a/main.py b/main.py index f369b82f3..eb97a2fb8 100644 --- a/main.py +++ b/main.py @@ -91,23 +91,16 @@ if __name__ == "__main__": threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() - address = args.listen - - dont_print = args.dont_print_server - - if args.output_directory: output_dir = os.path.abspath(args.output_directory) print(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) - port = args.port - if args.quick_test_for_ci: exit(0) call_on_start = None - if args.windows_standalone_build: + if args.auto_launch: def startup_server(address, port): import webbrowser webbrowser.open("http://{}:{}".format(address, port)) @@ -115,10 +108,10 @@ if __name__ == "__main__": if os.name == "nt": try: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) except KeyboardInterrupt: pass else: - loop.run_until_complete(run(server, address=address, port=port, verbose=not dont_print, call_on_start=call_on_start)) + loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) cleanup_temp() From 678f933d382641933920e84414fe36f89d1da5a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 May 2023 19:00:49 -0400 Subject: [PATCH 084/208] maximum_batch_area for xformers. Remove useless code. 
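
maximum_batch_area estimates from free VRAM how large a batch of latents can
be sampled at once; memory-efficient attention scales much better with area,
so it gets a far larger budget. A hedged sketch of the heuristic (the
constants are the empirical ones introduced in the diff below; the parameter
names are illustrative):

    def maximum_batch_area(memory_free_mb, memory_efficient_attention=False):
        if memory_efficient_attention:
            # xformers: memory use grows roughly linearly with area
            return int(max(50 * memory_free_mb, 0))
        # otherwise reserve ~1GiB and apply a conservative safety factor
        return int(max(((memory_free_mb - 1024) * 0.9) / 0.6, 0))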
---
 comfy/model_management.py | 7 ++++++-
 nodes.py                  | 4 +---
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 7070912df..b0640d674 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -312,7 +312,12 @@ def maximum_batch_area():
         return 0

     memory_free = get_free_memory() / (1024 * 1024)
-    area = ((memory_free - 1024) * 0.9) / (0.6)
+    if xformers_enabled():
+        #TODO: this needs to be tweaked
+        area = 50 * memory_free
+    else:
+        #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
+        area = ((memory_free - 1024) * 0.9) / (0.6)
     return int(max(area, 0))

 def cpu_mode():
diff --git a/nodes.py b/nodes.py
index c2bc36855..ca0769ba7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -105,15 +105,13 @@ class ConditioningSetArea:

     CATEGORY = "conditioning"

-    def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0):
+    def append(self, conditioning, width, height, x, y, strength):
         c = []
         for t in conditioning:
             n = [t[0], t[1].copy()]
             n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
             n[1]['strength'] = strength
             n[1]['set_area_to_bounds'] = False
-            n[1]['min_sigma'] = min_sigma
-            n[1]['max_sigma'] = max_sigma
             c.append(n)
         return (c, )

From 6fc4917634d457c07eb8b676da4fa88e0ef4704b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 6 May 2023 19:58:54 -0400
Subject: [PATCH 085/208] Make maximum_batch_area take into account pytorch 2.0 attention function.

More conservative xformers maximum_batch_area.
---
 comfy/model_management.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index b0640d674..39df8d9a7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -275,8 +275,17 @@ def xformers_enabled_vae():
     return XFORMERS_ENABLED_VAE

 def pytorch_attention_enabled():
+    global ENABLE_PYTORCH_ATTENTION
     return ENABLE_PYTORCH_ATTENTION

+def pytorch_attention_flash_attention():
+    global ENABLE_PYTORCH_ATTENTION
+    if ENABLE_PYTORCH_ATTENTION:
+        #TODO: more reliable way of checking for flash attention?
+        if torch.version.cuda: #pytorch flash attention only works on Nvidia
+            return True
+    return False
+
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
@@ -312,9 +321,9 @@ def maximum_batch_area():
         return 0

     memory_free = get_free_memory() / (1024 * 1024)
-    if xformers_enabled():
+    if xformers_enabled() or pytorch_attention_flash_attention():
         #TODO: this needs to be tweaked
-        area = 50 * memory_free
+        area = 20 * memory_free
     else:
         #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future
         area = ((memory_free - 1024) * 0.9) / (0.6)

From ae08fdb9990956f671d658aaf72a1eaf982b5b33 Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Tue, 9 May 2023 03:37:36 +0900
Subject: [PATCH 086/208] Clipspace Menu and MaskEditor application. (#548)

* Add clipspace feature.

* feat: copy content to clipspace

* feat: paste content from clipspace

Extend validation to allow for validating annotated_path in addition to other parameters.
Add support for annotated_filepath in folder_paths function.
Generalize the '/upload/image' API to allow for uploading images to the 'input', 'temp', or 'output' directories.

* rename contentClipboard -> clipspace

* Do deep copy for imgs on copy to clipspace. (see the sketch after this list)
* mask painting on clipspace * add original_imgs into clipspace * Preserve the original image when 'imgs' are modified * robust patch & refactoring folder_paths about annotated_filepath * wip * Only show the Paste menu if the ComfyApp.clipspace is not empty * clipspace feature added maskeditor feature added * instant refresh on paste force triggering 'changed' on paste action * enhance mask painting smooth drawing add brush_size +/- button * robust patch use mouseup event * robust patch again... * subfolder fix on paste logic attach subfolder if subfolder isn't empty * event listener patch add ], [ key event for brush size remove listener on close * Fix button positioning issue related to window height. Change brush size from button to slider. * clean commit * clean code * various bug fixes * paste action - prevent opening upload popup - ensure rendering after widget_value update * view api update - support annotated_filepath * maskeditor layout - prevent covering button by hidden div * remove dbg message * Add cursor functionality to display brush size * refactor: Replace brush preview feature with missionfloyd implementation * missionfloyd implementation * hiding brush preview off the canvas * change brush size on wheel event * keyup -> keydown event * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Add support for channel-specific image data retrieval in /view API to fix alpha mask loading issue When loading an image with an alpha mask in JavaScript canvas, there is an issue where the alpha and RGB channels are premultiplied. To avoid reliance on JavaScript canvas, I added support for channel-specific image data retrieval in the "/view" API. This allows us to retrieve data for each channel separately and fix the alpha mask loading issue. The changes have been committed to the repository. * Enable brush preview for key and slider events * optimize * preview fix * robust patch * fix copy (clipspace) action imgs[0] copy -> whole imgs copy * support batch images on clipspace, maskeditor * copy/paste bug fixes for batch images enhance selector preview on clipspace menu add img_paste_mode option into clipspace menu * crash fix * print message if clipspace content cannot editable * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * make default img_paste_mode to 'selected' refactor space -> tab * save clipspace files to input/clipspace instead of temp * show "clipspace/filename.png" instead of 'filename.png [clipspace]' in LoadImage/LoadImageMask * refresh fix related to FILE_COMBO * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * adjust margin based on missionfloyd impelements * mouse event -> pointer event * pen, touch, mouse drawing patched and tested * Update web/extensions/core/maskeditor.js Co-authored-by: missionfloyd * add comment about touch event. 
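
A recurring fix in the list above is deep-copying imgs whenever content moves
in or out of the clipspace, so that later edits in a node cannot mutate the
stored copy. A rough sketch of that copy (illustrative; the real loop lives
in web/scripts/app.js below):

    function cloneImgs(imgs) {
        // Re-create each Image from its src so both sides keep
        // independent elements (a shallow copy would share them).
        return imgs.map((img) => {
            const copy = new Image();
            copy.src = img.src;
            return copy;
        });
    }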
--------- Co-authored-by: Lt.Dr.Data Co-authored-by: missionfloyd --- folder_paths.py | 9 + nodes.py | 8 +- server.py | 122 ++++++- web/extensions/core/clipspace.js | 166 +++++++++ web/extensions/core/maskeditor.js | 589 ++++++++++++++++++++++++++++++ web/scripts/app.js | 114 ++++-- web/scripts/ui.js | 1 + web/scripts/widgets.js | 14 + 8 files changed, 976 insertions(+), 47 deletions(-) create mode 100644 web/extensions/core/clipspace.js create mode 100644 web/extensions/core/maskeditor.js diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..0acd22674 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,6 +57,10 @@ def get_input_directory(): global input_directory return input_directory +def get_clipspace_directory(): + global input_directory + return input_directory+"/clipspace" + #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -66,6 +70,8 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() + if type_name == "clipspace": + return get_clipspace_directory() return None @@ -81,6 +87,9 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] + elif name.endswith("[clipspace]"): + base_dir = get_clipspace_directory() + name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index ca0769ba7..1d9a5c872 100644 --- a/nodes.py +++ b/nodes.py @@ -973,8 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), )}, + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, } CATEGORY = "image" @@ -1014,9 +1015,10 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() + input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(os.listdir(input_dir)), ), - "channel": (s._color_channels, ),} + {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + "channel": (s._color_channels, ), } } CATEGORY = "mask" diff --git a/server.py b/server.py index 1c5c17916..48644d83a 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,9 @@ import execution import uuid import json import glob +from PIL import Image +from io import BytesIO + try: import aiohttp from aiohttp import web @@ -110,19 +113,26 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) + def get_dir_by_type(dir_type): + if dir_type is None: + type_dir = folder_paths.get_input_directory() + elif dir_type == "input": + type_dir = folder_paths.get_input_directory() + elif dir_type == "clipspace": + type_dir = folder_paths.get_clipspace_directory() + elif dir_type == "temp": + type_dir = folder_paths.get_temp_directory() + elif dir_type == "output": + type_dir = folder_paths.get_output_directory() + + return type_dir + @routes.post("/upload/image") async def upload_image(request): post = await request.post() image = post.get("image") - if post.get("type") is None: - upload_dir = folder_paths.get_input_directory() - elif post.get("type") == "input": - upload_dir = 
folder_paths.get_input_directory() - elif post.get("type") == "temp": - upload_dir = folder_paths.get_temp_directory() - elif post.get("type") == "output": - upload_dir = folder_paths.get_output_directory() + upload_dir = get_dir_by_type(post.get("type")) if not os.path.exists(upload_dir): os.makedirs(upload_dir) @@ -147,12 +157,62 @@ class PromptServer(): else: return web.Response(status=400) + @routes.post("/upload/mask") + async def upload_mask(request): + post = await request.post() + image = post.get("image") + original_image = post.get("original_image") + + upload_dir = get_dir_by_type(post.get("type")) + + if not os.path.exists(upload_dir): + os.makedirs(upload_dir) + + if image and image.file: + filename = image.filename + if not filename: + return web.Response(status=400) + + split = os.path.splitext(filename) + i = 1 + while os.path.exists(os.path.join(upload_dir, filename)): + filename = f"{split[0]} ({i}){split[1]}" + i += 1 + + filepath = os.path.join(upload_dir, filename) + + original_pil = Image.open(original_image.file).convert('RGBA') + mask_pil = Image.open(image.file).convert('RGBA') + + # alpha copy + new_alpha = mask_pil.getchannel('A') + original_pil.putalpha(new_alpha) + + original_pil.save(filepath) + + return web.json_response({"name": filename}) + else: + return web.Response(status=400) + @routes.get("/view") async def view_image(request): if "filename" in request.rel_url.query: - type = request.rel_url.query.get("type", "output") - output_dir = folder_paths.get_directory_by_type(type) + filename = request.rel_url.query["filename"] + filename,output_dir = folder_paths.annotated_filepath(filename) + + if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): + output_dir = folder_paths.get_clipspace_directory() + filename = filename[10:] + + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' 
in filename: + return web.Response(status=400) + + if output_dir is None: + type = request.rel_url.query.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) + if output_dir is None: return web.Response(status=400) @@ -162,13 +222,49 @@ class PromptServer(): return web.Response(status=403) output_dir = full_output_dir - filename = request.rel_url.query["filename"] filename = os.path.basename(filename) file = os.path.join(output_dir, filename) if os.path.isfile(file): - return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) - + if 'channel' not in request.rel_url.query: + channel = 'rgba' + else: + channel = request.rel_url.query["channel"] + + if channel == 'rgb': + with Image.open(file) as img: + if img.mode == "RGBA": + r, g, b, a = img.split() + new_img = Image.merge('RGB', (r, g, b)) + else: + new_img = img.convert("RGB") + + buffer = BytesIO() + new_img.save(buffer, format='PNG') + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + + elif channel == 'a': + with Image.open(file) as img: + if img.mode == "RGBA": + _, _, _, a = img.split() + else: + a = Image.new('L', img.size, 255) + + # alpha img + alpha_img = Image.new('RGBA', img.size) + alpha_img.putalpha(a) + alpha_buffer = BytesIO() + alpha_img.save(alpha_buffer, format='PNG') + alpha_buffer.seek(0) + + return web.Response(body=alpha_buffer.read(), content_type='image/png', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + else: + return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""}) + return web.Response(status=404) @routes.get("/prompt") diff --git a/web/extensions/core/clipspace.js b/web/extensions/core/clipspace.js new file mode 100644 index 000000000..adb5877ea --- /dev/null +++ b/web/extensions/core/clipspace.js @@ -0,0 +1,166 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; + +export class ClipspaceDialog extends ComfyDialog { + static items = []; + static instance = null; + + static registerButton(name, contextPredicate, callback) { + const item = + $el("button", { + type: "button", + textContent: name, + contextPredicate: contextPredicate, + onclick: callback + }) + + ClipspaceDialog.items.push(item); + } + + static invalidatePreview() { + if(ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0) { + const img_preview = document.getElementById("clipspace_preview"); + if(img_preview) { + img_preview.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + img_preview.style.maxHeight = "100%"; + img_preview.style.maxWidth = "100%"; + } + } + } + + static invalidate() { + if(ClipspaceDialog.instance) { + const self = ClipspaceDialog.instance; + // allow reconstruct controls when copying from non-image to image content. 
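+            // invalidate() rebuilds the whole modal body (the image-settings
+            // table plus the registered buttons) and either swaps it into the
+            // existing modal element or creates the modal on first use.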
+ const children = $el("div.comfy-modal-content", [ self.createImgSettings(), ...self.createButtons() ]); + + if(self.element) { + // update + self.element.removeChild(self.element.firstChild); + self.element.appendChild(children); + } + else { + // new + self.element = $el("div.comfy-modal", { parent: document.body }, [children,]); + } + + if(self.element.children[0].children.length <= 1) { + self.element.children[0].appendChild($el("p", {}, ["Unable to find the features to edit content of a format stored in the current Clipspace."])); + } + + ClipspaceDialog.invalidatePreview(); + } + } + + constructor() { + super(); + } + + createButtons(self) { + const buttons = []; + + for(let idx in ClipspaceDialog.items) { + const item = ClipspaceDialog.items[idx]; + if(!item.contextPredicate || item.contextPredicate()) + buttons.push(ClipspaceDialog.items[idx]); + } + + buttons.push( + $el("button", { + type: "button", + textContent: "Close", + onclick: () => { this.close(); } + }) + ); + + return buttons; + } + + createImgSettings() { + if(ComfyApp.clipspace.imgs) { + const combo_items = []; + const imgs = ComfyApp.clipspace.imgs; + + for(let i=0; i < imgs.length; i++) { + combo_items.push($el("option", {value:i}, [`${i}`])); + } + + const combo1 = $el("select", + {id:"clipspace_img_selector", onchange:(event) => { + ComfyApp.clipspace['selectedIndex'] = event.target.selectedIndex; + ClipspaceDialog.invalidatePreview(); + } }, combo_items); + + const row1 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Select Image"])]), + $el("td", {}, [combo1]) + ]); + + + const combo2 = $el("select", + {id:"clipspace_img_paste_mode", onchange:(event) => { + ComfyApp.clipspace['img_paste_mode'] = event.target.value; + } }, + [ + $el("option", {value:'selected'}, 'selected'), + $el("option", {value:'all'}, 'all') + ]); + combo2.value = ComfyApp.clipspace['img_paste_mode']; + + const row2 = + $el("tr", {}, + [ + $el("td", {}, [$el("font", {color:"white"}, ["Paste Mode"])]), + $el("td", {}, [combo2]) + ]); + + const td = $el("td", {align:'center', width:'100px', height:'100px', colSpan:'2'}, + [ $el("img",{id:"clipspace_preview", ondragstart:() => false},[]) ]); + + const row3 = + $el("tr", {}, [td]); + + return $el("table", {}, [row1, row2, row3]); + } + else { + return []; + } + } + + createImgPreview() { + if(ComfyApp.clipspace.imgs) { + return $el("img",{id:"clipspace_preview", ondragstart:() => false}); + } + else + return []; + } + + show() { + const img_preview = document.getElementById("clipspace_preview"); + ClipspaceDialog.invalidate(); + + this.element.style.display = "block"; + } +} + +app.registerExtension({ + name: "Comfy.Clipspace", + init(app) { + app.openClipspace = + function () { + if(!ClipspaceDialog.instance) { + ClipspaceDialog.instance = new ClipspaceDialog(app); + ComfyApp.clipspace_invalidate_handler = ClipspaceDialog.invalidate; + } + + if(ComfyApp.clipspace) { + ClipspaceDialog.instance.show(); + } + else + app.ui.dialog.show("Clipspace is Empty!"); + }; + } +}); \ No newline at end of file diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js new file mode 100644 index 000000000..c55f841b6 --- /dev/null +++ b/web/extensions/core/maskeditor.js @@ -0,0 +1,589 @@ +import { app } from "/scripts/app.js"; +import { ComfyDialog, $el } from "/scripts/ui.js"; +import { ComfyApp } from "/scripts/app.js"; +import { ClipspaceDialog } from "/extensions/core/clipspace.js"; + +// Helper function to convert a data URL to a Blob object +function 
dataURLToBlob(dataURL) { + const parts = dataURL.split(';base64,'); + const contentType = parts[0].split(':')[1]; + const byteString = atob(parts[1]); + const arrayBuffer = new ArrayBuffer(byteString.length); + const uint8Array = new Uint8Array(arrayBuffer); + for (let i = 0; i < byteString.length; i++) { + uint8Array[i] = byteString.charCodeAt(i); + } + return new Blob([arrayBuffer], { type: contentType }); +} + +function loadedImageToBlob(image) { + const canvas = document.createElement('canvas'); + + canvas.width = image.width; + canvas.height = image.height; + + const ctx = canvas.getContext('2d'); + + ctx.drawImage(image, 0, 0); + + const dataURL = canvas.toDataURL('image/png', 1); + const blob = dataURLToBlob(dataURL); + + return blob; +} + +async function uploadMask(filepath, formData) { + await fetch('/upload/mask', { + method: 'POST', + body: formData + }).then(response => {}).catch(error => { + console.error('Error:', error); + }); + + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; + + ClipspaceDialog.invalidatePreview(); +} + +function prepareRGB(image, backupCanvas, backupCtx) { + // paste mask data into alpha channel + backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); +} + +class MaskEditorDialog extends ComfyDialog { + static instance = null; + constructor() { + super(); + this.element = $el("div.comfy-modal", { parent: document.body }, + [ $el("div.comfy-modal-content", + [...this.createButtons()]), + ]); + MaskEditorDialog.instance = this; + } + + createButtons() { + return []; + } + + clearMask(self) { + } + + createButton(name, callback) { + var button = document.createElement("button"); + button.innerText = name; + button.addEventListener("click", callback); + return button; + } + createLeftButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "left"; + button.style.marginRight = "4px"; + return button; + } + createRightButton(name, callback) { + var button = this.createButton(name, callback); + button.style.cssFloat = "right"; + button.style.marginLeft = "4px"; + return button; + } + createLeftSlider(self, name, callback) { + const divElement = document.createElement('div'); + divElement.id = "maskeditor-slider"; + divElement.style.cssFloat = "left"; + divElement.style.fontFamily = "sans-serif"; + divElement.style.marginRight = "4px"; + divElement.style.color = "var(--input-text)"; + divElement.style.backgroundColor = "var(--comfy-input-bg)"; + divElement.style.borderRadius = "8px"; + divElement.style.borderColor = "var(--border-color)"; + divElement.style.borderStyle = "solid"; + divElement.style.fontSize = "15px"; + divElement.style.height = "21px"; + divElement.style.padding = "1px 6px"; + divElement.style.display = "flex"; + divElement.style.position = "relative"; + 
divElement.style.top = "2px";
+        self.brush_slider_input = document.createElement('input');
+        self.brush_slider_input.setAttribute('type', 'range');
+        self.brush_slider_input.setAttribute('min', '1');
+        self.brush_slider_input.setAttribute('max', '100');
+        self.brush_slider_input.setAttribute('value', '10');
+        const labelElement = document.createElement("label");
+        labelElement.textContent = name;
+
+        divElement.appendChild(labelElement);
+        divElement.appendChild(self.brush_slider_input);
+
+        self.brush_slider_input.addEventListener("change", callback);
+
+        return divElement;
+    }
+
+    setlayout(imgCanvas, maskCanvas) {
+        const self = this;
+
+        // The placeholder div is used only as hidden padding; keeping it
+        // position:relative prevents it from growing past a certain size
+        // and spilling outside of the window.
+        var placeholder = document.createElement("div");
+        placeholder.style.position = "relative";
+        placeholder.style.height = "50px";
+
+        var bottom_panel = document.createElement("div");
+        bottom_panel.style.position = "absolute";
+        bottom_panel.style.bottom = "0px";
+        bottom_panel.style.left = "20px";
+        bottom_panel.style.right = "20px";
+        bottom_panel.style.height = "50px";
+
+        var brush = document.createElement("div");
+        brush.id = "brush";
+        brush.style.backgroundColor = "transparent";
+        brush.style.outline = "1px dashed black";
+        brush.style.boxShadow = "0 0 0 1px white";
+        brush.style.borderRadius = "50%";
+        brush.style.MozBorderRadius = "50%";
+        brush.style.WebkitBorderRadius = "50%";
+        brush.style.position = "absolute";
+        brush.style.zIndex = 100;
+        brush.style.pointerEvents = "none";
+        this.brush = brush;
+        this.element.appendChild(imgCanvas);
+        this.element.appendChild(maskCanvas);
+        this.element.appendChild(placeholder); // must sit below bottom_panel in z-order so it doesn't cover the buttons
+        this.element.appendChild(bottom_panel);
+        document.body.appendChild(brush);
+
+        var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => {
+            self.brush_size = event.target.value;
+            self.updateBrushPreview(self, null, null);
+        });
+        var clearButton = this.createLeftButton("Clear",
+            () => {
+                self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height);
+                self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height);
+            });
+        var cancelButton = this.createRightButton("Cancel", () => {
+            document.removeEventListener("pointerup", MaskEditorDialog.handlePointerUp);
+            document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
+            self.close();
+        });
+        var saveButton = this.createRightButton("Save", () => {
+            document.removeEventListener("pointerup", MaskEditorDialog.handlePointerUp);
+            document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown);
+            self.save();
+        });
+
+        bottom_panel.appendChild(clearButton);
+        bottom_panel.appendChild(saveButton);
+        bottom_panel.appendChild(cancelButton);
+        bottom_panel.appendChild(brush_size_slider);
+
+        this.element.style.display = "block";
+        imgCanvas.style.position = "relative";
+        imgCanvas.style.top = "200";
+        imgCanvas.style.left = "0";
+
+        maskCanvas.style.position = "absolute";
+    }
+
+    show() {
+        // layout
+        const imgCanvas = document.createElement('canvas');
+        const maskCanvas = document.createElement('canvas');
+        const backupCanvas =
document.createElement('canvas');
+
+        imgCanvas.id = "imageCanvas";
+        maskCanvas.id = "maskCanvas";
+        backupCanvas.id = "backupCanvas";
+
+        this.setlayout(imgCanvas, maskCanvas);
+
+        // prepare content
+        this.maskCanvas = maskCanvas;
+        this.backupCanvas = backupCanvas;
+        this.maskCtx = maskCanvas.getContext('2d');
+        this.backupCtx = backupCanvas.getContext('2d');
+
+        this.setImages(imgCanvas, backupCanvas);
+        this.setEventHandler(maskCanvas);
+    }
+
+    setImages(imgCanvas, backupCanvas) {
+        const imgCtx = imgCanvas.getContext('2d');
+        const backupCtx = backupCanvas.getContext('2d');
+        const maskCtx = this.maskCtx;
+        const maskCanvas = this.maskCanvas;
+
+        // image load
+        const orig_image = new Image();
+        window.addEventListener("resize", () => {
+            // repositioning
+            imgCanvas.width = window.innerWidth - 250;
+            imgCanvas.height = window.innerHeight - 200;
+
+            // redraw image
+            let drawWidth = orig_image.width;
+            let drawHeight = orig_image.height;
+            if (orig_image.width > imgCanvas.width) {
+                drawWidth = imgCanvas.width;
+                drawHeight = (drawWidth / orig_image.width) * orig_image.height;
+            }
+
+            if (drawHeight > imgCanvas.height) {
+                drawHeight = imgCanvas.height;
+                drawWidth = (drawHeight / orig_image.height) * orig_image.width;
+            }
+
+            imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);
+
+            // update mask
+            backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
+            maskCanvas.width = drawWidth;
+            maskCanvas.height = drawHeight;
+            maskCanvas.style.top = imgCanvas.offsetTop + "px";
+            maskCanvas.style.left = imgCanvas.offsetLeft + "px";
+            maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
+        });
+
+        const filepath = ComfyApp.clipspace.images;
+
+        const touched_image = new Image();
+
+        touched_image.onload = function() {
+            backupCanvas.width = touched_image.width;
+            backupCanvas.height = touched_image.height;
+
+            prepareRGB(touched_image, backupCanvas, backupCtx);
+        };
+
+        const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src)
+        alpha_url.searchParams.delete('channel');
+        alpha_url.searchParams.set('channel', 'a');
+        touched_image.src = alpha_url;
+
+        // original image load
+        orig_image.onload = function() {
+            window.dispatchEvent(new Event('resize'));
+        };
+
+        const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src);
+        rgb_url.searchParams.delete('channel');
+        rgb_url.searchParams.set('channel', 'rgb');
+        orig_image.src = rgb_url;
+        this.image = orig_image;
+    }
+
+
+    setEventHandler(maskCanvas) {
+        maskCanvas.addEventListener("contextmenu", (event) => {
+            event.preventDefault();
+        });
+
+        const self = this;
+        maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event));
+        maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event));
+        document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp);
+        maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event));
+        maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event));
+        maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; });
+        maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; });
+        document.addEventListener('keydown', MaskEditorDialog.handleKeyDown);
+    }
+
+    brush_size = 10;
+    drawing_mode = false;
+    lastx = -1;
+    lasty = -1;
+    lasttime = 0;
+
+
static handleKeyDown(event) { + const self = MaskEditorDialog.instance; + if (event.key === ']') { + self.brush_size = Math.min(self.brush_size+2, 100); + } else if (event.key === '[') { + self.brush_size = Math.max(self.brush_size-2, 1); + } + + self.updateBrushPreview(self); + } + + static handlePointerUp(event) { + event.preventDefault(); + MaskEditorDialog.instance.drawing_mode = false; + } + + updateBrushPreview(self) { + const brush = self.brush; + + var centerX = self.cursorX; + var centerY = self.cursorY; + + brush.style.width = self.brush_size * 2 + "px"; + brush.style.height = self.brush_size * 2 + "px"; + brush.style.left = (centerX - self.brush_size) + "px"; + brush.style.top = (centerY - self.brush_size) + "px"; + } + + handleWheelEvent(self, event) { + if(event.deltaY < 0) + self.brush_size = Math.min(self.brush_size+2, 100); + else + self.brush_size = Math.max(self.brush_size-2, 1); + + self.brush_slider_input.value = self.brush_size; + + self.updateBrushPreview(self); + } + + draw_move(self, event) { + event.preventDefault(); + + this.cursorX = event.pageX; + this.cursorY = event.pageY; + + self.updateBrushPreview(self); + + if (event instanceof TouchEvent || event.buttons == 1) { + var diff = performance.now() - self.lasttime; + + const maskRect = self.maskCanvas.getBoundingClientRect(); + + var x = event.offsetX; + var y = event.offsetY + + if(event.offsetX == null) { + x = event.targetTouches[0].clientX - maskRect.left; + } + + if(event.offsetY == null) { + y = event.targetTouches[0].clientY - maskRect.top; + } + + var brush_size = this.brush_size; + if(event instanceof PointerEvent && event.pointerType == 'pen') { + brush_size *= event.pressure; + this.last_pressure = event.pressure; + } + else if(event instanceof TouchEvent && diff < 20){ + // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. 
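+                // this.last_pressure remembers the most recent pen pressure, so touch
+                // samples arriving within 20ms of the last event reuse it instead of
+                // falling back to the unscaled brush size.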
+                brush_size *= this.last_pressure;
+            }
+            else {
+                brush_size = this.brush_size;
+            }
+
+            if(diff > 20 && !this.drawing_mode)
+                requestAnimationFrame(() => {
+                    self.maskCtx.beginPath();
+                    self.maskCtx.fillStyle = "rgb(0,0,0)";
+                    self.maskCtx.globalCompositeOperation = "source-over";
+                    self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
+                    self.maskCtx.fill();
+                    self.lastx = x;
+                    self.lasty = y;
+                });
+            else
+                requestAnimationFrame(() => {
+                    self.maskCtx.beginPath();
+                    self.maskCtx.fillStyle = "rgb(0,0,0)";
+                    self.maskCtx.globalCompositeOperation = "source-over";
+
+                    var dx = x - self.lastx;
+                    var dy = y - self.lasty;
+
+                    var distance = Math.sqrt(dx * dx + dy * dy);
+                    var directionX = dx / distance;
+                    var directionY = dy / distance;
+
+                    for (var i = 0; i < distance; i+=5) {
+                        var px = self.lastx + (directionX * i);
+                        var py = self.lasty + (directionY * i);
+                        self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
+                        self.maskCtx.fill();
+                    }
+                    self.lastx = x;
+                    self.lasty = y;
+                });
+
+            self.lasttime = performance.now();
+        }
+        else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) {
+            var diff = performance.now() - self.lasttime; // time since the last draw event
+            const maskRect = self.maskCanvas.getBoundingClientRect();
+            const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
+            const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+
+            var brush_size = this.brush_size;
+            if(event instanceof PointerEvent && event.pointerType == 'pen') {
+                brush_size *= event.pressure;
+                this.last_pressure = event.pressure;
+            }
+            else if(event instanceof TouchEvent && diff < 20){
+                brush_size *= this.last_pressure;
+            }
+            else {
+                brush_size = this.brush_size;
+            }
+
+            if(diff > 20 && !this.drawing_mode) // drawing_mode cannot be tracked for touch events
+                requestAnimationFrame(() => {
+                    self.maskCtx.beginPath();
+                    self.maskCtx.globalCompositeOperation = "destination-out";
+                    self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
+                    self.maskCtx.fill();
+                    self.lastx = x;
+                    self.lasty = y;
+                });
+            else
+                requestAnimationFrame(() => {
+                    self.maskCtx.beginPath();
+                    self.maskCtx.globalCompositeOperation = "destination-out";
+
+                    var dx = x - self.lastx;
+                    var dy = y - self.lasty;
+
+                    var distance = Math.sqrt(dx * dx + dy * dy);
+                    var directionX = dx / distance;
+                    var directionY = dy / distance;
+
+                    for (var i = 0; i < distance; i+=5) {
+                        var px = self.lastx + (directionX * i);
+                        var py = self.lasty + (directionY * i);
+                        self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
+                        self.maskCtx.fill();
+                    }
+                    self.lastx = x;
+                    self.lasty = y;
+                });
+
+            self.lasttime = performance.now();
+        }
+    }
+
+    handlePointerDown(self, event) {
+        var brush_size = this.brush_size;
+        if(event instanceof PointerEvent && event.pointerType == 'pen') {
+            brush_size *= event.pressure;
+            this.last_pressure = event.pressure;
+        }
+
+        if ([0, 2, 5].includes(event.button)) {
+            self.drawing_mode = true;
+
+            event.preventDefault();
+            const maskRect = self.maskCanvas.getBoundingClientRect();
+            const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left;
+            const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top;
+
+            self.maskCtx.beginPath();
+            if (event.button == 0) {
+                self.maskCtx.fillStyle = "rgb(0,0,0)";
+                self.maskCtx.globalCompositeOperation = "source-over";
+            } else {
+                self.maskCtx.globalCompositeOperation = "destination-out";
+            }
+            self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
+            self.maskCtx.fill();
+            self.lastx = x;
+            self.lasty = y;
+            self.lasttime = performance.now();
+        }
+    }
+
+    save() {
+        const backupCtx =
this.backupCanvas.getContext('2d', {willReadFrequently:true}); + + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.drawImage(this.maskCanvas, + 0, 0, this.maskCanvas.width, this.maskCanvas.height, + 0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // paste mask data into alpha channel + const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + + // refine mask image + for (let i = 0; i < backupData.data.length; i += 4) { + if(backupData.data[i+3] == 255) + backupData.data[i+3] = 0; + else + backupData.data[i+3] = 255; + + backupData.data[i] = 0; + backupData.data[i+1] = 0; + backupData.data[i+2] = 0; + } + + backupCtx.globalCompositeOperation = 'source-over'; + backupCtx.putImageData(backupData, 0, 0); + + const formData = new FormData(); + const filename = "clipspace-mask-" + performance.now() + ".png"; + + const item = + { + "filename": filename, + "subfolder": "", + "type": "clipspace", + }; + + if(ComfyApp.clipspace.images) + ComfyApp.clipspace.images[0] = item; + + if(ComfyApp.clipspace.widgets) { + const index = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); + + if(index >= 0) + ComfyApp.clipspace.widgets[index].value = item; + } + + const dataURL = this.backupCanvas.toDataURL(); + const blob = dataURLToBlob(dataURL); + + const original_blob = loadedImageToBlob(this.image); + + formData.append('image', blob, filename); + formData.append('original_image', original_blob); + formData.append('type', "clipspace"); + + uploadMask(item, formData); + this.close(); + } +} + +app.registerExtension({ + name: "Comfy.MaskEditor", + init(app) { + const callback = + function () { + let dlg = new MaskEditorDialog(app); + dlg.show(); + }; + + const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 + ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + } +}); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 245605484..f4f7272db 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -25,6 +25,7 @@ export class ComfyApp { * @type {serialized node object} */ static clipspace = null; + static clipspace_invalidate_handler = null; constructor() { this.ui = new ComfyUI(this); @@ -143,22 +144,34 @@ export class ComfyApp { callback: (obj) => { var widgets = null; if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); + widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); } - let img = new Image(); var imgs = undefined; + var orig_imgs = undefined; if(this.imgs != undefined) { - img.src = this.imgs[0].src; - imgs = [img]; + imgs = []; + orig_imgs = []; + + for (let i = 0; i < this.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = this.imgs[i].src; + orig_imgs[i] = imgs[i]; + } } ComfyApp.clipspace = { 'widgets': widgets, 'imgs': imgs, - 'original_imgs': imgs, - 'images': this.images + 'original_imgs': orig_imgs, + 'images': this.images, + 'selectedIndex': 0, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action }; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } } }); @@ -167,48 +180,82 @@ export class ComfyApp { { content: "Paste (Clipspace)", callback: () => { - if(ComfyApp.clipspace != null) { - if(ComfyApp.clipspace.widgets != null && this.widgets != null) { - ComfyApp.clipspace.widgets.forEach(({ type, name, 
value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop) { - prop.callback(value); - } - }); - } - + if(ComfyApp.clipspace) { // image paste - if(ComfyApp.clipspace.imgs != undefined && this.imgs != undefined && this.widgets != null) { + if(ComfyApp.clipspace.imgs && this.imgs) { var filename = ""; if(this.images && ComfyApp.clipspace.images) { - this.images = ComfyApp.clipspace.images; + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + + } + else + app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; } - if(ComfyApp.clipspace.images != undefined) { - const clip_image = ComfyApp.clipspace.images[0]; + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + this.imgs = [img]; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); if(index_in_clip >= 0) { - filename = `${ComfyApp.clipspace.widgets[index_in_clip].value}`; + const item = ComfyApp.clipspace.widgets[index_in_clip].value; + if(item.type) + filename = `${item.filename} [${item.type}]`; + else + filename = item.filename; } } - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "" && ComfyApp.clipspace.imgs != undefined) { - this.imgs = ComfyApp.clipspace.imgs; + // for Load Image node. + if(this.widgets) { + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0 && filename != "") { + const postfix = ' [clipspace]'; + if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { + filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); + } - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; + this.widgets[index].value = filename; + if(this.widgets_values != undefined) { + this.widgets_values[index] = filename; + } } } } - this.trigger('changed'); + + // ensure render after update widget_value + if(ComfyApp.clipspace.widgets && this.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.callback(value); + } + }); + } } + + app.graph.setDirtyCanvas(true); } } ); @@ -1275,12 +1322,17 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] - if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - widget.options.values = def["input"]["required"][widget.name][0]; + if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { + console.log(widget.options.values = def["input"]["required"][widget.name][1].files); + widget.options.values = def["input"]["required"][widget.name][1].files; + } + else + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; + widget.callback(widget.value); } } } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 5accc9d86..77517aec1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -581,6 +581,7 @@ export class ComfyUI { }), $el("button", { id: 
"comfy-load-button", textContent: "Load", onclick: () => fileInput.click() }), $el("button", { id: "comfy-refresh-button", textContent: "Refresh", onclick: () => app.refreshComboInNodes() }), + $el("button", { id: "comfy-clipspace-button", textContent: "Clipspace", onclick: () => app.openClipspace() }), $el("button", { id: "comfy-clear-button", textContent: "Clear", onclick: () => { if (!confirmClear.value || confirm("Clear workflow?")) { app.clean(); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index cd471bc93..4a72246db 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,6 +256,20 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, + FILE_COMBO(node, inputName, inputData) { + const base_dir = inputData[1].base_dir; + let defaultValue = inputData[1].files[0]; + + const files = [] + for(let i in inputData[1].files) { + files[i] = inputData[1].files[i]; + const postfix = ' [clipspace]'; + if(base_dir == 'input' && files[i].endsWith(postfix)) + files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); + } + + return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; + }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; From 850daf0416367ba39d10195540f5b735952f0ee7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 14:13:06 -0400 Subject: [PATCH 087/208] Masked editor changes. Add a way to upload to subfolders. Clean up code. Fix some issues. --- folder_paths.py | 9 ---- nodes.py | 8 ++-- server.py | 74 ++++++++++++------------------- web/extensions/core/maskeditor.js | 9 ++-- web/scripts/app.js | 66 ++++++++------------------- web/scripts/widgets.js | 52 +++++++++++++++------- 6 files changed, 93 insertions(+), 125 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 0acd22674..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -57,10 +57,6 @@ def get_input_directory(): global input_directory return input_directory -def get_clipspace_directory(): - global input_directory - return input_directory+"/clipspace" - #NOTE: used in http server so don't put folders that should not be accessed remotely def get_directory_by_type(type_name): @@ -70,8 +66,6 @@ def get_directory_by_type(type_name): return get_temp_directory() if type_name == "input": return get_input_directory() - if type_name == "clipspace": - return get_clipspace_directory() return None @@ -87,9 +81,6 @@ def annotated_filepath(name): elif name.endswith("[temp]"): base_dir = get_temp_directory() name = name[:-7] - elif name.endswith("[clipspace]"): - base_dir = get_clipspace_directory() - name = name[:-12] else: return name, None diff --git a/nodes.py b/nodes.py index 1d9a5c872..699e60ae8 100644 --- a/nodes.py +++ b/nodes.py @@ -973,9 +973,9 @@ class LoadImage: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, )}, + {"image": (sorted(files), )}, } CATEGORY = "image" @@ -1015,9 +1015,9 @@ class LoadImageMask: @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() - input_dir = [f for f in 
os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": ("FILE_COMBO", {"base_dir": "input", "files": sorted(input_dir)}, ), + {"image": (sorted(files), ), "channel": (s._color_channels, ), } } diff --git a/server.py b/server.py index 48644d83a..3d02b2f7a 100644 --- a/server.py +++ b/server.py @@ -118,8 +118,6 @@ class PromptServer(): type_dir = folder_paths.get_input_directory() elif dir_type == "input": type_dir = folder_paths.get_input_directory() - elif dir_type == "clipspace": - type_dir = folder_paths.get_clipspace_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": @@ -127,73 +125,63 @@ class PromptServer(): return type_dir - @routes.post("/upload/image") - async def upload_image(request): - post = await request.post() + def image_upload(post, image_save_function=None): image = post.get("image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) + image_upload_type = post.get("type") + upload_dir = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename if not filename: return web.Response(status=400) + subfolder = post.get("subfolder", "") + full_output_folder = os.path.join(upload_dir, os.path.normpath(subfolder)) + + if os.path.commonpath((upload_dir, os.path.abspath(full_output_folder))) != upload_dir: + return web.Response(status=400) + + if not os.path.exists(full_output_folder): + os.makedirs(full_output_folder) + split = os.path.splitext(filename) + filepath = os.path.join(full_output_folder, filename) + i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): + while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" i += 1 - filepath = os.path.join(upload_dir, filename) + if image_save_function is not None: + image_save_function(image, post, filepath) + else: + with open(filepath, "wb") as f: + f.write(image.file.read()) - with open(filepath, "wb") as f: - f.write(image.file.read()) - - return web.json_response({"name" : filename}) + return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) else: return web.Response(status=400) + @routes.post("/upload/image") + async def upload_image(request): + post = await request.post() + return image_upload(post) + @routes.post("/upload/mask") async def upload_mask(request): post = await request.post() - image = post.get("image") - original_image = post.get("original_image") - upload_dir = get_dir_by_type(post.get("type")) - - if not os.path.exists(upload_dir): - os.makedirs(upload_dir) - - if image and image.file: - filename = image.filename - if not filename: - return web.Response(status=400) - - split = os.path.splitext(filename) - i = 1 - while os.path.exists(os.path.join(upload_dir, filename)): - filename = f"{split[0]} ({i}){split[1]}" - i += 1 - - filepath = os.path.join(upload_dir, filename) - - original_pil = Image.open(original_image.file).convert('RGBA') + def image_save_function(image, post, filepath): + original_pil = Image.open(post.get("original_image").file).convert('RGBA') mask_pil = Image.open(image.file).convert('RGBA') # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) - return web.json_response({"name": filename}) - else: - return web.Response(status=400) - + return image_upload(post, image_save_function) 
@routes.get("/view") async def view_image(request): @@ -201,10 +189,6 @@ class PromptServer(): filename = request.rel_url.query["filename"] filename,output_dir = folder_paths.annotated_filepath(filename) - if request.rel_url.query.get("type", "input") and filename.startswith("clipspace/"): - output_dir = folder_paths.get_clipspace_directory() - filename = filename[10:] - # validation for security: prevent accessing arbitrary path if filename[0] == '/' or '..' in filename: return web.Response(status=400) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index c55f841b6..0ffa50c69 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = `view?filename=${filepath.filename}&type=${filepath.type}`; + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -546,8 +546,8 @@ class MaskEditorDialog extends ComfyDialog { const item = { "filename": filename, - "subfolder": "", - "type": "clipspace", + "subfolder": "clipspace", + "type": "input", }; if(ComfyApp.clipspace.images) @@ -567,7 +567,8 @@ class MaskEditorDialog extends ComfyDialog { formData.append('image', blob, filename); formData.append('original_image', original_blob); - formData.append('type', "clipspace"); + formData.append('type', "input"); + formData.append('subfolder', "clipspace"); uploadMask(item, formData); this.close(); diff --git a/web/scripts/app.js b/web/scripts/app.js index f4f7272db..c6c29e45b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -183,7 +183,6 @@ export class ComfyApp { if(ComfyApp.clipspace) { // image paste if(ComfyApp.clipspace.imgs && this.imgs) { - var filename = ""; if(this.images && ComfyApp.clipspace.images) { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; @@ -209,49 +208,25 @@ export class ComfyApp { } } } - - if(ComfyApp.clipspace.images) { - const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; - if(clip_image.subfolder != '') - filename = `${clip_image.subfolder}/`; - filename += `${clip_image.filename} [${clip_image.type}]`; - } - else if(ComfyApp.clipspace.widgets) { - const index_in_clip = ComfyApp.clipspace.widgets.findIndex(obj => obj.name === 'image'); - if(index_in_clip >= 0) { - const item = ComfyApp.clipspace.widgets[index_in_clip].value; - if(item.type) - filename = `${item.filename} [${item.type}]`; - else - filename = item.filename; - } - } - - // for Load Image node. 
- if(this.widgets) { - const index = this.widgets.findIndex(obj => obj.name === 'image'); - if(index >= 0 && filename != "") { - const postfix = ' [clipspace]'; - if(filename.endsWith(postfix) && this.widgets[index].options.base_dir == 'input') { - filename = "clipspace/" + filename.slice(0, filename.indexOf(postfix)); - } - - this.widgets[index].value = filename; - if(this.widgets_values != undefined) { - this.widgets_values[index] = filename; - } - } - } } - // ensure render after update widget_value - if(ComfyApp.clipspace.widgets && this.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.callback(value); - } - }); + if(this.widgets) { + if(ComfyApp.clipspace.images) { + const clip_image = ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]; + const index = this.widgets.findIndex(obj => obj.name === 'image'); + if(index >= 0) { + this.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } } } @@ -1323,12 +1298,7 @@ export class ComfyApp { for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { - if(def["input"]["required"][widget.name][0] == "FILE_COMBO") { - console.log(widget.options.values = def["input"]["required"][widget.name][1].files); - widget.options.values = def["input"]["required"][widget.name][1].files; - } - else - widget.options.values = def["input"]["required"][widget.name][0]; + widget.options.values = def["input"]["required"][widget.name][0]; if(!widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 4a72246db..65edc0392 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -256,20 +256,6 @@ export const ComfyWidgets = { } return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; }, - FILE_COMBO(node, inputName, inputData) { - const base_dir = inputData[1].base_dir; - let defaultValue = inputData[1].files[0]; - - const files = [] - for(let i in inputData[1].files) { - files[i] = inputData[1].files[i]; - const postfix = ' [clipspace]'; - if(base_dir == 'input' && files[i].endsWith(postfix)) - files[i] = "clipspace/" + files[i].slice(0, files[i].indexOf(postfix)); - } - - return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { base_dir:base_dir, values: files }) }; - }, IMAGEUPLOAD(node, inputName, inputData, app) { const imageWidget = node.widgets.find((w) => w.name === "image"); let uploadWidget; @@ -280,10 +266,46 @@ export const ComfyWidgets = { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; - img.src = `/view?filename=${name}&type=input`; + let folder_separator = name.lastIndexOf("/"); + let subfolder = ""; + if (folder_separator > -1) { + subfolder = name.substring(0, folder_separator); + name = name.substring(folder_separator + 1); + } + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; node.setSizeForImage?.(); } + var default_value = imageWidget.value; + Object.defineProperty(imageWidget, "value", { + set 
: function(value) { + this._real_value = value; + }, + + get : function() { + let value = ""; + if (this._real_value) { + value = this._real_value; + } else { + return default_value; + } + + if (value.filename) { + let real_value = value; + value = ""; + if (real_value.subfolder) { + value = real_value.subfolder + "/"; + } + + value += real_value.filename; + + if(real_value.type && real_value.type !== "input") + value += ` [${real_value.type}]`; + } + return value; + } + }); + // Add our own callback to the combo widget to render an image when it changes const cb = node.callback; imageWidget.callback = function () { From a7ebd5aa1278a63f2f14852dce59b43834f6b9d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 15:52:33 -0400 Subject: [PATCH 088/208] Fix masked editor issue with firefox on windows. --- web/extensions/core/maskeditor.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 0ffa50c69..552059e86 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -368,7 +368,7 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (event instanceof TouchEvent || event.buttons == 1) { + if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -389,7 +389,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ // The firing interval of PointerEvents in Pen is unreliable, so it is supplemented by TouchEvents. brush_size *= this.last_pressure; } @@ -442,7 +442,7 @@ class MaskEditorDialog extends ComfyDialog { brush_size *= event.pressure; this.last_pressure = event.pressure; } - else if(event instanceof TouchEvent && diff < 20){ + else if(window.TouchEvent && event instanceof TouchEvent && diff < 20){ brush_size *= this.last_pressure; } else { From a8705dbfe20ba86eaac5a669c61453775c796441 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 17:05:28 -0400 Subject: [PATCH 089/208] Speed up the mask save and fix refresh replacing copied image. 
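The save speed-up comes from dropping Pillow's PNG compress_level from its
default of 6 down to 4, trading a slightly larger file for a much faster
deflate pass. A rough way to measure that tradeoff locally, with an
illustrative image size and an intentionally incompressible payload:

    import os
    import time
    from PIL import Image

    # random pixels are the worst case for the PNG compressor
    img = Image.frombytes("RGBA", (1024, 1024), os.urandom(1024 * 1024 * 4))
    for level in (1, 4, 6, 9):
        start = time.perf_counter()
        img.save("probe.png", compress_level=level)
        elapsed = time.perf_counter() - start
        print(f"level={level}: {elapsed:.3f}s, {os.path.getsize('probe.png')} bytes")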
--- server.py | 2 +- web/scripts/app.js | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 3d02b2f7a..c1226f304 100644 --- a/server.py +++ b/server.py @@ -179,7 +179,7 @@ class PromptServer(): # alpha copy new_alpha = mask_pil.getchannel('A') original_pil.putalpha(new_alpha) - original_pil.save(filepath) + original_pil.save(filepath, compress_level=4) return image_upload(post, image_save_function) diff --git a/web/scripts/app.js b/web/scripts/app.js index c6c29e45b..2da1b5581 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1300,7 +1300,7 @@ export class ComfyApp { if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { widget.options.values = def["input"]["required"][widget.name][0]; - if(!widget.options.values.includes(widget.value)) { + if(widget.name != 'image' && !widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; widget.callback(widget.value); } From c6e34963e412e1960f73ad357d10c2b7bd1464e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 8 May 2023 18:15:19 -0400 Subject: [PATCH 090/208] Make t2i adapter work with any latent resolution. --- comfy/t2i_adapter/adapter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 0221fff83..87e3d859e 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -56,7 +56,12 @@ class Downsample(nn.Module): def forward(self, x): assert x.shape[1] == self.channels - return self.op(x) + if not self.use_conv: + padding = [x.shape[2] % 2, x.shape[3] % 2] + self.op.padding = padding + + x = self.op(x) + return x class ResnetBlock(nn.Module): From d43e45ce624b82dadbe98646329d2b0fbc17edcf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 10:29:58 -0400 Subject: [PATCH 091/208] Remove print. --- nodes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nodes.py b/nodes.py index 699e60ae8..760db24e1 100644 --- a/nodes.py +++ b/nodes.py @@ -443,7 +443,6 @@ class ControlNetApply: def apply_controlnet(self, conditioning, control_net, image, strength): c = [] control_hint = image.movedim(-1,1) - print(control_hint.shape) for t in conditioning: n = [t[0], t[1].copy()] c_net = control_net.copy().set_cond_hint(control_hint, strength) From 314e526c5ce428a3717207c5c36a42a5c895b6a5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 12:18:18 -0400 Subject: [PATCH 092/208] Not needed anymore because sampling works with any latent size. --- comfy/samplers.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index dcf93cca2..6417f2ed4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -362,19 +362,8 @@ def resolve_cond_masks(conditions, h, w, device): else: box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) - # Make sure the height and width are divisible by 8 - if X % 8 != 0: - newx = X // 8 * 8 - W = W + (X - newx) - X = newx - if Y % 8 != 0: - newy = Y // 8 * 8 - H = H + (Y - newy) - Y = newy - if H % 8 != 0: - H = H + (8 - (H % 8)) - if W % 8 != 0: - W = W + (8 - (W % 8)) + H = max(8, H) + W = max(8, W) area = (int(H), int(W), int(Y), int(X)) modified['area'] = area From 02ca1c67f87e46e926aba325e73b2845d5244874 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 9 May 2023 23:51:52 -0400 Subject: [PATCH 093/208] Don't print traceback when processing interrupted. 
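The interrupt surfaces as a normal exception raised from inside the sampling
code, so the executor only needs to tell that one type apart from genuine
failures before deciding whether a traceback is worth printing. The pattern,
reduced to a standalone sketch (the exception class and the job callable are
stand-ins for the real ones):

    import traceback

    class InterruptProcessingException(Exception):
        """Raised when the user cancels the running prompt."""

    def run_safely(job):
        try:
            job()
        except InterruptProcessingException:
            # expected control flow: report it quietly, no traceback
            print("Processing interrupted")
        except Exception:
            # a real failure: keep the full traceback
            print(traceback.format_exc())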
--- execution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index c19c10bc6..edf884611 100644 --- a/execution.py +++ b/execution.py @@ -194,7 +194,10 @@ class PromptExecutor: if valid: recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: - print(traceback.format_exc()) + if isinstance(e, comfy.model_management.InterruptProcessingException): + print("Processing interrupted") + else: + print(traceback.format_exc()) to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From d6dee8af1df5e7dc80463b9e45bdce76767e4119 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 00:29:31 -0400 Subject: [PATCH 094/208] Only validate each input once. --- execution.py | 40 ++++++++++++++++++---------------------- main.py | 2 +- server.py | 2 +- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/execution.py b/execution.py index edf884611..3953fde3a 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}): + def execute(self, prompt, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -172,27 +172,15 @@ class PromptExecutor: executed = set() try: to_execute = [] - for x in prompt: - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - to_execute += [(0, x)] + for x in list(execute_outputs): + to_execute += [(0, x)] while len(to_execute) > 0: #always execute the output that depends on the least amount of unexecuted nodes first to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - class_ = nodes.NODE_CLASS_MAPPINGS[prompt[x]['class_type']] - if hasattr(class_, 'OUTPUT_NODE'): - if class_.OUTPUT_NODE == True: - valid = False - try: - m = validate_inputs(prompt, x) - valid = m[0] - except: - valid = False - if valid: - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -219,8 +207,11 @@ class PromptExecutor: comfy.model_management.soft_empty_cache() -def validate_inputs(prompt, item): +def validate_inputs(prompt, item, validated): unique_id = item + if unique_id in validated: + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -241,8 +232,9 @@ def validate_inputs(prompt, item): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) - r = validate_inputs(prompt, o_id) + r = validate_inputs(prompt, o_id, validated) if r[0] == False: + validated[o_id] = r return r else: if type_input == "INT": @@ -270,7 +262,10 @@ def validate_inputs(prompt, item): if isinstance(type_input, list): if val not in type_input: return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input)) - return (True, "") + + ret = (True, "") + validated[unique_id] = ret + return ret def validate_prompt(prompt): outputs = set() @@ -284,11 +279,12 @@ def validate_prompt(prompt): good_outputs = set() errors = [] + validated = {} for o in outputs: valid = False reason = "" try: - m = validate_inputs(prompt, o) + m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] except Exception as e: @@ -297,7 +293,7 @@ def validate_prompt(prompt): reason = "Parsing error" if valid == True: - good_outputs.add(x) + good_outputs.add(o) else: print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") @@ -307,7 +303,7 @@ def validate_prompt(prompt): errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) - return (True, "") + return (True, "", list(good_outputs)) class PromptQueue: diff --git a/main.py b/main.py index eb97a2fb8..d385df70a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: item, item_id = q.get() - e.execute(item[-2], item[-1]) + e.execute(item[-3], item[-2], item[-1]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): diff --git a/server.py b/server.py index c1226f304..b6ac7d483 100644 --- a/server.py +++ b/server.py @@ -312,7 +312,7 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data)) + self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) else: resp_code = 400 out_string = valid[1] From 8e3d1cbf3b8488b319675f952e1a868aa78f1161 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 01:45:27 -0400 Subject: [PATCH 095/208] Fix bug when uploading image with the same name. --- server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/server.py b/server.py index b6ac7d483..911f6a614 100644 --- a/server.py +++ b/server.py @@ -151,6 +151,7 @@ class PromptServer(): i = 1 while os.path.exists(filepath): filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) i += 1 if image_save_function is not None: From 51583164ef08d2173eb93eefa36bc50429cfe7c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 10:03:30 -0400 Subject: [PATCH 096/208] Make MaskToImage support masks with a batch size. --- comfy_extras/nodes_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 131cd6a9f..9916f3b21 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -72,7 +72,7 @@ class MaskToImage: FUNCTION = "mask_to_image" def mask_to_image(self, mask): - result = mask[None, :, :, None].expand(-1, -1, -1, 3) + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) return (result,) class ImageToMask: From f7c0f75d1fb1c6e3657f69247eace796882c62da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 13:58:19 -0400 Subject: [PATCH 097/208] Auto batching improvements. Try batching when cond sizes don't match with smart padding. 
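Two conds whose token counts differ, e.g. 77 vs 154 after prompt
concatenation, can now still share a batch: the shorter sequence is tiled
along the token axis up to the least common multiple of the two lengths,
and repeating the whole key/value sequence leaves cross-attention output
unchanged because the softmax renormalizes over the duplicates. A minimal
sketch of just the padding step, with made-up shapes:

    import math
    import torch

    def lcm(a, b):  # math.lcm only lands in Python 3.9
        return abs(a * b) // math.gcd(a, b)

    c1 = torch.randn(2, 77, 768)   # conditioning from one prompt chunk
    c2 = torch.randn(2, 154, 768)  # two chunks concatenated together
    target = lcm(c1.shape[1], c2.shape[1])  # 154 here

    # tile the shorter cond along the token dimension, then batch both
    c1 = c1.repeat(1, target // c1.shape[1], 1)
    batch = torch.cat([c1, c2])    # shape (4, 154, 768)

The tiling factor is capped at 4 so pathological length pairs fall back to
separate batches instead of paying for huge padded tensors.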
--- comfy/samplers.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6417f2ed4..aa44fa82d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,6 +6,10 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps +import math + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers #Returns predicted noise @@ -90,8 +94,16 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: - if c1['c_crossattn'].shape != c2['c_crossattn'].shape: - return False + s1 = c1['c_crossattn'].shape + s2 = c2['c_crossattn'].shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False if 'c_concat' in c1: if c1['c_concat'].shape != c2['c_concat'].shape: return False @@ -124,16 +136,28 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c_crossattn = [] c_concat = [] c_adm = [] + crossattn_max_len = 0 for x in c_list: if 'c_crossattn' in x: - c_crossattn.append(x['c_crossattn']) + c = x['c_crossattn'] + if crossattn_max_len == 0: + crossattn_max_len = c.shape[1] + else: + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + c_crossattn.append(c) if 'c_concat' in x: c_concat.append(x['c_concat']) if 'c_adm' in x: c_adm.append(x['c_adm']) out = {} - if len(c_crossattn) > 0: - out['c_crossattn'] = [torch.cat(c_crossattn)] + c_crossattn_out = [] + for c in c_crossattn: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + c_crossattn_out.append(c) + + if len(c_crossattn_out) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn_out)] if len(c_concat) > 0: out['c_concat'] = [torch.cat(c_concat)] if len(c_adm) > 0: From 602095f614276dd52fad718c223e0be17d12b11e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:49:49 -0400 Subject: [PATCH 098/208] Send execution_error message on websocket on execution exception. --- execution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 3953fde3a..7ee038975 100644 --- a/execution.py +++ b/execution.py @@ -185,7 +185,11 @@ class PromptExecutor: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") else: - print(traceback.format_exc()) + message = str(traceback.format_exc()) + print(message) + if self.server.client_id is not None: + self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + to_delete = [] for o in self.outputs: if (o not in current_outputs) and (o not in executed): From 3a7c3acc72435f312a8f050d8ad3a1c902d9cff4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 15:59:24 -0400 Subject: [PATCH 099/208] Send websocket message with list of cached nodes right before execution. 
--- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 7ee038975..7d18d3b65 100644 --- a/execution.py +++ b/execution.py @@ -169,6 +169,8 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + if self.server.client_id is not None: + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) executed = set() try: to_execute = [] From 974958ff81d9af92b01490bcc99dfc93f8bb5d30 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 10 May 2023 16:41:43 -0400 Subject: [PATCH 100/208] Make the prompt_id a uuid and return it when queueing the prompt. --- server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 911f6a614..6965ff3c1 100644 --- a/server.py +++ b/server.py @@ -81,7 +81,7 @@ class PromptServer(): # Reusing existing session, remove old self.sockets.pop(sid, None) else: - sid = uuid.uuid4().hex + sid = uuid.uuid4().hex self.sockets[sid] = ws @@ -313,7 +313,9 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] if valid[0]: - self.prompt_queue.put((number, id(prompt), prompt, extra_data, valid[2])) + prompt_id = str(uuid.uuid4()) + self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + return web.json_response({"prompt_id": prompt_id}) else: resp_code = 400 out_string = valid[1] From dfc74c19d944b4a4503e22297592fa3a537d3092 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 01:22:40 -0400 Subject: [PATCH 101/208] Add the prompt_id to some websocket messages. --- execution.py | 8 ++++---- main.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index 7d18d3b65..0ac4d462c 100644 --- a/execution.py +++ b/execution.py @@ -147,7 +147,7 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def execute(self, prompt, extra_data={}, execute_outputs=[]): + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -170,7 +170,7 @@ class PromptExecutor: current_outputs = set(self.outputs.keys()) if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) }, self.server.client_id) + self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() try: to_execute = [] @@ -190,7 +190,7 @@ class PromptExecutor: message = str(traceback.format_exc()) print(message) if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message }, self.server.client_id) + self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) to_delete = [] for o in self.outputs: @@ -207,7 +207,7 @@ class PromptExecutor: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None }, self.server.client_id) + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) gc.collect() comfy.model_management.soft_empty_cache() diff --git a/main.py b/main.py index d385df70a..00cbf3c4a 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ def prompt_worker(q, server): e = execution.PromptExecutor(server) while True: 
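This gives the frontend a chance to mark nodes that will not run this time
because their outputs are still cached. A minimal listener sketch, assuming
the stock /ws endpoint, its clientId query parameter, and JSON frames shaped
like {"type": <event>, "data": <payload>}; binary frames are skipped:

    import json
    import uuid
    import websocket  # pip install websocket-client

    client_id = uuid.uuid4().hex
    ws = websocket.WebSocket()
    ws.connect(f"ws://127.0.0.1:8188/ws?clientId={client_id}")
    while True:
        frame = ws.recv()
        if isinstance(frame, bytes):
            continue  # binary frames are not JSON
        msg = json.loads(frame)
        if msg.get("type") == "execution_cached":
            print("cached nodes:", msg["data"]["nodes"])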
item, item_id = q.get() - e.execute(item[-3], item[-2], item[-1]) + e.execute(item[2], item[1], item[3], item[4]) q.task_done(item_id, e.outputs) async def run(server, address='', port=8188, verbose=True, call_on_start=None): From 8ea165dd1ef877f58f3710f31ce43f27e0f739ab Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 11 May 2023 14:15:13 -0400 Subject: [PATCH 102/208] Add a way to overwrite images when uploading. --- server.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/server.py b/server.py index 6965ff3c1..a2bb26ad9 100644 --- a/server.py +++ b/server.py @@ -127,6 +127,7 @@ class PromptServer(): def image_upload(post, image_save_function=None): image = post.get("image") + overwrite = post.get("overwrite") image_upload_type = post.get("type") upload_dir = get_dir_by_type(image_upload_type) @@ -148,11 +149,14 @@ class PromptServer(): split = os.path.splitext(filename) filepath = os.path.join(full_output_folder, filename) - i = 1 - while os.path.exists(filepath): - filename = f"{split[0]} ({i}){split[1]}" - filepath = os.path.join(full_output_folder, filename) - i += 1 + if overwrite is not None and (overwrite == "true" or overwrite == "1"): + pass + else: + i = 1 + while os.path.exists(filepath): + filename = f"{split[0]} ({i}){split[1]}" + filepath = os.path.join(full_output_folder, filename) + i += 1 if image_save_function is not None: image_save_function(image, post, filepath) From 8a4ff5e34cc53252a9ff23e796904100d75bea55 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 12 May 2023 20:58:29 +0100 Subject: [PATCH 103/208] allow static files to be symlinks --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index a2bb26ad9..ef858a98a 100644 --- a/server.py +++ b/server.py @@ -362,7 +362,7 @@ class PromptServer(): def add_routes(self): self.app.add_routes(self.routes) self.app.add_routes([ - web.static('/', self.web_root), + web.static('/', self.web_root, follow_symlinks=True), ]) def get_queue_info(self): From d9e088ddfd97663abbb933c77f79d2a6c6127851 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:49:09 +0200 Subject: [PATCH 104/208] minor changes for tiled sampler --- comfy/ldm/modules/tomesd.py | 2 +- comfy/sd.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/tomesd.py b/comfy/ldm/modules/tomesd.py index 6a13b80c9..bb971e88f 100644 --- a/comfy/ldm/modules/tomesd.py +++ b/comfy/ldm/modules/tomesd.py @@ -36,7 +36,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, """ B, N, _ = metric.shape - if r <= 0: + if r <= 0 or w == 1 or h == 1: return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather diff --git a/comfy/sd.py b/comfy/sd.py index 3543bdb77..0200f7742 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -581,10 +581,7 @@ class VAE: samples = samples.cpu() return samples -def resize_image_to(tensor, target_latent_tensor, batched_number): - tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") - target_batch_size = target_latent_tensor.shape[0] - +def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] print(current_batch_size, target_batch_size) if current_batch_size == 1: @@ -623,7 +620,9 @@ class ControlNet: if 
self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -794,10 +793,14 @@ class T2IAdapter: if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint + self.control_input = None self.cond_hint = None - self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) + if x_noisy.shape[0] != self.cond_hint.shape[0]: + self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) + if self.control_input is None: self.t2i_model.to(self.device) self.control_input = self.t2i_model(self.cond_hint) self.t2i_model.cpu() From 19c014f4292863444a3d677d504ad58623395a58 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Fri, 12 May 2023 23:57:40 +0200 Subject: [PATCH 105/208] comment out annoying print statement --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 0200f7742..c6be900ad 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -583,7 +583,7 @@ class VAE: def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - print(current_batch_size, target_batch_size) + #print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor From c5c0ea666f8456b5a788092bad88528bbf34f559 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 12 May 2023 20:34:48 -0400 Subject: [PATCH 106/208] noise_mask in latent should be in a single format. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 760db24e1..c2201dafc 100644 --- a/nodes.py +++ b/nodes.py @@ -795,7 +795,7 @@ class SetLatentNoiseMask: def set_mask(self, samples, mask): s = samples.copy() - s["noise_mask"] = mask + s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): From 997dd1b1312a00cbedeafaf916e49f294a73a431 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 02:07:49 -0400 Subject: [PATCH 107/208] Fix queue delete. 
--- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index a2bb26ad9..8435d091b 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): if "delete" in json_data: to_delete = json_data['delete'] for id_to_delete in to_delete: - delete_func = lambda a: a[1] == int(id_to_delete) + delete_func = lambda a: a[1] == id_to_delete self.prompt_queue.delete_queue_item(delete_func) - + return web.Response(status=200) @routes.post("/interrupt") From 1201d2eae5820bb8124beb22b712d743415fd47d Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 13 May 2023 17:15:45 +0200 Subject: [PATCH 108/208] Make nodes map over input lists (#579) * allow nodes to map over lists * make work with IS_CHANGED and VALIDATE_INPUTS * give list outputs distinct socket shape * add rebatch node * add batch index logic * add repeat latent batch * deal with noise mask edge cases in latentfrombatch --- comfy/sample.py | 17 ++++-- comfy_extras/nodes_rebatch.py | 108 ++++++++++++++++++++++++++++++++++ execution.py | 90 +++++++++++++++++++++++----- nodes.py | 57 +++++++++++++++--- server.py | 1 + web/scripts/app.js | 3 +- 6 files changed, 250 insertions(+), 26 deletions(-) create mode 100644 comfy_extras/nodes_rebatch.py diff --git a/comfy/sample.py b/comfy/sample.py index bd38585ac..284efca61 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -2,17 +2,26 @@ import torch import comfy.model_management import comfy.samplers import math +import numpy as np -def prepare_noise(latent_image, seed, skip=0): +def prepare_noise(latent_image, seed, noise_inds=None): """ creates random noise given a latent image and a seed. optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - for _ in range(skip): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1]+1): noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - return noise + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 000000000..0a9daf272 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,108 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + if mask.shape[-1] != 
shape[-1] * 8 or mask.shape[-2] != shape[-2]: + torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if dimensions gone above target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", +} \ No newline at end of file diff --git a/execution.py b/execution.py index 0ac4d462c..cf2e5ea71 100644 --- a/execution.py +++ b/execution.py @@ -26,20 +26,81 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = obj else: if ("required" in valid_inputs and x in valid_inputs["required"]) or ("optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = input_data + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for 
x in h: if h[x] == "PROMPT": - input_data_all[x] = prompt + input_data_all[x] = [prompt] if h[x] == "EXTRA_PNGINFO": if "extra_pnginfo" in extra_data: - input_data_all[x] = extra_data['extra_pnginfo'] + input_data_all[x] = [extra_data['extra_pnginfo']] if h[x] == "UNIQUE_ID": - input_data_all[x] = unique_id + input_data_all[x] = [unique_id] return input_data_all +def map_node_over_list(obj, input_data_all, func, allow_interrupt=False): + # check if node wants the lists + intput_is_list = False + if hasattr(obj, "INPUT_IS_LIST"): + intput_is_list = obj.INPUT_IS_LIST + + max_len_input = max([len(x) for x in input_data_all.values()]) + + # get a slice of inputs, repeat last input when list isn't long enough + def slice_dict(d, i): + d_new = dict() + for k,v in d.items(): + d_new[k] = v[i if len(v) > i else -1] + return d_new + + results = [] + if intput_is_list: + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**input_data_all)) + else: + for i in range(max_len_input): + if allow_interrupt: + nodes.before_node_execution() + results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + return results + +def get_output_data(obj, input_data_all): + + results = [] + uis = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) + + for r in return_values: + if isinstance(r, dict): + if 'ui' in r: + uis.append(r['ui']) + if 'result' in r: + results.append(r['result']) + else: + results.append(r) + + output = [] + if len(results) > 0: + # check which outputs need concatenating + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + + ui = dict() + if len(uis) > 0: + ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} + return output, ui + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -63,13 +124,11 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute server.send_sync("executing", { "node": unique_id }, server.client_id) obj = class_def() - nodes.before_node_execution() - outputs[unique_id] = getattr(obj, obj.FUNCTION)(**input_data_all) - if "ui" in outputs[unique_id]: + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) - if "result" in outputs[unique_id]: - outputs[unique_id] = outputs[unique_id]["result"] + server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -105,7 +164,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item input_data_all = get_input_data(inputs, class_def, unique_id, outputs) if input_data_all is not None: try: - is_changed = class_def.IS_CHANGED(**input_data_all) + #is_changed = class_def.IS_CHANGED(**input_data_all) + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") prompt[unique_id]['is_changed'] = is_changed except: to_delete = True @@ -261,9 +321,11 @@ def validate_inputs(prompt, 
item, validated): if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) - ret = obj_class.VALIDATE_INPUTS(**input_data_all) - if ret != True: - return (False, "{}, {}".format(class_type, ret)) + #ret = obj_class.VALIDATE_INPUTS(**input_data_all) + ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") + for r in ret: + if r != True: + return (False, "{}, {}".format(class_type, r)) else: if isinstance(type_input, list): if val not in type_input: diff --git a/nodes.py b/nodes.py index c2201dafc..509dc0697 100644 --- a/nodes.py +++ b/nodes.py @@ -629,18 +629,57 @@ class LatentFromBatch: def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), }} RETURN_TYPES = ("LATENT",) - FUNCTION = "rotate" + FUNCTION = "frombatch" - CATEGORY = "latent" + CATEGORY = "latent/batch" - def rotate(self, samples, batch_index): + def frombatch(self, samples, batch_index, length): s = samples.copy() s_in = samples["samples"] batch_index = min(s_in.shape[0] - 1, batch_index) - s["samples"] = s_in[batch_index:batch_index + 1].clone() - s["batch_index"] = batch_index + length = min(s_in.shape[0] - batch_index, length) + s["samples"] = s_in[batch_index:batch_index + length].clone() + if "noise_mask" in samples: + masks = samples["noise_mask"] + if masks.shape[0] == 1: + s["noise_mask"] = masks.clone() + else: + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = masks[batch_index:batch_index + length].clone() + if "batch_index" not in s: + s["batch_index"] = [x for x in range(batch_index, batch_index+length)] + else: + s["batch_index"] = samples["batch_index"][batch_index:batch_index + length] + return (s,) + +class RepeatLatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "repeat" + + CATEGORY = "latent/batch" + + def repeat(self, samples, amount): + s = samples.copy() + s_in = samples["samples"] + + s["samples"] = s_in.repeat((amount, 1,1,1)) + if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1: + masks = samples["noise_mask"] + if masks.shape[0] < s_in.shape[0]: + masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]] + s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1)) + if "batch_index" in s: + offset = max(s["batch_index"]) - min(s["batch_index"]) + 1 + s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]] return (s,) class LatentUpscale: @@ -805,8 +844,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: - skip = latent["batch_index"] if "batch_index" in latent else 0 - noise = comfy.sample.prepare_noise(latent_image, seed, skip) + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) noise_mask = None if "noise_mask" in latent: @@ -1170,6 +1209,7 @@ NODE_CLASS_MAPPINGS = { "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentFromBatch": LatentFromBatch, + "RepeatLatentBatch": 
RepeatLatentBatch, "SaveImage": SaveImage, "PreviewImage": PreviewImage, "LoadImage": LoadImage, @@ -1244,6 +1284,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", "LatentComposite": "Latent Composite", + "LatentFromBatch" : "Latent From Batch", + "RepeatLatentBatch": "Repeat Latent Batch", # Image "SaveImage": "Save Image", "PreviewImage": "Preview Image", @@ -1299,3 +1341,4 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) diff --git a/server.py b/server.py index 8435d091b..cb66cc618 100644 --- a/server.py +++ b/server.py @@ -268,6 +268,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] info['name'] = x info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x diff --git a/web/scripts/app.js b/web/scripts/app.js index 2da1b5581..1a4a18b94 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -976,7 +976,8 @@ export class ComfyApp { for (const o in nodeData["output"]) { const output = nodeData["output"][o]; const outputName = nodeData["output_name"][o] || output; - this.addOutput(outputName, output); + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); } const s = this.computeSize(); From 44f9f9baf170ddf27891b240002300d8aa09fb2a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:17:16 -0400 Subject: [PATCH 109/208] Add the prompt id to some websocket messages. 
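Attaching the id lets an API client follow one specific prompt instead of guessing which queue entry a message refers to. A rough client-side sketch, assuming the default listen address and the third-party websocket-client package (the helper name is illustrative, not part of this patch):

    import json
    import websocket  # pip install websocket-client

    def wait_for_prompt(prompt_id, client_id):
        ws = websocket.WebSocket()
        ws.connect("ws://127.0.0.1:8188/ws?clientId={}".format(client_id))
        while True:
            msg = json.loads(ws.recv())
            data = msg.get("data", {})
            if data.get("prompt_id") != prompt_id:
                continue  # message about some other prompt in the queue
            if msg["type"] == "executing" and data.get("node") is None:
                return  # a null node for our prompt_id marks completion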
--- execution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/execution.py b/execution.py index cf2e5ea71..b9548229c 100644 --- a/execution.py +++ b/execution.py @@ -101,7 +101,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -116,19 +116,19 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id }, server.client_id) + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) obj = class_def() output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui }, server.client_id) + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): @@ -215,6 +215,9 @@ class PromptExecutor: else: self.server.client_id = None + if self.server.client_id is not None: + self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + with torch.inference_mode(): #delete cached outputs if nodes don't exist for them to_delete = [] @@ -242,7 +245,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") From cb4b8223981ec9e090ebf44205f5ce16d72f01cb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 11:54:45 -0400 Subject: [PATCH 110/208] Print custom nodes that take too much time to import. 
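The mechanics are plain wall-clock bracketing around each import; only modules slower than one second get reported. The same pattern in isolation, with stdlib modules standing in for custom node files:

    import importlib
    import time

    node_import_times = []
    for module_name in ("json", "decimal"):  # stand-ins for custom node modules
        time_before = time.time()
        importlib.import_module(module_name)
        node_import_times.append((time.time() - time_before, module_name))

    # tuples compare element-wise, so sorting orders by duration, slowest last;
    # "{:6.1f}" pads the duration to a six-character column with one decimal
    for n in sorted(node_import_times):
        print("{:6.1f} seconds to import:".format(n[0]), n[1])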
--- nodes.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/nodes.py b/nodes.py index 509dc0697..bc7968308 100644 --- a/nodes.py +++ b/nodes.py @@ -6,6 +6,7 @@ import json import hashlib import traceback import math +import time from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -1325,6 +1326,7 @@ def load_custom_node(module_path): def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") + node_import_times = [] for custom_node_path in node_paths: possible_modules = os.listdir(custom_node_path) if "__pycache__" in possible_modules: @@ -1333,7 +1335,16 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + time_before = time.time() load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path)) + + slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) + if len(slow_nodes) > 0: + print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + for n in sorted(slow_nodes): + print("{:6.1f} seconds to import:".format(n[0]), n[1]) + print() def init_custom_nodes(): load_custom_nodes() From cf439709b6b3ffae5ad15a9f7e59fedc214d5f1c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 12:50:21 -0400 Subject: [PATCH 111/208] Load nodes in comfy_extras before custom nodes. Change the slow import message. --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index bc7968308..956b739d9 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,15 +1341,15 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import, if this is one of yours please improve it if you can:") + print("\nDetected some custom nodes that were slow to import:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() def init_custom_nodes(): - load_custom_nodes() load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_nodes() From 92bf1cb61efcab45961d1119cb7ec7a076caf24e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:05:52 -0400 Subject: [PATCH 112/208] Change message. 
--- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 956b739d9..28215127c 100644 --- a/nodes.py +++ b/nodes.py @@ -1341,7 +1341,7 @@ def load_custom_nodes(): slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) if len(slow_nodes) > 0: - print("\nDetected some custom nodes that were slow to import:") + print("\nImport times for custom nodes:") for n in sorted(slow_nodes): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From 2ac744f6628d107b3534177eeca5ef06f6668609 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:15:31 -0400 Subject: [PATCH 113/208] Print all custom node import times. --- nodes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 28215127c..f3b7da1a9 100644 --- a/nodes.py +++ b/nodes.py @@ -1339,10 +1339,9 @@ def load_custom_nodes(): load_custom_node(module_path) node_import_times.append((time.time() - time_before, module_path)) - slow_nodes = list(filter(lambda a: a[0] > 1.0, node_import_times)) - if len(slow_nodes) > 0: + if len(node_import_times) > 0: print("\nImport times for custom nodes:") - for n in sorted(slow_nodes): + for n in sorted(node_import_times): print("{:6.1f} seconds to import:".format(n[0]), n[1]) print() From db4d3a8494a4a7dbb6f911ae126a92abec6bf91b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 13:23:42 -0400 Subject: [PATCH 114/208] Print if custom nodes imported successfully or not. --- nodes.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index f3b7da1a9..63d9adc3d 100644 --- a/nodes.py +++ b/nodes.py @@ -1318,11 +1318,14 @@ def load_custom_node(module_path): NODE_CLASS_MAPPINGS.update(module.NODE_CLASS_MAPPINGS) if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None: NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS) + return True else: print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.") + return False except Exception as e: print(traceback.format_exc()) print(f"Cannot import {module_path} module for custom nodes:", e) + return False def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") @@ -1336,13 +1339,17 @@ def load_custom_nodes(): module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue time_before = time.time() - load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path)) + success = load_custom_node(module_path) + node_import_times.append((time.time() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") for n in sorted(node_import_times): - print("{:6.1f} seconds to import:".format(n[0]), n[1]) + if n[2]: + import_message = "" + else: + import_message = " (IMPORT FAILED)" + print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) print() def init_custom_nodes(): From b0505eb7ab8af1986dabd97c23fae83a0539303d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 15:31:22 -0400 Subject: [PATCH 115/208] Return right type when none specified in upload route. Switch time.time to time.perf_counter for custom node import times. 
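The clock swap is the interesting part: time.time() tracks the wall clock, which NTP adjustments or suspend/resume can move around, while time.perf_counter() is monotonic and higher resolution, so short import timings stay meaningful. A stdlib-only illustration:

    import time

    t0 = time.perf_counter()
    time.sleep(0.01)
    elapsed = time.perf_counter() - t0
    assert elapsed > 0.0  # monotonic: never negative, unlike an adjusted wall clock
    print("{:.4f} seconds".format(elapsed))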
--- nodes.py | 4 ++-- server.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/nodes.py b/nodes.py index 63d9adc3d..c4aff1012 100644 --- a/nodes.py +++ b/nodes.py @@ -1338,9 +1338,9 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - time_before = time.time() + time_before = time.perf_counter() success = load_custom_node(module_path) - node_import_times.append((time.time() - time_before, module_path, success)) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") diff --git a/server.py b/server.py index d1079dd83..ba4dcba03 100644 --- a/server.py +++ b/server.py @@ -115,22 +115,23 @@ class PromptServer(): def get_dir_by_type(dir_type): if dir_type is None: - type_dir = folder_paths.get_input_directory() - elif dir_type == "input": + dir_type = "input" + + if dir_type == "input": type_dir = folder_paths.get_input_directory() elif dir_type == "temp": type_dir = folder_paths.get_temp_directory() elif dir_type == "output": type_dir = folder_paths.get_output_directory() - return type_dir + return type_dir, dir_type def image_upload(post, image_save_function=None): image = post.get("image") overwrite = post.get("overwrite") image_upload_type = post.get("type") - upload_dir = get_dir_by_type(image_upload_type) + upload_dir, image_upload_type = get_dir_by_type(image_upload_type) if image and image.file: filename = image.filename From 3a1f47764d76bb9878b55e82657044b3faceda9c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 13 May 2023 17:11:27 -0400 Subject: [PATCH 116/208] Print the torch device that is used on startup. 
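The printed banner is just the device type (e.g. "cpu" or "mps") except on CUDA, where torch.cuda.current_device() returns a bare index and the human-readable name has to be looked up. Roughly, and mirroring the helper added below:

    import torch

    def describe_device(device):
        if hasattr(device, "type"):  # torch.device objects: "cpu", "mps", ...
            return "{}".format(device.type)
        # a plain integer index means a CUDA device
        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

    print("Using device:", describe_device(torch.device("cpu")))  # Using device: cpu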
--- comfy/model_management.py | 42 ++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 39df8d9a7..c15323219 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -127,6 +127,32 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.cuda.current_device() + +def get_torch_device_name(device): + if hasattr(device, 'type'): + return "{}".format(device.type) + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + +try: + print("Using device:", get_torch_device_name(get_torch_device())) +except: + print("Could not pick default device.") + current_loaded_model = None current_gpu_controlnets = [] @@ -233,22 +259,6 @@ def unload_if_low_vram(model): return model.cpu() return model -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_autocast_device(dev): if hasattr(dev, 'type'): return dev.type From e7b9d2c02cffd59fecca4ee617137ea38641078a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:30:58 -0400 Subject: [PATCH 117/208] /prompt endpoint error is now in json format. --- server.py | 7 +++---- web/scripts/api.js | 2 +- web/scripts/app.js | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index ba4dcba03..f52117f10 100644 --- a/server.py +++ b/server.py @@ -323,12 +323,11 @@ class PromptServer(): self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) return web.json_response({"prompt_id": prompt_id}) else: - resp_code = 400 - out_string = valid[1] print("invalid prompt:", valid[1]) + return web.json_response({"error": valid[1]}, status=400) + else: + return web.json_response({"error": "no prompt"}, status=400) - return web.Response(body=out_string, status=resp_code) - @routes.post("/queue") async def post_queue(request): json_data = await request.json() diff --git a/web/scripts/api.js b/web/scripts/api.js index d29faa5ba..4f061c358 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -163,7 +163,7 @@ class ComfyApi extends EventTarget { if (res.status !== 200) { throw { - response: await res.text(), + response: await res.json(), }; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index 1a4a18b94..00d3c9746 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1222,7 +1222,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response || error.toString()); + this.ui.dialog.show(error.response.error || error.toString()); break; } From 9bf67c4c5a5c8b8d1efc2d4ce7e7ab1eccce1fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 01:34:25 -0400 Subject: [PATCH 118/208] Print prompt execution time. 
--- execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/execution.py b/execution.py index b9548229c..dd88029bc 100644 --- a/execution.py +++ b/execution.py @@ -6,6 +6,7 @@ import threading import heapq import traceback import gc +import time import torch import nodes @@ -215,6 +216,7 @@ class PromptExecutor: else: self.server.client_id = None + execution_start_time = time.perf_counter() if self.server.client_id is not None: self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) @@ -272,6 +274,7 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() comfy.model_management.soft_empty_cache() From d926f65f56217e7828ad27ec5b646c74398593c4 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sun, 14 May 2023 23:21:22 +0900 Subject: [PATCH 119/208] Feature/maskeditor context menu (#649) * add "Open in MaskEditor" to context menu * change save button name to 'Save to node' if open in node. clear clipspace_return_node after auto paste * * leak patch: prevent infinite duplication of MaskEditorDialog instance on every dialog open * prevent conflict of multiple opening of MaskEditorDialog * name of save button fix * patch: brushPreview hiding by dialog * consider close by 'esc' key on maskeditor. * bugfix about last patch * patch: invalid close detection * 'enter' key as save action * * batch support enhance - pick index based on imageIndex on copy action * paste fix on batch image node * typo --------- Co-authored-by: Lt.Dr.Data --- web/extensions/core/maskeditor.js | 120 ++++++++++++---- web/scripts/app.js | 226 +++++++++++++++++------------- 2 files changed, 221 insertions(+), 125 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 552059e86..4b0c12747 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -72,40 +72,50 @@ function prepareRGB(image, backupCanvas, backupCtx) { class MaskEditorDialog extends ComfyDialog { static instance = null; + + static getInstance() { + if(!MaskEditorDialog.instance) { + MaskEditorDialog.instance = new MaskEditorDialog(app); + } + + return MaskEditorDialog.instance; + } + + is_layout_created = false; + constructor() { super(); this.element = $el("div.comfy-modal", { parent: document.body }, [ $el("div.comfy-modal-content", [...this.createButtons()]), ]); - MaskEditorDialog.instance = this; } createButtons() { return []; } - clearMask(self) { - } - createButton(name, callback) { var button = document.createElement("button"); button.innerText = name; button.addEventListener("click", callback); return button; } + createLeftButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "left"; button.style.marginRight = "4px"; return button; } + createRightButton(name, callback) { var button = this.createButton(name, callback); button.style.cssFloat = "right"; button.style.marginLeft = "4px"; return button; } + createLeftSlider(self, name, callback) { const divElement = document.createElement('div'); divElement.id = "maskeditor-slider"; @@ -164,7 +174,7 @@ class MaskEditorDialog extends ComfyDialog { brush.style.MozBorderRadius = "50%"; brush.style.WebkitBorderRadius = "50%"; brush.style.position = "absolute"; - brush.style.zIndex = 100; + 
brush.style.zIndex = 8889; brush.style.pointerEvents = "none"; this.brush = brush; this.element.appendChild(imgCanvas); @@ -187,7 +197,8 @@ class MaskEditorDialog extends ComfyDialog { document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.close(); }); - var saveButton = this.createRightButton("Save", () => { + + this.saveButton = this.createRightButton("Save", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); self.save(); @@ -199,11 +210,10 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); - bottom_panel.appendChild(saveButton); + bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(brush_size_slider); - this.element.style.display = "block"; imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; @@ -212,25 +222,63 @@ class MaskEditorDialog extends ComfyDialog { } show() { - // layout - const imgCanvas = document.createElement('canvas'); - const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); + if(!this.is_layout_created) { + // layout + const imgCanvas = document.createElement('canvas'); + const maskCanvas = document.createElement('canvas'); + const backupCanvas = document.createElement('canvas'); - imgCanvas.id = "imageCanvas"; - maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; + imgCanvas.id = "imageCanvas"; + maskCanvas.id = "maskCanvas"; + backupCanvas.id = "backupCanvas"; - this.setlayout(imgCanvas, maskCanvas); + this.setlayout(imgCanvas, maskCanvas); - // prepare content - this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + // prepare content + this.imgCanvas = imgCanvas; + this.maskCanvas = maskCanvas; + this.backupCanvas = backupCanvas; + this.maskCtx = maskCanvas.getContext('2d'); + this.backupCtx = backupCanvas.getContext('2d'); - this.setImages(imgCanvas, backupCanvas); - this.setEventHandler(maskCanvas); + this.setEventHandler(maskCanvas); + + this.is_layout_created = true; + + // replacement of onClose hook since close is not real close + const self = this; + const observer = new MutationObserver(function(mutations) { + mutations.forEach(function(mutation) { + if (mutation.type === 'attributes' && mutation.attributeName === 'style') { + if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + ComfyApp.onClipspaceEditorClosed(); + } + + self.last_display_style = self.element.style.display; + } + }); + }); + + const config = { attributes: true }; + observer.observe(this.element, config); + } + + this.setImages(this.imgCanvas, this.backupCanvas); + + if(ComfyApp.clipspace_return_node) { + this.saveButton.innerText = "Save to node"; + } + else { + this.saveButton.innerText = "Save"; + } + this.saveButton.disabled = false; + + this.element.style.display = "block"; + this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. 
+ } + + isOpened() { + return this.element.style.display == "block"; } setImages(imgCanvas, backupCanvas) { @@ -239,6 +287,10 @@ class MaskEditorDialog extends ComfyDialog { const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; + backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); + maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); + // image load const orig_image = new Image(); window.addEventListener("resize", () => { @@ -296,8 +348,7 @@ class MaskEditorDialog extends ComfyDialog { rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; - }g - + } setEventHandler(maskCanvas) { maskCanvas.addEventListener("contextmenu", (event) => { @@ -327,6 +378,8 @@ class MaskEditorDialog extends ComfyDialog { self.brush_size = Math.min(self.brush_size+2, 100); } else if (event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + } else if(event.key === 'Enter') { + self.save(); } self.updateBrushPreview(self); @@ -514,7 +567,7 @@ class MaskEditorDialog extends ComfyDialog { } } - save() { + async save() { const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); @@ -570,7 +623,10 @@ class MaskEditorDialog extends ComfyDialog { formData.append('type', "input"); formData.append('subfolder', "clipspace"); - uploadMask(item, formData); + this.saveButton.innerText = "Saving..."; + this.saveButton.disabled = true; + await uploadMask(item, formData); + ComfyApp.onClipspaceEditorSave(); this.close(); } } @@ -578,13 +634,15 @@ class MaskEditorDialog extends ComfyDialog { app.registerExtension({ name: "Comfy.MaskEditor", init(app) { - const callback = + ComfyApp.open_maskeditor = function () { - let dlg = new MaskEditorDialog(app); - dlg.show(); + const dlg = MaskEditorDialog.getInstance(); + if(!dlg.isOpened()) { + dlg.show(); + } }; const context_predicate = () => ComfyApp.clipspace && ComfyApp.clipspace.imgs && ComfyApp.clipspace.imgs.length > 0 - ClipspaceDialog.registerButton("MaskEditor", context_predicate, callback); + ClipspaceDialog.registerButton("MaskEditor", context_predicate, ComfyApp.open_maskeditor); } }); \ No newline at end of file diff --git a/web/scripts/app.js b/web/scripts/app.js index 00d3c9746..87c5e30ca 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -26,6 +26,8 @@ export class ComfyApp { */ static clipspace = null; static clipspace_invalidate_handler = null; + static open_maskeditor = null; + static clipspace_return_node = null; constructor() { this.ui = new ComfyUI(this); @@ -49,6 +51,114 @@ export class ComfyApp { this.shiftDown = false; } + static isImageNode(node) { + return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); + } + + static onClipspaceEditorSave() { + if(ComfyApp.clipspace_return_node) { + ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); + } + } + + static onClipspaceEditorClosed() { + ComfyApp.clipspace_return_node = null; + } + + static copyToClipspace(node) { + var widgets = null; + if(node.widgets) { + widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); + } + + var imgs = undefined; + var orig_imgs = undefined; + if(node.imgs != undefined) { + imgs = []; + orig_imgs = []; + + for (let i = 0; i < node.imgs.length; i++) { + imgs[i] = new Image(); + imgs[i].src = node.imgs[i].src; + orig_imgs[i] = 
imgs[i]; + } + } + + var selectedIndex = 0; + if(node.imageIndex) { + selectedIndex = node.imageIndex; + } + + ComfyApp.clipspace = { + 'widgets': widgets, + 'imgs': imgs, + 'original_imgs': orig_imgs, + 'images': node.images, + 'selectedIndex': selectedIndex, + 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action + }; + + ComfyApp.clipspace_return_node = null; + + if(ComfyApp.clipspace_invalidate_handler) { + ComfyApp.clipspace_invalidate_handler(); + } + } + + static pasteFromClipspace(node) { + if(ComfyApp.clipspace) { + // image paste + if(ComfyApp.clipspace.imgs && node.imgs) { + if(node.images && ComfyApp.clipspace.images) { + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + } + else + app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + } + + if(ComfyApp.clipspace.imgs) { + // deep-copy to cut link with clipspace + if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + const img = new Image(); + img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; + node.imgs = [img]; + node.imageIndex = 0; + } + else { + const imgs = []; + for(let i=0; i obj.name === 'image'); + if(index >= 0) { + node.widgets[index].value = clip_image; + } + } + if(ComfyApp.clipspace.widgets) { + ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { + const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); + if (prop && prop.type != 'button') { + prop.value = value; + prop.callback(value); + } + }); + } + } + + app.graph.setDirtyCanvas(true); + } + } + /** * Invoke an extension callback * @param {keyof ComfyExtension} method The extension callback to execute @@ -138,102 +248,30 @@ export class ComfyApp { } } - options.push( - { - content: "Copy (Clipspace)", - callback: (obj) => { - var widgets = null; - if(this.widgets) { - widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); - } - - var imgs = undefined; - var orig_imgs = undefined; - if(this.imgs != undefined) { - imgs = []; - orig_imgs = []; + // prevent conflict of clipspace content + if(!ComfyApp.clipspace_return_node) { + options.push({ + content: "Copy (Clipspace)", + callback: (obj) => { ComfyApp.copyToClipspace(this); } + }); - for (let i = 0; i < this.imgs.length; i++) { - imgs[i] = new Image(); - imgs[i].src = this.imgs[i].src; - orig_imgs[i] = imgs[i]; + if(ComfyApp.clipspace != null) { + options.push({ + content: "Paste (Clipspace)", + callback: () => { ComfyApp.pasteFromClipspace(this); } + }); + } + + if(ComfyApp.isImageNode(this)) { + options.push({ + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); } - } - - ComfyApp.clipspace = { - 'widgets': widgets, - 'imgs': imgs, - 'original_imgs': orig_imgs, - 'images': this.images, - 'selectedIndex': 0, - 'img_paste_mode': 'selected' // reset to default im_paste_mode state on copy action - }; - - if(ComfyApp.clipspace_invalidate_handler) { - ComfyApp.clipspace_invalidate_handler(); - } - } - }); - - if(ComfyApp.clipspace != null) { - options.push( - { - content: "Paste (Clipspace)", - callback: () => { - if(ComfyApp.clipspace) { - // image paste - if(ComfyApp.clipspace.imgs && this.imgs) { - if(this.images && ComfyApp.clipspace.images) { - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - 
app.nodeOutputs[this.id + ""].images = this.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; - - } - else - app.nodeOutputs[this.id + ""].images = this.images = ComfyApp.clipspace.images; - } - - if(ComfyApp.clipspace.imgs) { - // deep-copy to cut link with clipspace - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - const img = new Image(); - img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; - this.imgs = [img]; - } - else { - const imgs = []; - for(let i=0; i obj.name === 'image'); - if(index >= 0) { - this.widgets[index].value = clip_image; - } - } - if(ComfyApp.clipspace.widgets) { - ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { - const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { - prop.value = value; - prop.callback(value); - } - }); - } - } - } - - app.graph.setDirtyCanvas(true); - } - } - ); + }); + } } }; } From acff543d669dba9b03fb500a10010f2da8739ff3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 12:50:21 -0400 Subject: [PATCH 120/208] Remove useless code. --- nodes.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/nodes.py b/nodes.py index c4aff1012..bc23e5c17 100644 --- a/nodes.py +++ b/nodes.py @@ -146,9 +146,6 @@ class ConditioningSetMask: return (c, ) class VAEDecode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -161,9 +158,6 @@ class VAEDecode: return (vae.decode(samples["samples"]), ) class VAEDecodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} @@ -176,9 +170,6 @@ class VAEDecodeTiled: return (vae.decode_tiled(samples["samples"]), ) class VAEEncode: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -203,9 +194,6 @@ class VAEEncode: return ({"samples":t}, ) class VAEEncodeTiled: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} @@ -220,9 +208,6 @@ class VAEEncodeTiled: return ({"samples":t}, ) class VAEEncodeForInpaint: - def __init__(self, device="cpu"): - self.device = device - @classmethod def INPUT_TYPES(s): return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}} From 587f89fe5a8e2bcb389fb4919dc33c330320fa41 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 15:10:40 -0400 Subject: [PATCH 121/208] Enable safe loading for upscale models. 
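Background for the one-liner: a plain torch.load() unpickles arbitrary Python objects, so a malicious upscale model file could run code at load time. A hedged sketch of what opting into safe loading amounts to (ComfyUI's actual load_torch_file helper may differ in detail):

    import torch

    def load_weights_safely(path):
        if path.lower().endswith(".safetensors"):
            import safetensors.torch
            return safetensors.torch.load_file(path, device="cpu")
        # weights_only (torch >= 1.13) restricts unpickling to tensors and
        # plain containers instead of arbitrary objects
        return torch.load(path, map_location="cpu", weights_only=True)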
--- comfy_extras/nodes_upscale_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index ab5b0ccfc..f9252ea0b 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -17,7 +17,7 @@ class UpscaleModelLoader: def load_model(self, model_name): model_path = folder_paths.get_full_path("upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) out = model_loading.load_state_dict(sd).eval() return (out, ) From 84ea21c815d426000c233e0c7b8c542764335cc8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 14 May 2023 17:02:40 -0400 Subject: [PATCH 122/208] Update litegraph from upstream. --- web/lib/litegraph.core.js | 145 +++++++++++++++++++++++++++++++++++--- 1 file changed, 137 insertions(+), 8 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 2bc6af0c3..6c81c3ffd 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -5880,13 +5880,13 @@ LGraphNode.prototype.executeAction = function(action) //when clicked on top of a node //and it is not interactive - if (node && this.allow_interaction && !skip_action && !this.read_only) { + if (node && (this.allow_interaction || node.flags.allow_interaction) && !skip_action && !this.read_only) { if (!this.live_mode && !node.flags.pinned) { this.bringToFront(node); } //if it wasn't selected? //not dragging mouse to connect two slots - if ( !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { + if ( this.allow_interaction && !this.connecting_node && !node.flags.collapsed && !this.live_mode ) { //Search for corner for resize if ( !skip_action && node.resizable !== false && node.inResizeCorner(e.canvasX, e.canvasY) @@ -6033,7 +6033,7 @@ LGraphNode.prototype.executeAction = function(action) } //double clicking - if (is_double_click && this.selected_nodes[node.id]) { + if (this.allow_interaction && is_double_click && this.selected_nodes[node.id]) { //double click node if (node.onDblClick) { node.onDblClick( e, pos, this ); @@ -6307,6 +6307,9 @@ LGraphNode.prototype.executeAction = function(action) this.dirty_canvas = true; } + //get node over + var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); + if (this.dragging_rectangle) { this.dragging_rectangle[2] = e.canvasX - this.dragging_rectangle[0]; @@ -6336,14 +6339,11 @@ LGraphNode.prototype.executeAction = function(action) this.ds.offset[1] += delta[1] / this.ds.scale; this.dirty_canvas = true; this.dirty_bgcanvas = true; - } else if (this.allow_interaction && !this.read_only) { + } else if ((this.allow_interaction || (node && node.flags.allow_interaction)) && !this.read_only) { if (this.connecting_node) { this.dirty_canvas = true; } - //get node over - var node = this.graph.getNodeOnPos(e.canvasX,e.canvasY,this.visible_nodes); - //remove mouseover flag for (var i = 0, l = this.graph._nodes.length; i < l; ++i) { if (this.graph._nodes[i].mouseOver && node != this.graph._nodes[i] ) { @@ -9911,7 +9911,7 @@ LGraphNode.prototype.executeAction = function(action) event, active_widget ) { - if (!node.widgets || !node.widgets.length) { + if (!node.widgets || !node.widgets.length || (!this.allow_interaction && !node.flags.allow_interaction)) { return null; } @@ -10300,6 +10300,119 @@ LGraphNode.prototype.executeAction = function(action) canvas.graph.add(group); }; + /** + * Determines the furthest nodes in each 
direction + * @param nodes {LGraphNode[]} the nodes to from which boundary nodes will be extracted + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.getBoundaryNodes = function(nodes) { + let top = null; + let right = null; + let bottom = null; + let left = null; + for (const nID in nodes) { + const node = nodes[nID]; + const [x, y] = node.pos; + const [width, height] = node.size; + + if (top === null || y < top.pos[1]) { + top = node; + } + if (right === null || x + width > right.pos[0] + right.size[0]) { + right = node; + } + if (bottom === null || y + height > bottom.pos[1] + bottom.size[1]) { + bottom = node; + } + if (left === null || x < left.pos[0]) { + left = node; + } + } + + return { + "top": top, + "right": right, + "bottom": bottom, + "left": left + }; + } + /** + * Determines the furthest nodes in each direction for the currently selected nodes + * @return {{left: LGraphNode, top: LGraphNode, right: LGraphNode, bottom: LGraphNode}} + */ + LGraphCanvas.prototype.boundaryNodesForSelection = function() { + return LGraphCanvas.getBoundaryNodes(Object.values(this.selected_nodes)); + } + + /** + * + * @param {LGraphNode[]} nodes a list of nodes + * @param {"top"|"bottom"|"left"|"right"} direction Direction to align the nodes + * @param {LGraphNode?} align_to Node to align to (if null, align to the furthest node in the given direction) + */ + LGraphCanvas.alignNodes = function (nodes, direction, align_to) { + if (!nodes) { + return; + } + + const canvas = LGraphCanvas.active_canvas; + let boundaryNodes = [] + if (align_to === undefined) { + boundaryNodes = LGraphCanvas.getBoundaryNodes(nodes) + } else { + boundaryNodes = { + "top": align_to, + "right": align_to, + "bottom": align_to, + "left": align_to + } + } + + for (const [_, node] of Object.entries(canvas.selected_nodes)) { + switch (direction) { + case "right": + node.pos[0] = boundaryNodes["right"].pos[0] + boundaryNodes["right"].size[0] - node.size[0]; + break; + case "left": + node.pos[0] = boundaryNodes["left"].pos[0]; + break; + case "top": + node.pos[1] = boundaryNodes["top"].pos[1]; + break; + case "bottom": + node.pos[1] = boundaryNodes["bottom"].pos[1] + boundaryNodes["bottom"].size[1] - node.size[1]; + break; + } + } + + canvas.dirty_canvas = true; + canvas.dirty_bgcanvas = true; + }; + + LGraphCanvas.onNodeAlign = function(value, options, event, prev_menu, node) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase(), node); + } + } + + LGraphCanvas.onGroupAlign = function(value, options, event, prev_menu) { + new LiteGraph.ContextMenu(["Top", "Bottom", "Left", "Right"], { + event: event, + callback: inner_clicked, + parentMenu: prev_menu, + }); + + function inner_clicked(value) { + LGraphCanvas.alignNodes(LGraphCanvas.active_canvas.selected_nodes, value.toLowerCase()); + } + } + LGraphCanvas.onMenuAdd = function (node, options, e, prev_menu, callback) { var canvas = LGraphCanvas.active_canvas; @@ -12900,6 +13013,14 @@ LGraphNode.prototype.executeAction = function(action) options.push({ content: "Options", callback: that.showShowGraphOptionsPanel }); }*/ + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align", + has_submenu: true, + callback: LGraphCanvas.onGroupAlign, + }) + } + if (this._graph_stack && 
this._graph_stack.length > 0) { options.push(null, { content: "Close subgraph", @@ -13014,6 +13135,14 @@ LGraphNode.prototype.executeAction = function(action) callback: LGraphCanvas.onMenuNodeToSubgraph }); + if (Object.keys(this.selected_nodes).length > 1) { + options.push({ + content: "Align Selected To", + has_submenu: true, + callback: LGraphCanvas.onNodeAlign, + }) + } + options.push(null, { content: "Remove", disabled: !(node.removable !== false && !node.block_delete ), From 1dd846a7bad8cfab679a0976e201c722871c6917 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:27:28 -0400 Subject: [PATCH 123/208] Fix outputs gone from history. --- execution.py | 16 +++++++++++----- main.py | 2 +- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index dd88029bc..0e2cc15c1 100644 --- a/execution.py +++ b/execution.py @@ -102,7 +102,7 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui -def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] @@ -117,7 +117,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -128,6 +128,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute output_data, output_ui = get_output_data(obj, input_data_all) outputs[unique_id] = output_data if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) executed.add(unique_id) @@ -205,6 +206,7 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): self.outputs = {} + self.outputs_ui = {} self.old_prompt = {} self.server = server @@ -234,6 +236,11 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) + for x in list(self.outputs_ui.keys()): + if x not in current_outputs: + d = self.outputs_ui.pop(x) + del d + if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() @@ -247,7 +254,7 @@ class PromptExecutor: to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) x = to_execute.pop(0)[-1] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) except Exception as e: if isinstance(e, comfy.model_management.InterruptProcessingException): print("Processing interrupted") @@ -413,8 +420,7 @@ class PromptQueue: prompt = 
self.currently_running.pop(item_id) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: - if "ui" in outputs[o]: - self.history[prompt[1]]["outputs"][o] = outputs[o]["ui"] + self.history[prompt[1]]["outputs"][o] = outputs[o] self.server.queue_updated() def get_current_queue(self): diff --git a/main.py b/main.py index 00cbf3c4a..50d3b9a62 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,7 @@ def prompt_worker(q, server): while True: item, item_id = q.get() e.execute(item[2], item[1], item[3], item[4]) - q.task_done(item_id, e.outputs) + q.task_done(item_id, e.outputs_ui) async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From ef815ba1e24eef45041adec8a55ecd628b20476f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 00:29:56 -0400 Subject: [PATCH 124/208] Switch default scheduler to normal. --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index aa44fa82d..fccf254ec 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,7 +495,7 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] From c02a554bcf6ef50f8e252c89dc0a56c08d4955c0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:25:24 -0400 Subject: [PATCH 125/208] Make DiffusersLoader work with subfolders. 
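The listing change in a nutshell: instead of taking only the first directory level under each diffusers folder, any directory containing a model_index.json is treated as a model root, so nested layouts and symlinked folders work. The same walk in isolation (example paths invented):

    import os

    def find_diffusers_models(search_path):
        paths = []
        for root, _subdirs, files in os.walk(search_path, followlinks=True):
            if "model_index.json" in files:
                paths.append(os.path.relpath(root, start=search_path))
        return paths

    # find_diffusers_models("models/diffusers") -> e.g. ["sd15", "nested/sdxl"]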
--- nodes.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index bc23e5c17..797ad6c9c 100644 --- a/nodes.py +++ b/nodes.py @@ -282,7 +282,10 @@ class DiffusersLoader: paths = [] for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths += next(os.walk(search_path))[1] + for root, subdir, files in os.walk(search_path, followlinks=True): + if "model_index.json" in files: + paths.append(os.path.relpath(root, start=search_path)) + return {"required": {"model_path": (paths,), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -292,9 +295,9 @@ class DiffusersLoader: def load_checkpoint(self, model_path, output_vae=True, output_clip=True): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): - paths = next(os.walk(search_path))[1] - if model_path in paths: - model_path = os.path.join(search_path, model_path) + path = os.path.join(search_path, model_path) + if os.path.exists(path): + model_path = path break return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) From 2ec6d1c6e364ab92e3d8149a83873ac47c797248 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 15 May 2023 03:31:03 -0400 Subject: [PATCH 126/208] Don't import custom nodes when the folder ends with .disabled --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index 797ad6c9c..e8b36c24a 100644 --- a/nodes.py +++ b/nodes.py @@ -1326,6 +1326,7 @@ def load_custom_nodes(): for possible_module in possible_modules: module_path = os.path.join(custom_node_path, possible_module) if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + if module_path.endswith(".disabled"): continue time_before = time.perf_counter() success = load_custom_node(module_path) node_import_times.append((time.perf_counter() - time_before, module_path, success)) From 5f7968f1fafb2cf5d15fe049fc53265ad0fc6696 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 01:12:44 -0400 Subject: [PATCH 127/208] Print the endpoint ip for localtunnel in the colab notebook. 
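Localtunnel's interstitial page asks visitors for the tunnel endpoint IP as a password, so the notebook now fetches and prints it before launching the tunnel. The fetch reduces to:

    import urllib.request

    ip = urllib.request.urlopen("https://ipv4.icanhazip.com").read().decode("utf8").strip("\n")
    print("The password/endpoint ip for localtunnel is:", ip)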
--- notebooks/comfyui_colab.ipynb | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/notebooks/comfyui_colab.ipynb b/notebooks/comfyui_colab.ipynb index fecfa6707..c5a209eec 100644 --- a/notebooks/comfyui_colab.ipynb +++ b/notebooks/comfyui_colab.ipynb @@ -175,6 +175,8 @@ "import threading\n", "import time\n", "import socket\n", + "import urllib.request\n", + "\n", "def iframe_thread(port):\n", " while True:\n", " time.sleep(0.5)\n", @@ -183,7 +185,9 @@ " if result == 0:\n", " break\n", " sock.close()\n", - " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\")\n", + " print(\"\\nComfyUI finished loading, trying to launch localtunnel (if it gets stuck here localtunnel is having issues)\\n\")\n", + "\n", + " print(\"The password/endpoint ip for localtunnel is:\", urllib.request.urlopen('https://ipv4.icanhazip.com').read().decode('utf8').strip(\"\\n\"))\n", " p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n", " for line in p.stdout:\n", " print(line.decode(), end='')\n", From 13d94caf49b21bd129ec867b04641973e3a102da Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 03:18:11 -0400 Subject: [PATCH 128/208] Add control_after_generate to combo primitive. --- web/extensions/core/widgetInputs.js | 2 +- web/scripts/widgets.js | 80 +++++++++++++++++++---------- 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index df7d8f071..4fe0a6013 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -300,7 +300,7 @@ app.registerExtension({ } } - if (widget.type === "number") { + if (widget.type === "number" || widget.type === "combo") { addValueControlWidget(this, widget, "fixed"); } diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 65edc0392..3d1acc53e 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,35 +19,61 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - let min = targetWidget.options.min; - let max = targetWidget.options.max; - // limit to something that javascript can handle - max = Math.min(1125899906842624, max); - min = Math.max(-1125899906842624, min); - let range = (max - min) / (targetWidget.options.step / 10); + console.log(targetWidget); + if (targetWidget.type == "combo" && v !== "fixed") { + let current_index = targetWidget.options.values.indexOf(targetWidget.value); + let current_length = targetWidget.options.values.length; + + switch (v) { + case "increment": + current_index += 1; + break; + case "decrement": + current_index -= 1; + break; + case "randomize": + current_index = Math.floor(Math.random() * current_length); + default: + break; + } + current_index = Math.max(0, current_index); + current_index = Math.min(current_length - 1, current_index); + if (current_index >= 0) { + let value = targetWidget.options.values[current_index]; + targetWidget.value = value; + targetWidget.callback(value); + } + } else { //number + let
min = targetWidget.options.min; + let max = targetWidget.options.max; + // limit to something that javascript can handle + max = Math.min(1125899906842624, max); + min = Math.max(-1125899906842624, min); + let range = (max - min) / (targetWidget.options.step / 10); + + //adjust values based on valueControl Behaviour + switch (v) { + case "fixed": + break; + case "increment": + targetWidget.value += targetWidget.options.step / 10; + break; + case "decrement": + targetWidget.value -= targetWidget.options.step / 10; + break; + case "randomize": + targetWidget.value = Math.floor(Math.random() * range) * (targetWidget.options.step / 10) + min; + default: + break; + } + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) + targetWidget.value = min; + + if (targetWidget.value > max) + targetWidget.value = max; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; - - if (targetWidget.value > max) - targetWidget.value = max; } return valueControl; }; From 7ada9e7d85f93495aa5006468a45220932f5e988 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Tue, 16 May 2023 22:55:00 +0900 Subject: [PATCH 129/208] allows touch drag --- web/scripts/app.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 87c5e30ca..ef3b44c83 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -902,7 +902,9 @@ export class ComfyApp { await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM - const canvasEl = (this.canvasEl = Object.assign(document.createElement("canvas"), { id: "graph-canvas" })); + const mainCanvas = document.createElement("canvas") + mainCanvas.style.touchAction = "none" + const canvasEl = (this.canvasEl = Object.assign(mainCanvas, { id: "graph-canvas" })); canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); From 11e7168d56e0987e52d0afb620189f08bda2b454 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 11:55:16 -0400 Subject: [PATCH 130/208] Remove print. --- web/scripts/widgets.js | 1 - 1 file changed, 1 deletion(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 3d1acc53e..94988d0f2 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -19,7 +19,6 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random var v = valueControl.value; - console.log(targetWidget); if (targetWidget.type == "combo" && v !== "fixed") { let current_index = targetWidget.options.values.indexOf(targetWidget.value); let current_length = targetWidget.options.values.length; From 4088e61aa6b8943e28ee243c0b1265c41974ef67 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 16 May 2023 15:35:07 -0400 Subject: [PATCH 131/208] Update litegraph from upstream. 
--- web/lib/litegraph.core.js | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 6c81c3ffd..95f4a2735 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -9734,7 +9734,7 @@ LGraphNode.prototype.executeAction = function(action) if (show_text) { ctx.textAlign = "center"; ctx.fillStyle = text_color; - ctx.fillText(w.name, widget_width * 0.5, y + H * 0.7); + ctx.fillText(w.label || w.name, widget_width * 0.5, y + H * 0.7); } break; case "toggle": @@ -9755,8 +9755,9 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); if (show_text) { ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = w.value ? text_color : secondary_text_color; ctx.textAlign = "right"; @@ -9791,7 +9792,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.textAlign = "center"; ctx.fillStyle = text_color; ctx.fillText( - w.name + " " + Number(w.value).toFixed(3), + w.label || w.name + " " + Number(w.value).toFixed(3), widget_width * 0.5, y + H * 0.7 ); @@ -9826,7 +9827,7 @@ LGraphNode.prototype.executeAction = function(action) ctx.fill(); } ctx.fillStyle = secondary_text_color; - ctx.fillText(w.name, margin * 2 + 5, y + H * 0.7); + ctx.fillText(w.label || w.name, margin * 2 + 5, y + H * 0.7); ctx.fillStyle = text_color; ctx.textAlign = "right"; if (w.type == "number") { @@ -9878,8 +9879,9 @@ LGraphNode.prototype.executeAction = function(action) //ctx.stroke(); ctx.fillStyle = secondary_text_color; - if (w.name != null) { - ctx.fillText(w.name, margin * 2, y + H * 0.7); + const label = w.label || w.name; + if (label != null) { + ctx.fillText(label, margin * 2, y + H * 0.7); } ctx.fillStyle = text_color; ctx.textAlign = "right"; From e7f2816c6f1da22e2018cf088bd45110ff265c79 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Thu, 18 May 2023 12:40:28 +0900 Subject: [PATCH 132/208] feat:Latent Save/Load (#662) * wip * latent dir * fix * fix * now working * mark todo * remove server.py changes to separate PRt --------- Co-authored-by: Lt.Dr.Data --- input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index e8b36c24a..a2c7713aa 100644 --- a/nodes.py +++ b/nodes.py @@ -29,6 +29,8 @@ import importlib import folder_paths +import safetensors.torch as sft + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -246,6 +248,91 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class SaveLatent: + def __init__(self): + self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") + self.type = "output" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, samples, 
filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(self.output_dir, subfolder) + + if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: + print("Saving latent outside the 'input/latents' folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + + # support save metadata for latent sharing + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {"workflow": prompt_info} + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + file = f"{filename}_{counter:05}_.latent" + file = os.path.join(full_output_folder, file) + + sft.save_file(samples, file, metadata=metadata) + + return {} + + +class LoadLatent: + input_dir = os.path.join(folder_paths.get_input_directory(), "latents") + + @classmethod + def INPUT_TYPES(s): + files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + return {"required": {"latent": [sorted(files), ]}, } + + CATEGORY = "_for_testing" + + RETURN_TYPES = ("LATENT", ) + FUNCTION = "load" + + def load(self, latent): + file = folder_paths.get_annotated_filepath(latent, self.input_dir) + + latent = sft.load_file(file, device="cpu") + + return (latent, ) + + class CheckpointLoader: @classmethod def INPUT_TYPES(s): @@ -1235,6 +1322,9 @@ NODE_CLASS_MAPPINGS = { "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, + + "LoadLatent": LoadLatent, + "SaveLatent": SaveLatent } NODE_DISPLAY_NAME_MAPPINGS = { From a7375103b9c80bb7607f85faa4afbf11ab5a5685 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:04:40 -0400 Subject: [PATCH 133/208] Some small changes to Load/SaveLatent. 
--- nodes.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index a2c7713aa..7255621d7 100644 --- a/nodes.py +++ b/nodes.py @@ -11,6 +11,7 @@ import time from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np +import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -29,7 +30,6 @@ import importlib import folder_paths -import safetensors.torch as sft def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -307,7 +307,10 @@ class SaveLatent: file = f"{filename}_{counter:05}_.latent" file = os.path.join(full_output_folder, file) - sft.save_file(samples, file, metadata=metadata) + output = {} + output["latent_tensor"] = samples["samples"] + + safetensors.torch.save_file(output, file, metadata=metadata) return {} @@ -328,9 +331,10 @@ class LoadLatent: def load(self, latent): file = folder_paths.get_annotated_filepath(latent, self.input_dir) - latent = sft.load_file(file, device="cpu") + latent = safetensors.torch.load_file(file, device="cpu") + samples = {"samples": latent["latent_tensor"]} - return (latent, ) + return (samples, ) class CheckpointLoader: From faf899ad5ae32f770f0dae6a9df457e81d2b5c38 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 17 May 2023 23:43:59 -0400 Subject: [PATCH 134/208] LoadLatent and SaveLatent should behave like the LoadImage and SaveImage. --- folder_paths.py | 33 +++++++ input/latents/_input_latents_will_be_put_here | 0 nodes.py | 90 +++++-------------- 3 files changed, 55 insertions(+), 68 deletions(-) delete mode 100644 input/latents/_input_latents_will_be_put_here diff --git a/folder_paths.py b/folder_paths.py index e5b89492c..28f117824 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -147,4 +147,37 @@ def get_filename_list(folder_name): output_list.update(filter_files_extensions(recursive_search(x), folders[1])) return sorted(list(output_list)) +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): + def map_filename(filename): + prefix_len = len(os.path.basename(filename_prefix)) + prefix = filename[:prefix_len + 1] + try: + digits = int(filename[prefix_len + 1:].split('_')[0]) + except: + digits = 0 + return (digits, prefix) + def compute_vars(input, image_width, image_height): + input = input.replace("%width%", str(image_width)) + input = input.replace("%height%", str(image_height)) + return input + + filename_prefix = compute_vars(filename_prefix, image_width, image_height) + + subfolder = os.path.dirname(os.path.normpath(filename_prefix)) + filename = os.path.basename(os.path.normpath(filename_prefix)) + + full_output_folder = os.path.join(output_dir, subfolder) + + if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: + print("Saving image outside the output folder is not allowed.") + return {} + + try: + counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 + except ValueError: + counter = 1 + except FileNotFoundError: + os.makedirs(full_output_folder, exist_ok=True) + counter = 1 + return full_output_folder, filename, counter, subfolder, filename_prefix diff --git a/input/latents/_input_latents_will_be_put_here b/input/latents/_input_latents_will_be_put_here deleted file mode 100644 index e69de29bb..000000000 diff --git a/nodes.py b/nodes.py index 7255621d7..7b450df38 100644 --- a/nodes.py +++ b/nodes.py @@ -251,13 
+251,12 @@ class VAEEncodeForInpaint: class SaveLatent: def __init__(self): - self.output_dir = os.path.join(folder_paths.get_input_directory(), "latents") - self.type = "output" + self.output_dir = folder_paths.get_output_directory() @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"})}, + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } RETURN_TYPES = () @@ -268,31 +267,7 @@ class SaveLatent: CATEGORY = "_for_testing" def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving latent outside the 'input/latents' folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) # support save metadata for latent sharing prompt_info = "" @@ -316,11 +291,10 @@ class SaveLatent: class LoadLatent: - input_dir = os.path.join(folder_paths.get_input_directory(), "latents") - @classmethod def INPUT_TYPES(s): - files = [f for f in os.listdir(s.input_dir) if os.path.isfile(os.path.join(s.input_dir, f)) and f.endswith(".latent")] + input_dir = folder_paths.get_input_directory() + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] return {"required": {"latent": [sorted(files), ]}, } CATEGORY = "_for_testing" @@ -329,13 +303,25 @@ class LoadLatent: FUNCTION = "load" def load(self, latent): - file = folder_paths.get_annotated_filepath(latent, self.input_dir) - - latent = safetensors.torch.load_file(file, device="cpu") + latent_path = folder_paths.get_annotated_filepath(latent) + latent = safetensors.torch.load_file(latent_path, device="cpu") samples = {"samples": latent["latent_tensor"]} - return (samples, ) + @classmethod + def IS_CHANGED(s, latent): + image_path = folder_paths.get_annotated_filepath(latent) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, latent): + if not folder_paths.exists_annotated_filepath(latent): + return "Invalid latent file: {}".format(latent) + return True + class CheckpointLoader: @classmethod @@ -1020,39 +1006,7 @@ class SaveImage: CATEGORY = "image" def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(input): - input = 
input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - filename_prefix = compute_vars(filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - if os.path.commonpath((self.output_dir, os.path.abspath(full_output_folder))) != self.output_dir: - print("Saving image outside the output folder is not allowed.") - return {} - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() for image in images: i = 255. * image.cpu().numpy() From 62a371e12b4763bf6f9aeb42ff4928138df6ae26 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 18 May 2023 02:41:21 -0400 Subject: [PATCH 135/208] Load workflow from latent file. --- nodes.py | 2 +- web/scripts/app.js | 7 ++++++- web/scripts/pnginfo.js | 16 ++++++++++++++++ web/scripts/ui.js | 2 +- 4 files changed, 24 insertions(+), 3 deletions(-) diff --git a/nodes.py b/nodes.py index 7b450df38..3c61cd2ec 100644 --- a/nodes.py +++ b/nodes.py @@ -274,7 +274,7 @@ class SaveLatent: if prompt is not None: prompt_info = json.dumps(prompt) - metadata = {"workflow": prompt_info} + metadata = {"prompt": prompt_info} if extra_pnginfo is not None: for x in extra_pnginfo: metadata[x] = json.dumps(extra_pnginfo[x]) diff --git a/web/scripts/app.js b/web/scripts/app.js index ef3b44c83..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2,7 +2,7 @@ import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111 } from "./pnginfo.js"; +import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -1308,6 +1308,11 @@ export class ComfyApp { this.loadGraphData(JSON.parse(reader.result)); }; reader.readAsText(file); + } else if (file.name?.endsWith(".latent")) { + const info = await getLatentMetadata(file); + if (info.workflow) { + this.loadGraphData(JSON.parse(info.workflow)); + } } } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 209b562a6..8ddb7a1c5 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,6 +47,22 @@ export function getPngMetadata(file) { }); } +export function getLatentMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const safetensorsData = new Uint8Array(event.target.result); + const dataView = new DataView(safetensorsData.buffer); + let header_size = dataView.getUint32(0, true); + let offset = 8; + let header = JSON.parse(String.fromCharCode(...safetensorsData.slice(offset, offset + header_size))); + r(header.__metadata__); + }; + + reader.readAsArrayBuffer(file); + }); +} + export async function importA1111(graph, parameters) { const p = parameters.lastIndexOf("\nSteps:"); if (p > -1) { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 77517aec1..2c9043d00 100644 --- 
a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,7 +465,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png", + accept: ".json,image/png,.latent", style: { display: "none" }, parent: document.body, onchange: () => { From 8bbd9815a976ef43e2665d45c5afb4a21c06c831 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 02:15:32 -0400 Subject: [PATCH 136/208] Support loading fp16 latent files. --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 3c61cd2ec..878e0b955 100644 --- a/nodes.py +++ b/nodes.py @@ -305,7 +305,7 @@ class LoadLatent: def load(self, latent): latent_path = folder_paths.get_annotated_filepath(latent) latent = safetensors.torch.load_file(latent_path, device="cpu") - samples = {"samples": latent["latent_tensor"]} + samples = {"samples": latent["latent_tensor"].float()} return (samples, ) @classmethod From 2998e232cb26b66e7ba42a53ada3a8285fcb2c15 Mon Sep 17 00:00:00 2001 From: malern <701073+malern@users.noreply.github.com> Date: Fri, 19 May 2023 19:57:15 +0100 Subject: [PATCH 137/208] Make multiline widget work with different canvas dimensions. It now scales the textarea positioning using the canvas height/width. --- web/scripts/widgets.js | 20 +++++++++++++------- web/style.css | 2 ++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 94988d0f2..82168b08b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -155,18 +155,24 @@ function addMultilineWidget(node, name, opts, app) { computeSize(node.size); } const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const t = ctx.getTransform(); const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + Object.assign(this.inputEl.style, { - left: `${t.a * margin + t.e}px`, - top: `${t.d * (y + widgetHeight - margin - 3) + t.f}px`, - width: `${(widgetWidth - margin * 2 - 3) * t.a}px`, - background: (!node.color)?'':node.color, - height: `${(this.parent.inputHeight - margin * 2 - 4) * t.d}px`, + transformOrigin: "0 0", + transform: transform, + left: "0px", + top: "0px", + width: `${widgetWidth - (margin * 2)}px`, + height: `${this.parent.inputHeight - (margin * 2)}px`, position: "absolute", + background: (!node.color)?'':node.color, color: (!node.color)?'':'white', zIndex: app.graph._nodes.indexOf(node), - fontSize: `${t.d * 10.0}px`, }); this.inputEl.hidden = !visible; }, diff --git a/web/style.css b/web/style.css index df220cc02..87f096e14 100644 --- a/web/style.css +++ b/web/style.css @@ -39,6 +39,8 @@ body { padding: 2px; resize: none; border: none; + box-sizing: border-box; + font-size: 10px; } .comfy-modal { From e6e1999f96adbe0f4b041d265837e00bde9283ab Mon Sep 17 00:00:00 2001 From: malern <701073+malern@users.noreply.github.com> Date: Fri, 19 May 2023 20:04:36 +0100 Subject: [PATCH 138/208] Render UI at a higher resolution when viewing with a higher pixel ratio --- web/scripts/app.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..514ca3958 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -921,8 +921,9 @@ export class ComfyApp { this.graph.start(); function resizeCanvas() { - canvasEl.width = 
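A .latent file is an ordinary safetensors container, so both the tensor payload and the workflow metadata that getLatentMetadata reads in the browser can be recovered from Python too. A rough sketch (file name hypothetical; header layout per the safetensors format):

    import json
    import struct
    import safetensors.torch

    path = "ComfyUI_00001_.latent"

    # tensor payload, upcast in case it was saved as fp16
    latent = safetensors.torch.load_file(path, device="cpu")
    samples = {"samples": latent["latent_tensor"].float()}

    # the embedded prompt/workflow metadata lives in the JSON header
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]  # little-endian uint64 length
        header = json.loads(f.read(header_size))
    metadata = header.get("__metadata__", {})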
canvasEl.offsetWidth; - canvasEl.height = canvasEl.offsetHeight; + canvasEl.width = canvasEl.offsetWidth * window.devicePixelRatio; + canvasEl.height = canvasEl.offsetHeight * window.devicePixelRatio; + canvasEl.getContext("2d").scale(window.devicePixelRatio, window.devicePixelRatio); canvas.draw(true, true); } From b9daf4e30f32e76a00628add0849d84b5ec2fe76 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 19 May 2023 22:40:28 -0400 Subject: [PATCH 139/208] Add a /object_info/{node_class} route to get only the info of one node. --- server.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/server.py b/server.py index f52117f10..18ce54306 100644 --- a/server.py +++ b/server.py @@ -261,23 +261,34 @@ class PromptServer(): async def get_prompt(request): return web.json_response(self.get_queue_info()) + def node_info(node_class): + obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] + info = {} + info['input'] = obj_class.INPUT_TYPES() + info['output'] = obj_class.RETURN_TYPES + info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) + info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] + info['name'] = node_class + info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class + info['description'] = '' + info['category'] = 'sd' + if hasattr(obj_class, 'CATEGORY'): + info['category'] = obj_class.CATEGORY + return info + @routes.get("/object_info") async def get_object_info(request): out = {} for x in nodes.NODE_CLASS_MAPPINGS: - obj_class = nodes.NODE_CLASS_MAPPINGS[x] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = x - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[x] if x in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else x - info['description'] = '' - info['category'] = 'sd' - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - out[x] = info + out[x] = node_info(x) + return web.json_response(out) + + @routes.get("/object_info/{node_class}") + async def get_object_info_node(request): + node_class = request.match_info.get("node_class", None) + out = {} + if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): + out[node_class] = node_info(node_class) return web.json_response(out) @routes.get("/history") From 36af98d75580292d4afbeafb5e7ba7f010145436 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Sat, 20 May 2023 15:23:28 +0200 Subject: [PATCH 140/208] improve sharpen and blur nodes --- comfy_extras/nodes_post_processing.py | 35 ++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ba699e2b8..37c824bde 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -59,6 +59,12 @@ class Blend: def g(self, x): return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) +def gaussian_kernel(kernel_size: int, sigma: float): + x, y = torch.meshgrid(torch.linspace(-1, 1, 
kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + class Blur: def __init__(self): pass @@ -88,12 +94,6 @@ class Blur: CATEGORY = "image/postprocessing" - def gaussian_kernel(self, kernel_size: int, sigma: float): - x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij") - d = torch.sqrt(x * x + y * y) - g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) - return g / g.sum() - def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): if blur_radius == 0: return (image,) @@ -101,10 +101,11 @@ class Blur: batch_size, height, width, channels = image.shape kernel_size = blur_radius * 2 + 1 - kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) + kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1) image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels) + padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') + blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) @@ -167,9 +168,15 @@ class Sharpen: "max": 31, "step": 1 }), - "alpha": ("FLOAT", { + "sigma": ("FLOAT", { "default": 1.0, "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.0, "max": 5.0, "step": 0.1 }), @@ -181,21 +188,21 @@ class Sharpen: CATEGORY = "image/postprocessing" - def sharpen(self, image: torch.Tensor, sharpen_radius: int, alpha: float): + def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): if sharpen_radius == 0: return (image,) batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1 + kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) center = kernel_size // 2 - kernel[center, center] = kernel_size**2 - kernel *= alpha + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) - sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels) + tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = sharpened.permute(0, 2, 3, 1) result = torch.clamp(sharpened, 0, 1) From 71666f248f769af073408a3475dd7a82a29d8247 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 10:08:47 -0400 Subject: [PATCH 141/208] Fix padding in Blur. 
--- comfy_extras/nodes_post_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 37c824bde..3be141dfe 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -105,7 +105,7 @@ class Blur: image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') - blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] + blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] blurred = blurred.permute(0, 2, 3, 1) return (blurred,) From 797c4e8d3b56559bb205ff8aab97d97dca424b9a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 15:07:21 -0400 Subject: [PATCH 142/208] Simplify and improve some vae attention code. --- comfy/ldm/modules/diffusionmodules/model.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5e4d2b60f..05caf7312 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -331,25 +331,13 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): # compute attention B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), + lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = out.transpose(2, 3).reshape(B, C, H, W) out = self.proj_out(out) return x+out From b8636a44aacd83ec6a9a19a6d3d3f5b76fc863c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 20 May 2023 15:43:39 -0400 Subject: [PATCH 143/208] Make scaled_dot_product switch to sliced attention on OOM. 
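The reworked blur pads with reflection and then crops the same radius back off, which keeps the output the same size as the input and avoids dark halos at the borders. A self-contained sketch of that pattern (a box kernel stands in for gaussian_kernel here):

    import torch
    import torch.nn.functional as F

    image = torch.rand(1, 3, 64, 64)          # (B, C, H, W)
    radius = 3
    kernel_size = radius * 2 + 1
    # stand-in box kernel; the node uses gaussian_kernel(kernel_size, sigma)
    kernel = torch.ones(3, 1, kernel_size, kernel_size) / kernel_size ** 2

    padded = F.pad(image, (radius, radius, radius, radius), mode='reflect')
    blurred = F.conv2d(padded, kernel, padding=kernel_size // 2, groups=3)
    blurred = blurred[:, :, radius:-radius, radius:-radius]
    assert blurred.shape == image.shape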
--- comfy/ldm/modules/diffusionmodules/model.py | 79 +++++++++++---------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 05caf7312..91e7d60ec 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -146,6 +146,41 @@ class ResnetBlock(nn.Module): return x+h +def slice_attention(q, k, v): + r1 = torch.zeros_like(k, device=q.device) + scale = (int(q.shape[-1])**(-0.5)) + + mem_free_total = model_management.get_free_memory(q.device) + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = torch.bmm(q[:, i:end], k) * scale + + s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) + del s1 + + r1[:, :, i:end] = torch.bmm(v, s2) + del s2 + break + except model_management.OOM_EXCEPTION as e: + steps *= 2 + if steps > 128: + raise e + print("out of memory error, increasing steps and trying again", steps) + + return r1 class AttnBlock(nn.Module): def __init__(self, in_channels): @@ -183,48 +218,15 @@ class AttnBlock(nn.Module): # compute attention b,c,h,w = q.shape - scale = (int(c)**(-0.5)) q = q.reshape(b,c,h*w) q = q.permute(0,2,1) # b,hw,c k = k.reshape(b,c,h*w) # b,c,hw v = v.reshape(b,c,h*w) - r1 = torch.zeros_like(k, device=q.device) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = torch.bmm(q[:, i:end], k) * scale - - s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1) - del s1 - - r1[:, :, i:end] = torch.bmm(v, s2) - del s2 - break - except model_management.OOM_EXCEPTION as e: - steps *= 2 - if steps > 128: - raise e - print("out of memory error, increasing steps and trying again", steps) - + r1 = slice_attention(q, k, v) h_ = r1.reshape(b,c,h,w) del r1 - h_ = self.proj_out(h_) return x+h_ @@ -335,9 +337,14 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = out.transpose(2, 3).reshape(B, C, H, W) + try: + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = out.transpose(2, 3).reshape(B, C, H, W) + except model_management.OOM_EXCEPTION as e: + print("scaled_dot_product_attention OOMed: switched to slice attention") + out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W) + out = self.proj_out(out) return x+out From 3c76f43057f140c583962327c18c7d5257e7495c Mon Sep 17 00:00:00 2001 From: 
comfyanonymous Date: Sat, 20 May 2023 23:06:33 -0400 Subject: [PATCH 144/208] Cleaner code. --- server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index 18ce54306..701c0e7a7 100644 --- a/server.py +++ b/server.py @@ -331,7 +331,8 @@ class PromptServer(): extra_data["client_id"] = json_data["client_id"] if valid[0]: prompt_id = str(uuid.uuid4()) - self.prompt_queue.put((number, prompt_id, prompt, extra_data, valid[2])) + outputs_to_execute = valid[2] + self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) From 516119ad835841bd176055cbc888843b418b8004 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 00:24:28 -0400 Subject: [PATCH 145/208] Print min and max values in validation error message. --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 0e2cc15c1..35f044346 100644 --- a/execution.py +++ b/execution.py @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value smaller than min. {}, {}".format(class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value bigger than max. {}, {}".format(class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) From 069657fbf3d8d977ead39ab206d8c917bbcc4997 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 01:35:08 -0400 Subject: [PATCH 146/208] Add DPM-Solver++(2M) SDE and exponential scheduler. exponential scheduler is the one recommended with this sampler. 
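The exponential schedule mentioned here spaces the sampling steps evenly in log-sigma. A sketch of what k_diffusion's get_sigmas_exponential computes (signature simplified):

    import math
    import torch

    def sigmas_exponential(n, sigma_min, sigma_max):
        # evenly spaced in log space from sigma_max down to sigma_min,
        # plus the trailing zero every schedule ends on
        sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()
        return torch.cat([sigmas, sigmas.new_zeros([1])])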
--- comfy/k_diffusion/sampling.py | 43 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 6 +++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c809d39fb..94d7a5762 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -605,3 +605,46 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d old_denoised = denoised return x + +@torch.no_grad() +def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): + """DPM-Solver++(2M) SDE.""" + + if solver_type not in {'heun', 'midpoint'}: + raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index fccf254ec..1fb928f8d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -495,10 +495,10 @@ def encode_adm(noise_augmentor, conds, batch_size, device): class KSampler: - SCHEDULERS = ["normal", "karras", "simple", "ddim_uniform"] + SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", - "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] + "dpmpp_2m", "dpmpp_2m_sde", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -532,6 +532,8 @@ class KSampler: if self.scheduler == "karras": sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) + elif self.scheduler == "exponential": + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max) elif self.scheduler == "normal": sigmas = self.model_wrap.get_sigmas(steps) elif self.scheduler == "simple": From 4796e615dd7faad38429fc8e716e3a817a28c526 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 10:34:26 -0400 Subject: [PATCH 147/208] Revert DPI fix since it caused more issues than it solved. 
--- web/scripts/app.js | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 514ca3958..97b7c8d31 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -921,9 +921,8 @@ export class ComfyApp { this.graph.start(); function resizeCanvas() { - canvasEl.width = canvasEl.offsetWidth * window.devicePixelRatio; - canvasEl.height = canvasEl.offsetHeight * window.devicePixelRatio; - canvasEl.getContext("2d").scale(window.devicePixelRatio, window.devicePixelRatio); + canvasEl.width = canvasEl.offsetWidth; + canvasEl.height = canvasEl.offsetHeight; canvas.draw(true, true); } From dc198650c0d2d281c9d87b23a8917b457a94d837 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 21 May 2023 11:34:29 -0400 Subject: [PATCH 148/208] sample_dpmpp_2m_sde no longer crashes when step == 1. --- comfy/k_diffusion/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 94d7a5762..c540d7411 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -628,6 +628,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised + h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() From 6cc450579b3314fe314e64af22c4be81afd1f87d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 00:22:24 -0400 Subject: [PATCH 149/208] Auto transpose images from exif data. --- comfy/k_diffusion/sampling.py | 2 +- nodes.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c540d7411..26930428f 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -620,6 +620,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl old_denoised = None h_last = None + h = None for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -628,7 +629,6 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl if sigmas[i + 1] == 0: # Denoising step x = denoised - h = None else: # DPM-Solver++(2M) SDE t, s = -sigmas[i].log(), -sigmas[i + 1].log() diff --git a/nodes.py b/nodes.py index 878e0b955..bae330bc9 100644 --- a/nodes.py +++ b/nodes.py @@ -8,7 +8,7 @@ import traceback import math import time -from PIL import Image +from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -1057,6 +1057,7 @@ class LoadImage: def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1100,6 +1101,7 @@ class LoadImageMask: def load_image(self, image, channel): image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) + i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") mask = None From ffc56c53c9cccfcc21c92fe14cb095bb32ea2744 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:22:38 -0400 Subject: [PATCH 150/208] Add a node_errors to the /prompt error json response. "node_errors" contains a dict keyed by node ids. The contents are a message and a list of dependent outputs. 
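Concretely, a rejected POST to /prompt now returns a body shaped roughly like the dict below (node ids, messages and outputs are illustrative, not taken from a real run):

    {
        "error": "Prompt has no properly connected outputs\n Value 99 bigger than max of 10. KSampler, steps",
        "node_errors": {
            "3": {
                "message": "Value 99 bigger than max of 10. KSampler, steps",
                "dependent_outputs": ["9"]
            }
        }
    }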
--- execution.py | 27 ++++++++++++++++----------- server.py | 4 ++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index 35f044346..212e789ca 100644 --- a/execution.py +++ b/execution.py @@ -299,18 +299,18 @@ def validate_inputs(prompt, item, validated): required_inputs = class_inputs['required'] for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) + return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) + return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) r = validate_inputs(prompt, o_id, validated) if r[0] == False: validated[o_id] = r @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) @@ -338,13 +338,13 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for r in ret: if r != True: - return (False, "{}, {}".format(class_type, r)) + return (False, "{}, {}".format(class_type, r), unique_id) else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) - ret = (True, "") + ret = (True, "", unique_id) validated[unique_id] = ret return ret @@ -356,10 +356,11 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs") + return (False, "Prompt has no outputs", [], []) good_outputs = set() errors = [] + node_errors = {} validated = {} for o in outputs: valid = False @@ -368,6 +369,7 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] + node_id = m[2] except Exception as e: print(traceback.format_exc()) valid = False @@ -379,12 +381,15 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) - return (True, "", list(good_outputs)) + return (True, "", list(good_outputs), node_errors) class PromptQueue: diff --git a/server.py b/server.py index 701c0e7a7..8429a63fb 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) - return web.json_response({"error": valid[1]}, status=400) + return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt"}, status=400) + return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @routes.post("/queue") async def post_queue(request): From db27b0405a31983916d6801cf84f7f1fc4503e6a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:25:50 -0400 Subject: [PATCH 151/208] object_info now returns if node is an output_node or not. 
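A quick way to exercise the per-node route added in the previous patch, assuming a local server on the default port:

    import json
    import urllib.request

    # assumes ComfyUI is running locally on its default port
    url = "http://127.0.0.1:8188/object_info/SaveImage"
    with urllib.request.urlopen(url) as resp:
        info = json.load(resp)["SaveImage"]

    print(info["output_node"])   # True: SaveImage declares OUTPUT_NODE
    print(info["category"])      # "image"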
--- server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server.py b/server.py index 8429a63fb..c0f79cbd5 100644 --- a/server.py +++ b/server.py @@ -272,6 +272,11 @@ class PromptServer(): info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = '' info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + if hasattr(obj_class, 'CATEGORY'): info['category'] = obj_class.CATEGORY return info From bfb13f5eee48545f1c4b0b8a377de80be84bb100 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 17:05:23 -0400 Subject: [PATCH 152/208] Remove useless call to /object_info --- web/extensions/core/colorPalette.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 2f2238a2b..bfcd847a3 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,7 +174,7 @@ const els = {} // const ctxMenu = LiteGraph.ContextMenu; app.registerExtension({ name: id, - init() { + addCustomNodeDefs(node_defs) { const sortObjectKeys = (unordered) => { return Object.keys(unordered).sort().reduce((obj, key) => { obj[key] = unordered[key]; @@ -182,10 +182,10 @@ app.registerExtension({ }, {}); }; - const getSlotTypes = async () => { + function getSlotTypes() { var types = []; - const defs = await api.getNodeDefs(); + const defs = node_defs; for (const nodeId in defs) { const nodeData = defs[nodeId]; @@ -212,8 +212,8 @@ app.registerExtension({ return types; }; - const completeColorPalette = async (colorPalette) => { - var types = await getSlotTypes(); + function completeColorPalette(colorPalette) { + var types = getSlotTypes(); for (const type of types) { if (!colorPalette.colors.node_slot[type]) { From 48fcc5b777b3a1ab5d6dc5fec6adaebeb32c2c93 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 20:51:30 -0400 Subject: [PATCH 153/208] Parsing error crash. --- execution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 212e789ca..25f2fcacd 100644 --- a/execution.py +++ b/execution.py @@ -374,6 +374,7 @@ def validate_prompt(prompt): print(traceback.format_exc()) valid = False reason = "Parsing error" + node_id = None if valid == True: good_outputs.add(o) @@ -381,9 +382,10 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + if node_id is not None: + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) From 34887b888546716b5c5507606289ca2728bf3123 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 03:12:56 -0400 Subject: [PATCH 154/208] Add experimental bislerp algorithm for latent upscaling. It's like bilinear but with slerp. 
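The slerp half of bislerp interpolates directions on the unit sphere while magnitudes blend linearly. A compact sketch of the per-pixel step the patch below applies along each image axis (assumes nonzero inputs; the patch itself guards the degenerate cases):

    import torch

    def slerp(low, high, t):
        # rows of shape (B, C): normalize, rotate between the two
        # directions by fraction t, keeping the interpolation spherical
        low_n = low / low.norm(dim=1, keepdim=True)
        high_n = high / high.norm(dim=1, keepdim=True)
        dot = (low_n * high_n).sum(1).clamp(-0.9995, 0.9995)
        omega = torch.acos(dot)
        so = torch.sin(omega)
        return (torch.sin((1.0 - t) * omega) / so).unsqueeze(1) * low + \
               (torch.sin(t * omega) / so).unsqueeze(1) * high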
--- comfy/utils.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++- nodes.py | 2 +- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 09e05d4ed..0f7b34503 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,6 +46,65 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +#slow and inefficient, should be optimized +def bislerp(samples, width, height): + shape = list(samples.shape) + width_scale = (shape[3]) / (width ) + height_scale = (shape[2]) / (height ) + + shape[3] = width + shape[2] = height + out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + + def algorithm(in1, w1, in2, w2): + dims = in1.shape + val = w2 + + #flatten to batches + low = in1.reshape(dims[0], -1) + high = in2.reshape(dims[0], -1) + + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + + # in case we divide by zero + low_norm[low_norm != low_norm] = 0.0 + high_norm[high_norm != high_norm] = 0.0 + + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res.reshape(dims) + + for x_dest in range(shape[3]): + for y_dest in range(shape[2]): + y = (y_dest) * height_scale + x = (x_dest) * width_scale + + x1 = max(math.floor(x), 0) + x2 = min(x1 + 1, samples.shape[3] - 1) + y1 = max(math.floor(y), 0) + y2 = min(y1 + 1, samples.shape[2] - 1) + + in1 = samples[:,:,y1,x1] + in2 = samples[:,:,y1,x2] + in3 = samples[:,:,y2,x1] + in4 = samples[:,:,y2,x2] + + if (x1 == x2) and (y1 == y2): + out_value = in1 + elif (x1 == x2): + out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + elif (y1 == y2): + out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + else: + o1 = algorithm(in1, (x2 - x), in2, (x - x1)) + o2 = algorithm(in3, (x2 - x), in4, (x - x1)) + out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + + out1[:,:,y_dest,x_dest] = out_value + return out1 + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -61,7 +120,11 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples[:,:,y:old_height-y,x:old_width-x] else: s = samples - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if upscale_method == "bislerp": + return bislerp(s, width, height) + else: + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) diff --git a/nodes.py b/nodes.py index bae330bc9..e5cec2632 100644 --- a/nodes.py +++ b/nodes.py @@ -749,7 +749,7 @@ class RepeatLatentBatch: return (s,) class LatentUpscale: - upscale_methods = ["nearest-exact", "bilinear", "area"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] crop_methods = ["disabled", "center"] @classmethod From 451fb4169ad900e5d33b540f039f56ced9a76157 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:35:32 -0400 Subject: [PATCH 155/208] Fix 'git pull' not working on the standalones. 
--- .github/workflows/windows_release_cu118_package.yml | 1 + .github/workflows/windows_release_nightly_pytorch.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml index 15322c86a..2d6048a23 100644 --- a/.github/workflows/windows_release_cu118_package.yml +++ b/.github/workflows/windows_release_cu118_package.yml @@ -30,6 +30,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - shell: bash run: | cd .. diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b6a18ec0a..767a7216b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - uses: actions/setup-python@v4 with: python-version: '3.11.3' From b8ccbec6d893d34dab90d2418a3fe00969251fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:40:24 -0400 Subject: [PATCH 156/208] Various improvements to bislerp. --- comfy/utils.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 0f7b34503..300eda6aa 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -56,35 +56,42 @@ def bislerp(samples, width, height): shape[2] = height out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) - def algorithm(in1, w1, in2, w2): + def algorithm(in1, in2, t): dims = in1.shape - val = w2 + val = t #flatten to batches low = in1.reshape(dims[0], -1) high = in2.reshape(dims[0], -1) - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) + low_weight = torch.norm(low, dim=1, keepdim=True) + low_weight[low_weight == 0] = 0.0000000001 + low_norm = low/low_weight + high_weight = torch.norm(high, dim=1, keepdim=True) + high_weight[high_weight == 0] = 0.0000000001 + high_norm = high/high_weight - # in case we divide by zero - low_norm[low_norm != low_norm] = 0.0 - high_norm[high_norm != high_norm] = 0.0 - - omega = torch.acos((low_norm*high_norm).sum(1)) + dot_prod = (low_norm*high_norm).sum(1) + dot_prod[dot_prod > 0.9995] = 0.9995 + dot_prod[dot_prod < -0.9995] = -0.9995 + omega = torch.acos(dot_prod) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm + res *= (low_weight * (1.0-val) + high_weight * val) return res.reshape(dims) for x_dest in range(shape[3]): for y_dest in range(shape[2]): - y = (y_dest) * height_scale - x = (x_dest) * width_scale + y = (y_dest + 0.5) * height_scale - 0.5 + x = (x_dest + 0.5) * width_scale - 0.5 x1 = max(math.floor(x), 0) x2 = min(x1 + 1, samples.shape[3] - 1) + wx = x - math.floor(x) + y1 = max(math.floor(y), 0) y2 = min(y1 + 1, samples.shape[2] - 1) + wy = y - math.floor(y) in1 = samples[:,:,y1,x1] in2 = samples[:,:,y1,x2] @@ -94,13 +101,13 @@ def bislerp(samples, width, height): if (x1 == x2) and (y1 == y2): out_value = in1 elif (x1 == x2): - out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + out_value = algorithm(in1, in3, wy) elif (y1 == y2): - out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + out_value = algorithm(in1, in2, wx) else: - o1 = 
algorithm(in1, (x2 - x), in2, (x - x1)) - o2 = algorithm(in3, (x2 - x), in4, (x - x1)) - out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + o1 = algorithm(in1, in2, wx) + o2 = algorithm(in3, in4, wx) + out_value = algorithm(o1, o2, wy) out1[:,:,y_dest,x_dest] = out_value return out1 From c00bb1a0b78f0d2cf2e4ec2dd9ae7d61cb07a637 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 12:53:38 -0400 Subject: [PATCH 157/208] Add a latent upscale by node. --- nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nodes.py b/nodes.py index e5cec2632..f0a93ebd5 100644 --- a/nodes.py +++ b/nodes.py @@ -768,6 +768,25 @@ class LatentUpscale: s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) +class LatentUpscaleBy: + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "upscale" + + CATEGORY = "latent" + + def upscale(self, samples, upscale_method, scale_by): + s = samples.copy() + width = round(samples["samples"].shape[3] * scale_by) + height = round(samples["samples"].shape[2] * scale_by) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") + return (s,) + class LatentRotate: @classmethod def INPUT_TYPES(s): @@ -1244,6 +1263,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentUpscaleBy": LatentUpscaleBy, "LatentFromBatch": LatentFromBatch, "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, @@ -1322,6 +1342,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", + "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", From 7310290f17aad79480edb92f22cd58f0997db964 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 22:26:50 -0400 Subject: [PATCH 158/208] Pull in latest upscale model code from chainner. 
--- .../architecture/OmniSR/ChannelAttention.py | 110 ++++ .../architecture/OmniSR/LICENSE | 201 ++++++ .../architecture/OmniSR/OSA.py | 577 ++++++++++++++++++ .../architecture/OmniSR/OSAG.py | 60 ++ .../architecture/OmniSR/OmniSR.py | 133 ++++ .../architecture/OmniSR/esa.py | 294 +++++++++ .../architecture/OmniSR/layernorm.py | 70 +++ .../architecture/OmniSR/pixelshuffle.py | 31 + .../chainner_models/architecture/RRDB.py | 17 +- .../chainner_models/architecture/block.py | 30 + comfy_extras/chainner_models/model_loading.py | 5 + comfy_extras/chainner_models/types.py | 4 +- 12 files changed, 1530 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/LICENSE create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSA.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSAG.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/esa.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/layernorm.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py new file mode 100644 index 000000000..f4d52aa1e --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py @@ -0,0 +1,110 @@ +import math + +import torch.nn as nn + + +class CA_layer(nn.Module): + def __init__(self, channel, reduction=16): + super(CA_layer, self).__init__() + # global average pooling + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False), + nn.GELU(), + nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False), + # nn.Sigmoid() + ) + + def forward(self, x): + y = self.fc(self.gap(x)) + return x * y.expand_as(x) + + +class Simple_CA_layer(nn.Module): + def __init__(self, channel): + super(Simple_CA_layer, self).__init__() + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=channel, + out_channels=channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return x * self.fc(self.gap(x)) + + +class ECA_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ECA_MaxPool_layer(nn.Module): + """Constructs a ECA module. 
+ Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_MaxPool_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.max_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
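A note on the ECA layers above: the 1-D conv kernel size is chosen
adaptively from the channel count (b=1, gamma=2, rounded up to the next odd
number). A standalone sketch of that formula with a few sample values:

    import math

    def eca_kernel_size(channel, b=1, gamma=2):
        # mirrors ECA_layer.__init__ / ECA_MaxPool_layer.__init__ above
        k = int(abs(math.log(channel, 2) + b) / gamma)
        return k if k % 2 else k + 1

    for c in (64, 128, 256):
        print(c, eca_kernel_size(c))  # 64 -> 3, 128 -> 5, 256 -> 5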
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py new file mode 100644 index 000000000..d7a129696 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSA.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:07:42 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce +from torch import einsum, nn + +from .layernorm import LayerNorm2d + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, length=1): + return val if isinstance(val, tuple) else ((val,) * length) + + +# helper classes + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class Conv_PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = LayerNorm2d(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU(), + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim, 1, 1, 0), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Gated_Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=1, bias=False, dropout=0.0): + super().__init__() + + hidden_features = int(dim * mult) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +# MBConv + + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + Rearrange("b c -> b c 1 1"), + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + 
+ +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = ( + torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() + > self.prob + ) + return x * keep_mask / (1 - self.prob) + + +def MBConv( + dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0 +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d( + hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim + ), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + # nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout=dropout) + + return net + + +# attention related classes +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.Dropout(dropout) + ) + + # relative positional bias + if self.with_pe: + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, "c i j -> (i j) c") + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange( + grid, "j ... -> 1 j ..." + ) + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum( + dim=-1 + ) + + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = ( + *x.shape, + x.device, + self.heads, + ) + + # flatten + + x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d") + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # add positional bias + if self.with_pe: + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, "i j h -> h i j") + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + + out = rearrange( + out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width + ) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, "(b x y) ... 
-> b x y ...", x=height, y=width) + + +class Block_Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + bias=False, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.ps = window_size + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + # project for queries, keys, values + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # split heads + + q, k, v = map( + lambda t: rearrange( + t, + "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d", + h=self.heads, + w1=self.ps, + w2=self.ps, + ), + (q, k, v), + ) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # attention + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + out = rearrange( + out, + "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)", + x=h // self.ps, + y=w // self.ps, + head=self.heads, + w1=self.ps, + w2=self.ps, + ) + + out = self.to_out(out) + return out + + +class Channel_Attention(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class Channel_Attention_grid(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention_grid, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = 
F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class OSA_Block(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + ffn_bias=True, + window_size=8, + with_pe=False, + dropout=0.0, + ): + super(OSA_Block, self).__init__() + + w = window_size + + self.layer = nn.Sequential( + MBConv( + channel_num, + channel_num, + downsample=False, + expansion_rate=1, + shrinkage_rate=0.25, + ), + Rearrange( + "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w + ), # block-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + Rearrange( + "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w + ), # grid-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention_grid( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + ) + + def forward(self, x): + out = self.layer(x) + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py new file mode 100644 index 000000000..477e81f9d --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSAG.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:08:49 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn + +from .esa import ESA +from .OSA import OSA_Block + + +class OSAG(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + block_num=4, + ffn_bias=False, + window_size=0, + pe=False, + ): + super(OSAG, self).__init__() + + # print("window_size: %d" % (window_size)) + # print("with_pe", pe) + # print("ffn_bias: %d" % (ffn_bias)) + + # block_script_name = kwargs.get("block_script_name", "OSA") + # block_class_name = kwargs.get("block_class_name", "OSA_Block") + + # script_name = "." 
+ block_script_name + # package = __import__(script_name, fromlist=True) + block_class = OSA_Block # getattr(package, block_class_name) + group_list = [] + for _ in range(block_num): + temp_res = block_class( + channel_num, + bias, + ffn_bias=ffn_bias, + window_size=window_size, + with_pe=pe, + ) + group_list.append(temp_res) + group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias)) + self.residual_layer = nn.Sequential(*group_list) + esa_channel = max(channel_num // 4, 16) + self.esa = ESA(esa_channel, channel_num) + + def forward(self, x): + out = self.residual_layer(x) + out = out + x + return self.esa(out) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py new file mode 100644 index 000000000..dec169520 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OmniSR.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:06:36 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .OSAG import OSAG +from .pixelshuffle import pixelshuffle_block + + +class OmniSR(nn.Module): + def __init__( + self, + state_dict, + **kwargs, + ): + super(OmniSR, self).__init__() + self.state = state_dict + + bias = True # Fine to assume this for now + block_num = 1 # Fine to assume this for now + ffn_bias = True + pe = True + + num_feat = state_dict["input.weight"].shape[0] or 64 + num_in_ch = state_dict["input.weight"].shape[1] or 3 + num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + + pixelshuffle_shape = state_dict["up.0.weight"].shape[0] + up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) + if up_scale - int(up_scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + up_scale = int(up_scale) + res_num = 0 + for key in state_dict.keys(): + if "residual_layer" in key: + temp_res_num = int(key.split(".")[1]) + if temp_res_num > res_num: + res_num = temp_res_num + res_num = res_num + 1 # zero-indexed + + residual_layer = [] + self.res_num = res_num + + self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + self.up_scale = up_scale + + for _ in range(res_num): + temp_res = OSAG( + channel_num=num_feat, + bias=bias, + block_num=block_num, + ffn_bias=ffn_bias, + window_size=self.window_size, + pe=pe, + ) + residual_layer.append(temp_res) + self.residual_layer = nn.Sequential(*residual_layer) + self.input = nn.Conv2d( + in_channels=num_in_ch, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.output = nn.Conv2d( + in_channels=num_feat, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias) + + # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. 
/ n)) + + # chaiNNer specific stuff + self.model_arch = "OmniSR" + self.sub_type = "SR" + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.scale = up_scale + + self.supports_fp16 = True # TODO: Test this + self.supports_bfp16 = True + self.min_size_restriction = 16 + + self.load_state_dict(state_dict, strict=False) + + def check_image_size(self, x): + _, _, h, w = x.size() + # import pdb; pdb.set_trace() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0) + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + residual = self.input(x) + out = self.residual_layer(residual) + + # origin + out = torch.add(self.output(out), residual) + out = self.up(out) + + out = out[:, :, : H * self.up_scale, : W * self.up_scale] + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py new file mode 100644 index 000000000..f9ce7f7a6 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/esa.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: esa.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:06 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layernorm import LayerNorm2d + + +def moment(x, dim=(2, 3), k=2): + assert len(x.size()) == 4 + mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1) + mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim) + return mk + + +class ESA(nn.Module): + """ + Modification of Enhanced Spatial Attention (ESA), which is proposed by + `Residual Feature Aggregation Network for Image Super-Resolution` + Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes + are deleted. 
+ """ + + def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = conv(f, f, kernel_size=3, padding=1) + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA_LN(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA_LN, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.norm = LayerNorm2d(n_feats) + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.norm(x) + c1_ = self.conv1(c1_) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + 
self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class AdaGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaGuidedFilter, self).__init__() + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=n_feats, + out_channels=1, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + self.r = 5 + + def box_filter(self, x, r): + channel = x.shape[1] + kernel_size = 2 * r + 1 + weight = 1.0 / (kernel_size**2) + box_kernel = weight * torch.ones( + (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device + ) + output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) + return output + + def forward(self, x): + _, _, H, W = x.shape + N = self.box_filter( + torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r + ) + + # epsilon = self.fc(self.gap(x)) + # epsilon = torch.pow(epsilon, 2) + epsilon = 1e-2 + + mean_x = self.box_filter(x, self.r) / N + var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x + + A = var_x / (var_x + epsilon) + b = (1 - A) * mean_x + m = A * x + b + + # mean_A = self.box_filter(A, self.r) / N + # mean_b = self.box_filter(b, self.r) / N + # m = mean_A * x + mean_b + return x * m + + +class AdaConvGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaConvGuidedFilter, self).__init__() + f = esa_channels + + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=f, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=f, + bias=bias, + ) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + y = self.vec_conv(x) + y = self.hor_conv(y) + + sigma = torch.pow(y, 2) + epsilon = self.fc(self.gap(y)) + + weight = sigma / (sigma + epsilon) + + m = weight * x + (1 - weight) + + return x * m diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py new file mode 100644 index 000000000..731a25f75 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: layernorm.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:20 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, 
grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py new file mode 100644 index 000000000..4260fb7c9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: pixelshuffle.py +# Created Date: Friday July 1st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 1st July 2022 10:18:39 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch.nn as nn + + +def pixelshuffle_block( + in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False +): + """ + Upsample features according to `upscale_factor`. 
+ """ + padding = kernel_size // 2 + conv = nn.Conv2d( + in_channels, + out_channels * (upscale_factor**2), + kernel_size, + padding=1, + bias=bias, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + return nn.Sequential(*[conv, pixel_shuffle]) diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py index 4d52f05dd..b50db7c24 100644 --- a/comfy_extras/chainner_models/architecture/RRDB.py +++ b/comfy_extras/chainner_models/architecture/RRDB.py @@ -79,6 +79,12 @@ class RRDBNet(nn.Module): self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None @@ -105,11 +111,15 @@ class RRDBNet(nn.Module): out_nc=self.num_filters, upscale_factor=3, act_type=self.act, + c2x2=c2x2, ) else: upsample_blocks = [ upsample_block( - in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] @@ -122,6 +132,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), B.ShortcutBlock( B.sequential( @@ -138,6 +149,7 @@ class RRDBNet(nn.Module): act_type=self.act, mode="CNA", plus=self.plus, + c2x2=c2x2, ) for _ in range(self.num_blocks) ], @@ -149,6 +161,7 @@ class RRDBNet(nn.Module): norm_type=self.norm, act_type=None, mode=self.mode, + c2x2=c2x2, ), ) ), @@ -160,6 +173,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=self.act, + c2x2=c2x2, ), # hr_conv1 B.conv_block( @@ -168,6 +182,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), ) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 214642cc4..d7bc5d227 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -141,6 +141,19 @@ def sequential(*args): ConvMode = Literal["CNA", "NAC", "CNAC"] +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + def conv_block( in_nc: int, out_nc: int, @@ -153,12 +166,17 @@ def conv_block( norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", + c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None @@ -285,6 +303,7 @@ class RRDB(nn.Module): _convtype="Conv2D", _spectral_norm=False, plus=False, + c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( @@ -298,6 +317,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, @@ -310,6 +330,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) 
self.RDB3 = ResidualDenseBlock_5C( nf, @@ -322,6 +343,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) def forward(self, x): @@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module): act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, + c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() @@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, @@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, @@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, @@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) if mode == "CNA": last_act = None @@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=last_act, mode=mode, + c2x2=c2x2, ) def forward(self, x): @@ -499,6 +527,7 @@ def upconv_block( norm_type: str | None = None, act_type="relu", mode="nearest", + c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ @@ -512,5 +541,6 @@ def upconv_block( pad_type=pad_type, norm_type=norm_type, act_type=act_type, + c2x2=c2x2, ) return sequential(upsample, conv) diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 8234ac5d1..2e66e6247 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) + # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel: # MAT elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: model = MAT(state_dict) + # Omni-SR + elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: + model = OmniSR(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 8e2bef47a..1906c0c7f 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, 
SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT) +PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -22,6 +23,7 @@ PyTorchSRModel = Union[ SwinIR, Swin2SR, HAT, + OmniSR, ] From 9b1396e93a19748dd4c4bb35637638bb0f91b5f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 24 May 2023 14:01:11 -0400 Subject: [PATCH 159/208] Fix issue importing other ui prompts. --- web/scripts/pnginfo.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 8ddb7a1c5..977b5ac2f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) { const embeddings = await api.getEmbeddings(); const opts = parameters .substr(p) + .split("\n")[1] .split(",") .reduce((p, n) => { const s = n.split(":"); From 8b4b0c3188110e1faa8865570637172ab4b60ba1 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 25 May 2023 19:23:47 +0200 Subject: [PATCH 160/208] vecorized bislerp --- comfy/utils.py | 117 +++++++++++++++++++++++++++---------------------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..cc0e5069a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -46,71 +47,81 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd -#slow and inefficient, should be optimized def bislerp(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width ) - height_scale = (shape[2]) / (height ) + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] - shape[3] = width - shape[2] = height - out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) - def algorithm(in1, in2, t): - dims = in1.shape - val = t + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms - #flatten to batches - low = in1.reshape(dims[0], -1) - high = in2.reshape(dims[0], -1) + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low/low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high/high_weight - - dot_prod = (low_norm*high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm - res *= (low_weight * (1.0-val) + high_weight * val) - return res.reshape(dims) - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 + #technically not mathematically correct, but more pleasing? 
+ res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - in1 = samples[:,:,y1,x1] - in2 = samples[:,:,y1,x2] - in3 = samples[:,:,y2,x1] - in4 = samples[:,:,y2,x2] + pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif (x1 == x2): - out_value = algorithm(in1, in3, wy) - elif (y1 == y2): - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) - out1[:,:,y_dest,x_dest] = out_value - return out1 + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + + coords_1 = coords_1.expand((n, c, h_new, -1)) + coords_2 = coords_2.expand((n, c, h_new, -1)) + ratios = ratios.expand((n, 1, h_new, -1)) + + pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + return result def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": From e1278fa925cf59350bae76dc3d0c59a0e9564789 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 13:30:59 -0400 Subject: [PATCH 161/208] Support old pytorch versions that don't have weights_only. 
--- comfy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..d58320b4a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -6,6 +6,10 @@ def load_torch_file(ckpt, safe_load=False): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: + if safe_load: + if not 'weights_only' in torch.load.__code__.co_varnames: + print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) else: From 87ab25fac77ff1d558fea3c02733a463cb1fa013 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:31:27 -0400 Subject: [PATCH 162/208] Do operations in same order as the one it replaces. --- comfy/utils.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 33c1c3dd7..f139fbb27 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -98,29 +98,27 @@ def bislerp(samples, width, height): n,c,h,w = samples.shape h_new, w_new = (height, width) - #linear h - ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) - coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) - coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) - ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - - pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) - #linear w - ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - coords_1 = coords_1.expand((n, c, h_new, -1)) - coords_2 = coords_2.expand((n, c, h_new, -1)) - ratios = ratios.expand((n, 1, h_new, -1)) - - pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) From eb4bd7711acec9a2a2d4f1d4dcc1d32e1236c976 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:42:56 -0400 Subject: [PATCH 163/208] Remove einops. 
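The removed einops calls map one-to-one onto movedim/reshape. A quick
equivalence check, assuming einops is still installed to compare against:

    import torch
    import einops

    t = torch.randn(2, 4, 8, 8)  # n c h w
    n, c, h, w = t.shape

    a = einops.rearrange(t, 'n c h w -> (n h w) c')
    b = t.movedim(1, -1).reshape((-1, c))
    assert torch.equal(a, b)

    # and the inverse used to rebuild the result tensor:
    a2 = einops.rearrange(a, '(n h w) c -> n c h w', n=n, h=h, w=w)
    b2 = b.reshape(n, h, w, c).movedim(-1, 1)
    assert torch.equal(a2, b2)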
--- comfy/utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index f139fbb27..5ed9aaa02 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,6 +1,5 @@ import torch import math -import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -104,12 +103,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) - pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) @@ -117,12 +116,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result def common_upscale(samples, width, height, upscale_method, crop): From 4d1ed829d9a934d9a303a725e325f90934854ac8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 19:33:30 -0500 Subject: [PATCH 164/208] Don't load some model types if weight is zero --- nodes.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nodes.py b/nodes.py index f0a93ebd5..68010f040 100644 --- a/nodes.py +++ b/nodes.py @@ -426,6 +426,9 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) @@ -507,6 +510,9 @@ class ControlNetApply: CATEGORY = "conditioning" def apply_controlnet(self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning, ) + c = [] control_hint = image.movedim(-1,1) for t in conditioning: @@ -613,6 +619,9 @@ class unCLIPConditioning: CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): + if strength == 0: + return (conditioning, ) + c = [] for t in conditioning: o = t[1].copy() From 679bd2845af8e22b2802cf326b99b40a26ba7811 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 May 2023 21:46:11 -0400 Subject: [PATCH 165/208] Safetensors isn't optional anymore. 
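load_torch_file already imports safetensors unconditionally for
.safetensors checkpoints, so the try/except guard here only delayed the
failure from startup to the first load on a broken install. With the
guard gone, the supported extension sets become plain constants; a
sketch of the resulting dispatch (the helper name is illustrative):

    import safetensors.torch
    import torch

    supported_pt_extensions = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors'}

    def load_checkpoint(path):
        if path.lower().endswith('.safetensors'):
            return safetensors.torch.load_file(path, device="cpu")
        return torch.load(path, map_location="cpu")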
--- folder_paths.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 28f117824..20b461c94 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,7 @@ import os -supported_ckpt_extensions = set(['.ckpt', '.pth']) -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) -try: - import safetensors.torch - supported_ckpt_extensions.add('.safetensors') - supported_pt_extensions.add('.safetensors') -except: - print("Could not import safetensors, safetensors support disabled.") - +supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} From 73e85fb3f4b104053fb1ac5d0aea456e373ea8c8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:00:47 -0500 Subject: [PATCH 166/208] Improve error output for failed nodes --- execution.py | 237 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 204 insertions(+), 33 deletions(-) diff --git a/execution.py b/execution.py index 25f2fcacd..691beb102 100644 --- a/execution.py +++ b/execution.py @@ -297,24 +297,80 @@ def validate_inputs(prompt, item, validated): class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + + errors = [] + valid = True + for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } + } + errors.append(error) + continue + val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + errors.append(error) + continue + o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. 
{}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) - r = validate_inputs(prompt, o_id, validated) - if r[0] == False: - validated[o_id] = r - return r + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type + } + } + errors.append(error) + continue + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + valid = False + continue + except Exception as ex: + typ, _, tb = sys.exc_info() + valid = False + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o_id] = (False, reasons, o_id) + continue else: if type_input == "INT": val = int(val) @@ -328,26 +384,97 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r), unique_id) + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f": {str(r)}" + else: + details += "." + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + errors.append(error) + continue + + if len(errors) > 0 or valid is not True: + ret = (False, errors, unique_id) + else: + ret = (True, [], unique_id) - ret = (True, "", unique_id) validated[unique_id] = ret return ret +def full_type_name(klass): + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' + klass.__qualname__ + def validate_prompt(prompt): outputs = set() for x in prompt: @@ -356,7 +483,13 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs", [], []) + error = { + "type": "prompt_no_outputs", + "message": "Prompt has no outputs", + "details": "", + "extra_info": {} + } + return (False, error, [], []) good_outputs = set() errors = [] @@ -364,34 +497,72 @@ def validate_prompt(prompt): validated = {} for o in outputs: valid = False - reason = "" + reasons = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] - reason = m[1] - node_id = m[2] - except Exception as e: - print(traceback.format_exc()) + reasons = m[1] + except Exception as ex: + typ, _, tb = sys.exc_info() valid = False - reason = "Parsing error" - node_id = None + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o] = (False, reasons, o) - if valid == True: + if valid is True: good_outputs.add(o) else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] - if node_id is not None: - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + print(f"Failed to validate prompt for output {o}:") + if len(reasons) > 0: + print("* (prompt):") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + errors += [(o, reasons)] + for node_id, result in validated.items(): + valid = result[0] + reasons = result[1] + # If a node upstream has errors, the nodes downstream will also + # be reported as invalid, but there will be no errors attached. + # So don't return those nodes as having errors in the response. 
+ if valid is not True and len(reasons) > 0: + if node_id not in node_errors: + class_type = prompt[node_id]['class_type'] + node_errors[node_id] = { + "errors": reasons, + "dependent_outputs": [], + "class_type": class_type + } + print(f"* {class_type} {node_id}:") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + node_errors[node_id]["dependent_outputs"].append(o) + print("Output will be ignored") if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) + errors_list = [] + for o, errors in errors: + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + errors_list = "\n".join(errors_list) - return (True, "", list(good_outputs), node_errors) + error = { + "type": "prompt_no_good_outputs", + "message": "Prompt has no properly connected outputs", + "details": errors_list, + "extra_info": {} + } + + return (False, error, list(good_outputs), node_errors) + + return (True, None, list(good_outputs), node_errors) class PromptQueue: From cc4d3435d3590288e21f3adfd42f044a7e45fae4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:48:55 -0500 Subject: [PATCH 167/208] Highlight failing nodes/inputs in frontend --- web/scripts/app.js | 74 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..21fe94802 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -771,16 +771,25 @@ export class ComfyApp { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { const res = origDrawNodeShape.apply(this, arguments); + const nodeErrors = self.lastPromptError?.node_errors[node.id]; + let color = null; + let lineWidth = 1; if (node.id === +self.runningNodeId) { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; } + else if (self.lastPromptError != null && nodeErrors?.errors) { + color = "red"; + lineWidth = 2; + } + + self.graphTime = Date.now() if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; - ctx.lineWidth = 1; + ctx.lineWidth = lineWidth; ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) @@ -807,11 +816,28 @@ export class ComfyApp { ctx.stroke(); ctx.strokeStyle = fgcolor; ctx.globalAlpha = 1; + } - if (self.progress) { - ctx.fillStyle = "green"; - ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); - ctx.fillStyle = bgcolor; + if (self.progress && node.id === +self.runningNodeId) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); + ctx.fillStyle = bgcolor; + } + + // Highlight inputs that failed validation + if (nodeErrors) { + ctx.lineWidth = 2; + ctx.strokeStyle = "red"; + for (const error of nodeErrors.errors) { + if (error.extra_info && error.extra_info.input_name) { + const inputIndex = node.findInputSlot(error.extra_info.input_name) + if (inputIndex !== -1) { + let pos = node.getConnectionPos(true, inputIndex); + ctx.beginPath(); + ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false) + ctx.stroke(); + } + } } } @@ -1243,6 +1269,31 @@ export class ComfyApp { return { workflow, output }; } + #formatError(error) { + if (error == 
null) { + return "(unknown error)" + } + else if (typeof error === "string") { + return error; + } + else if (error.stack && error.message) { + return error.toString() + } + else if (error.response) { + let message = error.response.error.message; + if (error.response.error.details) + message += ": " + error.response.error.details; + for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { + message += "\n" + nodeError.class_type + ":" + for (const errorReason of nodeError.errors) { + message += "\n - " + errorReason.message + ": " + errorReason.details + } + } + return message + } + return "(unknown error)" + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1250,8 +1301,10 @@ export class ComfyApp { if (this.#processingQueue) { return; } - + this.#processingQueue = true; + this.lastPromptError = null; + try { while (this.#queueItems.length) { ({ number, batchCount } = this.#queueItems.pop()); @@ -1262,7 +1315,12 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response.error || error.toString()); + const formattedError = this.#formatError(error) + this.ui.dialog.show(formattedError); + if (error.response) { + this.lastPromptError = error.response; + this.canvas.draw(true, true); + } break; } @@ -1360,6 +1418,8 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.lastPromptError = null; + this.graphTime = null } } From c33b7c5549b7b277011e2c3f50215ba466afb205 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:54:13 -0500 Subject: [PATCH 168/208] Improve invalid prompt error message --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 691beb102..66753ff90 100644 --- a/execution.py +++ b/execution.py @@ -554,8 +554,8 @@ def validate_prompt(prompt): errors_list = "\n".join(errors_list) error = { - "type": "prompt_no_good_outputs", - "message": "Prompt has no properly connected outputs", + "type": "prompt_outputs_failed_validation", + "message": "Prompt outputs failed validation", "details": errors_list, "extra_info": {} } From 0d834e3a2ba6272b8cee6503f574c0f06002ddc3 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:59:30 -0500 Subject: [PATCH 169/208] Add missing input name/config --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 66753ff90..632aaa843 100644 --- a/execution.py +++ b/execution.py @@ -365,6 +365,8 @@ def validate_inputs(prompt, item, validated): "message": "Exception when validating node", "details": str(ex), "extra_info": { + "input_name": x, + "input_config": info, "error_type": error_type, "traceback": traceback.format_tb(tb) } From ffec815257ddf2371b880eafd575838210fcea07 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 12:48:06 -0500 Subject: [PATCH 170/208] Send back more information about exceptions that happen during execution --- execution.py | 173 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 50 deletions(-) diff --git a/execution.py b/execution.py index 632aaa843..5ed9ff348 100644 --- a/execution.py +++ b/execution.py @@ -102,13 +102,19 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui +def 
format_value(x): + if isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -117,22 +123,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - obj = class_def() - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + "node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! 
Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -210,6 +258,44 @@ class PromptExecutor: self.old_prompt = {} self.server = server + def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "node_id": error["node_id"], + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "message": error["message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "node_id": error["node_id"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -244,42 +330,29 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] + output_node_id = None + to_execute = [] - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in 
executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() @@ -359,7 +432,7 @@ def validate_inputs(prompt, item, validated): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", @@ -367,7 +440,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] @@ -507,13 +580,13 @@ def validate_prompt(prompt): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", "details": str(ex), "extra_info": { - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] From 6b2a8a3845972bcff02184aaa8ded6eace8300ad Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:03:41 -0500 Subject: [PATCH 171/208] Show message in the frontend if prompt execution raises an exception --- execution.py | 14 +++++++++----- web/scripts/api.js | 6 ++++++ web/scripts/app.js | 35 ++++++++++++++++++++++++++++++----- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 5ed9ff348..f79c3d351 100644 --- a/execution.py +++ b/execution.py @@ -258,27 +258,31 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + node_id = error["node_id"] + class_type = prompt[node_id]["class_type"] + # First, send back the status to the frontend depending # on the exception type if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), - - "node_id": error["node_id"], } self.server.send_sync("execution_interrupted", mes, self.server.client_id) else: if self.server.client_id is not None: mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), "message": error["message"], "exception_type": error["exception_type"], "traceback": error["traceback"], - "node_id": error["node_id"], "current_inputs": error["current_inputs"], 
"current_outputs": error["current_outputs"], } @@ -346,7 +350,7 @@ class PromptExecutor: # error was raised success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: - self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) diff --git a/web/scripts/api.js b/web/scripts/api.js index 4f061c358..378165b3a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -88,6 +88,12 @@ class ComfyApi extends EventTarget { case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; default: if (this.#registered.has(msg.type)) { this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 21fe94802..e8ab32cf9 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -784,8 +784,10 @@ export class ComfyApp { color = "red"; lineWidth = 2; } - - self.graphTime = Date.now() + else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) { + color = "#f0f"; + lineWidth = 2; + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; @@ -895,6 +897,17 @@ export class ComfyApp { } }); + api.addEventListener("execution_start", ({ detail }) => { + this.lastExecutionError = null + }); + + api.addEventListener("execution_error", ({ detail }) => { + this.lastExecutionError = detail; + const formattedError = this.#formatExecutionError(detail); + this.ui.dialog.show(formattedError); + this.canvas.draw(true, true); + }); + api.init(); } @@ -1269,7 +1282,7 @@ export class ComfyApp { return { workflow, output }; } - #formatError(error) { + #formatPromptError(error) { if (error == null) { return "(unknown error)" } @@ -1294,6 +1307,18 @@ export class ComfyApp { return "(unknown error)" } + #formatExecutionError(error) { + if (error == null) { + return "(unknown error)" + } + + const traceback = error.traceback.join("") + const nodeId = error.node_id + const nodeType = error.node_type + + return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1315,7 +1340,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - const formattedError = this.#formatError(error) + const formattedError = this.#formatPromptError(error) this.ui.dialog.show(formattedError); if (error.response) { this.lastPromptError = error.response; @@ -1419,7 +1444,7 @@ export class ComfyApp { clean() { this.nodeOutputs = {}; this.lastPromptError = null; - this.graphTime = null + this.lastExecutionError = null; } } From e2d080b6941783e50155f694c11ab0da1b1ae240 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:07:51 -0500 Subject: [PATCH 172/208] Return null for value format --- execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index f79c3d351..9cebce928 100644 --- a/execution.py +++ b/execution.py @@ -103,7 +103,9 @@ 
def get_output_data(obj, input_data_all): return output, ui def format_value(x): - if isinstance(x, (int, float, bool, str)): + if x is None: + return None + elif isinstance(x, (int, float, bool, str)): return x else: return str(x) From a9e7e237248296c8fe0d79991e0f8c2c0f2cf530 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:11:34 -0500 Subject: [PATCH 173/208] Fix --- execution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 9cebce928..ffea00a8c 100644 --- a/execution.py +++ b/execution.py @@ -499,9 +499,7 @@ def validate_inputs(prompt, item, validated): if r is not True: details = f"{x}" if r is not False: - details += f": {str(r)}" - else: - details += "." + details += f" - {str(r)}" error = { "type": "custom_validation_failed", From 62bdd9d26aba086ffbeedd118140e2806e6f4345 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 16:35:54 -0500 Subject: [PATCH 174/208] Catch typecast errors --- execution.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index ffea00a8c..6af58a673 100644 --- a/execution.py +++ b/execution.py @@ -424,7 +424,8 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "received_type": received_type + "received_type": received_type, + "linked_node": val } } errors.append(error) @@ -440,28 +441,44 @@ def validate_inputs(prompt, item, validated): valid = False exception_type = full_type_name(typ) reasons = [{ - "type": "exception_during_validation", - "message": "Exception when validating node", + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", "details": str(ex), "extra_info": { "input_name": x, "input_config": info, "exception_type": exception_type, - "traceback": traceback.format_tb(tb) + "traceback": traceback.format_tb(tb), + "linked_node": val } }] validated[o_id] = (False, reasons, o_id) continue else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + try: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + errors.append(error) + continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: From 52c9590b7b65dba86e8622f6ad38974bc4045f31 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 01:51:39 -0500 Subject: [PATCH 175/208] Exception message --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 6af58a673..52c264b0f 100644 --- a/execution.py +++ b/execution.py @@ -447,6 +447,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "linked_node": val 
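The last few patches converge on a single structured shape for every
validation error the server reports, which the frontend uses to
highlight the failing node and input. A sketch of that shape with
illustrative values (not a real payload):

    error = {
        "type": "exception_during_inner_validation",   # machine-readable kind
        "message": "Exception when validating inner node",
        "details": "human-readable specifics",
        "extra_info": {                                # data the UI can use
            "input_name": "image",                     # illustrative
            "input_config": None,
            "exception_message": "example failure",    # illustrative
            "exception_type": "builtins.ValueError",
            "traceback": [],
            "linked_node": None,
        },
    }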
From 03f2d0a764726641e848ba4e069c8809a502afdf Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 02:02:11 -0500 Subject: [PATCH 176/208] Rename exception message field --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 52c264b0f..1a9a1ff73 100644 --- a/execution.py +++ b/execution.py @@ -171,7 +171,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute error_details = { "node_id": unique_id, - "message": str(ex), + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, @@ -282,7 +282,7 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), - "message": error["message"], + "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], From 00646b0813e4f395725f3013f18b13a46f4d619d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 21:48:49 -0500 Subject: [PATCH 177/208] Bitwise operations for masks --- comfy_extras/nodes_mask.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9916f3b21..9134c24da 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -167,7 +167,7 @@ class MaskComposite: "source": ("MASK",), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract"],), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), } } @@ -193,6 +193,12 @@ class MaskComposite: output[top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + elif operation == "or": + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + elif operation == "xor": + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() output = torch.clamp(output, 0.0, 1.0) From ad81fd682a5e5e7c1f258d7c11a000c0dfd07be3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:32:26 -0400 Subject: [PATCH 178/208] Fix issue with cancelling prompt. --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 1a9a1ff73..218a84c36 100644 --- a/execution.py +++ b/execution.py @@ -353,6 +353,7 @@ class PromptExecutor: success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + break for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) From f3ac938b4a5c031adb9ee2951f26360d6a2b36de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:42:53 -0400 Subject: [PATCH 179/208] Round the mask values for bitwise operations. 
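Without rounding, the float-to-bool cast treats every nonzero mask value
as on, so a feathered 0.3 edge would enter and/or/xor as if it were 1.0.
Rounding snaps the mask to {0, 1} first. A quick illustration (values
are arbitrary; note torch.round rounds half to even):

    import torch

    m = torch.tensor([0.0, 0.3, 0.6, 1.0])

    m.bool()          # tensor([False,  True,  True,  True])
    m.round().bool()  # tensor([False, False,  True,  True])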
--- comfy_extras/nodes_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9134c24da..15377af14 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -194,11 +194,11 @@ class MaskComposite: elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion elif operation == "and": - output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "or": - output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "xor": - output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) From 0fc483dcfdef457b50d3a67e66b4f463e6ef9d62 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 01:52:09 -0400 Subject: [PATCH 180/208] Refactor diffusers model convert code to be able to reuse it. --- comfy/diffusers_convert.py | 107 ----------------------------------- comfy/diffusers_load.py | 111 +++++++++++++++++++++++++++++++++++++ nodes.py | 4 +- 3 files changed, 113 insertions(+), 109 deletions(-) create mode 100644 comfy/diffusers_load.py diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index ceca80305..1eab54d4b 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,14 +1,5 @@ -import json -import os -import yaml - -import folder_paths -from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE -import os.path as osp import re import torch -from safetensors.torch import load_file, save_file # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) - - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' - - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') - - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) - - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 - - unet_path = 
osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") - - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py new file mode 100644 index 000000000..43877fb83 --- /dev/null +++ b/comfy/diffusers_load.py @@ -0,0 +1,111 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file +import diffusers_convert + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) + 
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/nodes.py b/nodes.py index 68010f040..90444a92c 100644 --- a/nodes.py +++ b/nodes.py @@ -17,7 +17,7 @@ import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) -import comfy.diffusers_convert +import comfy.diffusers_load import comfy.samplers import comfy.sample import comfy.sd @@ -377,7 +377,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: From a532888846809de7b8890e8beb10ea87edf39d7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 02:02:09 -0400 Subject: [PATCH 181/208] Support VAEs in diffusers format. --- comfy/sd.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c6be900ad..4df149fe1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision from . import gligen +from . 
import diffusers_convert def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -504,10 +505,16 @@ class VAE: if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") else: - self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + if ckpt_path is not None: + sd = utils.load_torch_file(ckpt_path) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + self.first_stage_model.load_state_dict(sd, strict=False) + self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() From 23ffafeb5d4a25bb5e41c34c9f04a0733643892c Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sun, 28 May 2023 23:31:40 +0900 Subject: [PATCH 182/208] typo fix: field name in error message --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index e8ab32cf9..26670239b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1316,7 +1316,7 @@ export class ComfyApp { const nodeId = error.node_id const nodeType = error.node_type - return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}` } async queuePrompt(number, batchCount = 1) { From b9818eb910b6ce683c38602c9b8fbd3979d97aaf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 02:48:50 -0400 Subject: [PATCH 183/208] Add route to get safetensors metadata: /view_metadata/loras?filename=lora.safetensors --- comfy/utils.py | 9 +++++++++ folder_paths.py | 2 ++ server.py | 25 ++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 5ed9aaa02..4e84e870b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import struct def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -50,6 +51,14 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def safetensors_header(safetensors_path, max_size=100*1024*1024): + with open(safetensors_path, "rb") as f: + header = f.read(8) + length_of_header = struct.unpack(' max_size: + return None + return f.read(length_of_header) + def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. 
NxC''' diff --git a/folder_paths.py b/folder_paths.py index 20b461c94..19245a617 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -126,11 +126,13 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths folders = folder_names_and_paths[folder_name] + filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + return None def get_filename_list(folder_name): global folder_names_and_paths diff --git a/server.py b/server.py index c0f79cbd5..72c565a63 100644 --- a/server.py +++ b/server.py @@ -22,7 +22,7 @@ except ImportError: import mimetypes from comfy.cli_args import args - +import comfy.utils @web.middleware async def cache_control(request: web.Request, handler): @@ -257,6 +257,29 @@ class PromptServer(): return web.Response(status=404) + @routes.get("/view_metadata/{folder_name}") + async def view_metadata(request): + folder_name = request.match_info.get("folder_name", None) + if folder_name is None: + return web.Response(status=404) + if not "filename" in request.rel_url.query: + return web.Response(status=404) + + filename = request.rel_url.query["filename"] + if not filename.endswith(".safetensors"): + return web.Response(status=404) + + safetensors_path = folder_paths.get_full_path(folder_name, filename) + if safetensors_path is None: + return web.Response(status=404) + out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) + if out is None: + return web.Response(status=404) + dt = json.loads(out) + if not "__metadata__" in dt: + return web.Response(status=404) + return web.json_response(dt["__metadata__"]) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) From 560e9f7a43242c51da2589a33f659ecd41914b20 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:29:00 -0400 Subject: [PATCH 184/208] Disable repo owner validation in update.py --- .ci/update_windows/update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index c09f29a80..ef9374c44 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'): else: raise AssertionError('Unknown merge analysis result') - +pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) repo = pygit2.Repository(str(sys.argv[1])) ident = pygit2.Signature('comfyui', 'comfy@ui') try: From 08abd838b82ea8d08a7e6f1484140d1694180381 Mon Sep 17 00:00:00 2001 From: "Lt.Dr.Data" Date: Tue, 30 May 2023 15:26:45 +0900 Subject: [PATCH 185/208] HOTFIX: Patched the conflict issue between the Combo Refresh feature and PrimitiveNodes. --- web/scripts/app.js | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 26670239b..64adc3e6a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1424,6 +1424,11 @@ export class ComfyApp { const def = defs[node.type]; + // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, + // and additional work is needed to consider the primitive logic in the refresh logic. 
+			if(!def)
+				continue;
+
 			for(const widgetNum in node.widgets) {
 				const widget = node.widgets[widgetNum]
 				if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) {

From eb448dd8e18125b569bea9002f909769678a6c43 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 30 May 2023 12:36:41 -0400
Subject: [PATCH 186/208] Auto load model in lowvram if not enough memory.

---
 comfy/model_management.py | 46 ++++++++++++++++++++++++---------------
 comfy/sd.py               | 18 +++++++++++++--
 2 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index c15323219..10a706793 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -15,9 +15,8 @@ vram_state = VRAMState.NORMAL_VRAM
 set_vram_to = VRAMState.NORMAL_VRAM

 total_vram = 0
-total_vram_available_mb = -1

-accelerate_enabled = False
+lowvram_available = True
 xpu_available = False

 directml_enabled = False
@@ -31,11 +30,12 @@ if args.directml is not None:
         directml_device = torch_directml.device(device_index)
         print("Using directml with device:", torch_directml.device_name(device_index))
         # torch_directml.disable_tiled_resources(True)
+        lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

 try:
     import torch
     if directml_enabled:
-        total_vram = 4097 #TODO
+        pass #TODO
     else:
         try:
             import intel_extension_for_pytorch as ipex
@@ -46,7 +46,7 @@ try:
             total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
     total_ram = psutil.virtual_memory().total / (1024 * 1024)
     if not args.normalvram and not args.cpu:
-        if total_vram <= 4096:
+        if lowvram_available and total_vram <= 4096:
             print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
             set_vram_to = VRAMState.LOW_VRAM
         elif total_vram > total_ram * 1.1 and total_vram > 14336:
@@ -92,6 +92,7 @@ if ENABLE_PYTORCH_ATTENTION:

 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
+    lowvram_available = True
 elif args.novram:
     set_vram_to = VRAMState.NO_VRAM
 elif args.highvram:
@@ -103,18 +104,18 @@ if args.force_fp32:
     FORCE_FP32 = True

-if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+
+if lowvram_available:
     try:
         import accelerate
-        accelerate_enabled = True
-        vram_state = set_vram_to
+        if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
+            vram_state = set_vram_to
     except Exception as e:
         import traceback
         print(traceback.format_exc())
-        print("ERROR: COULD NOT ENABLE LOW VRAM MODE.")
+        print("ERROR: LOW VRAM MODE NEEDS accelerate.")
+        lowvram_available = False

-    total_vram_available_mb = (total_vram - 1024) // 2
-    total_vram_available_mb = int(max(256, total_vram_available_mb))

 try:
     if torch.backends.mps.is_available():
@@ -199,22 +200,33 @@ def load_model_gpu(model):
         model.unpatch_model()
         raise e

-    model.model_patches_to(get_torch_device())
+    torch_dev = get_torch_device()
+    model.model_patches_to(torch_dev)
+
+    vram_set_state = vram_state
+    if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+        model_size = model.model_size()
+        current_free_mem = get_free_memory(torch_dev)
+        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 ))
+        if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
+            vram_set_state = VRAMState.LOW_VRAM
+
     current_loaded_model = model
-    if vram_state == VRAMState.CPU:
+
+    if vram_set_state == VRAMState.CPU:
         pass
-    elif vram_state == VRAMState.MPS:
+    elif vram_set_state == VRAMState.MPS:
         mps_device = torch.device("mps")
         real_model.to(mps_device)
         pass
-    elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
+    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
         real_model.to(get_torch_device())
     else:
-        if vram_state == VRAMState.NO_VRAM:
+        if vram_set_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
-        elif vram_state == VRAMState.LOW_VRAM:
-            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
+        elif vram_set_state == VRAMState.LOW_VRAM:
+            device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"})

         accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
diff --git a/comfy/sd.py b/comfy/sd.py
index 4df149fe1..ce17994f7 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -286,15 +286,29 @@ def model_lora_keys(model, key_map={}):

     return key_map

+
 class ModelPatcher:
-    def __init__(self, model):
+    def __init__(self, model, size=0):
+        self.size = size
         self.model = model
         self.patches = []
         self.backup = {}
         self.model_options = {"transformer_options":{}}
+        self.model_size()
+
+    def model_size(self):
+        if self.size > 0:
+            return self.size
+        model_sd = self.model.state_dict()
+        size = 0
+        for k in model_sd:
+            t = model_sd[k]
+            size += t.nelement() * t.element_size()
+        self.size = size
+        return size

     def clone(self):
-        n = ModelPatcher(self.model)
+        n = ModelPatcher(self.model, self.size)
         n.patches = self.patches[:]
         n.model_options = copy.deepcopy(self.model_options)
         return n

From 2260802d90c41f1475a7bf2960aa018dc25f1001 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 30 May 2023 16:44:09 -0400
Subject: [PATCH 187/208] Check if folder_name is valid instead of just
 throwing exception.

---
 folder_paths.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/folder_paths.py b/folder_paths.py
index 19245a617..fc37e52c7 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -125,6 +125,8 @@ def filter_files_extensions(files, extensions):

 def get_full_path(folder_name, filename):
     global folder_names_and_paths
+    if folder_name not in folder_names_and_paths:
+        return None
     folders = folder_names_and_paths[folder_name]
     filename = os.path.relpath(os.path.join("/", filename), "/")
     for x in folders[0]:

From 04f4fba013da1f556fc310235d5a30c2bfe682e8 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Tue, 30 May 2023 16:01:49 -0500
Subject: [PATCH 188/208] Fix litegraph dialog CSS

---
 web/index.html | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/index.html b/web/index.html
index bb79433ce..da0adb6c2 100644
--- a/web/index.html
+++ b/web/index.html
@@ -14,5 +14,5 @@ window.graph = app.graph;
 	</script>
 </head>
-<body>
+<body class="litegraph">
 </body>
 </html>

From 468c27afea29928d7d9fcd208e1137a36118ad13 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Tue, 30 May 2023 16:06:17 -0500
Subject: [PATCH 189/208] Fix litegraph dialog z-index/font

---
 web/style.css | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/web/style.css b/web/style.css
index 87f096e14..db82887c3 100644
--- a/web/style.css
+++ b/web/style.css
@@ -289,6 +289,11 @@ button.comfy-queue-btn {

 /* Context menu */

+.litegraph .dialog {
+	z-index: 1;
+	font-family: Arial;
+}
+
 .litegraph .litemenu-entry.has_submenu {
 	position: relative;
 	padding-right: 20px;

From 8ef197f02852b65509d6ebe06df8794b96a07f2f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 29 May 2023 11:26:57 -0400
Subject: [PATCH 190/208] Keep list of filenames and only refresh it when
 something changes.

---
 folder_paths.py | 47 +++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/folder_paths.py b/folder_paths.py
index fc37e52c7..f3d1b8773 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -31,6 +31,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
 input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")

+filename_list_cache = {}
+
 if not os.path.exists(input_directory):
     os.makedirs(input_directory)
@@ -111,12 +113,18 @@ def get_folder_paths(folder_name):
     return folder_names_and_paths[folder_name][0][:]

 def recursive_search(directory):
+    if not os.path.isdir(directory):
+        return [], {}
     result = []
+    dirs = {directory: os.path.getmtime(directory)}
     for root, subdir, file in os.walk(directory, followlinks=True):
         for filepath in file:
             #we os.path.join directory with a blank string to generate a path separator at the end.
             result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),''))
-    return result
+        for d in subdir:
+            path = os.path.join(root, d)
+            dirs[path] = os.path.getmtime(path)
+    return result, dirs

 def filter_files_extensions(files, extensions):
     return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files)))
@@ -136,13 +144,44 @@ def get_full_path(folder_name, filename):

     return None

-def get_filename_list(folder_name):
+def get_filename_list_(folder_name):
     global folder_names_and_paths
     output_list = set()
     folders = folder_names_and_paths[folder_name]
+    output_folders = {}
     for x in folders[0]:
-        output_list.update(filter_files_extensions(recursive_search(x), folders[1]))
-    return sorted(list(output_list))
+        files, folders_all = recursive_search(x)
+        output_list.update(filter_files_extensions(files, folders[1]))
+        output_folders = {**output_folders, **folders_all}
+
+    return (sorted(list(output_list)), output_folders)
+
+def cached_filename_list_(folder_name):
+    global filename_list_cache
+    global folder_names_and_paths
+    if folder_name not in filename_list_cache:
+        return None
+    out = filename_list_cache[folder_name]
+    for x in out[1]:
+        time_modified = out[1][x]
+        folder = x
+        if os.path.getmtime(folder) != time_modified:
+            return None
+
+    folders = folder_names_and_paths[folder_name]
+    for x in folders[0]:
+        if x not in out[1]:
+            return None
+
+    return out
+
+def get_filename_list(folder_name):
+    out = cached_filename_list_(folder_name)
+    if out is None:
+        out = get_filename_list_(folder_name)
+        global filename_list_cache
+        filename_list_cache[folder_name] = out
+    return out[0]

 def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
     def map_filename(filename):

From 1f34bf08f06550fb2f041188b5a01d395240be17 Mon Sep 17 00:00:00 2001
From: ltdrdata
Date: Wed, 31 May 2023 22:01:25 +0900
Subject: [PATCH 191/208] To support dynamic custom loading, separate the node
 registration process based on the defs in the registerNodes function.
---
 web/scripts/app.js | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/web/scripts/app.js b/web/scripts/app.js
index 64adc3e6a..9ecad8489 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1010,6 +1010,11 @@ export class ComfyApp {
 		const app = this;

 		// Load node definitions from the backend
 		const defs = await api.getNodeDefs();
+		this.registerNodesFromDefs(defs);
+		await this.#invokeExtensionsAsync("registerCustomNodes");
+	}
+
+	async registerNodesFromDefs(defs) {
 		await this.#invokeExtensionsAsync("addCustomNodeDefs", defs);

 		// Generate list of known widgets
@@ -1082,8 +1087,6 @@ export class ComfyApp {
 			LiteGraph.registerNodeType(nodeId, node);
 			node.category = nodeData.category;
 		}
-
-		await this.#invokeExtensionsAsync("registerCustomNodes");
 	}

 	/**

From 8e8d6070f2e80aff0200bb3ad0f31716a98d5739 Mon Sep 17 00:00:00 2001
From: ltdrdata
Date: Wed, 31 May 2023 23:26:56 +0900
Subject: [PATCH 192/208] race condition patch

---
 web/scripts/app.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/scripts/app.js b/web/scripts/app.js
index 9ecad8489..8a9c7ca49 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1010,7 +1010,7 @@ export class ComfyApp {
 		const app = this;
 		// Load node definitions from the backend
 		const defs = await api.getNodeDefs();
-		this.registerNodesFromDefs(defs);
+		await this.registerNodesFromDefs(defs);
 		await this.#invokeExtensionsAsync("registerCustomNodes");
 	}

From 03da8a34265bb333d03a51d7503697b5ede9b335 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 31 May 2023 13:03:24 -0400
Subject: [PATCH 193/208] This is useless for inference.

---
 comfy/sd.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index ce17994f7..fa7bd8d32 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -743,7 +743,7 @@ def load_controlnet(ckpt_path, model=None):
                             use_spatial_transformer=True,
                             transformer_depth=1,
                             context_dim=context_dim,
-                            use_checkpoint=True,
+                            use_checkpoint=False,
                             legacy=False,
                             use_fp16=use_fp16)
     else:
@@ -760,7 +760,7 @@ def load_controlnet(ckpt_path, model=None):
                             use_linear_in_transformer=True,
                             transformer_depth=1,
                             context_dim=context_dim,
-                            use_checkpoint=True,
+                            use_checkpoint=False,
                             legacy=False,
                             use_fp16=use_fp16)
     if pth:
@@ -1045,7 +1045,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     }

     unet_config = {
-        "use_checkpoint": True,
+        "use_checkpoint": False,
         "image_size": 32,
         "out_channels": 4,
         "attention_resolutions": [

From d200fa131420a8871633b7321664db419aab2712 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Wed, 31 May 2023 19:00:01 -0500
Subject: [PATCH 194/208] Prevent callers from mutating folder lists

---
 folder_paths.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/folder_paths.py b/folder_paths.py
index f3d1b8773..e179a28d4 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -181,7 +181,7 @@ def get_filename_list(folder_name):
         out = get_filename_list_(folder_name)
         global filename_list_cache
         filename_list_cache[folder_name] = out
-    return out[0]
+    return list(out[0])

 def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
     def map_filename(filename):

From 94680732d32b4b540251c122aee36df8d37266e1 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 1 Jun 2023 03:52:51 -0400
Subject: [PATCH 195/208] Empty cache on mps.
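A rough standalone sketch of the per-backend dispatch this adds, for context. It
assumes PyTorch 2.0 or newer, where torch.mps.empty_cache() exists; the helper
name is illustrative and not part of this patch:

    import torch

    def soft_empty_cache_sketch():
        # Release cached allocations on whichever backend is active.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()    # MPS: hand cached memory back to the OS
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()   # CUDA/ROCm: drop unused cached blocks
            if torch.version.cuda:
                torch.cuda.ipc_collect()

Calling something like this between generations keeps the reported free memory
closer to what is actually available.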
---
 comfy/model_management.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 10a706793..60bcd786b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -389,7 +389,10 @@ def should_use_fp16():

 def soft_empty_cache():
     global xpu_available
-    if xpu_available:
+    global vram_state
+    if vram_state == VRAMState.MPS:
+        torch.mps.empty_cache()
+    elif xpu_available:
         torch.xpu.empty_cache()
     elif torch.cuda.is_available():
         if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda

From 5c38958e49efd11b5234cb5ff472d752698c5090 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 1 Jun 2023 04:04:35 -0400
Subject: [PATCH 196/208] Tweak lowvram model memory so it's closer to what it
 was before.

---
 comfy/model_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 60bcd786b..e9af7f3a7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -207,7 +207,7 @@ def load_model_gpu(model):
     if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
         model_size = model.model_size()
         current_free_mem = get_free_memory(torch_dev)
-        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 ))
+        lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
         if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary
             vram_set_state = VRAMState.LOW_VRAM

From b5dd15c67ad3f4dbdc23811f40a4c121e318bfe9 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Thu, 1 Jun 2023 22:15:06 -0500
Subject: [PATCH 197/208] Send back prompt number from prompt/ endpoint

---
 server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/server.py b/server.py
index 72c565a63..0b64df147 100644
--- a/server.py
+++ b/server.py
@@ -361,7 +361,7 @@ class PromptServer():
                 prompt_id = str(uuid.uuid4())
                 outputs_to_execute = valid[2]
                 self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
-                return web.json_response({"prompt_id": prompt_id})
+                return web.json_response({"prompt_id": prompt_id, "number": number})
             else:
                 print("invalid prompt:", valid[1])
                 return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)

From b5dd15c67ad3f4dbdc23811f40a4c121e318bfe9 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Thu, 1 Jun 2023 23:26:23 -0500
Subject: [PATCH 198/208] System stats endpoint

---
 comfy/model_management.py | 27 +++++++++++++++++++++++++++
 server.py                 | 24 ++++++++++++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index e9af7f3a7..3b7b1dbf1 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -308,6 +308,33 @@ def pytorch_attention_flash_attention():
             return True
     return False

+def get_total_memory(dev=None, torch_total_too=False):
+    global xpu_available
+    global directml_enabled
+    if dev is None:
+        dev = get_torch_device()
+
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+        mem_total = psutil.virtual_memory().total
+    else:
+        if directml_enabled:
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
+        elif xpu_available:
+            mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_cuda + mem_total_torch
+
+    if torch_total_too:
+        return (mem_total, mem_total_torch)
+    else:
+        return mem_total
+
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
diff --git a/server.py b/server.py
index 0b64df147..acbc88f66 100644
--- a/server.py
+++ b/server.py
@@ -7,6 +7,7 @@ import execution
 import uuid
 import json
 import glob
+import torch

 from PIL import Image
 from io import BytesIO
@@ -23,6 +24,7 @@ except ImportError:
 import mimetypes
 from comfy.cli_args import args
 import comfy.utils
+import comfy.model_management

 @web.middleware
 async def cache_control(request: web.Request, handler):
@@ -280,6 +282,28 @@ class PromptServer():
             return web.json_response(dt["__metadata__"])

+        @routes.get("/system_stats")
+        async def get_queue(request):
+            device_index = comfy.model_management.get_torch_device()
+            device = torch.device(device_index)
+            device_name = comfy.model_management.get_torch_device_name(device_index)
+            vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
+            vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
+            system_stats = {
+                "devices": [
+                    {
+                        "name": device_name,
+                        "type": device.type,
+                        "index": device.index,
+                        "vram_total": vram_total,
+                        "vram_free": vram_free,
+                        "torch_vram_total": torch_vram_total,
+                        "torch_vram_free": torch_vram_free,
+                    }
+                ]
+            }
+            return web.json_response(system_stats)
+
         @routes.get("/prompt")
         async def get_prompt(request):
             return web.json_response(self.get_queue_info())

From 499641ebf1be190e20624ee352e9dc88884e3df1 Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Fri, 2 Jun 2023 00:14:41 -0500
Subject: [PATCH 199/208] More accurate total

---
 comfy/model_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3b7b1dbf1..0ea0c71e5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -328,7 +328,7 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             _, mem_total_cuda = torch.cuda.mem_get_info(dev)
             mem_total_torch = mem_reserved
-            mem_total = mem_total_cuda + mem_total_torch
+            mem_total = mem_total_cuda

From 67892b5ac584ff8def01a5852246c364f8408d95 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 2 Jun 2023 15:05:25 -0400
Subject: [PATCH 200/208] Refactor and improve model_management code related
 to free memory.

---
 comfy/model_management.py | 131 +++++++++++++++++++------------------
 server.py                 |   6 +-
 2 files changed, 68 insertions(+), 69 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 0ea0c71e5..9c3147d76 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1,6 +1,7 @@
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import torch

 class VRAMState(Enum):
     CPU = 0
@@ -33,28 +34,67 @@ if args.directml is not None:
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 try:
-    import torch
-    if directml_enabled:
-        pass #TODO
-    else:
-        try:
-            import intel_extension_for_pytorch as ipex
-            if torch.xpu.is_available():
-                xpu_available = True
-                total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024)
-        except:
-            total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
-    total_ram = psutil.virtual_memory().total / (1024 * 1024)
-    if not args.normalvram and not args.cpu:
-        if lowvram_available and total_vram <= 4096:
-            print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
-            set_vram_to = VRAMState.LOW_VRAM
-        elif total_vram > total_ram * 1.1 and total_vram > 14336:
-            print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
-            vram_state = VRAMState.HIGH_VRAM
+    import intel_extension_for_pytorch as ipex
+    if torch.xpu.is_available():
+        xpu_available = True
 except:
     pass

+def get_torch_device():
+    global xpu_available
+    global directml_enabled
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if vram_state == VRAMState.MPS:
+        return torch.device("mps")
+    if vram_state == VRAMState.CPU:
+        return torch.device("cpu")
+    else:
+        if xpu_available:
+            return torch.device("xpu")
+        else:
+            return torch.device(torch.cuda.current_device())
+
+def get_total_memory(dev=None, torch_total_too=False):
+    global xpu_available
+    global directml_enabled
+    if dev is None:
+        dev = get_torch_device()
+
+    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+        mem_total = psutil.virtual_memory().total
+        mem_total_torch = mem_total
+    else:
+        if directml_enabled:
+            mem_total = 1024 * 1024 * 1024 #TODO
+            mem_total_torch = mem_total
+        elif xpu_available:
+            mem_total = torch.xpu.get_device_properties(dev).total_memory
+            mem_total_torch = mem_total
+        else:
+            stats = torch.cuda.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_cuda
+
+    if torch_total_too:
+        return (mem_total, mem_total_torch)
+    else:
+        return mem_total
+
+total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
+total_ram = psutil.virtual_memory().total / (1024 * 1024)
+print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+if not args.normalvram and not args.cpu:
+    if lowvram_available and total_vram <= 4096:
+        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        set_vram_to = VRAMState.LOW_VRAM
+    elif total_vram > total_ram * 1.1 and total_vram > 14336:
+        print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
+        vram_state = VRAMState.HIGH_VRAM
+
 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
 except:
@@ -128,29 +168,17 @@ if args.cpu:

 print(f"Set vram state to: {vram_state.name}")

-def get_torch_device():
-    global xpu_available
-    global directml_enabled
-    if directml_enabled:
-        global directml_device
-        return directml_device
-    if vram_state == VRAMState.MPS:
-        return torch.device("mps")
-    if vram_state == VRAMState.CPU:
-        return torch.device("cpu")
-    else:
-        if xpu_available:
-            return torch.device("xpu")
-        else:
-            return torch.cuda.current_device()
-
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
-        return "{}".format(device.type)
-    return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
+        if device.type == "cuda":
+            return "{} {}".format(device, torch.cuda.get_device_name(device))
+        else:
+            return "{}".format(device.type)
+    else:
+        return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

 try:
-    print("Using device:", get_torch_device_name(get_torch_device()))
+    print("Device:", get_torch_device_name(get_torch_device()))
 except:
     print("Could not pick default device.")

@@ -308,33 +336,6 @@ def pytorch_attention_flash_attention():
             return True
     return False

-def get_total_memory(dev=None, torch_total_too=False):
-    global xpu_available
-    global directml_enabled
-    if dev is None:
-        dev = get_torch_device()
-
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
-        mem_total = psutil.virtual_memory().total
-    else:
-        if directml_enabled:
-            mem_total = 1024 * 1024 * 1024 #TODO
-            mem_total_torch = mem_total
-        elif xpu_available:
-            mem_total = torch.xpu.get_device_properties(dev).total_memory
-            mem_total_torch = mem_total
-        else:
-            stats = torch.cuda.memory_stats(dev)
-            mem_reserved = stats['reserved_bytes.all.current']
-            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
-            mem_total_torch = mem_reserved
-            mem_total = mem_total_cuda
-
-    if torch_total_too:
-        return (mem_total, mem_total_torch)
-    else:
-        return mem_total
-
 def get_free_memory(dev=None, torch_free_too=False):
     global xpu_available
     global directml_enabled
diff --git a/server.py b/server.py
index acbc88f66..5be822a6f 100644
--- a/server.py
+++ b/server.py
@@ -7,7 +7,6 @@ import execution
 import uuid
 import json
 import glob
-import torch

 from PIL import Image
 from io import BytesIO
@@ -284,9 +283,8 @@ class PromptServer():

         @routes.get("/system_stats")
         async def get_queue(request):
-            device_index = comfy.model_management.get_torch_device()
-            device = torch.device(device_index)
-            device_name = comfy.model_management.get_torch_device_name(device_index)
+            device = comfy.model_management.get_torch_device()
+            device_name = comfy.model_management.get_torch_device_name(device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
             system_stats = {

From 871a86593ae7eb96518d326c83cfded5d41c6fa6 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 2 Jun 2023 16:34:47 -0400
Subject: [PATCH 201/208] Smarter filename list caching.
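The caching rule, as a minimal standalone sketch: trust a cached listing for a
short interval, and only after that revalidate directory mtimes before
rescanning. The names _cache and cached_listdir are illustrative, not from this
patch:

    import os
    import time

    _cache = {}  # folder -> (listing, folder_mtime, checked_at)

    def cached_listdir(folder, ttl=0.5):
        now = time.perf_counter()
        entry = _cache.get(folder)
        if entry is not None:
            listing, mtime, checked_at = entry
            # Inside the ttl window we skip even the stat call.
            if now < checked_at + ttl or os.path.getmtime(folder) == mtime:
                return listing
        listing = sorted(os.listdir(folder))
        _cache[folder] = (listing, os.path.getmtime(folder), now)
        return listing

This avoids hammering the filesystem when many nodes request the same folder
listing in quick succession.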
---
 folder_paths.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/folder_paths.py b/folder_paths.py
index e179a28d4..8cee6afde 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -1,4 +1,5 @@
 import os
+import time

 supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors'])
 supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
@@ -154,7 +155,7 @@ def get_filename_list_(folder_name):
         output_list.update(filter_files_extensions(files, folders[1]))
         output_folders = {**output_folders, **folders_all}

-    return (sorted(list(output_list)), output_folders)
+    return (sorted(list(output_list)), output_folders, time.perf_counter())

 def cached_filename_list_(folder_name):
@@ -162,6 +163,8 @@ def cached_filename_list_(folder_name):
     if folder_name not in filename_list_cache:
         return None
     out = filename_list_cache[folder_name]
+    if time.perf_counter() < (out[2] + 0.5):
+        return out
     for x in out[1]:
         time_modified = out[1][x]
         folder = x

From 66e588d837275b26b428f737692357090ad41426 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 2 Jun 2023 16:48:56 -0400
Subject: [PATCH 202/208] Ignore folder path directories that don't exist.

---
 folder_paths.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/folder_paths.py b/folder_paths.py
index 8cee6afde..a1bf1444d 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -173,8 +173,9 @@ def cached_filename_list_(folder_name):

     folders = folder_names_and_paths[folder_name]
     for x in folders[0]:
-        if x not in out[1]:
-            return None
+        if os.path.isdir(x):
+            if x not in out[1]:
+                return None

     return out

From 700491d81a9faf5370a0c54d869e902bbfc839ec Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Jun 2023 01:47:21 -0400
Subject: [PATCH 203/208] Implement global average pooling for controlnet.
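When enabled, each ControlNet output is reduced to its per-channel spatial mean
and broadcast back to full resolution. A small sketch of just that operation
(pool_residual is an illustrative name):

    import torch

    def pool_residual(x):
        # x: [B, C, H, W] -> per-channel mean over H and W, broadcast back
        pooled = torch.mean(x, dim=(2, 3), keepdim=True)
        return pooled.repeat(1, 1, x.shape[2], x.shape[3])

    x = torch.randn(2, 320, 8, 8)
    out = pool_residual(x)
    assert out.shape == x.shape
    assert torch.allclose(out[0, 0, 0, 0], x[0, 0].mean())

This is what the shuffle control models expect, which is why the loader below
keys the behaviour off the _shuffle checkpoint filenames.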
---
 comfy/sd.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index fa7bd8d32..336fee4a6 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
         return torch.cat([tensor] * batched_number, dim=0)

 class ControlNet:
-    def __init__(self, control_model, device=None):
+    def __init__(self, control_model, global_average_pooling=False, device=None):
         self.control_model = control_model
         self.cond_hint_original = None
         self.cond_hint = None
@@ -630,6 +630,7 @@ class ControlNet:
             device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+        self.global_average_pooling = global_average_pooling

     def get_control(self, x_noisy, t, cond_txt, batched_number):
         control_prev = None
@@ -665,6 +666,9 @@ class ControlNet:
                 key = 'output'
                 index = i
             x = control[i]
+            if self.global_average_pooling:
+                x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
+
             x *= self.strength
             if x.dtype != output_dtype and not autocast_enabled:
                 x = x.to(output_dtype)
@@ -695,7 +699,7 @@ class ControlNet:
         self.cond_hint = None

     def copy(self):
-        c = ControlNet(self.control_model)
+        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
         c.cond_hint_original = self.cond_hint_original
         c.strength = self.strength
         return c
@@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None):
     if use_fp16:
         control_model = control_model.half()

-    control = ControlNet(control_model)
+    global_average_pooling = False
+    if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
+        global_average_pooling = True
+
+    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
     return control

 class T2IAdapter:

From 0a5fefd6213e3116359e0738533a9e3b733506c5 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Jun 2023 11:05:37 -0400
Subject: [PATCH 204/208] Cleanups and fixes for model_management.py

Hopefully fix regression on MPS and CPU.
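The core of the cleanup is tracking the processor type separately from the VRAM
strategy. Roughly, as a hypothetical sketch rather than the patch's exact code:

    from enum import Enum

    class VRAMState(Enum):
        DISABLED = 0
        NO_VRAM = 1
        LOW_VRAM = 2
        NORMAL_VRAM = 3
        HIGH_VRAM = 4
        SHARED = 5

    class CPUState(Enum):
        GPU = 0
        CPU = 1
        MPS = 2

    def effective_vram_state(cpu_state, requested):
        # CPU-only runs do no VRAM management; MPS shares memory with the CPU.
        if cpu_state == CPUState.CPU:
            return VRAMState.DISABLED
        if cpu_state == CPUState.MPS:
            return VRAMState.SHARED
        return requested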
---
 comfy/model_management.py | 63 ++++++++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 27 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9c3147d76..a492ca6b9 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -4,16 +4,22 @@ from comfy.cli_args import args
 import torch

 class VRAMState(Enum):
-    CPU = 0
+    DISABLED = 0
     NO_VRAM = 1
     LOW_VRAM = 2
     NORMAL_VRAM = 3
     HIGH_VRAM = 4
-    MPS = 5
+    SHARED = 5
+
+class CPUState(Enum):
+    GPU = 0
+    CPU = 1
+    MPS = 2

 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
 set_vram_to = VRAMState.NORMAL_VRAM
+cpu_state = CPUState.GPU

 total_vram = 0
@@ -40,15 +46,25 @@ try:
 except:
     pass

+try:
+    if torch.backends.mps.is_available():
+        cpu_state = CPUState.MPS
+except:
+    pass
+
+if args.cpu:
+    cpu_state = CPUState.CPU
+
 def get_torch_device():
     global xpu_available
     global directml_enabled
+    global cpu_state
     if directml_enabled:
         global directml_device
         return directml_device
-    if vram_state == VRAMState.MPS:
+    if cpu_state == CPUState.MPS:
         return torch.device("mps")
-    if vram_state == VRAMState.CPU:
+    if cpu_state == CPUState.CPU:
         return torch.device("cpu")
     else:
         if xpu_available:
@@ -143,8 +159,6 @@ if args.force_fp32:
     print("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True

-
-
 if lowvram_available:
     try:
         import accelerate
@@ -157,17 +171,15 @@ if lowvram_available:
         lowvram_available = False

-try:
-    if torch.backends.mps.is_available():
-        vram_state = VRAMState.MPS
-except:
-    pass
+if cpu_state != CPUState.GPU:
+    vram_state = VRAMState.DISABLED

-if args.cpu:
-    vram_state = VRAMState.CPU
+if cpu_state == CPUState.MPS:
+    vram_state = VRAMState.SHARED

 print(f"Set vram state to: {vram_state.name}")

+
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
         if device.type == "cuda":
@@ -241,13 +253,9 @@ def load_model_gpu(model):

     current_loaded_model = model

-    if vram_set_state == VRAMState.CPU:
+    if vram_set_state == VRAMState.DISABLED:
         pass
-    elif vram_set_state == VRAMState.MPS:
-        mps_device = torch.device("mps")
-        real_model.to(mps_device)
-        pass
-    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
+    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
         real_model.to(get_torch_device())
     else:
@@ -263,7 +271,7 @@ def load_model_gpu(model):
 def load_controlnet_gpu(control_models):
     global current_gpu_controlnets
     global vram_state
-    if vram_state == VRAMState.CPU:
+    if vram_state == VRAMState.DISABLED:
         return

     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@@ -308,7 +316,8 @@ def get_autocast_device(dev):
 def xformers_enabled():
     global xpu_available
     global directml_enabled
-    if vram_state == VRAMState.CPU:
+    global cpu_state
+    if cpu_state != CPUState.GPU:
         return False
     if xpu_available:
         return False
@@ -380,12 +389,12 @@ def maximum_batch_area():
     return int(max(area, 0))

 def cpu_mode():
-    global vram_state
-    return vram_state == VRAMState.CPU
+    global cpu_state
+    return cpu_state == CPUState.CPU

 def mps_mode():
-    global vram_state
-    return vram_state == VRAMState.MPS
+    global cpu_state
+    return cpu_state == CPUState.MPS

 def should_use_fp16():
     global xpu_available
@@ -417,8 +426,8 @@ def should_use_fp16():

 def soft_empty_cache():
     global xpu_available
-    global vram_state
-    if vram_state == VRAMState.MPS:
+    global cpu_state
+    if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
     elif xpu_available:
         torch.xpu.empty_cache()

From 32f282c861eabcee42fdec702b96ebc8924c9834 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Jun 2023 11:19:10 -0400
Subject: [PATCH 205/208] Search box style fix.

---
 web/style.css | 1 +
 1 file changed, 1 insertion(+)

diff --git a/web/style.css b/web/style.css
index db82887c3..47571a16e 100644
--- a/web/style.css
+++ b/web/style.css
@@ -336,6 +336,7 @@ button.comfy-queue-btn {
 	z-index: 9999 !important;
 	background-color: var(--comfy-menu-bg) !important;
 	overflow: hidden;
+	display: block;
 }

 .litegraph.litesearchbox input,

From c092ffcc18f0a44c062fe914ebda05b29bdcfbc0 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Jun 2023 11:46:52 -0400
Subject: [PATCH 206/208] Latest litegraph from upstream.

---
 web/lib/litegraph.core.js | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js
index 95f4a2735..908ed5f16 100644
--- a/web/lib/litegraph.core.js
+++ b/web/lib/litegraph.core.js
@@ -8099,11 +8099,15 @@ LGraphNode.prototype.executeAction = function(action)
 		bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR;
 		hovercolor = hovercolor || "#555";
 		textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR;
-		var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2;	// fix the height with the title
-		var pos = this.mouse;
-		var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
-		pos = this.last_click_position;
-		var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h );
+		var pos = this.ds.convertOffsetToCanvas(this.graph_mouse);
+		var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
+		pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null;
+		if(pos) {
+			var rect = this.canvas.getBoundingClientRect();
+			pos[0] -= rect.left;
+			pos[1] -= rect.top;
+		}
+		var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h );
 		ctx.fillStyle = hover ? hovercolor : bgcolor;
 		if(clicked)

From 0764bb5218ea49fdeeaebbfc10c6f5b87a8bc879 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 3 Jun 2023 11:47:20 -0400
Subject: [PATCH 207/208] Move node properties panel from double click to menu
 option.

---
 web/lib/litegraph.core.js | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js
index 908ed5f16..a60848d77 100644
--- a/web/lib/litegraph.core.js
+++ b/web/lib/litegraph.core.js
@@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action)
 		if (this.onShowNodePanel) {
 			this.onShowNodePanel(n);
 		}
-		else
-		{
-			this.showShowNodePanel(n);
-		}

 		if (this.onNodeDblClicked) {
 			this.onNodeDblClicked(n);
@@ -13071,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action)
 				has_submenu: true,
 				callback: LGraphCanvas.onShowMenuNodeProperties
 			},
+			{
+				content: "Properties Panel",
+				callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) }
+			},
 			null,
 			{
 				content: "Title",

From 126b4050dc34daabca51c236bfb5cc31dd48056d Mon Sep 17 00:00:00 2001
From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com>
Date: Sun, 4 Jun 2023 01:25:49 +0900
Subject: [PATCH 208/208] Fix intermittent crashes that occur when opening the
 MaskEditor. (#732)

---
 web/extensions/core/maskeditor.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js
index 4b0c12747..6cb3a5385 100644
--- a/web/extensions/core/maskeditor.js
+++ b/web/extensions/core/maskeditor.js
@@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog {
 			imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight);

 			// update mask
-			backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
 			maskCanvas.width = drawWidth;
 			maskCanvas.height = drawHeight;
 			maskCanvas.style.top = imgCanvas.offsetTop + "px";
 			maskCanvas.style.left = imgCanvas.offsetLeft + "px";
+			backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height);
 			maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height);
 		});