From af365e4dd152b23cd6cf993ddf9ed7c7e7088b39 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 03:12:18 -0500 Subject: [PATCH 01/98] All the unet ops with weights are now handled by comfy.ops --- comfy/controlnet.py | 10 ++++++++++ comfy/ldm/modules/attention.py | 18 ++++-------------- .../modules/diffusionmodules/openaimodel.py | 13 ++++++------- comfy/ops.py | 8 ++++++++ 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 433381df6..6dd99afdc 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -5,6 +5,7 @@ import comfy.utils import comfy.model_management import comfy.model_detection import comfy.model_patcher +import comfy.ops import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -248,6 +249,15 @@ class ControlLoraOps: else: raise ValueError(f"unsupported dimensions: {dims}") + class Conv3d(comfy.ops.Conv3d): + pass + + class GroupNorm(comfy.ops.GroupNorm): + pass + + class LayerNorm(comfy.ops.LayerNorm): + pass + class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, device=None): diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index f68452382..c2b85a691 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -83,16 +83,6 @@ class FeedForward(nn.Module): def forward(self, x): return self.net(x) - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) @@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module): self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head @@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module): context_dim = [context_dim] * depth self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) if not use_linear: self.proj_in = operations.Conv2d(in_channels, inner_dim, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 48264892c..855c3d1f4 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -177,7 +177,7 @@ class ResBlock(TimestepBlock): padding = kernel_size // 2 self.in_layers = nn.Sequential( - nn.GroupNorm(32, channels, dtype=dtype, device=device), + operations.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, 
device=device), ) @@ -206,12 +206,11 @@ class ResBlock(TimestepBlock): ), ) self.out_layers = nn.Sequential( - nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), + operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) - ), + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) + , ) if self.out_channels == channels: @@ -810,13 +809,13 @@ class UNetModel(nn.Module): self._feature_size += ch self.out = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), nn.SiLU(), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - nn.GroupNorm(32, ch, dtype=self.dtype, device=device), + operations.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) diff --git a/comfy/ops.py b/comfy/ops.py index 0bfb698aa..deb849d63 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d): def reset_parameters(self): return None +class GroupNorm(torch.nn.GroupNorm): + def reset_parameters(self): + return None + +class LayerNorm(torch.nn.LayerNorm): + def reset_parameters(self): + return None + def conv_nd(dims, *args, **kwargs): if dims == 2: return Conv2d(*args, **kwargs) From 31b0f6f3d8034371e95024d6bba5c193db79bd9d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 11:10:00 -0500 Subject: [PATCH 02/98] UNET weights can now be stored in fp8. --fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch. 
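As background for these two flags: fp8 storage halves weight memory again relative to fp16, but most GPUs cannot compute in fp8 directly, so the diff below falls back to torch.autocast whenever the storage dtype is not natively supported by the device. A minimal sketch of that decision logic, assuming PyTorch 2.1+ for the float8 dtypes (function names here are illustrative, not the ComfyUI API):

    import contextlib
    import torch

    def unet_storage_dtype(fp8_e4m3fn=False, fp8_e5m2=False):
        # Mirrors the mutually exclusive CLI flags added below.
        if fp8_e4m3fn:
            return torch.float8_e4m3fn
        if fp8_e5m2:
            return torch.float8_e5m2
        return torch.float16

    def compute_scope(device, dtype):
        # fp16/bf16/fp32 run natively; anything else (the fp8 formats)
        # computes under autocast in a wider dtype.
        if dtype in (torch.float32, torch.float16, torch.bfloat16):
            return contextlib.nullcontext()
        return torch.autocast(device_type=device.type)

Note the saving is in resident weight memory, not compute: autocast still upcasts per op. The follow-up commit accounts for this when estimating model size via dtype.itemsize, which reports 1 byte for both fp8 formats.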
--- comfy/cldm/cldm.py | 4 ++-- comfy/cli_args.py | 5 ++++- comfy/controlnet.py | 16 ++++++++++++---- .../ldm/modules/diffusionmodules/openaimodel.py | 4 ++-- comfy/model_base.py | 13 ++++++++++++- comfy/model_management.py | 15 +++++++++++++++ 6 files changed, 47 insertions(+), 10 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 76a525b37..bbe5891e6 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -283,7 +283,7 @@ class ControlNet(nn.Module): return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -295,7 +295,7 @@ class ControlNet(nn.Module): assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for module, zero_conv in zip(self.input_blocks, self.zero_convs): if guided_hint is not None: h = module(h, emb, context) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 72fce1087..58d034802 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -55,7 +55,10 @@ fp_group = parser.add_mutually_exclusive_group() fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") -parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group = parser.add_mutually_exclusive_group() +fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. 
This should only be used for testing stuff.") +fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") +fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpvae_group = parser.add_mutually_exclusive_group() fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.") diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6dd99afdc..5921e6b1d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,6 +1,7 @@ import torch import math import os +import contextlib import comfy.utils import comfy.model_management import comfy.model_detection @@ -147,24 +148,31 @@ class ControlNet(ControlBase): else: return None + dtype = self.control_model.dtype + if comfy.model_management.supports_dtype(self.device, dtype): + precision_scope = lambda a: contextlib.nullcontext(a) + else: + precision_scope = torch.autocast + dtype = torch.float32 + output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] y = cond.get('y', None) if y is not None: - y = y.to(self.control_model.dtype) + y = y.to(dtype) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) + with precision_scope(comfy.model_management.get_autocast_device(self.device)): + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 855c3d1f4..12efd833c 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -841,14 +841,14 @@ class UNetModel(nn.Module): self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) - h = x.type(self.dtype) + h = x for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) diff --git a/comfy/model_base.py b/comfy/model_base.py index 
253ea6667..5bfcc391d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -5,6 +5,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management import comfy.conds from enum import Enum +import contextlib from . import utils class ModelType(Enum): @@ -61,6 +62,13 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() + + if comfy.model_management.supports_dtype(xc.device, dtype): + precision_scope = lambda a: contextlib.nullcontext(a) + else: + precision_scope = torch.autocast + dtype = torch.float32 + xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() context = context.to(dtype) @@ -70,7 +78,10 @@ class BaseModel(torch.nn.Module): if hasattr(extra, "to"): extra = extra.to(dtype) extra_conds[o] = extra - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + + with precision_scope(comfy.model_management.get_autocast_device(xc.device)): + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): diff --git a/comfy/model_management.py b/comfy/model_management.py index d4acd8950..18d15f9d0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -459,6 +459,10 @@ def unet_inital_load_device(parameters, dtype): def unet_dtype(device=None, model_params=0): if args.bf16_unet: return torch.bfloat16 + if args.fp8_e4m3fn_unet: + return torch.float8_e4m3fn + if args.fp8_e5m2_unet: + return torch.float8_e5m2 if should_use_fp16(device=device, model_params=model_params): return torch.float16 return torch.float32 @@ -515,6 +519,17 @@ def get_autocast_device(dev): return dev.type return "cuda" +def supports_dtype(device, dtype): #TODO + if dtype == torch.float32: + return True + if torch.device("cpu") == device: + return False + if dtype == torch.float16: + return True + if dtype == torch.bfloat16: + return True + return False + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: From ca82ade7652c80727b402f51a115feb5df4ad27a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 11:52:06 -0500 Subject: [PATCH 03/98] Use .itemsize to get dtype size for fp8. --- comfy/model_management.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 18d15f9d0..94d596969 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -430,6 +430,13 @@ def dtype_size(dtype): dtype_size = 4 if dtype == torch.float16 or dtype == torch.bfloat16: dtype_size = 2 + elif dtype == torch.float32: + dtype_size = 4 + else: + try: + dtype_size = dtype.itemsize + except: #Old pytorch doesn't have .itemsize + pass return dtype_size def unet_offload_device(): From be3468ddd5db871e3943003e0fd7a2219c7d02e6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 12:49:00 -0500 Subject: [PATCH 04/98] Less useless downcasting. 
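The point of the diff below: previously the whole CLIP transformer was cast to the target dtype and the token embeddings were then cast back up, which is lossy (fp32 -> fp16 -> fp32 permanently discards precision). The patch instead sets the embeddings aside, casts the rest of the model, and restores the untouched fp32 embeddings. A standalone sketch of the same trick, using a deep copy instead of the patch's in-place swap (helper name is hypothetical; assumes a transformers-style model exposing get/set_input_embeddings):

    import copy

    def cast_keeping_embeddings(transformer, dtype):
        # Snapshot the input embeddings in fp32 before the global cast...
        emb = copy.deepcopy(transformer.get_input_embeddings()).float()
        # ...cast everything else down...
        transformer.to(dtype)
        # ...then restore the untouched fp32 copy.
        transformer.set_input_embeddings(emb)
        return transformer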
--- comfy/sd1_clip.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 58acb97fc..4e9f6bffe 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.inner_name = inner_name if dtype is not None: - self.transformer.to(dtype) inner_model = getattr(self.transformer, self.inner_name) if hasattr(inner_model, "embeddings"): - inner_model.embeddings.to(torch.float32) + embeddings_bak = inner_model.embeddings.to(torch.float32) + inner_model.embeddings = None + self.transformer.to(dtype) + inner_model.embeddings = embeddings_bak else: - self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32)) + previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True) + self.transformer.to(dtype) + self.transformer.set_input_embeddings(previous_inputs) self.max_length = max_length if freeze: From 26b1c0a77150be2253f88e4cd106a11112d96d59 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 13:47:41 -0500 Subject: [PATCH 05/98] Fix control lora on fp8. --- comfy/controlnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 5921e6b1d..6d37aa74f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -208,7 +208,7 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) else: return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) @@ -247,7 +247,7 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) else: return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) From 9b655d4fd72903d33af101177b0cb9576c5babd3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 4 Dec 2023 21:55:19 -0500 Subject: [PATCH 06/98] Fix memory issue with control loras. 
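For context on the fp8 fix just above: a ControlLora layer keeps a frozen base weight plus low-rank up/down factors and rebuilds the effective weight on every forward. The fix casts the stored (possibly fp8) weight to the activation dtype before adding the low-rank product, so the sum happens in the wider type. A minimal standalone sketch of that forward for the linear case (illustrative, not the ControlLoraOps class itself):

    import torch

    def control_lora_linear(x, weight, bias=None, up=None, down=None):
        # Cast the stored weight to the activation's dtype/device first.
        w = weight.to(x.dtype).to(x.device)
        if up is not None and down is not None:
            # Low-rank delta: (out, r) @ (r, in), reshaped to the weight's
            # shape (the flattens are no-ops for Linear factors).
            delta = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1))
            w = w + delta.reshape(weight.shape).to(x.dtype)
        return torch.nn.functional.linear(x, w, bias)

The memory fix in this commit is separate: cleanup after sampling has to run on positive_copy/negative_copy (see the sample.py diff below), since those copies are the conds actually used during sampling.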
--- comfy/sample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 034db97ee..bcbed3343 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -101,7 +101,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative samples = samples.cpu() cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): @@ -113,6 +113,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() cleanup_additional_models(models) - cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))) + cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples From 1bbd65ab307a5510af1b2e6145fad0b6c583fe6c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 5 Dec 2023 12:48:41 -0500 Subject: [PATCH 07/98] Missed this one. --- comfy/ldm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c2b85a691..d3348c472 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -384,7 +384,7 @@ class BasicTransformerBlock(nn.Module): self.is_res = inner_dim == dim if self.ff_in: - self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device) self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.disable_self_attn = disable_self_attn From 44265e081031a4647b295b32e7f6b77ab71c80c9 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:27:13 +0000 Subject: [PATCH 08/98] Allow connecting primitivenode to reroutes --- web/extensions/core/rerouteNode.js | 56 ++++++++++++++++--- web/extensions/core/widgetInputs.js | 86 +++++++++++++++++++++-------- 2 files changed, 112 insertions(+), 30 deletions(-) diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index 499a171da..cfa952f3c 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -1,10 +1,11 @@ import { app } from "../../scripts/app.js"; +import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs.js"; // Node that allows you to redirect connections for cleaner graphs app.registerExtension({ name: "Comfy.RerouteNode", - registerCustomNodes() { + registerCustomNodes(app) { class RerouteNode { constructor() { if (!this.properties) { @@ -16,6 +17,12 @@ app.registerExtension({ this.addInput("", "*"); this.addOutput(this.properties.showOutputText ? 
"*" : "", "*"); + this.onAfterGraphConfigured = function () { + requestAnimationFrame(() => { + this.onConnectionsChange(LiteGraph.INPUT, null, true, null); + }); + }; + this.onConnectionsChange = function (type, index, connected, link_info) { this.applyOrientation(); @@ -54,8 +61,7 @@ app.registerExtension({ // We've found a circle currentNode.disconnectInput(link.target_slot); currentNode = null; - } - else { + } else { // Move the previous node currentNode = node; } @@ -94,8 +100,11 @@ app.registerExtension({ updateNodes.push(node); } else { // We've found an output - const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; - if (inputType && nodeOutType !== inputType) { + const nodeOutType = + node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type + ? node.inputs[link.target_slot].type + : null; + if (inputType && inputType !== "*" && nodeOutType !== inputType) { // The output doesnt match our input so disconnect it node.disconnectInput(link.target_slot); } else { @@ -111,6 +120,9 @@ app.registerExtension({ const displayType = inputType || outputType || "*"; const color = LGraphCanvas.link_type_colors[displayType]; + let widgetConfig; + let targetWidget; + let widgetType; // Update the types of each node for (const node of updateNodes) { // If we dont have an input type we are always wildcard but we'll show the output type @@ -125,10 +137,38 @@ app.registerExtension({ const link = app.graph.links[l]; if (link) { link.color = color; + + if (app.configuringGraph) continue; + const targetNode = app.graph.getNodeById(link.target_id); + const targetInput = targetNode.inputs?.[link.target_slot]; + if (targetInput?.widget) { + const config = getWidgetConfig(targetInput); + if (!widgetConfig) { + widgetConfig = config[1] ?? {}; + widgetType = config[0]; + } + if (!targetWidget) { + targetWidget = targetNode.widgets?.find((w) => w.name === targetInput.widget.name); + } + + const merged = mergeIfValid(targetInput, [config[0], widgetConfig]); + if (merged.customConfig) { + widgetConfig = merged.customConfig; + } + } } } } + for (const node of updateNodes) { + if (widgetConfig && outputType) { + node.inputs[0].widget = { name: "value" }; + setWidgetConfig(node.inputs[0], [widgetType ?? displayType, widgetConfig], targetWidget); + } else { + setWidgetConfig(node.inputs[0], null); + } + } + if (inputNode) { const link = app.graph.links[inputNode.inputs[0].link]; if (link) { @@ -173,8 +213,8 @@ app.registerExtension({ }, { // naming is inverted with respect to LiteGraphNode.horizontal - // LiteGraphNode.horizontal == true means that - // each slot in the inputs and outputs are layed out horizontally, + // LiteGraphNode.horizontal == true means that + // each slot in the inputs and outputs are layed out horizontally, // which is the opposite of the visual orientation of the inputs and outputs as a node content: "Set " + (this.properties.horizontal ? 
"Horizontal" : "Vertical"), callback: () => { @@ -187,7 +227,7 @@ app.registerExtension({ applyOrientation() { this.horizontal = this.properties.horizontal; if (this.horizontal) { - // we correct the input position, because LiteGraphNode.horizontal + // we correct the input position, because LiteGraphNode.horizontal // doesn't account for title presence // which reroute nodes don't have this.inputs[0].pos = [this.size[0] / 2, 0]; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b6fa411f7..c33f7346a 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -5,6 +5,11 @@ const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); +const TARGET = Symbol(); // Used for reroutes to specify the real target widget + +export function getWidgetConfig(slot) { + return slot.widget[CONFIG] ?? slot.widget[GET_CONFIG](); +} function getConfig(widgetName) { const { nodeData } = this.constructor; @@ -100,7 +105,6 @@ function getWidgetType(config) { return { type }; } - function isValidCombo(combo, obj) { // New input isnt a combo if (!(obj instanceof Array)) { @@ -121,6 +125,31 @@ function isValidCombo(combo, obj) { return true; } +export function setWidgetConfig(slot, config, target) { + if (!slot.widget) return; + if (config) { + slot.widget[GET_CONFIG] = () => config; + slot.widget[TARGET] = target; + } else { + delete slot.widget; + } + + if (slot.link) { + const link = app.graph.links[slot.link]; + if (link) { + const originNode = app.graph.getNodeById(link.origin_id); + if (originNode.type === "PrimitiveNode") { + if (config) { + originNode.recreateWidget(); + } else if(!app.configuringGraph) { + originNode.disconnectOutput(0); + originNode.onLastDisconnect(); + } + } + } + } +} + export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) { if (!config1) { config1 = output.widget[CONFIG] ?? 
output.widget[GET_CONFIG](); @@ -434,14 +463,20 @@ app.registerExtension({ for (const linkInfo of links) { const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; - const widgetName = input.widget.name; - if (widgetName) { - const widget = node.widgets.find((w) => w.name === widgetName); - if (widget) { - widget.value = this.widgets[0].value; - if (widget.callback) { - widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); - } + let widget; + if (input.widget[TARGET]) { + widget = input.widget[TARGET]; + } else { + const widgetName = input.widget.name; + if (widgetName) { + widget = node.widgets.find((w) => w.name === widgetName); + } + } + + if (widget) { + widget.value = this.widgets[0].value; + if (widget.callback) { + widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } } } @@ -494,14 +529,13 @@ app.registerExtension({ this.#mergeWidgetConfig(); if (!links?.length) { - this.#onLastDisconnect(); + this.onLastDisconnect(); } } } onConnectOutput(slot, type, input, target_node, target_slot) { // Fires before the link is made allowing us to reject it if it isn't valid - // No widget, we cant connect if (!input.widget) { if (!(input.type in ComfyWidgets)) return false; @@ -519,6 +553,10 @@ app.registerExtension({ #onFirstConnection(recreating) { // First connection can fire before the graph is ready on initial load so random things can be missing + if (!this.outputs[0].links) { + this.onLastDisconnect(); + return; + } const linkId = this.outputs[0].links[0]; const link = this.graph.links[linkId]; if (!link) return; @@ -546,10 +584,10 @@ app.registerExtension({ this.outputs[0].name = type; this.outputs[0].widget = widget; - this.#createWidget(widget[CONFIG] ?? config, theirNode, widget.name, recreating); + this.#createWidget(widget[CONFIG] ?? 
config, theirNode, widget.name, recreating, widget[TARGET]); } - #createWidget(inputData, node, widgetName, recreating) { + #createWidget(inputData, node, widgetName, recreating, targetWidget) { let type = inputData[0]; if (type instanceof Array) { @@ -563,7 +601,9 @@ app.registerExtension({ widget = this.addWidget(type, "value", null, () => {}, {}); } - if (node?.widgets && widget) { + if (targetWidget) { + widget.value = targetWidget.value; + } else if (node?.widgets && widget) { const theirWidget = node.widgets.find((w) => w.name === widgetName); if (theirWidget) { widget.value = theirWidget.value; @@ -577,7 +617,7 @@ app.registerExtension({ } addValueControlWidgets(this, widget, control_value, undefined, inputData); let filter = this.widgets_values?.[2]; - if(filter && this.widgets.length === 3) { + if (filter && this.widgets.length === 3) { this.widgets[2].value = filter; } } @@ -610,12 +650,14 @@ app.registerExtension({ } } - #recreateWidget() { - const values = this.widgets.map((w) => w.value); + recreateWidget() { + const values = this.widgets?.map((w) => w.value); this.#removeWidgets(); this.#onFirstConnection(true); - for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; - return this.widgets[0]; + if (values?.length) { + for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; + } + return this.widgets?.[0]; } #mergeWidgetConfig() { @@ -631,7 +673,7 @@ app.registerExtension({ if (links?.length < 2 && hasConfig) { // Copy the widget options from the source if (links.length) { - this.#recreateWidget(); + this.recreateWidget(); } return; @@ -657,7 +699,7 @@ app.registerExtension({ // Only allow connections where the configs match const output = this.outputs[0]; const config2 = input.widget[GET_CONFIG](); - return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget); + return !!mergeIfValid.call(this, output, config2, forceUpdate, this.recreateWidget); } #removeWidgets() { @@ -672,7 +714,7 @@ app.registerExtension({ } } - #onLastDisconnect() { + onLastDisconnect() { // We cant remove + re-add the output here as if you drag a link over the same link // it removes, then re-adds, causing it to break this.outputs[0].type = "*"; From a99da6667fadf4683ec24e44546cd5ce8f9e7aff Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:28:05 +0000 Subject: [PATCH 09/98] reroute + primitive tests --- tests-ui/tests/widgetInputs.test.js | 174 +++++++++++++++++++++++++++- tests-ui/utils/ezgraph.js | 5 +- 2 files changed, 171 insertions(+), 8 deletions(-) diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 8e191adf0..67e3fa341 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -1,7 +1,13 @@ // @ts-check /// -const { start, makeNodeDef, checkBeforeAndAfterReload, assertNotNullOrUndefined } = require("../utils"); +const { + start, + makeNodeDef, + checkBeforeAndAfterReload, + assertNotNullOrUndefined, + createDefaultWorkflow, +} = require("../utils"); const lg = require("../utils/litegraph"); /** @@ -36,7 +42,7 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWi if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); - if(widgetType === "combo") { + if (widgetType === "combo") { const filterWidget = primitive.widgets.control_filter_list; 
expect(filterWidget.widget.type).toBe("string"); } @@ -308,8 +314,8 @@ describe("widget inputs", () => { const { ez } = await start({ mockNodeDefs: { ...makeNodeDef("TestNode1", {}, [["A", "B"]]), - ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true}] }), - ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true}] }), + ...makeNodeDef("TestNode2", { example: [["A", "B"], { forceInput: true }] }), + ...makeNodeDef("TestNode3", { example: [["A", "B", "C"], { forceInput: true }] }), }, }); @@ -330,7 +336,7 @@ describe("widget inputs", () => { const n1 = ez.TestNode1(); n1.widgets.example.convertToInput(); - const p = ez.PrimitiveNode() + const p = ez.PrimitiveNode(); p.outputs[0].connectTo(n1.inputs[0]); const value = p.widgets.value; @@ -380,7 +386,7 @@ describe("widget inputs", () => { // Check random control.value = "randomize"; filter.value = "/D/"; - for(let i = 0; i < 100; i++) { + for (let i = 0; i < 100; i++) { control["afterQueued"](); expect(value.value === "D" || value.value === "DD").toBeTruthy(); } @@ -392,4 +398,160 @@ describe("widget inputs", () => { control["afterQueued"](); expect(value.value).toBe("B"); }); + + describe("reroutes", () => { + async function checkOutput(graph, values) { + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { ckpt_name: "model1.safetensors" }, class_type: "CheckpointLoaderSimple" }, + 2: { inputs: { text: "positive", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 3: { inputs: { text: "negative", clip: ["1", 1] }, class_type: "CLIPTextEncode" }, + 4: { + inputs: { width: values.width ?? 512, height: values.height ?? 512, batch_size: values?.batch_size ?? 1 }, + class_type: "EmptyLatentImage", + }, + 5: { + inputs: { + seed: 0, + steps: 20, + cfg: 8, + sampler_name: "euler", + scheduler: values?.scheduler ?? "normal", + denoise: 1, + model: ["1", 0], + positive: ["2", 0], + negative: ["3", 0], + latent_image: ["4", 0], + }, + class_type: "KSampler", + }, + 6: { inputs: { samples: ["5", 0], vae: ["1", 2] }, class_type: "VAEDecode" }, + 7: { + inputs: { filename_prefix: values.filename_prefix ?? 
"ComfyUI", images: ["6", 0] }, + class_type: "SaveImage", + }, + }); + } + + async function waitForWidget(node) { + // widgets are created slightly after the graph is ready + // hard to find an exact hook to get these so just wait for them to be ready + for (let i = 0; i < 10; i++) { + await new Promise((r) => setTimeout(r, 10)); + if (node.widgets?.value) { + return; + } + } + } + + it("can connect primitive via a reroute path to a widget input", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.sampler.widgets.scheduler.convertToInput(); + nodes.save.widgets.filename_prefix.convertToInput(); + + let widthReroute = ez.Reroute(); + let schedulerReroute = ez.Reroute(); + let fileReroute = ez.Reroute(); + + let widthNext = widthReroute; + let schedulerNext = schedulerReroute; + let fileNext = fileReroute; + + for (let i = 0; i < 5; i++) { + let next = ez.Reroute(); + widthNext.outputs[0].connectTo(next.inputs[0]); + widthNext = next; + + next = ez.Reroute(); + schedulerNext.outputs[0].connectTo(next.inputs[0]); + schedulerNext = next; + + next = ez.Reroute(); + fileNext.outputs[0].connectTo(next.inputs[0]); + fileNext = next; + } + + widthNext.outputs[0].connectTo(nodes.empty.inputs.width); + schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler); + fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix); + + let widthPrimitive = ez.PrimitiveNode(); + let schedulerPrimitive = ez.PrimitiveNode(); + let filePrimitive = ez.PrimitiveNode(); + + widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0]); + schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0]); + filePrimitive.outputs[0].connectTo(fileReroute.inputs[0]); + expect(widthPrimitive.widgets.value.value).toBe(512); + widthPrimitive.widgets.value.value = 1024; + expect(schedulerPrimitive.widgets.value.value).toBe("normal"); + schedulerPrimitive.widgets.value.value = "simple"; + expect(filePrimitive.widgets.value.value).toBe("ComfyUI"); + filePrimitive.widgets.value.value = "ComfyTest"; + + await checkBeforeAndAfterReload(graph, async () => { + widthPrimitive = graph.find(widthPrimitive); + schedulerPrimitive = graph.find(schedulerPrimitive); + filePrimitive = graph.find(filePrimitive); + await waitForWidget(filePrimitive); + expect(widthPrimitive.widgets.length).toBe(2); + expect(schedulerPrimitive.widgets.length).toBe(3); + expect(filePrimitive.widgets.length).toBe(1); + + await checkOutput(graph, { + width: 1024, + scheduler: "simple", + filename_prefix: "ComfyTest", + }); + }); + }); + it("can connect primitive via a reroute path to multiple widget inputs", async () => { + const { ez, graph } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.empty.widgets.width.convertToInput(); + nodes.empty.widgets.height.convertToInput(); + nodes.empty.widgets.batch_size.convertToInput(); + + let reroute = ez.Reroute(); + let prevReroute = reroute; + for (let i = 0; i < 5; i++) { + const next = ez.Reroute(); + prevReroute.outputs[0].connectTo(next.inputs[0]); + prevReroute = next; + } + + const r1 = ez.Reroute(prevReroute.outputs[0]); + const r2 = ez.Reroute(prevReroute.outputs[0]); + const r3 = ez.Reroute(r2.outputs[0]); + const r4 = ez.Reroute(r2.outputs[0]); + + r1.outputs[0].connectTo(nodes.empty.inputs.width); + r3.outputs[0].connectTo(nodes.empty.inputs.height); + r4.outputs[0].connectTo(nodes.empty.inputs.batch_size); + + let primitive = ez.PrimitiveNode(); + 
primitive.outputs[0].connectTo(reroute.inputs[0]); + expect(primitive.widgets.value.value).toBe(1); + primitive.widgets.value.value = 64; + + await checkBeforeAndAfterReload(graph, async (r) => { + primitive = graph.find(primitive); + await waitForWidget(primitive); + + // Ensure widget configs are merged + expect(primitive.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(primitive.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(primitive.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + await checkOutput(graph, { + width: 64, + height: 64, + batch_size: 64, + }); + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 898b82db0..3101aa292 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -117,7 +117,7 @@ export class EzOutput extends EzSlot { const inp = input.input; const inName = inp.name || inp.label || inp.type; throw new Error( - `Connecting from ${input.node.node.type}[${inName}#${input.index}] -> ${this.node.node.type}[${ + `Connecting from ${input.node.node.type}#${input.node.id}[${inName}#${input.index}] -> ${this.node.node.type}#${this.node.id}[${ this.output.name ?? this.output.type }#${this.index}] failed.` ); @@ -179,6 +179,7 @@ export class EzWidget { set value(v) { this.widget.value = v; + this.widget.callback?.call?.(this.widget, v) } get isConvertedToInput() { @@ -319,7 +320,7 @@ export class EzGraph { } stringify() { - return JSON.stringify(this.app.graph.serialize(), undefined, "\t"); + return JSON.stringify(this.app.graph.serialize(), undefined); } /** From bcc469a2c95d40e0d64152d1531bc95d84fa98c5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 20:28:52 +0000 Subject: [PATCH 10/98] try to stop test failing --- tests-ui/tests/extensions.test.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js index b82e55c32..159e5113a 100644 --- a/tests-ui/tests/extensions.test.js +++ b/tests-ui/tests/extensions.test.js @@ -52,7 +52,7 @@ describe("extensions", () => { const nodeNames = Object.keys(defs); const nodeCount = nodeNames.length; expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); - for (let i = 0; i < nodeCount; i++) { + for (let i = 0; i < 10; i++) { // It should be send the JS class and the original JSON definition const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; @@ -133,7 +133,7 @@ describe("extensions", () => { expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); - }); + }, 15000); it("allows custom nodeDefs and widgets to be registered", async () => { const widgetMock = jest.fn((node, inputName, inputData, app) => { From 8de6f94f5cc5ee4ae690876108c1dd7705e59595 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 5 Dec 2023 21:02:10 +0000 Subject: [PATCH 11/98] Allow widget placeholder replacement on primitives --- web/extensions/core/saveImageExtraOutput.js | 73 ++------------------- web/extensions/core/widgetInputs.js | 13 +++- web/scripts/utils.js | 67 +++++++++++++++++++ 3 files changed, 83 insertions(+), 70 deletions(-) 
create mode 100644 web/scripts/utils.js diff --git a/web/extensions/core/saveImageExtraOutput.js b/web/extensions/core/saveImageExtraOutput.js index 99e2213bf..a0506b43b 100644 --- a/web/extensions/core/saveImageExtraOutput.js +++ b/web/extensions/core/saveImageExtraOutput.js @@ -1,5 +1,5 @@ import { app } from "../../scripts/app.js"; - +import { applyTextReplacements } from "../../scripts/utils.js"; // Use widget values and dates in output filenames app.registerExtension({ @@ -7,84 +7,19 @@ app.registerExtension({ async beforeRegisterNodeDef(nodeType, nodeData, app) { if (nodeData.name === "SaveImage") { const onNodeCreated = nodeType.prototype.onNodeCreated; - - // Simple date formatter - const parts = { - d: (d) => d.getDate(), - M: (d) => d.getMonth() + 1, - h: (d) => d.getHours(), - m: (d) => d.getMinutes(), - s: (d) => d.getSeconds(), - }; - const format = - Object.keys(parts) - .map((k) => k + k + "?") - .join("|") + "|yyy?y?"; - - function formatDate(text, date) { - return text.replace(new RegExp(format, "g"), function (text) { - if (text === "yy") return (date.getFullYear() + "").substring(2); - if (text === "yyyy") return date.getFullYear(); - if (text[0] in parts) { - const p = parts[text[0]](date); - return (p + "").padStart(text.length, "0"); - } - return text; - }); - } - - // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R + // When the SaveImage node is created we want to override the serialization of the output name widget to run our S&R nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? onNodeCreated.apply(this, arguments) : undefined; const widget = this.widgets.find((w) => w.name === "filename_prefix"); widget.serializeValue = () => { - return widget.value.replace(/%([^%]+)%/g, function (match, text) { - const split = text.split("."); - if (split.length !== 2) { - // Special handling for dates - if (split[0].startsWith("date:")) { - return formatDate(split[0].substring(5), new Date()); - } - - if (text !== "width" && text !== "height") { - // Dont warn on standard replacements - console.warn("Invalid replacement pattern", text); - } - return match; - } - - // Find node with matching S&R property name - let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]); - // If we cant, see if there is a node with that title - if (!nodes.length) { - nodes = app.graph._nodes.filter((n) => n.title === split[0]); - } - if (!nodes.length) { - console.warn("Unable to find node", split[0]); - return match; - } - - if (nodes.length > 1) { - console.warn("Multiple nodes matched", split[0], "using first match"); - } - - const node = nodes[0]; - - const widget = node.widgets?.find((w) => w.name === split[1]); - if (!widget) { - console.warn("Unable to find widget", split[1], "on node", split[0], node); - return match; - } - - return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); - }); + return applyTextReplacements(app, widget.value); }; return r; }; } else { - // When any other node is created add a property to alias the node + // When any other node is created add a property to alias the node const onNodeCreated = nodeType.prototype.onNodeCreated; nodeType.prototype.onNodeCreated = function () { const r = onNodeCreated ? 
onNodeCreated.apply(this, arguments) : undefined; diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index b6fa411f7..b8dd47d0e 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -1,5 +1,6 @@ import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; +import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; @@ -405,11 +406,16 @@ app.registerExtension({ }; }, registerCustomNodes() { + const replacePropertyName = "Run widget replace on values"; class PrimitiveNode { constructor() { this.addOutput("connect to widget input", "*"); this.serialize_widgets = true; this.isVirtualNode = true; + + if (!this.properties || !(replacePropertyName in this.properties)) { + this.addProperty(replacePropertyName, false, "boolean"); + } } applyToGraph(extraLinks = []) { @@ -430,6 +436,11 @@ app.registerExtension({ } let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks]; + let v = this.widgets?.[0].value; + if(v && this.properties[replacePropertyName]) { + v = applyTextReplacements(app, v); + } + // For each output link copy our value over the original widget value for (const linkInfo of links) { const node = this.graph.getNodeById(linkInfo.target_id); @@ -438,7 +449,7 @@ app.registerExtension({ if (widgetName) { const widget = node.widgets.find((w) => w.name === widgetName); if (widget) { - widget.value = this.widgets[0].value; + widget.value = v; if (widget.callback) { widget.callback(widget.value, app.canvas, node, app.canvas.graph_mouse, {}); } diff --git a/web/scripts/utils.js b/web/scripts/utils.js new file mode 100644 index 000000000..401aca9e4 --- /dev/null +++ b/web/scripts/utils.js @@ -0,0 +1,67 @@ +// Simple date formatter +const parts = { + d: (d) => d.getDate(), + M: (d) => d.getMonth() + 1, + h: (d) => d.getHours(), + m: (d) => d.getMinutes(), + s: (d) => d.getSeconds(), +}; +const format = + Object.keys(parts) + .map((k) => k + k + "?") + .join("|") + "|yyy?y?"; + +function formatDate(text, date) { + return text.replace(new RegExp(format, "g"), function (text) { + if (text === "yy") return (date.getFullYear() + "").substring(2); + if (text === "yyyy") return date.getFullYear(); + if (text[0] in parts) { + const p = parts[text[0]](date); + return (p + "").padStart(text.length, "0"); + } + return text; + }); +} + +export function applyTextReplacements(app, value) { + return value.replace(/%([^%]+)%/g, function (match, text) { + const split = text.split("."); + if (split.length !== 2) { + // Special handling for dates + if (split[0].startsWith("date:")) { + return formatDate(split[0].substring(5), new Date()); + } + + if (text !== "width" && text !== "height") { + // Dont warn on standard replacements + console.warn("Invalid replacement pattern", text); + } + return match; + } + + // Find node with matching S&R property name + let nodes = app.graph._nodes.filter((n) => n.properties?.["Node name for S&R"] === split[0]); + // If we cant, see if there is a node with that title + if (!nodes.length) { + nodes = app.graph._nodes.filter((n) => n.title === split[0]); + } + if (!nodes.length) { + console.warn("Unable to find node", split[0]); + return match; + } + + if (nodes.length > 1) { + console.warn("Multiple nodes matched", split[0], "using first match"); + } + + const node = nodes[0]; + + const widget = 
node.widgets?.find((w) => w.name === split[1]); + if (!widget) { + console.warn("Unable to find widget", split[1], "on node", split[0], node); + return match; + } + + return ((widget.value ?? "") + "").replaceAll(/\/|\\/g, "_"); + }); +} From 8112a0d9fcb80c341afa53798f62acdf618cee2b Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:56:03 +0900 Subject: [PATCH 12/98] improve: Mask Editor (#2171) * renewal mask editor * fix: ignoring keydown when 2nd open --- web/extensions/core/maskeditor.js | 356 +++++++++++++++++++++--------- 1 file changed, 251 insertions(+), 105 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 8ace79562..1ea4dbcaa 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -33,6 +33,18 @@ function loadedImageToBlob(image) { return blob; } +function loadImage(imagePath) { + return new Promise((resolve, reject) => { + const image = new Image(); + + image.onload = function() { + resolve(image); + }; + + image.src = imagePath; + }); +} + async function uploadMask(filepath, formData) { await api.fetchApi('/upload/mask', { method: 'POST', @@ -50,25 +62,25 @@ async function uploadMask(filepath, formData) { ClipspaceDialog.invalidatePreview(); } -function prepareRGB(image, backupCanvas, backupCtx) { +function prepare_mask(image, maskCanvas, maskCtx) { // paste mask data into alpha channel - backupCtx.drawImage(image, 0, 0, backupCanvas.width, backupCanvas.height); - const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); + maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height); + const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height); - // refine mask image - for (let i = 0; i < backupData.data.length; i += 4) { - if(backupData.data[i+3] == 255) - backupData.data[i+3] = 0; + // invert mask + for (let i = 0; i < maskData.data.length; i += 4) { + if(maskData.data[i+3] == 255) + maskData.data[i+3] = 0; else - backupData.data[i+3] = 255; + maskData.data[i+3] = 255; - backupData.data[i] = 0; - backupData.data[i+1] = 0; - backupData.data[i+2] = 0; + maskData.data[i] = 0; + maskData.data[i+1] = 0; + maskData.data[i+2] = 0; } - backupCtx.globalCompositeOperation = 'source-over'; - backupCtx.putImageData(backupData, 0, 0); + maskCtx.globalCompositeOperation = 'source-over'; + maskCtx.putImageData(maskData, 0, 0); } class MaskEditorDialog extends ComfyDialog { @@ -184,14 +196,13 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); document.body.appendChild(brush); - var brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { + this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { self.brush_size = event.target.value; self.updateBrushPreview(self, null, null); }); var clearButton = this.createLeftButton("Clear", () => { self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); - self.backupCtx.clearRect(0, 0, self.backupCanvas.width, self.backupCanvas.height); }); var cancelButton = this.createRightButton("Cancel", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); @@ -213,34 +224,37 @@ class MaskEditorDialog extends ComfyDialog { bottom_panel.appendChild(clearButton); bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); - bottom_panel.appendChild(brush_size_slider); + bottom_panel.appendChild(this.brush_size_slider); 
+ + imgCanvas.style.position = "absolute"; + maskCanvas.style.position = "absolute"; - imgCanvas.style.position = "relative"; imgCanvas.style.top = "200"; imgCanvas.style.left = "0"; - maskCanvas.style.position = "absolute"; + maskCanvas.style.top = imgCanvas.style.top; + maskCanvas.style.left = imgCanvas.style.left; } - show() { + async show() { + this.zoom_ratio = 1.0; + this.pan_x = 0; + this.pan_y = 0; + if(!this.is_layout_created) { // layout const imgCanvas = document.createElement('canvas'); const maskCanvas = document.createElement('canvas'); - const backupCanvas = document.createElement('canvas'); imgCanvas.id = "imageCanvas"; maskCanvas.id = "maskCanvas"; - backupCanvas.id = "backupCanvas"; this.setlayout(imgCanvas, maskCanvas); // prepare content this.imgCanvas = imgCanvas; this.maskCanvas = maskCanvas; - this.backupCanvas = backupCanvas; - this.maskCtx = maskCanvas.getContext('2d'); - this.backupCtx = backupCanvas.getContext('2d'); + this.maskCtx = maskCanvas.getContext('2d', {willReadFrequently: true }); this.setEventHandler(maskCanvas); @@ -252,6 +266,8 @@ class MaskEditorDialog extends ComfyDialog { mutations.forEach(function(mutation) { if (mutation.type === 'attributes' && mutation.attributeName === 'style') { if(self.last_display_style && self.last_display_style != 'none' && self.element.style.display == 'none') { + document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); + self.brush.style.display = "none"; ComfyApp.onClipspaceEditorClosed(); } @@ -264,7 +280,8 @@ class MaskEditorDialog extends ComfyDialog { observer.observe(this.element, config); } - this.setImages(this.imgCanvas, this.backupCanvas); + // The keydown event needs to be reconfigured when closing the dialog as it gets removed. + document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); if(ComfyApp.clipspace_return_node) { this.saveButton.innerText = "Save to node"; @@ -275,97 +292,157 @@ class MaskEditorDialog extends ComfyDialog { this.saveButton.disabled = false; this.element.style.display = "block"; + this.element.style.width = "85%"; + this.element.style.margin = "0 7.5%"; + this.element.style.height = "100vh"; + this.element.style.top = "50%"; + this.element.style.left = "42%"; this.element.style.zIndex = 8888; // NOTE: alert dialog must be high priority. 
+ + await this.setImages(this.imgCanvas); + + this.is_visible = true; } isOpened() { return this.element.style.display == "block"; } - setImages(imgCanvas, backupCanvas) { - const imgCtx = imgCanvas.getContext('2d'); - const backupCtx = backupCanvas.getContext('2d'); + invalidateCanvas(orig_image, mask_image) { + this.imgCanvas.width = orig_image.width; + this.imgCanvas.height = orig_image.height; + + this.maskCanvas.width = orig_image.width; + this.maskCanvas.height = orig_image.height; + + let imgCtx = this.imgCanvas.getContext('2d', {willReadFrequently: true }); + let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true }); + + imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height); + prepare_mask(mask_image, this.maskCanvas, maskCtx); + } + + async setImages(imgCanvas) { + let self = this; + + const imgCtx = imgCanvas.getContext('2d', {willReadFrequently: true }); const maskCtx = this.maskCtx; const maskCanvas = this.maskCanvas; - backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); imgCtx.clearRect(0,0,this.imgCanvas.width,this.imgCanvas.height); maskCtx.clearRect(0,0,this.maskCanvas.width,this.maskCanvas.height); // image load - const orig_image = new Image(); - window.addEventListener("resize", () => { - // repositioning - imgCanvas.width = window.innerWidth - 250; - imgCanvas.height = window.innerHeight - 200; - - // redraw image - let drawWidth = orig_image.width; - let drawHeight = orig_image.height; - if (orig_image.width > imgCanvas.width) { - drawWidth = imgCanvas.width; - drawHeight = (drawWidth / orig_image.width) * orig_image.height; - } - - if (drawHeight > imgCanvas.height) { - drawHeight = imgCanvas.height; - drawWidth = (drawHeight / orig_image.height) * orig_image.width; - } - - imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); - - // update mask - maskCanvas.width = drawWidth; - maskCanvas.height = drawHeight; - maskCanvas.style.top = imgCanvas.offsetTop + "px"; - maskCanvas.style.left = imgCanvas.offsetLeft + "px"; - backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); - maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); - }); - const filepath = ComfyApp.clipspace.images; - const touched_image = new Image(); - - touched_image.onload = function() { - backupCanvas.width = touched_image.width; - backupCanvas.height = touched_image.height; - - prepareRGB(touched_image, backupCanvas, backupCtx); - }; - const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) alpha_url.searchParams.delete('channel'); alpha_url.searchParams.delete('preview'); alpha_url.searchParams.set('channel', 'a'); - touched_image.src = alpha_url; + let mask_image = await loadImage(alpha_url); // original image load - orig_image.onload = function() { - window.dispatchEvent(new Event('resize')); - }; - const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); rgb_url.searchParams.delete('channel'); rgb_url.searchParams.set('channel', 'rgb'); - orig_image.src = rgb_url; - this.image = orig_image; + this.image = new Image(); + this.image.onload = function() { + maskCanvas.width = self.image.width; + maskCanvas.height = self.image.height; + + self.invalidateCanvas(self.image, mask_image); + self.initializeCanvasPanZoom(); + }; + this.image.src = rgb_url; } - setEventHandler(maskCanvas) { - maskCanvas.addEventListener("contextmenu", 
(event) => { - event.preventDefault(); - }); + initializeCanvasPanZoom() { + // set initialize + let drawWidth = this.image.width; + let drawHeight = this.image.height; + let width = this.element.clientWidth; + let height = this.element.clientHeight; + + if (this.image.width > width) { + drawWidth = width; + drawHeight = (drawWidth / this.image.width) * this.image.height; + } + + if (drawHeight > height) { + drawHeight = height; + drawWidth = (drawHeight / this.image.height) * this.image.width; + } + + this.zoom_ratio = drawWidth/this.image.width; + + const canvasX = (width - drawWidth) / 2; + const canvasY = (height - drawHeight) / 2; + this.pan_x = canvasX; + this.pan_y = canvasY; + + this.invalidatePanZoom(); + } + + + invalidatePanZoom() { + let raw_width = this.image.width * this.zoom_ratio; + let raw_height = this.image.height * this.zoom_ratio; + + if(this.pan_x + raw_width < 10) { + this.pan_x = 10 - raw_width; + } + + if(this.pan_y + raw_height < 10) { + this.pan_y = 10 - raw_height; + } + + let width = `${raw_width}px`; + let height = `${raw_height}px`; + + let left = `${this.pan_x}px`; + let top = `${this.pan_y}px`; + + this.maskCanvas.style.width = width; + this.maskCanvas.style.height = height; + this.maskCanvas.style.left = left; + this.maskCanvas.style.top = top; + + this.imgCanvas.style.width = width; + this.imgCanvas.style.height = height; + this.imgCanvas.style.left = left; + this.imgCanvas.style.top = top; + } + + + setEventHandler(maskCanvas) { const self = this; - maskCanvas.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); - maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); - document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); - maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); - maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); - maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); - maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); - document.addEventListener('keydown', MaskEditorDialog.handleKeyDown); + + if(!this.handler_registered) { + maskCanvas.addEventListener("contextmenu", (event) => { + event.preventDefault(); + }); + + this.element.addEventListener('wheel', (event) => this.handleWheelEvent(self,event)); + this.element.addEventListener('pointermove', (event) => this.pointMoveEvent(self,event)); + this.element.addEventListener('touchmove', (event) => this.pointMoveEvent(self,event)); + + this.element.addEventListener('dragstart', (event) => { + if(event.ctrlKey) { + event.preventDefault(); + } + }); + + maskCanvas.addEventListener('pointerdown', (event) => this.handlePointerDown(self,event)); + maskCanvas.addEventListener('pointermove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('touchmove', (event) => this.draw_move(self,event)); + maskCanvas.addEventListener('pointerover', (event) => { this.brush.style.display = "block"; }); + maskCanvas.addEventListener('pointerleave', (event) => { this.brush.style.display = "none"; }); + + document.addEventListener('pointerup', MaskEditorDialog.handlePointerUp); + + this.handler_registered = true; + } } brush_size = 10; @@ -378,8 +455,10 @@ class MaskEditorDialog extends ComfyDialog { const self = MaskEditorDialog.instance; if (event.key === ']') { self.brush_size = Math.min(self.brush_size+2, 100); + self.brush_slider_input.value = self.brush_size; } else if 
(event.key === '[') { self.brush_size = Math.max(self.brush_size-2, 1); + self.brush_slider_input.value = self.brush_size; } else if(event.key === 'Enter') { self.save(); } @@ -389,6 +468,10 @@ class MaskEditorDialog extends ComfyDialog { static handlePointerUp(event) { event.preventDefault(); + + this.mousedown_x = null; + this.mousedown_y = null; + MaskEditorDialog.instance.drawing_mode = false; } @@ -398,24 +481,70 @@ class MaskEditorDialog extends ComfyDialog { var centerX = self.cursorX; var centerY = self.cursorY; - brush.style.width = self.brush_size * 2 + "px"; - brush.style.height = self.brush_size * 2 + "px"; - brush.style.left = (centerX - self.brush_size) + "px"; - brush.style.top = (centerY - self.brush_size) + "px"; + brush.style.width = self.brush_size * 2 * this.zoom_ratio + "px"; + brush.style.height = self.brush_size * 2 * this.zoom_ratio + "px"; + brush.style.left = (centerX - self.brush_size * this.zoom_ratio) + "px"; + brush.style.top = (centerY - self.brush_size * this.zoom_ratio) + "px"; } handleWheelEvent(self, event) { - if(event.deltaY < 0) - self.brush_size = Math.min(self.brush_size+2, 100); - else - self.brush_size = Math.max(self.brush_size-2, 1); + event.preventDefault(); - self.brush_slider_input.value = self.brush_size; + if(event.ctrlKey) { + // zoom canvas + if(event.deltaY < 0) { + this.zoom_ratio = Math.min(10.0, this.zoom_ratio+0.2); + } + else { + this.zoom_ratio = Math.max(0.2, this.zoom_ratio-0.2); + } + + this.invalidatePanZoom(); + } + else { + // adjust brush size + if(event.deltaY < 0) + this.brush_size = Math.min(this.brush_size+2, 100); + else + this.brush_size = Math.max(this.brush_size-2, 1); + + this.brush_slider_input.value = this.brush_size; + + this.updateBrushPreview(this); + } + } + + pointMoveEvent(self, event) { + this.cursorX = event.pageX; + this.cursorY = event.pageY; self.updateBrushPreview(self); + + if(event.ctrlKey) { + event.preventDefault(); + self.pan_move(self, event); + } + } + + pan_move(self, event) { + if(event.buttons == 1) { + if(this.mousedown_x) { + let deltaX = this.mousedown_x - event.clientX; + let deltaY = this.mousedown_y - event.clientY; + + self.pan_x = this.mousedown_pan_x - deltaX; + self.pan_y = this.mousedown_pan_y - deltaY; + + self.invalidatePanZoom(); + } + } } draw_move(self, event) { + if(event.ctrlKey) { + return; + } + event.preventDefault(); this.cursorX = event.pageX; @@ -439,6 +568,9 @@ class MaskEditorDialog extends ComfyDialog { y = event.targetTouches[0].clientY - maskRect.top; } + x /= self.zoom_ratio; + y /= self.zoom_ratio; + var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { brush_size *= event.pressure; @@ -489,8 +621,8 @@ class MaskEditorDialog extends ComfyDialog { } else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { const maskRect = self.maskCanvas.getBoundingClientRect(); - const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; - const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; + const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { @@ -540,6 +672,17 @@ class MaskEditorDialog extends ComfyDialog { } handlePointerDown(self, event) { + if(event.ctrlKey) { + if (event.buttons == 1) { + this.mousedown_x = event.clientX; + this.mousedown_y 
= event.clientY; + + this.mousedown_pan_x = this.pan_x; + this.mousedown_pan_y = this.pan_y; + } + return; + } + var brush_size = this.brush_size; if(event instanceof PointerEvent && event.pointerType == 'pen') { brush_size *= event.pressure; @@ -551,8 +694,8 @@ class MaskEditorDialog extends ComfyDialog { event.preventDefault(); const maskRect = self.maskCanvas.getBoundingClientRect(); - const x = event.offsetX || event.targetTouches[0].clientX - maskRect.left; - const y = event.offsetY || event.targetTouches[0].clientY - maskRect.top; + const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; + const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; self.maskCtx.beginPath(); if (event.button == 0) { @@ -570,15 +713,18 @@ class MaskEditorDialog extends ComfyDialog { } async save() { - const backupCtx = this.backupCanvas.getContext('2d', {willReadFrequently:true}); + const backupCanvas = document.createElement('canvas'); + const backupCtx = backupCanvas.getContext('2d', {willReadFrequently:true}); + backupCanvas.width = this.image.width; + backupCanvas.height = this.image.height; - backupCtx.clearRect(0,0,this.backupCanvas.width,this.backupCanvas.height); + backupCtx.clearRect(0,0, backupCanvas.width, backupCanvas.height); backupCtx.drawImage(this.maskCanvas, 0, 0, this.maskCanvas.width, this.maskCanvas.height, - 0, 0, this.backupCanvas.width, this.backupCanvas.height); + 0, 0, backupCanvas.width, backupCanvas.height); // paste mask data into alpha channel - const backupData = backupCtx.getImageData(0, 0, this.backupCanvas.width, this.backupCanvas.height); + const backupData = backupCtx.getImageData(0, 0, backupCanvas.width, backupCanvas.height); // refine mask image for (let i = 0; i < backupData.data.length; i += 4) { @@ -615,7 +761,7 @@ class MaskEditorDialog extends ComfyDialog { ComfyApp.clipspace.widgets[index].value = item; } - const dataURL = this.backupCanvas.toDataURL(); + const dataURL = backupCanvas.toDataURL(); const blob = dataURLToBlob(dataURL); let original_url = new URL(this.image.src); From 2db86b4676ed2b5c8551beea25dd2ef3fe3c4f66 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 6 Dec 2023 05:13:14 -0500 Subject: [PATCH 13/98] Slightly faster lora applying. 
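The speedup in the change below comes from passing non_blocking=True to every .to() call in cast_to_device, so host-to-device copies are queued on the CUDA stream instead of blocking the Python thread while LoRA weights are moved. A minimal standalone sketch of the pattern (the pin_memory() call and the "cuda" device are assumptions for this illustration; the copy only truly overlaps with compute when the source tensor is in pinned host memory):

```
import torch

def cast_to_device(tensor, device, dtype):
    # Queue the device transfer and dtype conversion without stalling the host.
    return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)

if torch.cuda.is_available():
    lora_weight = torch.randn(1024, 1024).pin_memory()  # page-locked memory enables async copies
    w = cast_to_device(lora_weight, "cuda", torch.float16)
    torch.cuda.synchronize()  # wait for the queued work before using w on the host
```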
---
 comfy/model_management.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 94d596969..3588d3503 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -550,12 +550,12 @@ def cast_to_device(tensor, device, dtype, copy=False):
     if device_supports_cast:
         if copy:
             if tensor.device == device:
-                return tensor.to(dtype, copy=copy)
-            return tensor.to(device, copy=copy).to(dtype)
+                return tensor.to(dtype, copy=copy, non_blocking=True)
+            return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True)
         else:
-            return tensor.to(device).to(dtype)
+            return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)
     else:
-        return tensor.to(dtype).to(device, copy=copy)
+        return tensor.to(device, dtype, copy=copy, non_blocking=True)

 def xformers_enabled():
     global directml_enabled

From 03eadbb53c82954ae5e42efa44903ed1319ff3d6 Mon Sep 17 00:00:00 2001
From: asagi4 <130366179+asagi4@users.noreply.github.com>
Date: Wed, 6 Dec 2023 21:12:49 +0200
Subject: [PATCH 14/98] Make HyperTile deterministic

---
 comfy_extras/nodes_hypertile.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py
index 0d7d4c954..15736b835 100644
--- a/comfy_extras/nodes_hypertile.py
+++ b/comfy_extras/nodes_hypertile.py
@@ -2,9 +2,10 @@ import math

 from einops import rearrange
-import random
+# Use torch rng for consistency across generations
+from torch import randint

-def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
     min_value = min(min_value, value)

     # All big divisors of value (inclusive)
@@ -12,8 +13,7 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
     ns = [value // i for i in divisors[:max_options]]  # has at least 1 element

-    random.seed(counter)
-    idx = random.randint(0, len(ns) - 1)
+    idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()

     return ns[idx]

@@ -42,7 +42,6 @@ class HyperTile:
         latent_tile_size = max(32, tile_size) // 8
         self.temp = None
-        self.counter = 1

         def hypertile_in(q, k, v, extra_options):
             if q.shape[-1] in apply_to:
@@ -53,10 +52,8 @@ class HyperTile:
                 h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

                 factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
-                nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
-                self.counter += 1
-                nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
-                self.counter += 1
+                nh = random_divisor(h, latent_tile_size * factor, swap_size)
+                nw = random_divisor(w, latent_tile_size * factor, swap_size)

                 if nh * nw > 1:
                     q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

From fbdb14d4c4c3d2e783d585506c6b598487ec7a9d Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 6 Dec 2023 15:55:09 -0500
Subject: [PATCH 15/98] Cleaner CLIP text encoder implementation.

Use a simple CLIP model implementation instead of the one from
transformers. This will allow some interesting things that would be too
hackish to implement using the transformers implementation.
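One detail worth calling out in the new clip_model.py below: causal masking is implemented by hand, as an upper-triangular matrix of -inf added to the attention scores before softmax. A self-contained sketch of that construction, using the same empty/fill_/triu_ idiom as the diff (the toy sizes are arbitrary):

```
import torch

seq_len = 4
# -inf strictly above the diagonal: token i may only attend to tokens 0..i
causal_mask = torch.empty(seq_len, seq_len).fill_(float("-inf")).triu_(1)

scores = torch.randn(seq_len, seq_len) + causal_mask
probs = scores.softmax(dim=-1)
# the upper triangle of probs is exactly 0 and every row still sums to 1
```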
--- comfy/clip_model.py | 126 +++++++++++++++++++++++++++++++++ comfy/ldm/modules/attention.py | 23 ++++-- comfy/sd1_clip.py | 60 ++++++---------- comfy/sd2_clip.py | 6 +- comfy/sdxl_clip.py | 6 +- 5 files changed, 172 insertions(+), 49 deletions(-) create mode 100644 comfy/clip_model.py diff --git a/comfy/clip_model.py b/comfy/clip_model.py new file mode 100644 index 000000000..e6a7bfa66 --- /dev/null +++ b/comfy/clip_model.py @@ -0,0 +1,126 @@ +import torch +from comfy.ldm.modules.attention import optimized_attention_for_device + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, operations): + super().__init__() + + self.heads = heads + self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x, mask=None, optimized_attention=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + + out = optimized_attention(q, k, v, self.heads, mask) + return self.out_proj(out) + +ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + +class CLIPMLP(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations): + super().__init__() + self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device) + self.activation = ACTIVATIONS[activation] + self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class CLIPLayer(torch.nn.Module): + def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): + super().__init__() + self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations) + self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations) + + def forward(self, x, mask=None, optimized_attention=None): + x += self.self_attn(self.layer_norm1(x), mask, optimized_attention) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations): + super().__init__() + self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, mask=True) + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return 
x, intermediate + +class CLIPEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + super().__init__() + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) + self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + #TODO: attention_mask + x, i = self.encoder(x, intermediate_output=intermediate_output) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + + pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),] + return x, i, pooled_output + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + return self.text_model(*args, **kwargs) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index d3348c472..8299b1d94 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None): del q, k if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + if mask.dtype == torch.bool: + mask = rearrange(mask, 'b ... 
-> b (...)') #TODO: check if this bool part matches pytorch attention + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + else: + sim += mask # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -340,6 +343,18 @@ else: if model_management.pytorch_attention_enabled(): optimized_attention_masked = attention_pytorch +def optimized_attention_for_device(device, mask=False): + if device == torch.device("cpu"): #TODO + if model_management.pytorch_attention_enabled(): + return attention_pytorch + else: + return attention_basic + if mask: + return optimized_attention_masked + + return optimized_attention + + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): super().__init__() diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4e9f6bffe..1acd972c4 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -1,12 +1,14 @@ import os -from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils +from transformers import CLIPTokenizer import comfy.ops import torch import traceback import zipfile from . import model_management import contextlib +import comfy.clip_model +import json def gen_empty_tokens(special_tokens, length): start_token = special_tokens.get("start", None) @@ -65,35 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): "hidden" ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, - freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None, - special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig, - model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32 + freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS - self.num_layers = 12 - if textmodel_path is not None: - self.transformer = model_class.from_pretrained(textmodel_path) - else: - if textmodel_json_config is None: - textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") - config = config_class.from_json_file(textmodel_json_config) - self.num_layers = config.num_hidden_layers - with comfy.ops.use_comfy_ops(device, dtype): - with modeling_utils.no_init_weights(): - self.transformer = model_class(config) - self.inner_name = inner_name - if dtype is not None: - inner_model = getattr(self.transformer, self.inner_name) - if hasattr(inner_model, "embeddings"): - embeddings_bak = inner_model.embeddings.to(torch.float32) - inner_model.embeddings = None - self.transformer.to(dtype) - inner_model.embeddings = embeddings_bak - else: - previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True) - self.transformer.to(dtype) - self.transformer.set_input_embeddings(previous_inputs) + if textmodel_json_config is None: + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") + + with open(textmodel_json_config) as f: + config = json.load(f) + + self.transformer = model_class(config, dtype, device, comfy.ops) + self.num_layers = self.transformer.num_layers 
self.max_length = max_length if freeze: @@ -108,7 +94,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": assert layer_idx is not None - assert abs(layer_idx) <= self.num_layers + assert abs(layer_idx) < self.num_layers self.clip_layer(layer_idx) self.layer_default = (self.layer, self.layer_idx) @@ -119,7 +105,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): param.requires_grad = False def clip_layer(self, layer_idx): - if abs(layer_idx) >= self.num_layers: + if abs(layer_idx) > self.num_layers: self.layer = "last" else: self.layer = "hidden" @@ -174,7 +160,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: + if self.transformer.dtype != torch.float32: precision_scope = torch.autocast else: precision_scope = lambda a, dtype: contextlib.nullcontext(a) @@ -190,20 +176,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if tokens[x, y] == max_token: break - outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] + z = outputs[0] else: - z = outputs.hidden_states[self.layer_idx] - if self.layer_norm_hidden_state: - z = getattr(self.transformer, self.inner_name).final_layer_norm(z) + z = outputs[1] - if hasattr(outputs, "pooler_output"): - pooled_output = outputs.pooler_output.float() + if outputs[2] is not None: + pooled_output = outputs[2].float() else: pooled_output = None diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 2ee0ca055..9c878d54a 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -3,13 +3,13 @@ import torch import os class SD2ClipHModel(sd1_clip.SDClipModel): - def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): + def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): if layer == "penultimate": layer="hidden" - layer_idx=23 + layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index 673399e22..b35056bb9 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -3,13 +3,13 @@ import torch import os class SDXLClipG(sd1_clip.SDClipModel): - def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", 
layer_idx=None, textmodel_path=None, dtype=None): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None): if layer == "penultimate": layer="hidden" layer_idx=-2 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") - super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) def load_sd(self, sd): @@ -37,7 +37,7 @@ class SDXLTokenizer: class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False) self.clip_g = SDXLClipG(device=device, dtype=dtype) def clip_layer(self, layer_idx): From efb704c758f916bdf3b8fcaa3c2ade69d03a27f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Dec 2023 02:51:02 -0500 Subject: [PATCH 16/98] Support attention masking in CLIP implementation. --- comfy/clip_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index e6a7bfa66..c61353dcf 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -100,8 +100,12 @@ class CLIPTextModel_(torch.nn.Module): def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): x = self.embeddings(input_tokens) - #TODO: attention_mask - x, i = self.encoder(x, intermediate_output=intermediate_output) + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + + x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output) x = self.final_layer_norm(x) if i is not None and final_layer_norm_intermediate: i = self.final_layer_norm(i) From cdff08102346f34b6d5bbe65f036a6731e685285 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 7 Dec 2023 15:22:35 -0500 Subject: [PATCH 17/98] Fix hypertile. --- comfy_extras/nodes_hypertile.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 15736b835..e7446b2e5 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -13,7 +13,10 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: ns = [value // i for i in divisors[:max_options]] # has at least 1 element - idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() + if len(ns) - 1 > 0: + idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() + else: + idx = 0 return ns[idx] From 9ac0b487acf569ebe8a2d87ed750fed58b59262d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 8 Dec 2023 02:35:45 -0500 Subject: [PATCH 18/98] Make --gpu-only put intermediate values in GPU memory instead of cpu. 
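The patch below introduces model_management.intermediate_device() and routes latents, text-encoder outputs, and decoded images through it instead of hard-coded .cpu() calls. A condensed sketch of the dispatch (args.gpu_only and the bare "cuda" device stand in for ComfyUI's parsed CLI flag and its get_torch_device() helper):

```
import torch

class _Args:          # placeholder for the parsed command-line arguments
    gpu_only = False

args = _Args()

def intermediate_device():
    # Where tensors handed between nodes should live: on the GPU under
    # --gpu-only, otherwise in system RAM as before.
    if args.gpu_only:
        return torch.device("cuda")
    return torch.device("cpu")

samples = torch.zeros(1, 4, 64, 64)
samples = samples.to(intermediate_device())  # replaces the old samples.cpu()
```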
--- comfy/clip_vision.py | 4 ++-- comfy/model_management.py | 6 ++++++ comfy/sample.py | 4 ++-- comfy/sd.py | 23 ++++++++++++----------- comfy/sd1_clip.py | 6 +++--- comfy/utils.py | 12 ++++++------ comfy_extras/nodes_canny.py | 2 +- comfy_extras/nodes_post_processing.py | 2 +- nodes.py | 6 +++--- 9 files changed, 36 insertions(+), 29 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 9e2e03d72..449be8e44 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -54,10 +54,10 @@ class ClipVisionModel(): t = outputs[k] if t is not None: if k == 'hidden_states': - outputs["penultimate_hidden_states"] = t[-2].cpu() + outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device()) outputs["hidden_states"] = None else: - outputs[k] = t.cpu() + outputs[k] = t.to(comfy.model_management.intermediate_device()) return outputs diff --git a/comfy/model_management.py b/comfy/model_management.py index 3588d3503..ef9bec545 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -508,6 +508,12 @@ def text_encoder_dtype(device=None): else: return torch.float32 +def intermediate_device(): + if args.gpu_only: + return get_torch_device() + else: + return torch.device("cpu") + def vae_device(): return get_torch_device() diff --git a/comfy/sample.py b/comfy/sample.py index bcbed3343..eadd6dcc8 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -98,7 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.cpu() + samples = samples.to(comfy.model_management.intermediate_device()) cleanup_additional_models(models) cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) @@ -111,7 +111,7 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent sigmas = sigmas.to(model.load_device) samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.cpu() + samples = samples.to(comfy.model_management.intermediate_device()) cleanup_additional_models(models) cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control"))) return samples diff --git a/comfy/sd.py b/comfy/sd.py index f4f84d0a0..43e201d36 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -190,6 +190,7 @@ class VAE: offload_device = model_management.vae_offload_device() self.vae_dtype = model_management.vae_dtype() self.first_stage_model.to(self.vae_dtype) + self.output_device = model_management.intermediate_device() self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) @@ -201,9 +202,9 @@ class VAE: decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() output = torch.clamp(( - 
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar)) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output @@ -214,9 +215,9 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() - samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar) + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) samples /= 3.0 return samples @@ -228,15 +229,15 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu") + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - pixel_samples = pixel_samples.cpu().movedim(1,-1) + pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): @@ -252,10 +253,10 @@ class VAE: free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu") + samples = 
torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) - samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float() + samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 1acd972c4..4530168ab 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -39,7 +39,7 @@ class ClipTokenWeightEncoder: out, pooled = self.encode(to_encode) if pooled is not None: - first_pooled = pooled[0:1].cpu() + first_pooled = pooled[0:1].to(model_management.intermediate_device()) else: first_pooled = pooled @@ -56,8 +56,8 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - return out[-1:].cpu(), first_pooled - return torch.cat(output, dim=-2).cpu(), first_pooled + return out[-1:].to(model_management.intermediate_device()), first_pooled + return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" diff --git a/comfy/utils.py b/comfy/utils.py index 505577047..f8026ddab 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -376,7 +376,7 @@ def lanczos(samples, width, height): images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) - return result + return result.to(samples.device, samples.dtype) def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": @@ -405,17 +405,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) @torch.inference_mode() -def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None): - output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu") +def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): + output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device) for b in range(samples.shape[0]): s = samples[b:b+1] - out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") - out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") + out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) + out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): s_in = 
s[:,:,y:y+tile_y,x:x+tile_x] - ps = function(s_in).cpu() + ps = function(s_in).to(output_device) mask = torch.ones_like(ps) feather = round(overlap * upscale_amount) for t in range(feather): diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 94d453f2c..730dded5f 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -291,7 +291,7 @@ class Canny: def detect_edge(self, image, low_threshold, high_threshold): output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1) + img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) return (img_out,) NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 12704f545..71660f8a5 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -226,7 +226,7 @@ class Sharpen: batch_size, height, width, channels = image.shape kernel_size = sharpen_radius * 2 + 1 - kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10) + kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) center = kernel_size // 2 kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) diff --git a/nodes.py b/nodes.py index 24e591fdd..db96e0e2d 100644 --- a/nodes.py +++ b/nodes.py @@ -947,8 +947,8 @@ class GLIGENTextBoxApply: return (c, ) class EmptyLatentImage: - def __init__(self, device="cpu"): - self.device = device + def __init__(self): + self.device = comfy.model_management.intermediate_device() @classmethod def INPUT_TYPES(s): @@ -961,7 +961,7 @@ class EmptyLatentImage: CATEGORY = "latent" def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) return ({"samples":latent}, ) From a4ec54a40d978c4249dc6a7e2d5133657d1fd109 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 8 Dec 2023 02:49:30 -0500 Subject: [PATCH 19/98] Add linear_start and linear_end to model_config.sampling_settings --- comfy/model_sampling.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 69c8b1f01..cc8745c10 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -22,10 +22,17 @@ class V_PREDICTION(EPS): class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None): super().__init__() - beta_schedule = "linear" + if model_config is not None: - beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + beta_schedule = sampling_settings.get("beta_schedule", "linear") + linear_start = sampling_settings.get("linear_start", 0.00085) + linear_end = sampling_settings.get("linear_end", 0.012) + + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) self.sigma_data = 1.0 def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, From 97015b6b383718bdc65cb617e3050069a156679d Mon 
Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 8 Dec 2023 16:02:08 -0500 Subject: [PATCH 20/98] Cleanup. --- comfy/samplers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 1d012a514..ffc1fe3ac 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -276,10 +276,7 @@ class KSamplerX0Inpaint(torch.nn.Module): x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed) if denoise_mask is not None: - out *= denoise_mask - - if denoise_mask is not None: - out += self.latent_image * latent_mask + out = out * denoise_mask + self.latent_image * latent_mask return out def simple_scheduler(model, steps): From 9aaf368a415d23cabd80ae30ba7e4bb918635b4a Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 9 Dec 2023 13:04:35 +0000 Subject: [PATCH 21/98] Fix internal reroutes connected to other groups --- tests-ui/tests/groupNode.test.js | 26 ++++++++++++++++++++++++++ web/extensions/core/groupNode.js | 4 ++++ 2 files changed, 30 insertions(+) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index ce54c1154..9bcb19e93 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -383,6 +383,32 @@ describe("group node", () => { getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id]) ); }); + test("groups can connect to each other via internal reroutes", async () => { + const { ez, graph, app } = await start(); + + const latent = ez.EmptyLatentImage(); + const vae = ez.VAELoader(); + const latentReroute = ez.Reroute(); + const vaeReroute = ez.Reroute(); + + latent.outputs[0].connectTo(latentReroute.inputs[0]); + vae.outputs[0].connectTo(vaeReroute.inputs[0]); + + const group1 = await convertToGroup(app, graph, "test", [latentReroute, vaeReroute]); + group1.menu.Clone.call(); + expect(app.graph._nodes).toHaveLength(4); + const group2 = graph.find(app.graph._nodes[3]); + expect(group2.node.type).toEqual("workflow/test"); + expect(group2.id).not.toEqual(group1.id); + + group1.outputs.VAE.connectTo(group2.inputs.VAE); + group1.outputs.LATENT.connectTo(group2.inputs.LATENT); + + const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE); + ez.PreviewImage(decode.outputs[0]); + + expect((await graph.toPrompt()).output).toEqual({}); + }); test("displays generated image on group node", async () => { const { ez, graph, app } = await start(); const nodes = createDefaultWorkflow(ez, graph); diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 6766f356d..9a1d9b207 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -602,6 +602,10 @@ export class GroupNodeHandler { innerNode = innerNode.getInputNode(0); } + if (l && GroupNodeHandler.isGroupNode(innerNode)) { + return innerNode.updateLink(l); + } + link.origin_id = innerNode.id; link.origin_slot = l?.origin_slot ?? 
output.slot; return link; From 080ef75c3148060bfccdf82f5f063e9a0cdacd0d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 9 Dec 2023 13:19:21 +0000 Subject: [PATCH 22/98] fix --- tests-ui/tests/groupNode.test.js | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index 9bcb19e93..dc9d4bd49 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -405,9 +405,14 @@ describe("group node", () => { group1.outputs.LATENT.connectTo(group2.inputs.LATENT); const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE); - ez.PreviewImage(decode.outputs[0]); + const preview = ez.PreviewImage(decode.outputs[0]); - expect((await graph.toPrompt()).output).toEqual({}); + expect((await graph.toPrompt()).output).toEqual({ + [latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" }, + [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" }, + [preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" }, + }); }); test("displays generated image on group node", async () => { const { ez, graph, app } = await start(); From 174eba8e957b4b885d4d510d53dca859226ba9ef Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Dec 2023 11:56:31 -0500 Subject: [PATCH 23/98] Use own clip vision model implementation. --- comfy/clip_model.py | 70 ++++++++++++++++++++++++++++++++++++++++---- comfy/clip_vision.py | 33 +++++++++++---------- 2 files changed, 81 insertions(+), 22 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index c61353dcf..850b5fdbe 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -57,12 +57,7 @@ class CLIPEncoder(torch.nn.Module): self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) def forward(self, x, mask=None, intermediate_output=None): - optimized_attention = optimized_attention_for_device(x.device, mask=True) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) if intermediate_output is not None: if intermediate_output < 0: @@ -105,6 +100,12 @@ class CLIPTextModel_(torch.nn.Module): mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output) x = self.final_layer_norm(x) if i is not None and final_layer_norm_intermediate: @@ -128,3 +129,60 @@ class CLIPTextModel(torch.nn.Module): def forward(self, *args, **kwargs): return self.text_model(*args, **kwargs) + +class CLIPVisionEmbeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None): + super().__init__() + 
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device)) + + self.patch_embedding = operations.Conv2d( + in_channels=num_channels, + out_channels=embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=False, + dtype=dtype, + device=device + ) + + num_patches = (image_size // patch_size) ** 2 + num_positions = num_patches + 1 + self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + + def forward(self, pixel_values): + embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) + return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight + + +class CLIPVision(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations) + self.pre_layrnorm = operations.LayerNorm(embed_dim) + self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) + self.post_layernorm = operations.LayerNorm(embed_dim) + + def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + x = self.embeddings(pixel_values) + x = self.pre_layrnorm(x) + #TODO: attention_mask? + x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output) + pooled_output = self.post_layernorm(x[:, 0, :]) + return x, i, pooled_output + +class CLIPVisionModelProjection(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.vision_model = CLIPVision(config_dict, dtype, device, operations) + self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False) + + def forward(self, *args, **kwargs): + x = self.vision_model(*args, **kwargs) + out = self.visual_projection(x[2]) + return (x[0], x[1], out) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 449be8e44..ae87c75b4 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,13 +1,20 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils from .utils import load_torch_file, transformers_convert, common_upscale import os import torch import contextlib +import json import comfy.ops import comfy.model_patcher import comfy.model_management import comfy.utils +import comfy.clip_model + +class Output: + def __getitem__(self, key): + return getattr(self, key) + def __setitem__(self, key, item): + setattr(self, key, item) def clip_preprocess(image, size=224): mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) @@ -22,17 +29,16 @@ def clip_preprocess(image, size=224): class ClipVisionModel(): def __init__(self, json_config): - config = CLIPVisionConfig.from_json_file(json_config) + with open(json_config) as f: + config = json.load(f) + self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = torch.float32 if comfy.model_management.should_use_fp16(self.load_device, 
prioritize_performance=False): self.dtype = torch.float16 - with comfy.ops.use_comfy_ops(offload_device, self.dtype): - with modeling_utils.no_init_weights(): - self.model = CLIPVisionModelWithProjection(config) - self.model.to(self.dtype) + self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): @@ -48,17 +54,12 @@ class ClipVisionModel(): precision_scope = lambda a, b: contextlib.nullcontext(a) with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32): - outputs = self.model(pixel_values=pixel_values, output_hidden_states=True) - - for k in outputs: - t = outputs[k] - if t is not None: - if k == 'hidden_states': - outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device()) - outputs["hidden_states"] = None - else: - outputs[k] = t.to(comfy.model_management.intermediate_device()) + out = self.model(pixel_values=pixel_values, intermediate_output=-2) + outputs = Output() + outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) + outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) return outputs def convert_to_transformers(sd, prefix): From da74e3bbe3705d6c3141db8c19f5217f51c6d4a7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Dec 2023 12:01:17 -0500 Subject: [PATCH 24/98] Update pytorch nightly packaging workflow. --- ...update_comfyui_and_python_dependencies.bat | 3 -- .../windows_base_files/run_nvidia_gpu.bat | 2 -- .../windows_release_nightly_pytorch.yml | 33 +++++++++++++++---- 3 files changed, 26 insertions(+), 12 deletions(-) delete mode 100755 .ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat delete mode 100755 .ci/nightly/windows_base_files/run_nvidia_gpu.bat diff --git a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat b/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat deleted file mode 100755 index 94f5d1023..000000000 --- a/.ci/nightly/update_windows/update_comfyui_and_python_dependencies.bat +++ /dev/null @@ -1,3 +0,0 @@ -..\python_embeded\python.exe .\update.py ..\ComfyUI\ -..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -pause diff --git a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat b/.ci/nightly/windows_base_files/run_nvidia_gpu.bat deleted file mode 100755 index 8ee2f3402..000000000 --- a/.ci/nightly/windows_base_files/run_nvidia_gpu.bat +++ /dev/null @@ -1,2 +0,0 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --use-pytorch-cross-attention -pause diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b793f7fe2..90e09d27a 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -2,6 +2,24 @@ name: "Windows Release Nightly pytorch" on: workflow_dispatch: + inputs: + cu: + description: 'cuda version' + required: true + type: string + default: "121" + + python_minor: + description: 'python minor version' + required: true + type: string + default: "12" + + 
python_patch: + description: 'python patch version' + required: true + type: string + default: "1" # push: # branches: # - master @@ -20,21 +38,21 @@ jobs: persist-credentials: false - uses: actions/setup-python@v4 with: - python-version: '3.11.6' + python-version: 3.${{ inputs.python_minor }}.${{ inputs.python_patch }} - shell: bash run: | cd .. cp -r ComfyUI ComfyUI_copy - curl https://www.python.org/ftp/python/3.11.6/python-3.11.6-embed-amd64.zip -o python_embeded.zip + curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip unzip python_embeded.zip -d python_embeded cd python_embeded - echo 'import site' >> ./python311._pth + echo 'import site' >> ./python3${{ inputs.python_minor }}._pth curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py - python -m pip wheel torch torchvision torchaudio aiohttp==3.8.5 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir + python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir ls ../temp_wheel_dir ./python.exe -s -m pip install --pre ../temp_wheel_dir/* - sed -i '1i../ComfyUI' ./python311._pth + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth cd .. git clone https://github.com/comfyanonymous/taesd @@ -49,9 +67,10 @@ jobs: mkdir update cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/windows_base_files/* ./ - cp -r ComfyUI/.ci/nightly/update_windows/* ./update/ - cp -r ComfyUI/.ci/nightly/windows_base_files/* ./ + echo "..\python_embeded\python.exe .\update.py ..\ComfyUI\\ + ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 + pause" > ./update/update_comfyui_and_python_dependencies.bat cd .. "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma -mx=8 -mfb=64 -md=32m -ms=on -mf=BCJ2 ComfyUI_windows_portable_nightly_pytorch.7z ComfyUI_windows_portable_nightly_pytorch From 9e411073e901f766118a7b82f613872fd745ecc2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Dec 2023 13:41:30 -0500 Subject: [PATCH 25/98] Add instructions for those that have python 3.12 --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 450a012bb..167214c05 100644 --- a/README.md +++ b/README.md @@ -93,23 +93,27 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier. +Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead. 
### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` -This is the command to install the nightly with ROCm 5.7 that might have some performance improvements: +This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` ### NVIDIA -Nvidia users should install pytorch using this command: +Nvidia users should install stable pytorch using this command: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` +This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements: + +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121``` + #### Troubleshooting If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with: From cb63e230b41193601e48778111eff045391cfbe2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Dec 2023 14:15:09 -0500 Subject: [PATCH 26/98] Make lora code a bit cleaner. --- comfy/lora.py | 14 +++++++------- comfy/model_patcher.py | 14 +++++++++++--- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index 29c59d893..ecd518084 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -43,7 +43,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -64,7 +64,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -116,7 +116,7 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) w_norm_name = "{}.w_norm".format(x) @@ -126,21 +126,21 @@ def load_lora(lora, to_load): if w_norm is not None: loaded_keys.add(w_norm_name) - patch_dict[to_load[x]] = (w_norm,) + patch_dict[to_load[x]] = ("diff", (w_norm,)) if b_norm is not None: loaded_keys.add(b_norm_name) - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,)) diff_name = "{}.diff".format(x) diff_weight = lora.get(diff_name, None) if diff_weight is not None: - patch_dict[to_load[x]] = (diff_weight,) + patch_dict[to_load[x]] = ("diff", (diff_weight,)) loaded_keys.add(diff_name) diff_bias_name = 
"{}.diff_b".format(x) diff_bias = lora.get(diff_bias_name, None) if diff_bias is not None: - patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) + patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,)) loaded_keys.add(diff_bias_name) for x in lora.keys(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a3cffc3be..d78cdfd4d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -217,13 +217,19 @@ class ModelPatcher: v = (self.calculate_weight(v[1:], v[0].clone(), key), ) if len(v) == 1: + patch_type = "diff" + elif len(v) == 2: + patch_type = v[0] + v = v[1] + + if patch_type == "diff": w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype) - elif len(v) == 4: #lora/locon + elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) if v[2] is not None: @@ -237,7 +243,7 @@ class ModelPatcher: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) - elif len(v) == 8: #lokr + elif patch_type == "lokr": w1 = v[0] w2 = v[1] w1_a = v[3] @@ -276,7 +282,7 @@ class ModelPatcher: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) - else: #loha + elif patch_type == "loha": w1a = v[0] w1b = v[1] if v[2] is not None: @@ -305,6 +311,8 @@ class ModelPatcher: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) + else: + print("patch type not recognized", patch_type, key) return weight From 614b7e731f7f9fdcf11eeb46e0623b0977a7e634 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 9 Dec 2023 18:15:26 -0500 Subject: [PATCH 27/98] Implement GLora. 
--- comfy/lora.py | 11 +++++++++++ comfy/model_patcher.py | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index ecd518084..5e4009b47 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -118,6 +118,17 @@ def load_lora(lora, to_load): if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) + #glora + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = "{}.b2.weight".format(x) + if a1_name in lora: + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d78cdfd4d..55ca913ec 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -311,6 +311,16 @@ class ModelPatcher: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) except Exception as e: print("ERROR", key, e) + elif patch_type == "glora": + if v[4] is not None: + alpha *= v[4] / v[0].shape[0] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32) + + weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) else: print("patch type not recognized", patch_type, key) From 340177e6e85d076ab9e222e4f3c6a22f1fb4031f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Dec 2023 01:30:35 -0500 Subject: [PATCH 28/98] Disable non blocking on mps. --- comfy/model_management.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index ef9bec545..0c51eee51 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -553,15 +553,19 @@ def cast_to_device(tensor, device, dtype, copy=False): elif is_intel_xpu(): device_supports_cast = True + non_blocking = True + if is_device_mps(device): + non_blocking = False #pytorch bug? 
mps doesn't support non blocking + if device_supports_cast: if copy: if tensor.device == device: - return tensor.to(dtype, copy=copy, non_blocking=True) - return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True) + return tensor.to(dtype, copy=copy, non_blocking=non_blocking) + return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True) + return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) else: - return tensor.to(device, dtype, copy=copy, non_blocking=True) + return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): global directml_enabled From 69033081c50de94cbc2a4fce12900611da04b1e9 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Mon, 11 Dec 2023 00:24:16 +0900 Subject: [PATCH 29/98] mask editor bugfix - Addressing the issue where an unnecessary hidden panel disrupts the drawing. --- web/extensions/core/maskeditor.js | 6 ------ 1 file changed, 6 deletions(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 1ea4dbcaa..bb2f16d42 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -167,10 +167,6 @@ class MaskEditorDialog extends ComfyDialog { // If it is specified as relative, using it only as a hidden placeholder for padding is recommended // to prevent anomalies where it exceeds a certain size and goes outside of the window. - var placeholder = document.createElement("div"); - placeholder.style.position = "relative"; - placeholder.style.height = "50px"; - var bottom_panel = document.createElement("div"); bottom_panel.style.position = "absolute"; bottom_panel.style.bottom = "0px"; @@ -192,7 +188,6 @@ class MaskEditorDialog extends ComfyDialog { this.brush = brush; this.element.appendChild(imgCanvas); this.element.appendChild(maskCanvas); - this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button this.element.appendChild(bottom_panel); document.body.appendChild(brush); @@ -218,7 +213,6 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(imgCanvas); this.element.appendChild(maskCanvas); - this.element.appendChild(placeholder); // must below z-index than bottom_panel to avoid covering button this.element.appendChild(bottom_panel); bottom_panel.appendChild(clearButton); From 57926635e8d84ae9eea4a0416cc75e363f5ede45 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 10 Dec 2023 23:00:54 -0500 Subject: [PATCH 30/98] Switch text encoder to manual cast. Use fp16 text encoder weights for CPU inference to lower memory usage. 
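The pattern, in miniature: weights keep a compact storage dtype and get cast to the activation's dtype and device on every forward call, so no autocast context is needed and CPU inference can hold fp16 weights in memory. A minimal sketch with an illustrative class name (fp16 storage, fp32 activations assumed):

```python
import torch

class ManualCastLinear(torch.nn.Linear):
    def forward(self, input):
        # Mirrors the cast_bias_weight helper below: match the activations.
        weight = self.weight.to(device=input.device, dtype=input.dtype)
        bias = None if self.bias is None else self.bias.to(device=input.device, dtype=input.dtype)
        return torch.nn.functional.linear(input, weight, bias)

layer = ManualCastLinear(8, 8).half()   # fp16 storage
out = layer(torch.randn(2, 8))          # fp32 activations -> fp32 compute
```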
--- comfy/model_management.py | 3 +++ comfy/ops.py | 33 +++++++++++++++++++++++++ comfy/sd1_clip.py | 52 +++++++++++++++++---------------------- 3 files changed, 59 insertions(+), 29 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0c51eee51..a6c8fb352 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -503,6 +503,9 @@ def text_encoder_dtype(device=None): elif args.fp32_text_enc: return torch.float32 + if is_device_cpu(device): + return torch.float16 + if should_use_fp16(device, prioritize_performance=False): return torch.float16 else: diff --git a/comfy/ops.py b/comfy/ops.py index deb849d63..e48568409 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs): else: raise ValueError(f"unsupported dimensions: {dims}") +def cast_bias_weight(s, input): + bias = None + if s.bias is not None: + bias = s.bias.to(device=input.device, dtype=input.dtype) + weight = s.weight.to(device=input.device, dtype=input.dtype) + return weight, bias + +class manual_cast: + class Linear(Linear): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + class Conv2d(Conv2d): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + class Conv3d(Conv3d): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + class GroupNorm(GroupNorm): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + class LayerNorm(LayerNorm): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + @contextmanager def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way old_torch_nn_linear = torch.nn.Linear diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4530168ab..6ffef515e 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -78,7 +78,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): with open(textmodel_json_config) as f: config = json.load(f) - self.transformer = model_class(config, dtype, device, comfy.ops) + self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast) self.num_layers = self.transformer.num_layers self.max_length = max_length @@ -160,37 +160,31 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = torch.LongTensor(tokens).to(device) - if self.transformer.dtype != torch.float32: - precision_scope = torch.autocast + attention_mask = None + if self.enable_attention_masks: + attention_mask = torch.zeros_like(tokens) + max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 + for x in range(attention_mask.shape[0]): + for y in range(attention_mask.shape[1]): + attention_mask[x, y] = 1 + if tokens[x, y] == max_token: + break + + outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + self.transformer.set_input_embeddings(backup_embeds) + + if self.layer == "last": + z = outputs[0] else: - precision_scope = lambda a, dtype: contextlib.nullcontext(a) + z = outputs[1] - with 
precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): - attention_mask = None - if self.enable_attention_masks: - attention_mask = torch.zeros_like(tokens) - max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1 - for x in range(attention_mask.shape[0]): - for y in range(attention_mask.shape[1]): - attention_mask[x, y] = 1 - if tokens[x, y] == max_token: - break + if outputs[2] is not None: + pooled_output = outputs[2].float() + else: + pooled_output = None - outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) - self.transformer.set_input_embeddings(backup_embeds) - - if self.layer == "last": - z = outputs[0] - else: - z = outputs[1] - - if outputs[2] is not None: - pooled_output = outputs[2].float() - else: - pooled_output = None - - if self.text_projection is not None and pooled_output is not None: - pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() + if self.text_projection is not None and pooled_output is not None: + pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() return z.float(), pooled_output def encode(self, tokens): From ab93abd4b2eaf99d4a52f9a036600d9d46355d92 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:33:35 +0000 Subject: [PATCH 31/98] Prevent cleaning graph state on undo/redo (#2255) * Prevent cleaning graph state on undo/redo * Remove pause rendering due to LG bug --- web/extensions/core/undoRedo.js | 25 +++++++++++-------------- web/scripts/app.js | 7 +++++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index c6613b0f0..3cb137520 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) { } const undoRedo = async (e) => { + const updateState = async (source, target) => { + const prevState = source.pop(); + if (prevState) { + target.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState, false); + activeState = prevState; + } + } if (e.ctrlKey || e.metaKey) { if (e.key === "y") { - const prevState = redo.pop(); - if (prevState) { - undo.push(activeState); - isOurLoad = true; - await app.loadGraphData(prevState); - activeState = prevState; - } + updateState(redo, undo); return true; } else if (e.key === "z") { - const prevState = undo.pop(); - if (prevState) { - redo.push(activeState); - isOurLoad = true; - await app.loadGraphData(prevState); - activeState = prevState; - } + updateState(undo, redo); return true; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index 5faf41fb3..d2a6f4de4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1559,9 +1559,12 @@ export class ComfyApp { /** * Populates the graph with the specified workflow data * @param {*} graphData A serialized graph object + * @param { boolean } clean If the graph state, e.g. images, should be cleared */ - async loadGraphData(graphData) { - this.clean(); + async loadGraphData(graphData, clean = true) { + if (clean !== false) { + this.clean(); + } let reset_invalid_values = false; if (!graphData) { From ba07cb748e4793a6393288d621aa8e2f0f282595 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Dec 2023 18:24:44 -0500 Subject: [PATCH 32/98] Use faster manual cast for fp8 in unet. 
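The core decision added in model_management.unet_manual_cast, condensed (the real helper takes the inference device and checks fp16 support itself):

```python
import torch

def unet_manual_cast(weight_dtype, fp16_supported):
    # None means the weights are usable as-is for inference.
    if weight_dtype == torch.float32:
        return None
    if fp16_supported and weight_dtype == torch.float16:
        return None
    # fp8 weights (or fp16 without hardware support) compute in a wider
    # dtype through the manual_cast ops instead of relying on autocast.
    return torch.float16 if fp16_supported else torch.float32
```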
--- comfy/model_base.py | 19 ++++++++++--------- comfy/model_management.py | 16 +++++++++++++++- comfy/ops.py | 9 +++++++++ comfy/sd.py | 12 ++++++++++-- comfy/supported_models_base.py | 4 ++++ 5 files changed, 48 insertions(+), 12 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 5bfcc391d..bab7b9b34 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -4,6 +4,7 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management import comfy.conds +import comfy.ops from enum import Enum import contextlib from . import utils @@ -41,9 +42,14 @@ class BaseModel(torch.nn.Module): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config + self.manual_cast_dtype = model_config.manual_cast_dtype if not unet_config.get("disable_unet_model_creation", False): - self.diffusion_model = UNetModel(**unet_config, device=device) + if self.manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops + self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -63,11 +69,8 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() - if comfy.model_management.supports_dtype(xc.device, dtype): - precision_scope = lambda a: contextlib.nullcontext(a) - else: - precision_scope = torch.autocast - dtype = torch.float32 + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype xc = xc.to(dtype) t = self.model_sampling.timestep(t).float() @@ -79,9 +82,7 @@ class BaseModel(torch.nn.Module): extra = extra.to(dtype) extra_conds[o] = extra - with precision_scope(comfy.model_management.get_autocast_device(xc.device)): - model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): diff --git a/comfy/model_management.py b/comfy/model_management.py index a6c8fb352..fe0374a8b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0): return torch.float16 return torch.float32 +# None means no manual cast +def unet_manual_cast(weight_dtype, inference_device): + if weight_dtype == torch.float32: + return None + + fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False) + if fp16_supported and weight_dtype == torch.float16: + return None + + if fp16_supported: + return torch.float16 + else: + return torch.float32 + def text_encoder_offload_device(): if args.gpu_only: return get_torch_device() @@ -538,7 +552,7 @@ def get_autocast_device(dev): def supports_dtype(device, dtype): #TODO if dtype == torch.float32: return True - if torch.device("cpu") == device: + if is_device_cpu(device): return False if dtype == torch.float16: return True diff --git a/comfy/ops.py b/comfy/ops.py index e48568409..a67bc809f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -62,6 +62,15 @@ class manual_cast: weight, bias = cast_bias_weight(self, input) return torch.nn.functional.layer_norm(input, 
self.normalized_shape, weight, bias, self.eps) + @classmethod + def conv_nd(s, dims, *args, **kwargs): + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + @contextmanager def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way old_torch_nn_linear = torch.nn.Linear diff --git a/comfy/sd.py b/comfy/sd.py index 43e201d36..8c056e4ea 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) class WeightsLoader(torch.nn.Module): pass model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) + model_config.set_manual_cast(manual_cast_dtype) + if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -470,7 +474,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o print("left over keys:", left_over) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device) + model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device) if inital_load_device != torch.device("cpu"): print("loaded straight to GPU") model_management.load_model_gpu(model_patcher) @@ -481,6 +485,9 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = model_management.get_torch_device() + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + if "input_blocks.0.0.weight" in sd: #ldm model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: @@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format else: print(diffusers_keys[k], k) offload_device = model_management.unet_offload_device() + model_config.set_manual_cast(manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() if len(left_over) > 0: print("left over keys in unet:", left_over) - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) + return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) def load_unet(unet_path): sd = comfy.utils.load_torch_file(unet_path) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 3412cfea0..49087d23e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -22,6 +22,8 @@ class BASE: sampling_settings = {} latent_format = latent_formats.LatentFormat + manual_cast_dtype = None + @classmethod def matches(s, unet_config): for k in s.unet_config: @@ -71,3 +73,5 @@ class BASE: 
replace_prefix = {"": "first_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def set_manual_cast(self, manual_cast_dtype): + self.manual_cast_dtype = manual_cast_dtype From b0aab1e4ea3dfefe09c4f07de0e5237558097e22 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Dec 2023 18:36:29 -0500 Subject: [PATCH 33/98] Add an option --fp16-unet to force using fp16 for the unet. --- comfy/cli_args.py | 1 + comfy/model_management.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 58d034802..d9c8668f4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -57,6 +57,7 @@ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fpunet_group = parser.add_mutually_exclusive_group() fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") +fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.") fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") diff --git a/comfy/model_management.py b/comfy/model_management.py index fe0374a8b..b6a9471bf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -466,6 +466,8 @@ def unet_inital_load_device(parameters, dtype): def unet_dtype(device=None, model_params=0): if args.bf16_unet: return torch.bfloat16 + if args.fp16_unet: + return torch.float16 if args.fp8_e4m3fn_unet: return torch.float8_e4m3fn if args.fp8_e5m2_unet: From 77755ab8dbc74f3f231aa817590401d7969f96a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Dec 2023 23:27:13 -0500 Subject: [PATCH 34/98] Refactor comfy.ops comfy.ops -> comfy.ops.disable_weight_init This should make it more clear what they actually do. Some unused code has also been removed. 
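The calling convention after the refactor, sketched with a hypothetical make_block helper: model code receives an `operations` namespace and builds layers through it, so choosing between plain fast-init layers and manual-cast layers is a constructor argument rather than the module-level monkey patching the removed use_comfy_ops context manager did:

```python
import comfy.ops

def make_block(dim, operations=comfy.ops.disable_weight_init):
    # disable_weight_init only skips torch's parameter init (checkpoint
    # weights overwrite it anyway); manual_cast subclasses it and adds
    # the per-forward dtype/device cast.
    return operations.Linear(dim, dim)

block = make_block(64)                              # weights used as stored
block_cast = make_block(64, comfy.ops.manual_cast)  # cast on each forward
```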
--- comfy/cldm/cldm.py | 2 +- comfy/clip_vision.py | 2 +- comfy/controlnet.py | 27 ++-- comfy/ldm/modules/attention.py | 13 +- comfy/ldm/modules/diffusionmodules/model.py | 39 +++--- .../modules/diffusionmodules/openaimodel.py | 14 +-- comfy/ldm/modules/diffusionmodules/util.py | 41 ------ comfy/ldm/modules/temporal_ae.py | 5 +- comfy/model_base.py | 2 +- comfy/ops.py | 119 +++++++----------- 10 files changed, 94 insertions(+), 170 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index bbe5891e6..00373a790 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -53,7 +53,7 @@ class ControlNet(nn.Module): transformer_depth_middle=None, transformer_depth_output=None, device=None, - operations=comfy.ops, + operations=comfy.ops.disable_weight_init, **kwargs, ): super().__init__() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index ae87c75b4..ba8a3a8d5 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -38,7 +38,7 @@ class ClipVisionModel(): if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): self.dtype = torch.float16 - self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops) + self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6d37aa74f..3212ac8c4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -208,9 +208,9 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) else: - return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias) class Conv2d(torch.nn.Module): def __init__( @@ -247,24 +247,9 @@ class ControlLoraOps: def forward(self, input): if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) - - def conv_nd(self, dims, *args, **kwargs): - if dims == 2: - return self.Conv2d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - - class Conv3d(comfy.ops.Conv3d): - pass - - class GroupNorm(comfy.ops.GroupNorm): - pass - - class LayerNorm(comfy.ops.LayerNorm): - pass + return 
torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet): @@ -278,7 +263,9 @@ class ControlLora(ControlNet): controlnet_config = model.model_config.unet_config.copy() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - controlnet_config["operations"] = ControlLoraOps() + class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): + pass + controlnet_config["operations"] = control_lora_ops self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) dtype = model.get_dtype() self.control_model.to(dtype) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8299b1d94..8d86aa53d 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -19,6 +19,7 @@ if model_management.xformers_enabled(): from comfy.cli_args import args import comfy.ops +ops = comfy.ops.disable_weight_init # CrossAttn precision handling if args.dont_upcast_attention: @@ -55,7 +56,7 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops): super().__init__() self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) @@ -65,7 +66,7 @@ class GEGLU(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) @@ -356,7 +357,7 @@ def optimized_attention_for_device(device, mask=False): class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -389,7 +390,7 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, - disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -558,7 +559,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=comfy.ops): + use_checkpoint=True, dtype=None, device=None, operations=ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth @@ -632,7 +633,7 @@ class SpatialVideoTransformer(SpatialTransformer): disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, - dtype=None, device=None, operations=comfy.ops + 
dtype=None, device=None, operations=ops ): super().__init__( in_channels, diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index f23417fd2..fce29cb85 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -8,6 +8,7 @@ from typing import Optional, Any from comfy import model_management import comfy.ops +ops = comfy.ops.disable_weight_init if model_management.xformers_enabled_vae(): import xformers @@ -48,7 +49,7 @@ class Upsample(nn.Module): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = comfy.ops.Conv2d(in_channels, + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, @@ -78,7 +79,7 @@ class Downsample(nn.Module): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = comfy.ops.Conv2d(in_channels, + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, @@ -105,30 +106,30 @@ class ResnetBlock(nn.Module): self.swish = torch.nn.SiLU(inplace=True) self.norm1 = Normalize(in_channels) - self.conv1 = comfy.ops.Conv2d(in_channels, + self.conv1 = ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = comfy.ops.Linear(temb_channels, + self.temb_proj = ops.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) - self.conv2 = comfy.ops.Conv2d(out_channels, + self.conv2 = ops.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = comfy.ops.Conv2d(in_channels, + self.conv_shortcut = ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = comfy.ops.Conv2d(in_channels, + self.nin_shortcut = ops.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, @@ -245,22 +246,22 @@ class AttnBlock(nn.Module): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = comfy.ops.Conv2d(in_channels, + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = comfy.ops.Conv2d(in_channels, + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, @@ -312,14 +313,14 @@ class Model(nn.Module): # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ - comfy.ops.Linear(self.ch, + ops.Linear(self.ch, self.temb_ch), - comfy.ops.Linear(self.temb_ch, + ops.Linear(self.temb_ch, self.temb_ch), ]) # downsampling - self.conv_in = comfy.ops.Conv2d(in_channels, + self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -388,7 +389,7 @@ class Model(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = ops.Conv2d(block_in, out_ch, kernel_size=3, stride=1, @@ -461,7 +462,7 @@ class Encoder(nn.Module): self.in_channels = in_channels # downsampling - self.conv_in = comfy.ops.Conv2d(in_channels, + self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -506,7 +507,7 @@ class Encoder(nn.Module): # end self.norm_out = Normalize(block_in) - self.conv_out = 
comfy.ops.Conv2d(block_in, + self.conv_out = ops.Conv2d(block_in, 2*z_channels if double_z else z_channels, kernel_size=3, stride=1, @@ -541,7 +542,7 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - conv_out_op=comfy.ops.Conv2d, + conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, **ignorekwargs): @@ -565,7 +566,7 @@ class Decoder(nn.Module): self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = comfy.ops.Conv2d(z_channels, + self.conv_in = ops.Conv2d(z_channels, block_in, kernel_size=3, stride=1, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 12efd833c..057dd16b2 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -12,13 +12,13 @@ from .util import ( checkpoint, avg_pool_nd, zero_module, - normalization, timestep_embedding, AlphaBlender, ) from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops +ops = comfy.ops.disable_weight_init class TimestepBlock(nn.Module): """ @@ -70,7 +70,7 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -106,7 +106,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -159,7 +159,7 @@ class ResBlock(TimestepBlock): skip_t_emb=False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): super().__init__() self.channels = channels @@ -284,7 +284,7 @@ class VideoResBlock(ResBlock): down: bool = False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): super().__init__( channels, @@ -434,7 +434,7 @@ class UNetModel(nn.Module): disable_temporal_crossattention=False, max_ddpm_temb_period=10000, device=None, - operations=comfy.ops, + operations=ops, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" @@ -581,7 +581,7 @@ class UNetModel(nn.Module): up=False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): if self.use_temporal_resblocks: return VideoResBlock( diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 704bbe574..68175b62a 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -16,7 +16,6 @@ import numpy as np from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config -import comfy.ops class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] @@ -273,46 +272,6 @@ def mean_flat(tensor): return tensor.mean(dim=list(range(1, len(tensor.shape)))) -def normalization(channels, dtype=None): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels, dtype=dtype) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return comfy.ops.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return comfy.ops.Linear(*args, **kwargs) - - def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. 
diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 11ae049f3..7ea68dc9e 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -5,6 +5,7 @@ import torch from einops import rearrange, repeat import comfy.ops +ops = comfy.ops.disable_weight_init from .diffusionmodules.model import ( AttnBlock, @@ -130,9 +131,9 @@ class AttnVideoBlock(AttnBlock): time_embed_dim = self.in_channels * 4 self.video_time_embed = torch.nn.Sequential( - comfy.ops.Linear(self.in_channels, time_embed_dim), + ops.Linear(self.in_channels, time_embed_dim), torch.nn.SiLU(), - comfy.ops.Linear(time_embed_dim, self.in_channels), + ops.Linear(time_embed_dim, self.in_channels), ) self.merge_strategy = merge_strategy diff --git a/comfy/model_base.py b/comfy/model_base.py index bab7b9b34..412c83792 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -48,7 +48,7 @@ class BaseModel(torch.nn.Module): if self.manual_cast_dtype is not None: operations = comfy.ops.manual_cast else: - operations = comfy.ops + operations = comfy.ops.disable_weight_init self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) diff --git a/comfy/ops.py b/comfy/ops.py index a67bc809f..08c633847 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,66 +1,26 @@ import torch from contextlib import contextmanager -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None +class disable_weight_init: + class Linear(torch.nn.Linear): + def reset_parameters(self): + return None -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None + class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None -class Conv3d(torch.nn.Conv3d): - def reset_parameters(self): - return None + class Conv3d(torch.nn.Conv3d): + def reset_parameters(self): + return None -class GroupNorm(torch.nn.GroupNorm): - def reset_parameters(self): - return None + class GroupNorm(torch.nn.GroupNorm): + def reset_parameters(self): + return None -class LayerNorm(torch.nn.LayerNorm): - def reset_parameters(self): - return None - -def conv_nd(dims, *args, **kwargs): - if dims == 2: - return Conv2d(*args, **kwargs) - elif dims == 3: - return Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - -def cast_bias_weight(s, input): - bias = None - if s.bias is not None: - bias = s.bias.to(device=input.device, dtype=input.dtype) - weight = s.weight.to(device=input.device, dtype=input.dtype) - return weight, bias - -class manual_cast: - class Linear(Linear): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) - - class Conv2d(Conv2d): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) - - class Conv3d(Conv3d): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) - - class GroupNorm(GroupNorm): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) - - class LayerNorm(LayerNorm): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + class LayerNorm(torch.nn.LayerNorm): + def 
reset_parameters(self): + return None @classmethod def conv_nd(s, dims, *args, **kwargs): @@ -71,20 +31,35 @@ class manual_cast: else: raise ValueError(f"unsupported dimensions: {dims}") -@contextmanager -def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way - old_torch_nn_linear = torch.nn.Linear - force_device = device - force_dtype = dtype - def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): - if force_device is not None: - device = force_device - if force_dtype is not None: - dtype = force_dtype - return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) +def cast_bias_weight(s, input): + bias = None + if s.bias is not None: + bias = s.bias.to(device=input.device, dtype=input.dtype) + weight = s.weight.to(device=input.device, dtype=input.dtype) + return weight, bias - torch.nn.Linear = linear_with_dtype - try: - yield - finally: - torch.nn.Linear = old_torch_nn_linear +class manual_cast(disable_weight_init): + class Linear(disable_weight_init.Linear): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + class Conv2d(disable_weight_init.Conv2d): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + class Conv3d(disable_weight_init.Conv3d): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + class GroupNorm(disable_weight_init.GroupNorm): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + class LayerNorm(disable_weight_init.LayerNorm): + def forward(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) From 3152023fbc4f8ee6598a863314ca98d48ea9c2e6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Dec 2023 23:50:38 -0500 Subject: [PATCH 35/98] Use inference dtype for unet memory usage estimation. --- comfy/model_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 412c83792..a7582b330 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -177,9 +177,12 @@ class BaseModel(torch.nn.Module): def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): + dtype = self.get_dtype() + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * input_shape[2] * input_shape[3] From 32b7e7e769c206a06bf6e10ad2ddb6af9a378f56 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Dec 2023 03:32:23 -0500 Subject: [PATCH 36/98] Add manual cast to controlnet. 
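Roughly, the composition used for ControlLora below, wrapped in a hypothetical factory for illustration (ControlLoraOps supplies the lora-aware Linear/Conv2d forwards; the mixed-in ops class supplies the remaining layer types and decides whether weights are cast per call):

```python
import comfy.ops
from comfy.controlnet import ControlLoraOps

def control_lora_ops_for(manual_cast_dtype):
    base = comfy.ops.disable_weight_init if manual_cast_dtype is None else comfy.ops.manual_cast
    class control_lora_ops(ControlLoraOps, base):
        pass
    return control_lora_ops
```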
--- comfy/cldm/cldm.py | 28 +++++++++++------------ comfy/controlnet.py | 56 ++++++++++++++++++++++++++------------------- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 00373a790..5eee5a51c 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -141,24 +141,24 @@ class ControlNet(nn.Module): ) ] ) - self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)]) self.input_hint_block = TimestepEmbedSequential( - operations.conv_nd(dims, hint_channels, 16, 3, padding=1), + operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 16, 3, padding=1), + operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), + operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 32, 3, padding=1), + operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), + operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 96, 3, padding=1), + operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), + operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) + operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) self._feature_size = model_channels @@ -206,7 +206,7 @@ class ControlNet(nn.Module): ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: @@ -234,7 +234,7 @@ class ControlNet(nn.Module): ) ch = out_ch input_block_chans.append(ch) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) ds *= 2 self._feature_size += ch @@ -276,11 +276,11 @@ class ControlNet(nn.Module): operations=operations )] self.middle_block = TimestepEmbedSequential(*mid_block) - self.middle_block_out = self.make_zero_conv(ch, operations=operations) + self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch - def make_zero_conv(self, channels, operations=None): - return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) + def make_zero_conv(self, channels, operations=None, dtype=None, device=None): + return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) def forward(self, x, hint, timesteps, context, y=None, **kwargs): t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 
3212ac8c4..110b5c7c2 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -36,13 +36,13 @@ class ControlBase: self.cond_hint = None self.strength = 1.0 self.timestep_percent_range = (0.0, 1.0) + self.global_average_pooling = False self.timestep_range = None if device is None: device = comfy.model_management.get_torch_device() self.device = device self.previous_controlnet = None - self.global_average_pooling = False def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): self.cond_hint_original = cond_hint @@ -77,6 +77,7 @@ class ControlBase: c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range + c.global_average_pooling = self.global_average_pooling def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -129,12 +130,14 @@ class ControlBase: return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.load_device = load_device + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling self.model_sampling_current = None + self.manual_cast_dtype = manual_cast_dtype def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -149,11 +152,8 @@ class ControlNet(ControlBase): return None dtype = self.control_model.dtype - if comfy.model_management.supports_dtype(self.device, dtype): - precision_scope = lambda a: contextlib.nullcontext(a) - else: - precision_scope = torch.autocast - dtype = torch.float32 + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: @@ -171,12 +171,11 @@ class ControlNet(ControlBase): timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - with precision_scope(comfy.model_management.get_autocast_device(self.device)): - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) self.copy_to(c) return c @@ -207,10 +206,11 @@ class ControlLoraOps: self.bias = None def forward(self, input): + weight, bias = comfy.ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), 
self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias) + return torch.nn.functional.linear(input, weight, bias) class Conv2d(torch.nn.Module): def __init__( @@ -246,10 +246,11 @@ class ControlLoraOps: def forward(self, input): + weight, bias = comfy.ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet): @@ -263,12 +264,19 @@ class ControlLora(ControlNet): controlnet_config = model.model_config.unet_config.copy() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): - pass - controlnet_config["operations"] = control_lora_ops - self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) + self.manual_cast_dtype = model.manual_cast_dtype dtype = model.get_dtype() - self.control_model.to(dtype) + if self.manual_cast_dtype is None: + class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): + pass + else: + class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast): + pass + dtype = self.manual_cast_dtype + + controlnet_config["operations"] = control_lora_ops + controlnet_config["dtype"] = dtype + self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) self.control_model.to(comfy.model_management.get_torch_device()) diffusion_model = model.diffusion_model sd = diffusion_model.state_dict() @@ -372,6 +380,10 @@ def load_controlnet(ckpt_path, model=None): if controlnet_config is None: unet_dtype = comfy.model_management.unet_dtype() controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + load_device = comfy.model_management.get_torch_device() + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + controlnet_config["operations"] = comfy.ops.manual_cast controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) @@ -400,14 +412,12 @@ def load_controlnet(ckpt_path, model=None): missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) print(missing, unexpected) - control_model = control_model.to(unet_dtype) - global_average_pooling = False filename = 
os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True - control = ControlNet(control_model, global_average_pooling=global_average_pooling) + control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control class T2IAdapter(ControlBase): From 824e4935f53fdbda8f4608f511b4c2e8daf79dfa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Dec 2023 12:03:29 -0500 Subject: [PATCH 37/98] Add dtype parameter to VAE object. --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8c056e4ea..220637a05 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -151,7 +151,7 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None, dtype=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -188,7 +188,9 @@ class VAE: device = model_management.vae_device() self.device = device offload_device = model_management.vae_offload_device() - self.vae_dtype = model_management.vae_dtype() + if dtype is None: + dtype = model_management.vae_dtype() + self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() From b454a67bb964fc20bac0354d009c1c811a289d89 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Dec 2023 19:09:53 -0500 Subject: [PATCH 38/98] Support segmind vega model. --- comfy/model_detection.py | 8 +++++++- comfy/supported_models.py | 12 +++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c682c3e1a..e3af422a3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -289,7 +289,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype): 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] + Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} + + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] for unet_config in supported_models: matches = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 455323b96..2f2dee871 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -217,6 +217,16 @@ class SSD1B(SDXL): "use_temporal_attention": False, } +class Segmind_Vega(SDXL): + unet_config = { + 
"model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 1, 1, 2, 2], + "context_dim": 2048, + "adm_in_channels": 2816, + "use_temporal_attention": False, + } + class SVD_img2vid(supported_models_base.BASE): unet_config = { "model_channels": 320, @@ -242,5 +252,5 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] +models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega] models += [SVD_img2vid] From 390078904c791c7c66c08478ed3d657b42ba7888 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 13 Dec 2023 05:56:39 +0000 Subject: [PATCH 39/98] Group node fixes (#2259) * Prevent cleaning graph state on undo/redo * Remove pause rendering due to LG bug * Fix crash on disconnected internal reroutes * Fix widget inputs being incorrect order and value * Fix initial primitive values on connect * basic support for basic rerouted converted inputs * Populate primitive to reroute input * dont crash on bad primitive links * Fix convert to group changing control value * reduce restrictions * fix random crash in tests --- tests-ui/tests/groupNode.test.js | 134 ++++++++++++++++++++++++++-- tests-ui/utils/ezgraph.js | 8 ++ tests-ui/utils/index.js | 9 ++ web/extensions/core/groupNode.js | 127 +++++++++++++++++++++----- web/extensions/core/rerouteNode.js | 1 + web/extensions/core/widgetInputs.js | 19 +++- web/scripts/app.js | 4 +- 7 files changed, 275 insertions(+), 27 deletions(-) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index dc9d4bd49..625890a09 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -1,7 +1,7 @@ // @ts-check /// -const { start, createDefaultWorkflow } = require("../utils"); +const { start, createDefaultWorkflow, getNodeDef, checkBeforeAndAfterReload } = require("../utils"); const lg = require("../utils/litegraph"); describe("group node", () => { @@ -273,7 +273,7 @@ describe("group node", () => { let reroutes = []; let prevNode = nodes.ckpt; - for(let i = 0; i < 5; i++) { + for (let i = 0; i < 5; i++) { const reroute = ez.Reroute(); prevNode.outputs[0].connectTo(reroute.inputs[0]); prevNode = reroute; @@ -283,7 +283,7 @@ describe("group node", () => { const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]); expect((await graph.toPrompt()).output).toEqual(getOutput()); - + group.menu["Convert to nodes"].call(); expect((await graph.toPrompt()).output).toEqual(getOutput()); }); @@ -407,12 +407,18 @@ describe("group node", () => { const decode = ez.VAEDecode(group2.outputs.LATENT, group2.outputs.VAE); const preview = ez.PreviewImage(decode.outputs[0]); - expect((await graph.toPrompt()).output).toEqual({ + const output = { [latent.id]: { inputs: { width: 512, height: 512, batch_size: 1 }, class_type: "EmptyLatentImage" }, [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, [decode.id]: { inputs: { samples: [latent.id + "", 0], vae: [vae.id + "", 0] }, class_type: "VAEDecode" }, [preview.id]: { inputs: { images: [decode.id + "", 0] }, class_type: "PreviewImage" }, - }); + }; + expect((await graph.toPrompt()).output).toEqual(output); + + // Ensure missing connections dont cause errors + group2.inputs.VAE.disconnect(); + delete output[decode.id].inputs.vae; + expect((await graph.toPrompt()).output).toEqual(output); }); test("displays 
generated image on group node", async () => { const { ez, graph, app } = await start(); @@ -673,6 +679,55 @@ describe("group node", () => { 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, }); }); + test("correctly handles widget inputs", async () => { + const { ez, graph, app } = await start(); + const upscaleMethods = (await getNodeDef("ImageScaleBy")).input.required["upscale_method"][0]; + + const image = ez.LoadImage(); + const scale1 = ez.ImageScaleBy(image.outputs[0]); + const scale2 = ez.ImageScaleBy(image.outputs[0]); + const preview1 = ez.PreviewImage(scale1.outputs[0]); + const preview2 = ez.PreviewImage(scale2.outputs[0]); + scale1.widgets.upscale_method.value = upscaleMethods[1]; + scale1.widgets.upscale_method.convertToInput(); + + const group = await convertToGroup(app, graph, "test", [scale1, scale2]); + expect(group.inputs.length).toBe(3); + expect(group.inputs[0].input.type).toBe("IMAGE"); + expect(group.inputs[1].input.type).toBe("IMAGE"); + expect(group.inputs[2].input.type).toBe("COMBO"); + + // Ensure links are maintained + expect(group.inputs[0].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[1].connection?.originNode?.id).toBe(image.id); + expect(group.inputs[2].connection).toBeFalsy(); + + // Ensure primitive gets correct type + const primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(group.inputs[2]); + expect(primitive.widgets.value.widget.options.values).toBe(upscaleMethods); + expect(primitive.widgets.value.value).toBe(upscaleMethods[1]); // Ensure value is copied + primitive.widgets.value.value = upscaleMethods[1]; + + await checkBeforeAndAfterReload(graph, async (r) => { + const scale1id = r ? `${group.id}:0` : scale1.id; + const scale2id = r ? `${group.id}:1` : scale2.id; + // Ensure widget value is applied to prompt + expect((await graph.toPrompt()).output).toStrictEqual({ + [image.id]: { inputs: { image: "example.png", upload: "image" }, class_type: "LoadImage" }, + [scale1id]: { + inputs: { upscale_method: upscaleMethods[1], scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [scale2id]: { + inputs: { upscale_method: "nearest-exact", scale_by: 1, image: [`${image.id}`, 0] }, + class_type: "ImageScaleBy", + }, + [preview1.id]: { inputs: { images: [`${scale1id}`, 0] }, class_type: "PreviewImage" }, + [preview2.id]: { inputs: { images: [`${scale2id}`, 0] }, class_type: "PreviewImage" }, + }); + }); + }); test("adds widgets in node execution order", async () => { const { ez, graph, app } = await start(); const scale = ez.LatentUpscale(); @@ -846,4 +901,73 @@ describe("group node", () => { expect(p2.widgets.control_after_generate.value).toBe("randomize"); expect(p2.widgets.control_filter_list.value).toBe("/.+/"); }); + test("internal reroutes work with converted inputs and merge options", async () => { + const { ez, graph, app } = await start(); + const vae = ez.VAELoader(); + const latent = ez.EmptyLatentImage(); + const decode = ez.VAEDecode(latent.outputs.LATENT, vae.outputs.VAE); + const scale = ez.ImageScale(decode.outputs.IMAGE); + ez.PreviewImage(scale.outputs.IMAGE); + + const r1 = ez.Reroute(); + const r2 = ez.Reroute(); + + latent.widgets.width.value = 64; + latent.widgets.height.value = 128; + + latent.widgets.width.convertToInput(); + latent.widgets.height.convertToInput(); + latent.widgets.batch_size.convertToInput(); + + scale.widgets.width.convertToInput(); + scale.widgets.height.convertToInput(); + + r1.inputs[0].input.label = "hbw"; + 
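+ // r1 ("hbw") fans out to the latent height/batch_size and the scale width, and r2 ("wh") feeds
+ // the latent width and the scale height, so each converted group input below has to merge the
+ // option configs of every widget behind its reroute.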
r1.outputs[0].connectTo(latent.inputs.height); + r1.outputs[0].connectTo(latent.inputs.batch_size); + r1.outputs[0].connectTo(scale.inputs.width); + + r2.inputs[0].input.label = "wh"; + r2.outputs[0].connectTo(latent.inputs.width); + r2.outputs[0].connectTo(scale.inputs.height); + + const group = await convertToGroup(app, graph, "test", [r1, r2, latent, decode, scale]); + + expect(group.inputs[0].input.type).toBe("VAE"); + expect(group.inputs[1].input.type).toBe("INT"); + expect(group.inputs[2].input.type).toBe("INT"); + + const p1 = ez.PrimitiveNode(); + const p2 = ez.PrimitiveNode(); + p1.outputs[0].connectTo(group.inputs[1]); + p2.outputs[0].connectTo(group.inputs[2]); + + expect(p1.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p1.widgets.value.widget.options?.max).toBe(4096); // batch max + expect(p1.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p2.widgets.value.widget.options?.min).toBe(16); // width/height min + expect(p2.widgets.value.widget.options?.max).toBe(8192); // width/height max + expect(p2.widgets.value.widget.options?.step).toBe(80); // width/height step * 10 + + expect(p1.widgets.value.value).toBe(128); + expect(p2.widgets.value.value).toBe(64); + + p1.widgets.value.value = 16; + p2.widgets.value.value = 32; + + await checkBeforeAndAfterReload(graph, async (r) => { + const id = (v) => (r ? `${group.id}:` : "") + v; + expect((await graph.toPrompt()).output).toStrictEqual({ + 1: { inputs: { vae_name: "vae1.safetensors" }, class_type: "VAELoader" }, + [id(2)]: { inputs: { width: 32, height: 16, batch_size: 16 }, class_type: "EmptyLatentImage" }, + [id(3)]: { inputs: { samples: [id(2), 0], vae: ["1", 0] }, class_type: "VAEDecode" }, + [id(4)]: { + inputs: { upscale_method: "nearest-exact", width: 16, height: 32, crop: "disabled", image: [id(3), 0] }, + class_type: "ImageScale", + }, + 5: { inputs: { images: [id(4), 0] }, class_type: "PreviewImage" }, + }); + }); + }); }); diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 3101aa292..8a55246ee 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -78,6 +78,14 @@ export class EzInput extends EzSlot { this.input = input; } + get connection() { + const link = this.node.node.inputs?.[this.index]?.link; + if (link == null) { + return null; + } + return new EzConnection(this.node.app, this.node.app.graph.links[link]); + } + disconnect() { this.node.node.disconnectInput(this.index); } diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 3a018f566..6a08e8594 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -104,3 +104,12 @@ export function createDefaultWorkflow(ez, graph) { return { ckpt, pos, neg, empty, sampler, decode, save }; } + +export async function getNodeDefs() { + const { api } = require("../../web/scripts/api"); + return api.getNodeDefs(); +} + +export async function getNodeDef(nodeId) { + return (await getNodeDefs())[nodeId]; +} \ No newline at end of file diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 9a1d9b207..dc962ac24 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -174,6 +174,11 @@ export class GroupNodeConfig { node.index = i; this.processNode(node, seenInputs, seenOutputs); } + + for (const p of this.#convertedToProcess) { + p(); + } + this.#convertedToProcess = null; await app.registerNodeDef("workflow/" + this.name, this.nodeDef); } @@ -192,7 +197,10 @@ export class GroupNodeConfig { if 
(!this.linksFrom[sourceNodeId]) { this.linksFrom[sourceNodeId] = {}; } - this.linksFrom[sourceNodeId][sourceNodeSlot] = l; + if (!this.linksFrom[sourceNodeId][sourceNodeSlot]) { + this.linksFrom[sourceNodeId][sourceNodeSlot] = []; + } + this.linksFrom[sourceNodeId][sourceNodeSlot].push(l); if (!this.linksTo[targetNodeId]) { this.linksTo[targetNodeId] = {}; @@ -230,11 +238,11 @@ export class GroupNodeConfig { // Skip as its not linked if (!linksFrom) return; - let type = linksFrom["0"][5]; + let type = linksFrom["0"][0][5]; if (type === "COMBO") { // Use the array items const source = node.outputs[0].widget.name; - const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; + const fromTypeName = this.nodeData.nodes[linksFrom["0"][0][2]].type; const fromType = globalDefs[fromTypeName]; const input = fromType.input.required[source] ?? fromType.input.optional[source]; type = input[0]; @@ -258,10 +266,33 @@ export class GroupNodeConfig { return null; } + let config = {}; let rerouteType = "*"; if (linksFrom) { - const [, , id, slot] = linksFrom["0"]; - rerouteType = this.nodeData.nodes[id].inputs[slot].type; + for (const [, , id, slot] of linksFrom["0"]) { + const node = this.nodeData.nodes[id]; + const input = node.inputs[slot]; + if (rerouteType === "*") { + rerouteType = input.type; + } + if (input.widget) { + const targetDef = globalDefs[node.type]; + const targetWidget = + targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + + const widget = [targetWidget[0], config]; + const res = mergeIfValid( + { + widget, + }, + targetWidget, + false, + null, + widget + ); + config = res?.customConfig ?? config; + } + } } else if (linksTo) { const [id, slot] = linksTo["0"]; rerouteType = this.nodeData.nodes[id].outputs[slot].type; @@ -282,10 +313,11 @@ export class GroupNodeConfig { } } + config.forceInput = true; return { input: { required: { - [rerouteType]: [rerouteType, {}], + [rerouteType]: [rerouteType, config], }, }, output: [rerouteType], @@ -420,10 +452,18 @@ export class GroupNodeConfig { defaultInput: true, }); this.nodeDef.input.required[name] = config; + this.newToOldWidgetMap[name] = { node, inputName }; + + if (!this.oldToNewWidgetMap[node.index]) { + this.oldToNewWidgetMap[node.index] = {}; + } + this.oldToNewWidgetMap[node.index][inputName] = name; + inputMap[slots.length + i] = this.inputCount++; } } + #convertedToProcess = []; processNodeInputs(node, seenInputs, inputs) { const inputMapping = []; @@ -434,7 +474,11 @@ export class GroupNodeConfig { const linksTo = this.linksTo[node.index] ?? 
{}; const inputMap = (this.oldToNewInputMap[node.index] = {}); this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs); + + // Converted inputs have to be processed after all other nodes as they'll be at the end of the list + this.#convertedToProcess.push(() => + this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) + ); return inputMapping; } @@ -597,11 +641,15 @@ export class GroupNodeHandler { const output = this.groupData.newToOldOutputMap[link.origin_slot]; let innerNode = this.innerNodes[output.node.index]; let l; - while (innerNode.type === "Reroute") { + while (innerNode?.type === "Reroute") { l = innerNode.getInputLink(0); innerNode = innerNode.getInputNode(0); } + if (!innerNode) { + return null; + } + if (l && GroupNodeHandler.isGroupNode(innerNode)) { return innerNode.updateLink(l); } @@ -669,6 +717,8 @@ export class GroupNodeHandler { top = newNode.pos[1]; } + if (!newNode.widgets) continue; + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; if (map) { const widgets = Object.keys(map); @@ -725,7 +775,7 @@ export class GroupNodeHandler { } }; - const reconnectOutputs = () => { + const reconnectOutputs = (selectedIds) => { for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { const output = node.outputs[groupOutputId]; if (!output.links) continue; @@ -865,7 +915,7 @@ export class GroupNodeHandler { if (innerNode.type === "PrimitiveNode") { innerNode.primitiveValue = newValue; const primitiveLinked = this.groupData.primitiveToWidget[old.node.index]; - for (const linked of primitiveLinked) { + for (const linked of primitiveLinked ?? []) { const node = this.innerNodes[linked.nodeId]; const widget = node.widgets.find((w) => w.name === linked.inputName); @@ -874,6 +924,18 @@ export class GroupNodeHandler { } } continue; + } else if (innerNode.type === "Reroute") { + const rerouteLinks = this.groupData.linksFrom[old.node.index]; + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } + } + } } const widget = innerNode.widgets?.find((w) => w.name === old.inputName); @@ -901,33 +963,58 @@ export class GroupNodeHandler { this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; } } + return true; } + populateReroute(node, nodeId, map) { + if (node.type !== "Reroute") return; + + const link = this.groupData.linksFrom[nodeId]?.[0]?.[0]; + if (!link) return; + const [, , targetNodeId, targetNodeSlot] = link; + const targetNode = this.groupData.nodeData.nodes[targetNodeId]; + const inputs = targetNode.inputs; + const targetWidget = inputs?.[targetNodeSlot].widget; + if (!targetWidget) return; + + const offset = inputs.length - (targetNode.widgets_values?.length ?? 
0); + const v = targetNode.widgets_values?.[targetNodeSlot - offset]; + if (v == null) return; + + const widgetName = Object.values(map)[0]; + const widget = this.node.widgets.find(w => w.name === widgetName); + if(widget) { + widget.value = v; + } + } + + populateWidgets() { + if (!this.node.widgets) return; + for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { const node = this.groupData.nodeData.nodes[nodeId]; - - if (!node.widgets_values?.length) continue; - - const map = this.groupData.oldToNewWidgetMap[nodeId]; + const map = this.groupData.oldToNewWidgetMap[nodeId] ?? {}; const widgets = Object.keys(map); + if (!node.widgets_values?.length) { + // special handling for populating values into reroutes + // this allows primitives connected to them to pick up the correct value + this.populateReroute(node, nodeId, map); + continue; + } + let linkedShift = 0; for (let i = 0; i < widgets.length; i++) { const oldName = widgets[i]; const newName = map[oldName]; const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const mainWidget = this.node.widgets[widgetIndex]; - if (!newName) { - // New name will be null if its a converted widget - this.populatePrimitive(node, nodeId, oldName, i, linkedShift); - + if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift)) { // Find the inner widget and shift by the number of linked widgets as they will have been removed too const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); linkedShift += innerWidget.linkedWidgets?.length ?? 0; - continue; } - if (widgetIndex === -1) { continue; } diff --git a/web/extensions/core/rerouteNode.js b/web/extensions/core/rerouteNode.js index cfa952f3c..4feff91e5 100644 --- a/web/extensions/core/rerouteNode.js +++ b/web/extensions/core/rerouteNode.js @@ -54,6 +54,7 @@ app.registerExtension({ const linkId = currentNode.inputs[0].link; if (linkId !== null) { const link = app.graph.links[linkId]; + if (!link) return; const node = app.graph.getNodeById(link.origin_id); const type = node.constructor.type; if (type === "Reroute") { diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 865db4923..3f1c1f8c1 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -180,7 +180,7 @@ export function mergeIfValid(output, config2, forceUpdate, recreateWidget, confi const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; for (const k of keys.values()) { - if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { + if (k !== "default" && k !== "forceInput" && k !== "defaultInput" && k !== "control_after_generate" && k !== "multiline") { let v1 = config1[1][k]; let v2 = config2[1]?.[k]; @@ -633,6 +633,14 @@ app.registerExtension({ } } + // Restore any saved control values + const controlValues = this.controlValues; + if(this.lastType === this.widgets[0].type && controlValues?.length === this.widgets.length - 1) { + for(let i = 0; i < controlValues.length; i++) { + this.widgets[i + 1].value = controlValues[i]; + } + } + // When our value changes, update other widgets to reflect our changes // e.g. so LoadImage shows correct image const callback = widget.callback; @@ -721,6 +729,15 @@ app.registerExtension({ w.onRemove(); } } + + // Temporarily store the current values in case the node is being recreated + // e.g. 
by group node conversion + this.controlValues = []; + this.lastType = this.widgets[0]?.type; + for(let i = 1; i < this.widgets.length; i++) { + this.controlValues.push(this.widgets[i].value); + } + setTimeout(() => { delete this.lastType; delete this.controlValues }, 15); this.widgets.length = 0; } } diff --git a/web/scripts/app.js b/web/scripts/app.js index d2a6f4de4..62169abfb 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1774,7 +1774,9 @@ export class ComfyApp { if (parent?.updateLink) { link = parent.updateLink(link); } - inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + if (link) { + inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + } } } } From 6761233e9dd875002c4ff9dac4574828ac564156 Mon Sep 17 00:00:00 2001 From: Rafie Walker Date: Wed, 13 Dec 2023 21:52:11 +0100 Subject: [PATCH 40/98] Implement Self-Attention Guidance (#2201) * First SAG test * need to put extra options on the model instead of patcher * no errors and results seem not-broken * Use @ashen-uncensored formula, which works better!!! * Fix a crash when using weird resolutions. Remove an unnecessary UNet call * Improve comments, optimize memory in blur routine * SAG works with sampler_cfg_function --- comfy/samplers.py | 75 ++++++++++++++++++++++--- comfy_extras/nodes_sag.py | 115 ++++++++++++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 182 insertions(+), 9 deletions(-) create mode 100644 comfy_extras/nodes_sag.py diff --git a/comfy/samplers.py b/comfy/samplers.py index ffc1fe3ac..1cdad736d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,6 +1,7 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch +import torch.nn.functional as F import enum from comfy import model_management import math @@ -60,10 +61,10 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option for t in range(rr): mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - conditionning = {} + conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) control = None if 'control' in conds: @@ -82,7 +83,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option patches['middle_patch'] = [gligen_patch] - return (input_x, mult, conditionning, area, control, patches) + return (input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -246,15 +247,71 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option return out_cond, out_uncond - if math.isclose(cond_scale, 1.0): + # if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out. 
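+    # Self-Attention Guidance (https://arxiv.org/abs/2210.00939): the attention map saved from the
+    # uncond pass selects the regions to blur, and the model's prediction on that degraded input is
+    # then used as an extra guidance signal below.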
+ if math.isclose(cond_scale, 1.0) and "sag" not in model_options: uncond = None - cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale if "sampler_cfg_function" in model_options: - args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} - return x - model_options["sampler_cfg_function"](args) - else: - return uncond + (cond - uncond) * cond_scale + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + cfg_result = x - model_options["sampler_cfg_function"](args) + + if "sag" in model_options: + assert uncond is not None, "SAG requires uncond guidance" + sag_scale = model_options["sag_scale"] + sag_sigma = model_options["sag_sigma"] + sag_threshold = model_options.get("sag_threshold", 1.0) + + # these methods are added by the sag patcher + uncond_attn = model.get_attn_scores() + mid_shape = model.get_mid_block_shape() + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options) + cfg_result += (degraded - sag) * sag_scale + return cfg_result + +def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py new file mode 100644 index 000000000..1ec0c93ac --- /dev/null +++ b/comfy_extras/nodes_sag.py @@ -0,0 +1,115 @@ +import torch +from torch import einsum +from einops import rearrange, repeat +import os +from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION + +# from comfy/ldm/modules/attention.py +# but modified to return attention scores as well as output +def attention_basic_with_sim(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + + h = heads + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, 
dim_head) + .contiguous(), + (q, k, v), + ) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale + + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return (out, sim) + +class SagNode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale, blur_sigma): + m = model.clone() + # set extra options on the model + m.model_options["sag"] = True + m.model_options["sag_scale"] = scale + m.model_options["sag_sigma"] = blur_sigma + + attn_scores = None + mid_block_shape = None + m.model.get_attn_scores = lambda: attn_scores + m.model.get_mid_block_shape = lambda: mid_block_shape + + # TODO: make this work properly with chunked batches + # currently, we can only save the attn from one UNet call + def attn_and_record(q, k, v, extra_options): + nonlocal attn_scores + # if uncond, save the attention scores + heads = extra_options["n_heads"] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] // len(cond_or_uncond) + if 1 in cond_or_uncond: + uncond_index = cond_or_uncond.index(1) + # do the entire attention operation, but save the attention scores to attn_scores + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... 
cn] + n_slices = heads * b + attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] + return out + else: + return optimized_attention(q, k, v, heads=heads) + + # from diffusers: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch + def set_model_patch_replace(patch, name, key): + to = m.model_options["transformer_options"] + if "patches_replace" not in to: + to["patches_replace"] = {} + if name not in to["patches_replace"]: + to["patches_replace"][name] = {} + to["patches_replace"][name][key] = patch + set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0)) + # from diffusers: + # unet.mid_block.attentions[0].register_forward_hook() + def forward_hook(m, inp, out): + nonlocal mid_block_shape + mid_block_shape = out[0].shape[-2:] + m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "Self-Attention Guidance": SagNode, +} diff --git a/nodes.py b/nodes.py index db96e0e2d..3d24750cb 100644 --- a/nodes.py +++ b/nodes.py @@ -1867,6 +1867,7 @@ def init_custom_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_sag.py", ] for node_file in extras_files: From ba04a87d104ca73d8ed8e8423706edcdf5e209a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Dec 2023 16:10:03 -0500 Subject: [PATCH 41/98] Refactor and improve the sag node. Moved all the sag related code to comfy_extras/nodes_sag.py --- comfy/model_patcher.py | 19 +- comfy/samplers.py | 533 +++++++++++++++++--------------------- comfy_extras/nodes_sag.py | 103 ++++++-- 3 files changed, 334 insertions(+), 321 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 55ca913ec..e0acdc961 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -61,6 +61,9 @@ class ModelPatcher: else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + def set_model_sampler_post_cfg_function(self, post_cfg_function): + self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function @@ -70,13 +73,17 @@ class ModelPatcher: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] - def set_model_patch_replace(self, patch, name, block_name, number): + def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): to = self.model_options["transformer_options"] if "patches_replace" not in to: to["patches_replace"] = {} if name not in to["patches_replace"]: to["patches_replace"][name] = {} - to["patches_replace"][name][(block_name, number)] = patch + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -84,11 +91,11 @@ class ModelPatcher: def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") - def set_model_attn1_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn1", block_name, number) + def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) - def set_model_attn2_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, 
"attn2", block_name, number) + def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): self.set_model_patch(patch, "attn1_output_patch") diff --git a/comfy/samplers.py b/comfy/samplers.py index 1cdad736d..106e72876 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,7 +1,6 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch -import torch.nn.functional as F import enum from comfy import model_management import math @@ -9,310 +8,260 @@ from comfy import model_base import comfy.utils import comfy.conds +def get_area_and_mult(conds, x_in, timestep_in): + area = (x_in.shape[2], x_in.shape[3], 0, 0) + strength = 1.0 + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] + if timestep_in[0] > timestep_start: + return None + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] + if timestep_in[0] < timestep_end: + return None + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] + + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if 'mask' in conds: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in conds: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + + conditioning = {} + model_conds = conds["model_conds"] + for c in model_conds: + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + + control = None + if 'control' in conds: + control = conds['control'] + + patches = None + if 'gligen' in conds: + gligen = conds['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + return (input_x, mult, conditioning, area, control, patches) + +def cond_equal_size(c1, c2): + if c1 is c2: + return True + if c1.keys() != c2.keys(): + return False + for k in c1: + if not c1[k].can_concat(c2[k]): + return False + return True + +def can_concat_cond(c1, c2): + if c1[0].shape != c2[0].shape: + return False + + #control + if (c1[4] is None) != (c2[4] is None): + return False + if c1[4] is not None: + if c1[4] is not c2[4]: + return False + + #patches + if (c1[5] is None) != (c2[5] is None): + return False + if (c1[5] is not None): + if c1[5] is not c2[5]: 
+ return False + + return cond_equal_size(c1[2], c2[2]) + +def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + c_adm = [] + crossattn_max_len = 0 + + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + + return out + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): + out_cond = torch.zeros_like(x_in) + out_count = torch.ones_like(x_in) * 1e-37 + + out_uncond = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in) * 1e-37 + + COND = 0 + UNCOND = 1 + + to_run = [] + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, COND)] + if uncond is not None: + for x in uncond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, UNCOND)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: + to_batch = batch_amount + break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x += [p[0]] + mult += [p[1]] + c += [p[2]] + area += [p[3]] + cond_or_uncond += [o[1]] + control = p[4] + patches = p[5] + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) + + transformer_options = {} + if 'transformer_options' in model_options: + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + + c['transformer_options'] = transformer_options + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + del input_x + + for o in range(batch_chunks): + if cond_or_uncond[o] == COND: + out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + else: + out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + del mult + + out_cond /= out_count + del out_count + out_uncond /= out_uncond_count + del 
out_uncond_count + return out_cond, out_uncond #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 + if math.isclose(cond_scale, 1.0): + uncond_ = None + else: + uncond_ = uncond - if 'timestep_start' in conds: - timestep_start = conds['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in conds: - timestep_end = conds['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in conds: - area = conds['area'] - if 'strength' in conds: - strength = conds['strength'] - - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in conds: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in conds: - mask_strength = conds["mask_strength"] - mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in conds: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - - conditioning = {} - model_conds = conds["model_conds"] - for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - - control = None - if 'control' in conds: - control = conds['control'] - - patches = None - if 'gligen' in conds: - gligen = conds['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) - else: - gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) - - patches['middle_patch'] = [gligen_patch] - - return (input_x, mult, conditioning, area, control, patches) - - def cond_equal_size(c1, c2): - if c1 is c2: - return True - if c1.keys() != c2.keys(): - return False - for k in c1: - if not c1[k].can_concat(c2[k]): - return False - return True - - def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: - return False - - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: - return False - - #patches - if (c1[5] is None) != (c2[5] is None): - return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - - return cond_equal_size(c1[2], c2[2]) - - def cond_cat(c_list): - c_crossattn = [] - c_concat = [] - c_adm = [] - crossattn_max_len = 0 - - temp = {} - for x in c_list: - for k in x: - cur = temp.get(k, []) - cur.append(x[k]) - temp[k] = cur - - out = {} - for k in temp: - conds = temp[k] - out[k] = conds[0].concat(conds[1:]) - - return out - - def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, 
model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - - to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, UNCOND)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x - - for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult - - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond - - - # if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out. 
- if math.isclose(cond_scale, 1.0) and "sag" not in model_options: - uncond = None - - cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} cfg_result = x - model_options["sampler_cfg_function"](args) - if "sag" in model_options: - assert uncond is not None, "SAG requires uncond guidance" - sag_scale = model_options["sag_scale"] - sag_sigma = model_options["sag_sigma"] - sag_threshold = model_options.get("sag_threshold", 1.0) + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) - # these methods are added by the sag patcher - uncond_attn = model.get_attn_scores() - mid_shape = model.get_mid_block_shape() - # create the adversarially blurred image - degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold) - degraded_noised = degraded + x - uncond_pred - # call into the UNet - (sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options) - cfg_result += (degraded - sag) * sag_scale return cfg_result -def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): - # reshape and GAP the attention map - _, hw1, hw2 = attn.shape - b, _, lh, lw = x0.shape - attn = attn.reshape(b, -1, hw1, hw2) - # Global Average Pool - mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - # Reshape - mask = ( - mask.reshape(b, *mid_shape) - .unsqueeze(1) - .type(attn.dtype) - ) - # Upsample - mask = F.interpolate(mask, (lh, lw)) - - blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) - blurred = blurred * mask + x0 * (1 - mask) - return blurred - -def gaussian_blur_2d(img, kernel_size, sigma): - ksize_half = (kernel_size - 1) * 0.5 - - x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) - - pdf = torch.exp(-0.5 * (x / sigma).pow(2)) - - x_kernel = pdf / pdf.sum() - x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) - - kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) - kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) - - padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] - - img = F.pad(img, padding, mode="reflect") - img = F.conv2d(img, kernel2d, groups=img.shape[-3]) - return img - class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): super().__init__() diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 1ec0c93ac..4c609565a 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -1,8 +1,12 @@ import torch from torch import einsum +import torch.nn.functional as F +import math + from einops import rearrange, repeat import os from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +import comfy.samplers # from comfy/ldm/modules/attention.py # but modified to return attention scores as well as output @@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): ) return (out, sim) -class SagNode: +def 
create_blur_map(x0, attn, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + ratio = round(math.sqrt(lh * lw / hw1)) + mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img + +class SelfAttentionGuidance: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -63,15 +109,9 @@ class SagNode: def patch(self, model, scale, blur_sigma): m = model.clone() - # set extra options on the model - m.model_options["sag"] = True - m.model_options["sag_scale"] = scale - m.model_options["sag_sigma"] = blur_sigma - + attn_scores = None mid_block_shape = None - m.model.get_attn_scores = lambda: attn_scores - m.model.get_mid_block_shape = lambda: mid_block_shape # TODO: make this work properly with chunked batches # currently, we can only save the attn from one UNet call @@ -92,24 +132,41 @@ class SagNode: else: return optimized_attention(q, k, v, heads=heads) + def post_cfg_function(args): + nonlocal attn_scores + nonlocal mid_block_shape + uncond_attn = attn_scores + + sag_scale = scale + sag_sigma = blur_sigma + sag_threshold = 1.0 + model = args["model"] + uncond_pred = args["uncond_denoised"] + uncond = args["uncond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"] + x = args["input"] + + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + return cfg_result + (degraded - sag) * sag_scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + # from diffusers: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch - def set_model_patch_replace(patch, name, key): - to = m.model_options["transformer_options"] - if "patches_replace" not in to: - to["patches_replace"] = {} - if name not in to["patches_replace"]: - to["patches_replace"][name] = {} - to["patches_replace"][name][key] = patch - set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0)) - # from diffusers: - # unet.mid_block.attentions[0].register_forward_hook() - def forward_hook(m, inp, out): - nonlocal mid_block_shape - mid_block_shape = out[0].shape[-2:] - m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook) + m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) + 
return (m, ) NODE_CLASS_MAPPINGS = { - "Self-Attention Guidance": SagNode, + "SelfAttentionGuidance": SelfAttentionGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SelfAttentionGuidance": "Self-Attention Guidance", } From 6c5990f7dba2d5d0ad04c7ed5a702b067926cbe2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Dec 2023 20:28:04 -0500 Subject: [PATCH 42/98] Fix cfg being calculated more than once if sampler_cfg_function. --- comfy/samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 106e72876..7dc27528a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -250,10 +250,11 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option uncond_ = uncond cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale for fn in model_options.get("sampler_post_cfg_function", []): args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, From 329c57199302f6b9ccfebb86c96e937c386da92f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Dec 2023 11:41:49 -0500 Subject: [PATCH 43/98] Improve code legibility. --- comfy/samplers.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7dc27528a..39bc3774a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -2,6 +2,7 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch import enum +import collections from comfy import model_management import math from comfy import model_base @@ -61,9 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in): for c in model_conds: conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - control = None - if 'control' in conds: - control = conds['control'] + control = conds.get('control', None) patches = None if 'gligen' in conds: @@ -78,7 +77,8 @@ def get_area_and_mult(conds, x_in, timestep_in): patches['middle_patch'] = [gligen_patch] - return (input_x, mult, conditioning, area, control, patches) + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) + return cond_obj(input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -91,24 +91,24 @@ def cond_equal_size(c1, c2): return True def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: + if c1.input_x.shape != c2.input_x.shape: return False - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: + def objects_concatable(obj1, obj2): + if (obj1 is None) != (obj2 is None): return False + if obj1 is not None: + if obj1 is not obj2: + return False + return True - #patches - if (c1[5] is None) != (c2[5] is None): + if not objects_concatable(c1.control, c2.control): return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - return cond_equal_size(c1[2], c2[2]) + if not 
objects_concatable(c1.patches, c2.patches): + return False + + return cond_equal_size(c1.conditioning, c2.conditioning) def cond_cat(c_list): c_crossattn = [] @@ -184,13 +184,13 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): for x in to_batch: o = to_run.pop(x) p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) From b12b48e170ccff156dc6ec11242bb6af7d8437fd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Dec 2023 20:11:46 -0500 Subject: [PATCH 44/98] cleanup. --- comfy_extras/nodes_sag.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 4c609565a..0bcda84f0 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -111,7 +111,6 @@ class SelfAttentionGuidance: m = model.clone() attn_scores = None - mid_block_shape = None # TODO: make this work properly with chunked batches # currently, we can only save the attn from one UNet call @@ -134,7 +133,6 @@ class SelfAttentionGuidance: def post_cfg_function(args): nonlocal attn_scores - nonlocal mid_block_shape uncond_attn = attn_scores sag_scale = scale From a5056cfb1f41f1f9e6fcd523ef8091e6e7cd6e3b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Dec 2023 01:28:16 -0500 Subject: [PATCH 45/98] Remove useless code. --- comfy/ldm/modules/attention.py | 4 +--- comfy_extras/nodes_sag.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8d86aa53d..3e12886b0 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None): # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * scale + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 0bcda84f0..7e293ef63 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -27,9 +27,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * scale + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale else: sim = einsum('b i d, b j d -> b i j', q, k) * scale From 574363a8a69cb48db71c96d03fe056d56853f4f6 Mon Sep 17 00:00:00 2001 From: Hari Date: Sat, 16 Dec 2023 00:28:16 +0530 Subject: [PATCH 46/98] Implement Perp-Neg --- comfy/samplers.py | 3 +- comfy_extras/nodes_perpneg.py | 58 +++++++++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_perpneg.py diff --git a/comfy/samplers.py b/comfy/samplers.py index 39bc3774a..35c9ccf05 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -251,7 +251,8 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option 
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, + "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} cfg_result = x - model_options["sampler_cfg_function"](args) else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py new file mode 100644 index 000000000..36f2eb01a --- /dev/null +++ b/comfy_extras/nodes_perpneg.py @@ -0,0 +1,58 @@ +import torch +import comfy.model_management +import comfy.sample +import comfy.samplers +import comfy.utils + + +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "clip": ("CLIP", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, clip, neg_scale): + m = model.clone() + + tokens = clip.tokenize("") + nocond, nocond_pooled = clip.encode_from_tokens(tokens, return_pooled=True) + nocond = [[nocond, {"pooled_output": nocond_pooled}]] + nocond = comfy.sample.convert_cond(nocond) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + + (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond, None, x, sigma, model_options) + + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + cfg_result = x - cfg_result + return cfg_result + + m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg", +} diff --git a/nodes.py b/nodes.py index 3d24750cb..3031b10aa 100644 --- a/nodes.py +++ b/nodes.py @@ -1868,6 +1868,7 @@ def init_custom_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_sag.py", + "nodes_perpneg.py", ] for node_file in extras_files: From 9cad2f06ff93e3ac512f7f008c11026530900b51 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Dec 2023 14:40:57 -0500 Subject: [PATCH 47/98] Make perp neg take a conditioning input instead of a CLIP one. 
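
This means any CONDITIONING output (for example a CLIPTextEncode run on an empty prompt) can be plugged in instead of a CLIP model, and the node no longer has to tokenize and encode the empty prompt itself. For reference, a minimal standalone sketch of the projection the node's cfg_function applies (illustrative names, simplified from nodes_perpneg.py; the real function also converts the result back with x - cfg_result):

    import torch

    def perp_neg_guidance(pos, neg, nocond, cond_scale, neg_scale):
        # pos/neg are the cond/uncond denoised predictions minus the
        # no-cond prediction; drop the part of pos that lies along neg
        perp = ((pos * neg).sum() / (torch.norm(neg) ** 2)) * neg
        return nocond + cond_scale * (pos - neg_scale * perp)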
--- comfy_extras/nodes_perpneg.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 36f2eb01a..0c5ccb77a 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -9,7 +9,7 @@ class PerpNeg: @classmethod def INPUT_TYPES(s): return {"required": {"model": ("MODEL", ), - "clip": ("CLIP", ), + "empty_conditioning": ("CONDITIONING", ), "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), }} RETURN_TYPES = ("MODEL",) @@ -17,13 +17,9 @@ class PerpNeg: CATEGORY = "_for_testing" - def patch(self, model, clip, neg_scale): + def patch(self, model, empty_conditioning, neg_scale): m = model.clone() - - tokens = clip.tokenize("") - nocond, nocond_pooled = clip.encode_from_tokens(tokens, return_pooled=True) - nocond = [[nocond, {"pooled_output": nocond_pooled}]] - nocond = comfy.sample.convert_cond(nocond) + nocond = comfy.sample.convert_cond(empty_conditioning) def cfg_function(args): model = args["model"] @@ -33,9 +29,9 @@ class PerpNeg: x = args["input"] sigma = args["sigma"] model_options = args["model_options"] - + (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond, None, x, sigma, model_options) - + pos = noise_pred_pos - noise_pred_nocond neg = noise_pred_neg - noise_pred_nocond perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg From 014c8bf2f227eea118eb2f232962647289314853 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Dec 2023 15:26:12 -0500 Subject: [PATCH 48/98] Refactor LCM to support more model types. --- comfy_extras/nodes_model_advanced.py | 46 +++++----------------------- 1 file changed, 8 insertions(+), 38 deletions(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index efcdf1932..83ef73c70 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -17,41 +17,19 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteDistilled(torch.nn.Module): +class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 - def __init__(self): - super().__init__() - self.sigma_data = 1.0 - timesteps = 1000 - beta_start = 0.00085 - beta_end = 0.012 + def __init__(self, model_config=None): + super().__init__(model_config) - betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) + self.skip_steps = self.num_timesteps // self.original_timesteps - self.skip_steps = timesteps // self.original_timesteps - - - alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) for x in range(self.original_timesteps): - alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] - sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 - self.set_sigmas(sigmas) - - def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] + self.set_sigmas(sigmas_valid) def timestep(self, sigma): log_sigma = sigma.log() @@ 
-66,14 +44,6 @@ class ModelSamplingDiscreteDistilled(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp().to(timestep.device) - def percent_to_sigma(self, percent): - if percent <= 0.0: - return 999999999.9 - if percent >= 1.0: - return 0.0 - percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)).item() - def rescale_zero_terminal_snr_sigmas(sigmas): alphas_cumprod = 1 / ((sigmas * sigmas) + 1) @@ -154,7 +124,7 @@ class ModelSamplingContinuousEDM: class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) model_sampling.set_sigma_range(sigma_min, sigma_max) m.add_object_patch("model_sampling", model_sampling) return (m, ) From adc40e3d7bc612e81874cf9f5738bda0e17ce0a3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Dec 2023 15:46:23 -0500 Subject: [PATCH 49/98] Forgot this. --- comfy_extras/nodes_model_advanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 83ef73c70..541ce8fa5 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -92,7 +92,7 @@ class ModelSamplingDiscrete: class ModelSamplingAdvanced(sampling_base, sampling_type): pass - model_sampling = ModelSamplingAdvanced() + model_sampling = ModelSamplingAdvanced(model.model.model_config) if zsnr: model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) From 719fa0866fcd7744de3bf5ffd9ddd076f7c36b98 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 15 Dec 2023 18:53:08 -0500 Subject: [PATCH 50/98] Set clip vision model in eval mode so it works without inference mode. --- comfy/clip_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index ba8a3a8d5..85b017e0c 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -39,6 +39,7 @@ class ClipVisionModel(): self.dtype = torch.float16 self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init) + self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): From 6596654d4792dae97831d429fcb095376f243a7c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Dec 2023 01:21:00 -0500 Subject: [PATCH 51/98] Add a LatentBatch node. 
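
The node concatenates two LATENT inputs along the batch dimension, bilinearly upscaling samples2 when the spatial sizes differ, and merges the batch_index lists so downstream nodes keep per-sample indices. For example (shapes only): batching a [2, 4, 64, 64] latent with a [3, 4, 32, 32] latent gives a [5, 4, 64, 64] latent with batch_index [0, 1, 0, 1, 2].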
--- comfy_extras/nodes_latent.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index cedf39d63..2eefc4c55 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -3,9 +3,7 @@ import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: - latent.movedim(1, -1) latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") - latent.movedim(-1, 1) return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) @@ -102,9 +100,32 @@ class LatentInterpolate: samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) return (samples_out,) +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, } From 172984db0175845c1a16bc3100fed0e46b42f604 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Dec 2023 01:29:57 -0500 Subject: [PATCH 52/98] Fix SAG not working on certain resolutions. --- comfy_extras/nodes_sag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 7e293ef63..fea673d6c 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = round(math.sqrt(lh * lw / hw1)) + ratio = math.ceil(math.sqrt(lh * lw / hw1)) mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] # Reshape From 574efd3782c022fd00f55745d784207f6d318b15 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Dec 2023 02:30:16 -0500 Subject: [PATCH 53/98] Fix perpneg not working on SDXL. 
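
SDXL models produce their extra conditioning (the pooled/size "y" input) through model.extra_conds, so the raw converted empty conditioning has to go through comfy.samplers.encode_model_conds before being passed to calc_cond_uncond_batch; without that step the nocond batch was missing the inputs the SDXL UNet expects.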
--- comfy_extras/nodes_perpneg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index 0c5ccb77a..45e4d418f 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -29,8 +29,9 @@ class PerpNeg: x = args["input"] sigma = args["sigma"] model_options = args["model_options"] + nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") - (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond, None, x, sigma, model_options) + (noise_pred_nocond, _) = comfy.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options) pos = noise_pred_pos - noise_pred_nocond neg = noise_pred_neg - noise_pred_nocond From 13e6d5366e87ae76f517e1d79349e51fe92087b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Dec 2023 02:47:26 -0500 Subject: [PATCH 54/98] Switch clip vision to manual cast. Make it use the same dtype as the text encoder. --- comfy/clip_vision.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 85b017e0c..a95616f1d 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -34,11 +34,8 @@ class ClipVisionModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() - self.dtype = torch.float32 - if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): - self.dtype = torch.float16 - - self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init) + self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) @@ -47,15 +44,8 @@ class ClipVisionModel(): def encode_image(self, image): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device)) - - if self.dtype != torch.float32: - precision_scope = torch.autocast - else: - precision_scope = lambda a, b: contextlib.nullcontext(a) - - with precision_scope(comfy.model_management.get_autocast_device(self.load_device), torch.float32): - out = self.model(pixel_values=pixel_values, intermediate_output=-2) + pixel_values = clip_preprocess(image.to(self.load_device)).float() + out = self.model(pixel_values=pixel_values, intermediate_output=-2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) From e45d920ae392c608b9cfcb1f863cfc8688ebb518 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 16 Dec 2023 03:06:10 -0500 Subject: [PATCH 55/98] Don't resize clip vision image when the size is already good. 
--- comfy/clip_vision.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index a95616f1d..4564fcfb2 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -19,11 +19,13 @@ class Output: def clip_preprocess(image, size=224): mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) - scale = (size / min(image.shape[1], image.shape[2])) - image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + scale = (size / min(image.shape[2], image.shape[3])) + image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) From 6453dc1ca2d98d89af7cf312bb48d1e3fd2ca27f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sat, 16 Dec 2023 14:16:12 +0000 Subject: [PATCH 56/98] Fix name counter preventing more than 3 of the same node Fix linked widget offset when populating values --- tests-ui/tests/groupNode.test.js | 32 ++++++++++++++++++++++++++++++++ web/extensions/core/groupNode.js | 9 +++++---- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index 625890a09..e6ebedd91 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -970,4 +970,36 @@ describe("group node", () => { }); }); }); + test("converted inputs with linked widgets map values correctly on creation", async () => { + const { ez, graph, app } = await start(); + const k1 = ez.KSampler(); + const k2 = ez.KSampler(); + k1.widgets.seed.convertToInput(); + k2.widgets.seed.convertToInput(); + + const rr = ez.Reroute(); + rr.outputs[0].connectTo(k1.inputs.seed); + rr.outputs[0].connectTo(k2.inputs.seed); + + const group = await convertToGroup(app, graph, "test", [k1, k2, rr]); + expect(group.widgets.steps.value).toBe(20); + expect(group.widgets.cfg.value).toBe(8); + expect(group.widgets.scheduler.value).toBe("normal"); + expect(group.widgets["KSampler steps"].value).toBe(20); + expect(group.widgets["KSampler cfg"].value).toBe(8); + expect(group.widgets["KSampler scheduler"].value).toBe("normal"); + }); + test("allow multiple of the same node type to be added", async () => { + const { ez, graph, app } = await start(); + const nodes = [...Array(10)].map(() => ez.ImageScaleBy()); + const group = await convertToGroup(app, graph, "test", nodes); + expect(group.inputs.length).toBe(10); + expect(group.outputs.length).toBe(10); + expect(group.widgets.length).toBe(20); + expect(group.widgets.map((w) => w.widget.name)).toStrictEqual( + [...Array(10)] + .map((_, i) => `${i > 0 ? "ImageScaleBy " : ""}${i > 1 ? 
i + " " : ""}`) + .flatMap((p) => [`${p}upscale_method`, `${p}scale_by`]) + ); + }); }); diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index dc962ac24..4cf1f7621 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -331,16 +331,17 @@ export class GroupNodeConfig { getInputConfig(node, inputName, seenInputs, config, extra) { let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + let key = name; let prefix = ""; // Special handling for primitive to include the title if it is set rather than just "value" if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) { prefix = `${node.title ?? node.type} `; - name = `${prefix}${inputName}`; + key = name = `${prefix}${inputName}`; if (name in seenInputs) { name = `${prefix}${seenInputs[name]} ${inputName}`; } } - seenInputs[name] = (seenInputs[name] ?? 1) + 1; + seenInputs[key] = (seenInputs[key] ?? 1) + 1; if (inputName === "seed" || inputName === "noise_seed") { if (!extra) extra = {}; @@ -1010,10 +1011,10 @@ export class GroupNodeHandler { const newName = map[oldName]; const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); const mainWidget = this.node.widgets[widgetIndex]; - if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift)) { + if (this.populatePrimitive(node, nodeId, oldName, i, linkedShift) || widgetIndex === -1) { // Find the inner widget and shift by the number of linked widgets as they will have been removed too const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); - linkedShift += innerWidget.linkedWidgets?.length ?? 0; + linkedShift += innerWidget?.linkedWidgets?.length ?? 0; } if (widgetIndex === -1) { continue; From a036b940752fff830fde4108cd243a35df2fa1ee Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Dec 2023 02:37:22 -0500 Subject: [PATCH 57/98] Move SaveAnimated nodes to image->animation. --- comfy_extras/nodes_images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 5ad2235a5..aa80f5269 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -74,7 +74,7 @@ class SaveAnimatedWEBP: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): method = self.methods.get(method) @@ -136,7 +136,7 @@ class SaveAnimatedPNG: OUTPUT_NODE = True - CATEGORY = "_for_testing" + CATEGORY = "image/animation" def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append From 2f9d6a97ec7e3cb25beb13a320da8ec4573355d3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 17 Dec 2023 16:59:21 -0500 Subject: [PATCH 58/98] Add --deterministic option to make pytorch use deterministic algorithms. 
---
 comfy/cli_args.py         | 2 +-
 comfy/model_management.py | 4 ++++
 main.py                   | 4 ++++
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index d9c8668f4..8de0adb53 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -102,7 +102,7 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
-
+parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6a9471bf..23f39c985 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -28,6 +28,10 @@ total_vram = 0
 lowvram_available = True
 xpu_available = False

+if args.deterministic:
+    print("Using deterministic algorithms for pytorch")
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
 directml_enabled = False
 if args.directml is not None:
     import torch_directml

diff --git a/main.py b/main.py
index 1f9c5f443..f6aeceed2 100644
--- a/main.py
+++ b/main.py
@@ -64,6 +64,10 @@ if __name__ == "__main__":
         os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
         print("Set cuda device to:", args.cuda_device)

+    if args.deterministic:
+        if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
+            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
+
     import cuda_malloc

 import comfy.utils

From 2258f851593fcb4bf34d22dddd3b7cb711db91ec Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 18 Dec 2023 03:18:40 -0500
Subject: [PATCH 59/98] Support stable zero 123 model.

To use it, use the ImageOnlyCheckpointLoader to load the checkpoint and the new Stable_Zero123 node.
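
The conditioning node concatenates the CLIP vision image embedding with a 4 value camera embedding derived from the elevation/azimuth inputs; when the resulting context is not 768 wide, the model's cc_projection layer maps it back to the 768 dimensions the UNet cross attention expects. As a worked example (values rounded), camera_embeddings(elevation=0.0, azimuth=90.0) yields [0.0, 1.0, 0.0, 1.5708]: -deg2rad(elevation), sin and cos of the azimuth in radians, and deg2rad(90).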
--- comfy/model_base.py | 30 ++++++++++++++++++ comfy/sample.py | 3 +- comfy/supported_models.py | 29 ++++++++++++++++- comfy_extras/nodes_stable3d.py | 58 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 5 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/nodes_stable3d.py diff --git a/comfy/model_base.py b/comfy/model_base.py index a7582b330..c80848b27 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -328,3 +328,33 @@ class SVD_img2vid(BaseModel): out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) return out + +class Stable_Zero123(BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): + super().__init__(model_config, model_type, device=device) + self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) + self.cc_projection.weight.copy_(cc_projection_weight) + self.cc_projection.bias.copy_(cc_projection_bias) + + def extra_conds(self, **kwargs): + out = {} + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + if cross_attn.shape[-1] != 768: + cross_attn = self.cc_projection(cross_attn) + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + return out diff --git a/comfy/sample.py b/comfy/sample.py index eadd6dcc8..4b0d15c49 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -47,7 +47,8 @@ def convert_cond(cond): temp = c[1].copy() model_conds = temp.get("model_conds", {}) if c[0] is not None: - model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) + model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove + temp["cross_attn"] = c[0] temp["model_conds"] = model_conds out.append(temp) return out diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2f2dee871..251bf6ace 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -252,5 +252,32 @@ class SVD_img2vid(supported_models_base.BASE): def clip_target(self): return None -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega] +class Stable_Zero123(supported_models_base.BASE): + unet_config = { + "context_dim": 768, + "model_channels": 320, + "use_linear_in_transformer": False, + "adm_in_channels": None, + "use_temporal_attention": False, + "in_channels": 8, + } + + unet_extra_config = { + "num_heads": 8, + "num_head_channels": -1, + } + + clip_vision_prefix = "cond_stage_model.model.visual." 
+ + latent_format = latent_formats.SD15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"]) + return out + + def clip_target(self): + return None + + +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py new file mode 100644 index 000000000..fa64d9246 --- /dev/null +++ b/comfy_extras/nodes_stable3d.py @@ -0,0 +1,58 @@ +import torch +import nodes +import comfy.utils + +def camera_embeddings(elevation, azimuth): + elevation = torch.as_tensor([elevation]) + azimuth = torch.as_tensor([azimuth]) + embeddings = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + torch.deg2rad( + 90 - torch.full_like(elevation, 0) + ), + ], dim=-1).unsqueeze(1) + + return embeddings + + +class Zero123_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + cam_embeds = camera_embeddings(elevation, azimuth) + cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +NODE_CLASS_MAPPINGS = { + "Zero123_Conditioning": Zero123_Conditioning, +} diff --git a/nodes.py b/nodes.py index 3031b10aa..7ed7a8e4a 100644 --- a/nodes.py +++ b/nodes.py @@ -1869,6 +1869,7 @@ def init_custom_nodes(): "nodes_video_model.py", "nodes_sag.py", "nodes_perpneg.py", + "nodes_stable3d.py", ] for node_file in extras_files: From d2f322902cf33c7235e7982900b5a88d55d5ecd1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Dec 2023 03:59:50 -0500 Subject: [PATCH 60/98] Fix wrong Stable Zero123 node name. 
--- comfy_extras/nodes_stable3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index fa64d9246..c6791d8de 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -20,7 +20,7 @@ def camera_embeddings(elevation, azimuth): return embeddings -class Zero123_Conditioning: +class StableZero123_Conditioning: @classmethod def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), @@ -54,5 +54,5 @@ class Zero123_Conditioning: return (positive, negative, {"samples":latent}) NODE_CLASS_MAPPINGS = { - "Zero123_Conditioning": Zero123_Conditioning, + "StableZero123_Conditioning": StableZero123_Conditioning, } From 8cf1daa108400f2e29188fa0b4404d6ebc83b864 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Dec 2023 12:54:23 -0500 Subject: [PATCH 61/98] Fix SDXL area composition sometimes not using the right pooled output. --- comfy/model_base.py | 10 ++++++++++ comfy/samplers.py | 7 ++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index c80848b27..f2a6f9841 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -126,9 +126,15 @@ class BaseModel(torch.nn.Module): cond_concat.append(blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = comfy.conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) if adm is not None: out['y'] = comfy.conds.CONDRegular(adm) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + return out def load_model_weights(self, sd, unet_prefix=""): @@ -322,6 +328,10 @@ class SVD_img2vid(BaseModel): out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + if "time_conditioning" in kwargs: out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) diff --git a/comfy/samplers.py b/comfy/samplers.py index 35c9ccf05..18bd75ef1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -599,6 +599,10 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, positive) + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) @@ -613,9 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model if latent_image is not None: latent_image = model.process_latent_in(latent_image) - if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} From 571ea8cdcc2d1bf4fa7f398dad68415dacfff02f Mon Sep 17 00:00:00 2001 From: 
comfyanonymous Date: Mon, 18 Dec 2023 17:03:32 -0500 Subject: [PATCH 62/98] Fix SAG not working with cfg 1.0 --- comfy/model_patcher.py | 8 ++++++-- comfy/samplers.py | 2 +- comfy_extras/nodes_sag.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e0acdc961..6acb2d647 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -55,14 +55,18 @@ class ModelPatcher: def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) - def set_model_sampler_cfg_function(self, sampler_cfg_function): + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True - def set_model_sampler_post_cfg_function(self, post_cfg_function): + def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + if disable_cfg1_optimization: + self.model_options["disable_cfg1_optimization"] = True def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function diff --git a/comfy/samplers.py b/comfy/samplers.py index 18bd75ef1..47f347787 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -244,7 +244,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - if math.isclose(cond_scale, 1.0): + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: uncond_ = None else: uncond_ = uncond diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index fea673d6c..450ac3eea 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -151,7 +151,7 @@ class SelfAttentionGuidance: (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) return cfg_result + (degraded - sag) * sag_scale - m.set_model_sampler_post_cfg_function(post_cfg_function) + m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) # from diffusers: # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch From 9a7619b72de8e9e6cbc2818d4deef0914539fbe3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 19 Dec 2023 02:32:59 -0500 Subject: [PATCH 63/98] Fix regression with inpaint model. 
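
The previous reordering made encode_model_conds run before latent_image was mapped through model.process_latent_in, so models that concat the latent (like the inpaint ones) could see unscaled latents in their extra conds. Scaling the latent first restores the old behavior while keeping the SDXL pooled output fix.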
---
 comfy/samplers.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/comfy/samplers.py b/comfy/samplers.py
index 47f347787..0453c1f6f 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -599,6 +599,9 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     calculate_start_end_timesteps(model, negative)
     calculate_start_end_timesteps(model, positive)

+    if latent_image is not None:
+        latent_image = model.process_latent_in(latent_image)
+
     if hasattr(model, 'extra_conds'):
         positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
         negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
@@ -614,10 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])

-    if latent_image is not None:
-        latent_image = model.process_latent_in(latent_image)
-
-
     extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}

     samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)

From 40ea2bd01113f9fd46be8e6a61cded204155f9a4 Mon Sep 17 00:00:00 2001
From: Oleksiy Nehlyadyuk
Date: Tue, 19 Dec 2023 17:07:55 +0300
Subject: [PATCH 64/98] Update requirements.txt

The UI launches with the `torchvision` module missing and spits out a `ModuleNotFoundError`; installing the `torchvision` module fixed it.

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 14524485a..b698f2feb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 torch
 torchsde
+torchvision
 einops
 transformers>=4.25.1
 safetensors>=0.3.0

From e65110fd93a3f9e4c378e87b26a9fc6c5c68cc2d Mon Sep 17 00:00:00 2001
From: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
Date: Tue, 19 Dec 2023 20:22:01 +0000
Subject: [PATCH 65/98] Fix dom widgets not being hidden

---
 web/scripts/domWidget.js | 1 +
 1 file changed, 1 insertion(+)

diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js
index e919428a0..bb4c892b5 100644
--- a/web/scripts/domWidget.js
+++ b/web/scripts/domWidget.js
@@ -177,6 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () {
 			for (const w of node.widgets) {
 				if (w.element) {
 					w.element.hidden = hidden;
+					w.element.style.display = hidden ?
"none" : null; if (hidden) { w.options.onHide?.(w); } From 8680ac3dfd51ab1276eb05d17ef8837e023f4a1f Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 19 Dec 2023 20:38:07 +0000 Subject: [PATCH 66/98] try to improve test reliability --- .github/workflows/test-ui.yaml | 2 +- tests-ui/afterSetup.js | 9 +++++++++ tests-ui/jest.config.js | 2 ++ 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 tests-ui/afterSetup.js diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 950691755..4b8b97934 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -22,5 +22,5 @@ jobs: run: | npm ci npm run test:generate - npm test + npm test -- --verbose working-directory: ./tests-ui diff --git a/tests-ui/afterSetup.js b/tests-ui/afterSetup.js new file mode 100644 index 000000000..983f3af64 --- /dev/null +++ b/tests-ui/afterSetup.js @@ -0,0 +1,9 @@ +const { start } = require("./utils"); +const lg = require("./utils/litegraph"); + +// Load things once per test file before to ensure its all warmed up for the tests +beforeAll(async () => { + lg.setup(global); + await start({ resetEnv: true }); + lg.teardown(global); +}); diff --git a/tests-ui/jest.config.js b/tests-ui/jest.config.js index b5a5d646d..86fff5057 100644 --- a/tests-ui/jest.config.js +++ b/tests-ui/jest.config.js @@ -2,8 +2,10 @@ const config = { testEnvironment: "jsdom", setupFiles: ["./globalSetup.js"], + setupFilesAfterEnv: ["./afterSetup.js"], clearMocks: true, resetModules: true, + testTimeout: 10000 }; module.exports = config; From e82942cc293a7f707f3ba5611e33ec2284278268 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Dec 2023 02:51:18 -0500 Subject: [PATCH 67/98] Add a denoise parameter to the SDTurboScheduler. --- comfy_extras/nodes_custom_sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 008d0b8d6..8791d8ae3 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -87,6 +87,7 @@ class SDTurboScheduler: return {"required": {"model": ("MODEL",), "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -94,8 +95,9 @@ class SDTurboScheduler: FUNCTION = "get_sigmas" - def get_sigmas(self, model, steps): - timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] + def get_sigmas(self, model, steps, denoise): + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] sigmas = model.model.model_sampling.sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) From 5f54614e7fa8a7ae493c7ac8a8c0677970cac908 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Dec 2023 16:22:18 -0500 Subject: [PATCH 68/98] Add a RebatchImages node. 
--- comfy_extras/nodes_rebatch.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 88a4ebe29..3010fbd4b 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -99,10 +99,40 @@ class LatentRebatch: return (output_list,) +class ImageRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "image/batch" + + def rebatch(self, images, batch_size): + batch_size = batch_size[0] + + output_list = [] + all_images = [] + for img in images: + for i in range(img.shape[0]): + all_images.append(img[i:i+1]) + + for i in range(0, len(all_images), batch_size): + output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) + + return (output_list,) + NODE_CLASS_MAPPINGS = { "RebatchLatents": LatentRebatch, + "RebatchImages": ImageRebatch, } NODE_DISPLAY_NAME_MAPPINGS = { "RebatchLatents": "Rebatch Latents", -} \ No newline at end of file + "RebatchImages": "Rebatch Images", +} From a1e1c69f7d555ae281ec46ca7a40c7195f3a249c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Dec 2023 16:39:09 -0500 Subject: [PATCH 69/98] LoadImage now loads all the frames from animated images as a batch. --- nodes.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/nodes.py b/nodes.py index 7ed7a8e4a..027bf55d9 100644 --- a/nodes.py +++ b/nodes.py @@ -9,7 +9,7 @@ import math import time import random -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageSequence from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch @@ -1410,17 +1410,30 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) - i = Image.open(image_path) - i = ImageOps.exif_transpose(i) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - if 'A' in i.getbands(): - mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) + img = Image.open(image_path) + output_images = [] + output_masks = [] + for i in ImageSequence.Iterator(img): + i = ImageOps.exif_transpose(i) + image = i.convert("RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + if len(output_images) > 1: + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (image, mask.unsqueeze(0)) + output_image = output_images[0] + output_mask = output_masks[0] + + return (output_image, output_mask) @classmethod def IS_CHANGED(s, image): From 6781b181ef8ab8101e6bdf45a580509d6e4e1f7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Dec 2023 02:35:01 -0500 Subject: [PATCH 70/98] Fix potential tensor device issue with ImageCompositeMasked. 
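
composite() now moves the source tensor and the mask onto the destination tensor's device before blending, so the node no longer errors out when its inputs live on different devices (for example a mask created on the CPU composited onto an image on the GPU).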
--- comfy_extras/nodes_mask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index d8c65c2b6..a7d164bf7 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -6,6 +6,7 @@ import comfy.utils from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): + source = source.to(destination.device) if resize_source: source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") @@ -20,7 +21,7 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou if mask is None: mask = torch.ones_like(source) else: - mask = mask.clone() + mask = mask.to(destination.device, copy=True) mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) From d35267e85a865c30a5fa63fdb0a21f94f4cc37e7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 21 Dec 2023 13:21:25 -0500 Subject: [PATCH 71/98] Litegraph updates. Update from upstream repo. Auto select value in prompt. Increase maximum number of nodes to 10k. --- web/lib/litegraph.core.js | 45 +++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index f571edb30..434c4a83b 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -48,7 +48,7 @@ EVENT_LINK_COLOR: "#A86", CONNECTING_LINK_COLOR: "#AFA", - MAX_NUMBER_OF_NODES: 1000, //avoid infinite loops + MAX_NUMBER_OF_NODES: 10000, //avoid infinite loops DEFAULT_POSITION: [100, 100], //default node position VALID_SHAPES: ["default", "box", "round", "card"], //,"circle" @@ -3788,16 +3788,42 @@ /** * returns the bounding of the object, used for rendering purposes - * bounding is: [topleft_cornerx, topleft_cornery, width, height] * @method getBounding - * @return {Float32Array[4]} the total size + * @param out {Float32Array[4]?} [optional] a place to store the output, to free garbage + * @param compute_outer {boolean?} [optional] set to true to include the shadow and connection points in the bounding calculation + * @return {Float32Array[4]} the bounding box in format of [topleft_cornerx, topleft_cornery, width, height] */ - LGraphNode.prototype.getBounding = function(out) { + LGraphNode.prototype.getBounding = function(out, compute_outer) { out = out || new Float32Array(4); - out[0] = this.pos[0] - 4; - out[1] = this.pos[1] - LiteGraph.NODE_TITLE_HEIGHT; - out[2] = this.flags.collapsed ? (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) : this.size[0] + 4; - out[3] = this.flags.collapsed ? 
LiteGraph.NODE_TITLE_HEIGHT : this.size[1] + LiteGraph.NODE_TITLE_HEIGHT; + const nodePos = this.pos; + const isCollapsed = this.flags.collapsed; + const nodeSize = this.size; + + let left_offset = 0; + // 1 offset due to how nodes are rendered + let right_offset = 1 ; + let top_offset = 0; + let bottom_offset = 0; + + if (compute_outer) { + // 4 offset for collapsed node connection points + left_offset = 4; + // 6 offset for right shadow and collapsed node connection points + right_offset = 6 + left_offset; + // 4 offset for collapsed nodes top connection points + top_offset = 4; + // 5 offset for bottom shadow and collapsed node connection points + bottom_offset = 5 + top_offset; + } + + out[0] = nodePos[0] - left_offset; + out[1] = nodePos[1] - LiteGraph.NODE_TITLE_HEIGHT - top_offset; + out[2] = isCollapsed ? + (this._collapsed_width || LiteGraph.NODE_COLLAPSED_WIDTH) + right_offset : + nodeSize[0] + right_offset; + out[3] = isCollapsed ? + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset : + nodeSize[1] + LiteGraph.NODE_TITLE_HEIGHT + bottom_offset; if (this.onBounding) { this.onBounding(out); @@ -7674,7 +7700,7 @@ LGraphNode.prototype.executeAction = function(action) continue; } - if (!overlapBounding(this.visible_area, n.getBounding(temp))) { + if (!overlapBounding(this.visible_area, n.getBounding(temp, true))) { continue; } //out of the visible area @@ -11336,6 +11362,7 @@ LGraphNode.prototype.executeAction = function(action) name_element.innerText = title; var value_element = dialog.querySelector(".value"); value_element.value = value; + value_element.select(); var input = value_element; input.addEventListener("keydown", function(e) { From 261bcbb0d933c3bf1fce02e6cc652936da2de1e0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Dec 2023 04:05:42 -0500 Subject: [PATCH 72/98] A few missing comfy ops in the VAE. 
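
quant_conv/post_quant_conv and the Normalize group norms are now created through comfy.ops, which skips the weight initialization torch runs on construction (the weights get overwritten by the checkpoint load right after) and routes the remaining VAE layers with weights through the comfy.ops classes.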
--- comfy/ldm/models/autoencoder.py | 5 +++-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index d2f1d74a9..b91ec3249 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -8,6 +8,7 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri from comfy.ldm.util import instantiate_from_config from comfy.ldm.modules.ema import LitEma +import comfy.ops class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = True): @@ -161,12 +162,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine): }, **kwargs, ) - self.quant_conv = torch.nn.Conv2d( + self.quant_conv = comfy.ops.disable_weight_init.Conv2d( (1 + ddconfig["double_z"]) * ddconfig["z_channels"], (1 + ddconfig["double_z"]) * embed_dim, 1, ) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim def get_autoencoder_params(self) -> list: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index fce29cb85..cc81c1f23 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -41,7 +41,7 @@ def nonlinearity(x): def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): From 36a7953142ccf3f9debf9305e3cbeb3bfe956ee3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 22 Dec 2023 14:24:04 -0500 Subject: [PATCH 73/98] Greatly improve lowvram sampling speed by getting rid of accelerate. Let me know if this breaks anything. 
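
Instead of dispatching through accelerate's meta device hooks, lowvram mode now flips a comfy_cast_weights flag on the comfy.ops layers and moves whole modules to the GPU until the lowvram memory budget is spent; layers left behind cast their weights to the input's device and dtype on every forward call. A minimal standalone sketch of that cast-on-forward pattern (illustrative class name, simplified from the comfy/ops.py changes below):

    import torch

    class CastLinear(torch.nn.Linear):
        comfy_cast_weights = True

        def forward(self, input):
            if not self.comfy_cast_weights:
                return super().forward(input)
            # weights may live on the offload device / in a storage dtype;
            # cast them to match the input on every call
            weight = self.weight.to(device=input.device, dtype=input.dtype)
            bias = None
            if self.bias is not None:
                bias = self.bias.to(device=input.device, dtype=input.dtype)
            return torch.nn.functional.linear(input, weight, bias)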
--- comfy/controlnet.py | 2 +- comfy/model_base.py | 6 +-- comfy/model_management.py | 52 ++++++++++++---------- comfy/ops.py | 92 ++++++++++++++++++++++++++++++--------- requirements.txt | 1 - 5 files changed, 103 insertions(+), 50 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 110b5c7c2..8404054f3 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -283,7 +283,7 @@ class ControlLora(ControlNet): cm = self.control_model.state_dict() for k in sd: - weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) + weight = sd[k] try: comfy.utils.set_attr(self.control_model, k, weight) except: diff --git a/comfy/model_base.py b/comfy/model_base.py index f2a6f9841..b3a1fcd51 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -162,11 +162,7 @@ class BaseModel(torch.nn.Module): def state_dict_for_saving(self, clip_state_dict, vae_state_dict): clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) - unet_sd = self.diffusion_model.state_dict() - unet_state_dict = {} - for k in unet_sd: - unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k) - + unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) if self.get_dtype() == torch.float16: diff --git a/comfy/model_management.py b/comfy/model_management.py index 23f39c985..61c967f64 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -218,15 +218,8 @@ if args.force_fp16: FORCE_FP16 = True if lowvram_available: - try: - import accelerate - if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): - vram_state = set_vram_to - except Exception as e: - import traceback - print(traceback.format_exc()) - print("ERROR: LOW VRAM MODE NEEDS accelerate.") - lowvram_available = False + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to if cpu_state != CPUState.GPU: @@ -298,8 +291,20 @@ class LoadedModel: if lowvram_model_memory > 0: print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) - device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) + mem_counter = 0 + for m in self.real_model.modules(): + if hasattr(m, "comfy_cast_weights"): + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + module_mem = 0 + sd = m.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nelement() * t.element_size() + if mem_counter + module_mem < lowvram_model_memory: + m.to(self.device) + mem_counter += module_mem + self.model_accelerated = True if is_intel_xpu() and not args.disable_ipex_optimize: @@ -309,7 +314,11 @@ class LoadedModel: def model_unload(self): if self.model_accelerated: - accelerate.hooks.remove_hook_from_submodules(self.real_model) + for m in self.real_model.modules(): + if hasattr(m, "prev_comfy_cast_weights"): + m.comfy_cast_weights = m.prev_comfy_cast_weights + del m.prev_comfy_cast_weights + self.model_accelerated = False self.model.unpatch_model(self.model.offload_device) @@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size 
= loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) + lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM else: lowvram_model_memory = 0 if vram_set_state == VRAMState.NO_VRAM: - lowvram_model_memory = 256 * 1024 * 1024 + lowvram_model_memory = 64 * 1024 * 1024 cur_loaded_model = loaded_model.model_load(lowvram_model_memory) current_loaded_models.insert(0, loaded_model) @@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO return True return False +def device_supports_non_blocking(device): + if is_device_mps(device): + return False #pytorch bug? mps doesn't support non blocking + return True + def cast_to_device(tensor, device, dtype, copy=False): device_supports_cast = False if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: @@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False): elif is_intel_xpu(): device_supports_cast = True - non_blocking = True - if is_device_mps(device): - non_blocking = False #pytorch bug? mps doesn't support non blocking + non_blocking = device_supports_non_blocking(device) if device_supports_cast: if copy: @@ -742,11 +754,7 @@ def soft_empty_cache(force=False): torch.cuda.empty_cache() torch.cuda.ipc_collect() -def resolve_lowvram_weight(weight, model, key): - if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break. - key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device. 
- op = comfy.utils.get_attr(model, '.'.join(key_split[:-1])) - weight = op._hf_hook.weights_map[key_split[-1]] +def resolve_lowvram_weight(weight, model, key): #TODO: remove return weight #TODO: might be cleaner to put this somewhere else diff --git a/comfy/ops.py b/comfy/ops.py index 08c633847..f6f85de60 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,27 +1,93 @@ import torch from contextlib import contextmanager +import comfy.model_management + +def cast_bias_weight(s, input): + bias = None + non_blocking = comfy.model_management.device_supports_non_blocking(input.device) + if s.bias is not None: + bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking) + return weight, bias + class disable_weight_init: class Linear(torch.nn.Linear): + comfy_cast_weights = False def reset_parameters(self): return None + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class Conv2d(torch.nn.Conv2d): + comfy_cast_weights = False def reset_parameters(self): return None + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class Conv3d(torch.nn.Conv3d): + comfy_cast_weights = False def reset_parameters(self): return None + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return self._conv_forward(input, weight, bias) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + class GroupNorm(torch.nn.GroupNorm): + comfy_cast_weights = False def reset_parameters(self): return None + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + class LayerNorm(torch.nn.LayerNorm): + comfy_cast_weights = False def reset_parameters(self): return None + def forward_comfy_cast_weights(self, input): + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: @@ -31,35 +97,19 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") -def cast_bias_weight(s, input): - bias = None - if s.bias is not None: - bias = s.bias.to(device=input.device, dtype=input.dtype) - weight = s.weight.to(device=input.device, dtype=input.dtype) - return weight, bias class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): - def forward(self, input): 
- weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + comfy_cast_weights = True class Conv2d(disable_weight_init.Conv2d): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + comfy_cast_weights = True class Conv3d(disable_weight_init.Conv3d): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + comfy_cast_weights = True class GroupNorm(disable_weight_init.GroupNorm): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + comfy_cast_weights = True class LayerNorm(disable_weight_init.LayerNorm): - def forward(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + comfy_cast_weights = True diff --git a/requirements.txt b/requirements.txt index 14524485a..da1fbb27e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ einops transformers>=4.25.1 safetensors>=0.3.0 aiohttp -accelerate pyyaml Pillow scipy From a252963f956a7d76344e3f0ce24b1047480a25af Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 23 Dec 2023 04:25:06 -0500 Subject: [PATCH 74/98] --disable-smart-memory now unloads everything like it did originally. --- comfy/model_management.py | 4 ++++ execution.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 61c967f64..3adc42702 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -754,6 +754,10 @@ def soft_empty_cache(force=False): torch.cuda.empty_cache() torch.cuda.ipc_collect() +def unload_all_models(): + free_memory(1e30, get_torch_device()) + + def resolve_lowvram_weight(weight, model, key): #TODO: remove return weight diff --git a/execution.py b/execution.py index 7db1f095b..7ad171313 100644 --- a/execution.py +++ b/execution.py @@ -382,6 +382,8 @@ class PromptExecutor: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() From d0165d819afe76bd4e6bdd710eb5f3e571b6a804 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 24 Dec 2023 07:06:59 -0500 Subject: [PATCH 75/98] Fix SVD lowvram mode. 
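For context: in lowvram mode a module's parameters and buffers can still be resident on the offload device when forward() runs, so tensors like AlphaBlender.mix_factor must follow the activation they are blended with. A minimal sketch of the pattern this patch applies in the diff below (the class and names here are illustrative, not part of the diff):

    import torch

    class AlphaBlenderSketch(torch.nn.Module):
        # sketch of the device-safe "learned" merge strategy
        def __init__(self):
            super().__init__()
            self.mix_factor = torch.nn.Parameter(torch.tensor([0.0]))

        def get_alpha(self, x):
            # cast the possibly CPU-resident parameter to the device of the
            # tensor it is merged with instead of assuming the devices match
            return torch.sigmoid(self.mix_factor.to(x.device))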
--- comfy/ldm/modules/diffusionmodules/util.py | 6 +++--- comfy/ldm/modules/temporal_ae.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 68175b62a..ac7e27173 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -51,9 +51,9 @@ class AlphaBlender(nn.Module): if self.merge_strategy == "fixed": # make shape compatible # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) - alpha = self.mix_factor + alpha = self.mix_factor.to(image_only_indicator.device) elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor) + alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device)) # make shape compatible # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) elif self.merge_strategy == "learned_with_images": @@ -61,7 +61,7 @@ class AlphaBlender(nn.Module): alpha = torch.where( image_only_indicator.bool(), torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), + rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"), ) alpha = rearrange(alpha, self.rearrange_pattern) # make shape compatible diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 7ea68dc9e..2992aeafc 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -82,14 +82,14 @@ class VideoResBlock(ResnetBlock): x = self.time_stack(x, temb) - alpha = self.get_alpha(bs=b // timesteps) + alpha = self.get_alpha(bs=b // timesteps).to(x.device) x = alpha * x + (1.0 - alpha) * x_mix x = rearrange(x, "b c t h w -> (b t) c h w") return x -class AE3DConv(torch.nn.Conv2d): +class AE3DConv(ops.Conv2d): def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): super().__init__(in_channels, out_channels, *args, **kwargs) if isinstance(video_kernel_size, Iterable): @@ -97,7 +97,7 @@ class AE3DConv(torch.nn.Conv2d): else: padding = int(video_kernel_size // 2) - self.time_mix_conv = torch.nn.Conv3d( + self.time_mix_conv = ops.Conv3d( in_channels=out_channels, out_channels=out_channels, kernel_size=video_kernel_size, @@ -167,7 +167,7 @@ class AttnVideoBlock(AttnBlock): emb = emb[:, None, :] x_mix = x_mix + emb - alpha = self.get_alpha() + alpha = self.get_alpha().to(x.device) x_mix = self.time_mix_block(x_mix, timesteps=timesteps) x = alpha * x + (1.0 - alpha) * x_mix # alpha merge From 392878a2621d131ac9e856fb2d428d9c6e2a022e Mon Sep 17 00:00:00 2001 From: shiimizu Date: Mon, 25 Dec 2023 19:17:40 -0800 Subject: [PATCH 76/98] Fix hiding dom widgets. --- web/scripts/domWidget.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index bb4c892b5..eb0742d38 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -177,7 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () { for (const w of node.widgets) { if (w.element) { w.element.hidden = hidden; - w.element.style.display = hidden ? "none" : null; + w.element.style.display = hidden ? "none" : undefined; if (hidden) { w.options.onHide?.(w); } From 61b3f15f8f2bc0822cb98eac48742fb32f6af396 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Dec 2023 05:02:02 -0500 Subject: [PATCH 77/98] Fix lowvram mode not working with unCLIP and Revision code. 
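Same root cause as the SVD fix: registered buffers (data_mean, data_std and the alphas_cumprod schedules) can sit on the offload device while the input is on the GPU, so they are now cast to the input's device at every use. A distilled sketch of the pattern, assuming buffers registered as in CLIPEmbeddingNoiseAugmentation:

    import torch

    class NoiseAugScaleSketch(torch.nn.Module):
        def __init__(self, data_mean, data_std):
            super().__init__()
            self.register_buffer("data_mean", data_mean)
            self.register_buffer("data_std", data_std)

        def scale(self, x):
            # re-normalize to centered mean and unit variance; the buffers may
            # live on the CPU in lowvram mode, so follow the input's device
            return (x - self.data_mean.to(x.device)) / self.data_std.to(x.device)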
--- comfy/ldm/modules/diffusionmodules/upscaling.py | 4 ++-- comfy/ldm/modules/encoders/noise_aug_modules.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 709a7f52e..768a47f9c 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -43,8 +43,8 @@ class AbstractLowScaleModel(nn.Module): def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise) def forward(self, x): return x, None diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index b59bf204b..66767b587 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -15,12 +15,12 @@ class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): def scale(self, x): # re-normalize to centered mean and unit variance - x = (x - self.data_mean) * 1. / self.data_std + x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) return x def unscale(self, x): # back to original data stats - x = (x * self.data_std) + self.data_mean + x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return x def forward(self, x, noise_level=None): From f21bb41787ce590ea6eff16163ee83404d9ff0d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 26 Dec 2023 12:52:21 -0500 Subject: [PATCH 78/98] Fix taesd VAE in lowvram mode. --- comfy/taesd/taesd.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 46f3097a2..8f96c54e5 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -7,9 +7,10 @@ import torch import torch.nn as nn import comfy.utils +import comfy.ops def conv(n_in, n_out, **kwargs): - return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs) class Clamp(nn.Module): def forward(self, x): @@ -19,7 +20,7 @@ class Block(nn.Module): def __init__(self, n_in, n_out): super().__init__() self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) - self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.fuse = nn.ReLU() def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) From f15dce71fde5eee12a5689e86468368a1791d200 Mon Sep 17 00:00:00 2001 From: AYF Date: Wed, 27 Dec 2023 00:55:11 -0500 Subject: [PATCH 79/98] Add title to the API workflow json. (#2380) * Add `title` to the API workflow json. * API: Move `title` to `_meta` dictionary, imply unused. 
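With this change each node entry in the exported API-format JSON carries the node's UI title under a `_meta` key that the backend ignores. Roughly, one exported entry looks like the following (node id and title are hypothetical):

    # shape of one entry in the API workflow JSON, shown as a Python dict
    api_prompt = {
        "3": {
            "inputs": {"seed": 42, "steps": 20},
            "class_type": "KSampler",
            "_meta": {"title": "KSampler (main pass)"},  # ignored by the backend
        },
    }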
--- web/scripts/app.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 62169abfb..73dba65cc 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1784,6 +1784,10 @@ export class ComfyApp { output[String(node.id)] = { inputs, class_type: node.comfyClass, + // Ignored by the backend. + "_meta": { + title: node.title, + }, }; } } From e478b1794e91977c50dc6eea6228ef1248044507 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Dec 2023 01:07:02 -0500 Subject: [PATCH 80/98] Only add _meta title to api prompt when dev mode is enabled in UI. --- web/scripts/app.js | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 73dba65cc..62b71c0a1 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1781,14 +1781,19 @@ export class ComfyApp { } } - output[String(node.id)] = { + let node_data = { inputs, class_type: node.comfyClass, - // Ignored by the backend. - "_meta": { - title: node.title, - }, }; + + if (this.ui.settings.getSettingValue("Comfy.DevMode")) { + // Ignored by the backend. + node_data["_meta"] = { + title: node.title, + } + } + + output[String(node.id)] = node_data; } } From c782144433e41c21ae2dfd75d0bc28255d2e966d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 27 Dec 2023 13:50:57 -0500 Subject: [PATCH 81/98] Fix clip vision lowvram mode not working. --- comfy/clip_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 850b5fdbe..7397b7a26 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module): def forward(self, pixel_values): embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) - return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight + return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device) class CLIPVision(torch.nn.Module): From a8baa40d85aafb7d0d33221ce86eb6ca1402b4c7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Dec 2023 12:23:07 -0500 Subject: [PATCH 82/98] Cleanup. --- .vscode/settings.json | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 202121e10..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "path-intellisense.mappings": { - "../": "${workspaceFolder}/web/extensions/core" - }, - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, - "python.formatting.provider": "none" -} From e1e322cf69319d125680d791822d8f4733fea027 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Dec 2023 21:41:10 -0500 Subject: [PATCH 83/98] Load weights that can't be lowvramed to target device. 
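The lowvram loader measures each module's weight footprint and fills the VRAM budget greedily; modules that lack comfy_cast_weights cannot stream their weights at forward time, so they are loaded to the target device outright. A condensed sketch of the partitioning decision in the diff below (helper names follow the patch; the cast-flag bookkeeping is omitted):

    def module_size(module):
        # total bytes of the parameters/buffers in the module's state dict
        return sum(t.nelement() * t.element_size() for t in module.state_dict().values())

    def partition_lowvram(model, device, budget):
        mem_counter = 0
        for m in model.modules():
            if hasattr(m, "comfy_cast_weights"):
                size = module_size(m)
                if mem_counter + size < budget:  # fits: keep resident on the GPU
                    m.to(device)
                    mem_counter += size
            elif hasattr(m, "weight"):
                # cannot be cast on the fly, so it must live on the target device
                m.to(device)
                mem_counter += module_size(m)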
--- comfy/model_management.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3adc42702..c0cb4130c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -259,6 +259,14 @@ print("VAE dtype:", VAE_DTYPE) current_loaded_models = [] +def module_size(module): + module_mem = 0 + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nelement() * t.element_size() + return module_mem + class LoadedModel: def __init__(self, model): self.model = model @@ -296,14 +304,14 @@ class LoadedModel: if hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True - module_mem = 0 - sd = m.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem = module_size(m) if mem_counter + module_mem < lowvram_model_memory: m.to(self.device) mem_counter += module_mem + elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode + m.to(self.device) + mem_counter += module_size(m) + print("lowvram: loaded module regularly", m) self.model_accelerated = True From 12e822c6c8a9019abd1127e0e61f1405de8d14e3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 28 Dec 2023 21:46:20 -0500 Subject: [PATCH 84/98] Use function to calculate model size in model patcher. --- comfy/model_patcher.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6acb2d647..b1b5ea6a8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -28,13 +28,9 @@ class ModelPatcher: if self.size > 0: return self.size model_sd = self.model.state_dict() - size = 0 - for k in model_sd: - t = model_sd[k] - size += t.nelement() * t.element_size() - self.size = size + self.size = comfy.model_management.module_size(self.model) self.model_keys = set(model_sd.keys()) - return size + return self.size def clone(self): n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) From 04b713dda1c4109f84386b17b0f7c25722f0ae15 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 29 Dec 2023 17:33:30 -0500 Subject: [PATCH 85/98] Fix VALIDATE_INPUTS getting called multiple times. Allow VALIDATE_INPUTS to only validate specific inputs. 
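Validation now introspects the VALIDATE_INPUTS signature once, exempts exactly those inputs from the generic type checks, and invokes the function a single time with only the inputs it asks for. The introspection step in isolation (ExampleNode is a hypothetical node):

    import inspect

    class ExampleNode:
        @classmethod
        def VALIDATE_INPUTS(s, image):  # validates only "image"
            return True

    validate_function_inputs = inspect.getfullargspec(ExampleNode.VALIDATE_INPUTS).args
    # -> ['s', 'image']; "image" skips the generic checks and is forwarded,
    # while every other input (e.g. "channel") keeps the default validation
    inputs = {"image": "example.png", "channel": "red"}
    input_filtered = {k: v for k, v in inputs.items() if k in validate_function_inputs}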
--- execution.py | 58 +++++++++++++++++++++++++++++++--------------------- nodes.py | 5 +---- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/execution.py b/execution.py index 7ad171313..53ba2e0f8 100644 --- a/execution.py +++ b/execution.py @@ -7,6 +7,7 @@ import threading import heapq import traceback import gc +import inspect import torch import nodes @@ -402,6 +403,10 @@ def validate_inputs(prompt, item, validated): errors = [] valid = True + validate_function_inputs = [] + if hasattr(obj_class, "VALIDATE_INPUTS"): + validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + for x in required_inputs: if x not in inputs: error = { @@ -531,29 +536,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: + if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -580,6 +563,35 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue + if len(validate_function_inputs) > 0: + input_data_all = get_input_data(inputs, obj_class, unique_id) + input_filtered = {} + for x in input_data_all: + if x in validate_function_inputs: + input_filtered[x] = input_data_all[x] + + #ret = obj_class.VALIDATE_INPUTS(**input_filtered) + ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + for x in input_filtered: + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: diff --git a/nodes.py b/nodes.py index 027bf55d9..8e3ec947c 100644 --- a/nodes.py +++ b/nodes.py @@ -1491,13 +1491,10 @@ class LoadImageMask: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image, channel): + def VALIDATE_INPUTS(s, image): if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) - if channel not in s._color_channels: - return "Invalid color channel: {}".format(channel) - return True class ImageScale: From 144e6580a4a43d5390769665a5032bb584481ff1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 29 Dec 2023 17:47:24 -0500 Subject: [PATCH 86/98] This cache timeout is pretty useless in practice. 
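The filename list cache is now validated purely by comparing the recorded folder modification times instead of also short-circuiting inside a half-second wall-clock window. A sketch of the check that remains, with the mapping layout assumed from the diff below:

    import os

    def cache_still_valid(folder_mtimes):
        # folder_mtimes maps folder path -> mtime recorded when the list was cached
        for folder, recorded in folder_mtimes.items():
            if os.path.getmtime(folder) != recorded:
                return False  # folder changed since the listing was built
        return True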
--- folder_paths.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 98704945e..a8726d8dd 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -184,8 +184,7 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] - if time.perf_counter() < (out[2] + 0.5): - return out + for x in out[1]: time_modified = out[1][x] folder = x From 1b103e0cb2d7aeb05fc8b7e006d4438e7bceca20 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 30 Dec 2023 05:38:21 -0500 Subject: [PATCH 87/98] Add argument to run the VAE on the CPU. --- comfy/cli_args.py | 2 ++ comfy/model_management.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 8de0adb53..50d7b62fa 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -66,6 +66,8 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.") + fpte_group = parser.add_mutually_exclusive_group() fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") diff --git a/comfy/model_management.py b/comfy/model_management.py index c0cb4130c..fefd3c8c9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -186,6 +186,9 @@ except: if is_intel_xpu(): VAE_DTYPE = torch.bfloat16 +if args.cpu_vae: + VAE_DTYPE = torch.float32 + if args.fp16_vae: VAE_DTYPE = torch.float16 elif args.bf16_vae: @@ -555,6 +558,8 @@ def intermediate_device(): return torch.device("cpu") def vae_device(): + if args.cpu_vae: + return torch.device("cpu") return get_torch_device() def vae_offload_device(): From 36e15f2507ee81e27140cf15ffcda40070968928 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 31 Dec 2023 05:05:14 -0500 Subject: [PATCH 88/98] Reregister nodes when pressing refresh button. --- web/scripts/app.js | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 62b71c0a1..7353f5a3b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -2020,12 +2020,8 @@ export class ComfyApp { async refreshComboInNodes() { const defs = await api.getNodeDefs(); - for(const nodeId in LiteGraph.registered_node_types) { - const node = LiteGraph.registered_node_types[nodeId]; - const nodeDef = defs[nodeId]; - if(!nodeDef) continue; - - node.nodeData = nodeDef; + for (const nodeId in defs) { + this.registerNodeDef(nodeId, defs[nodeId]); } for(let nodeNum in this.graph._nodes) { From d1f3637a5a944d0607b899babd8ff11d87100503 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 31 Dec 2023 15:37:20 -0500 Subject: [PATCH 89/98] Add a denoise parameter to BasicScheduler node. 
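denoise behaves like the KSampler parameter: the schedule is computed for steps/denoise steps and only the tail is kept, so denoise=0.5 with 20 steps samples the last 20 sigmas of a 40-step schedule. The arithmetic in isolation (sigmas_for is a hypothetical stand-in for the scheduler call; denoise is assumed in (0, 1]):

    def tail_sigmas(sigmas_for, steps, denoise):
        # sigmas_for(n) is assumed to return n + 1 sigma values
        total_steps = steps if denoise >= 1.0 else int(steps / denoise)
        sigmas = sigmas_for(total_steps)
        return sigmas[-(steps + 1):]  # keep only the final `steps` intervals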
--- comfy_extras/nodes_custom_sampler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 8791d8ae3..d5f9ba007 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -13,6 +13,7 @@ class BasicScheduler: {"model": ("MODEL",), "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -20,8 +21,13 @@ class BasicScheduler: FUNCTION = "get_sigmas" - def get_sigmas(self, model, scheduler, steps): - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() + def get_sigmas(self, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + total_steps = int(steps/denoise) + + sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] return (sigmas, ) From 66831eb6e96cd974fb2d0fc4f299b23c6af16685 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 Jan 2024 14:27:56 -0500 Subject: [PATCH 90/98] Add node id and prompt id to websocket progress packet. --- main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index f6aeceed2..bcf738729 100644 --- a/main.py +++ b/main.py @@ -106,6 +106,8 @@ def prompt_worker(q, server): item, item_id = queue_item execution_start_time = time.perf_counter() prompt_id = item[1] + server.last_prompt_id = prompt_id + e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, e.outputs_ui) @@ -131,7 +133,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): def hook(value, total, preview_image): comfy.model_management.throw_exception_if_processing_interrupted() - server.send_sync("progress", {"value": value, "max": total}, server.client_id) + progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + + server.send_sync("progress", progress, server.client_id) if preview_image is not None: server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) From 79f73a4b33c76867098b182cb0db1b657b2996f5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Jan 2024 01:50:29 -0500 Subject: [PATCH 91/98] Remove useless code. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 057dd16b2..cb0a79835 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -437,9 +437,6 @@ class UNetModel(nn.Module): operations=ops, ): super().__init__() - assert use_spatial_transformer == True, "use_spatial_transformer has to be true" - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
@@ -456,7 +453,6 @@ class UNetModel(nn.Module): if num_head_channels == -1: assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels From a47f609f904842a12c54c465fc93bda38257e289 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Jan 2024 01:50:57 -0500 Subject: [PATCH 92/98] Auto detect out_channels from model. --- comfy/model_detection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e3af422a3..ad16c0fbf 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config = { "use_checkpoint": False, "image_size": 32, - "out_channels": 4, "use_spatial_transformer": True, "legacy": False } @@ -49,6 +48,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0] num_res_blocks = [] channel_mult = [] @@ -122,6 +122,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): transformer_depth_middle = -1 unet_config["in_channels"] = in_channels + unet_config["out_channels"] = out_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks unet_config["transformer_depth"] = transformer_depth From 8e2c99e3cf3b85390ff9aa47edb7cbd319dfdc3b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Jan 2024 11:50:00 -0500 Subject: [PATCH 93/98] Fix issue when websocket is deleted when data is being sent. --- server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index 9b1e3269d..bd6f026b2 100644 --- a/server.py +++ b/server.py @@ -584,7 +584,8 @@ class PromptServer(): message = self.encode_bytes(event, data) if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_bytes, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_bytes, message) @@ -593,7 +594,8 @@ class PromptServer(): message = {"type": event, "data": data} if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_json, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) From 5eddfdd80caae18305cde55624c1b932a3e4a360 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Jan 2024 13:24:34 -0500 Subject: [PATCH 94/98] Refactor VAE code. Replace constants with downscale_ratio and latent_channels. 
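With the hardcoded 8 and 4 factored into attributes, pixel and latent shapes are derived from the instance, which lets VAEs with other geometries (such as the x4 upscaler VAE added two patches later with downscale_ratio 4) reuse the same code paths. The shape relation in brief:

    def latent_shape(batch, height, width, latent_channels=4, downscale_ratio=8):
        # pixels (B, 3, H, W) <-> latents (B, C, H // r, W // r)
        return (batch, latent_channels, height // downscale_ratio, width // downscale_ratio)

    assert latent_shape(1, 512, 512) == (1, 4, 64, 64)                       # SD1.x/SD2.x default
    assert latent_shape(1, 512, 512, downscale_ratio=4) == (1, 4, 128, 128)  # x4 upscaler VAE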
--- comfy/sd.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 220637a05..10a6715a8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -157,6 +157,8 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + self.downscale_ratio = 8 + self.latent_channels = 4 if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -204,9 +206,9 @@ class VAE: decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() output = torch.clamp(( - (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar)) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar)) / 3.0) / 2.0, min=0.0, max=1.0) return output @@ -217,9 +219,9 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((2. 
* a - 1.).to(self.vae_dtype).to(self.device)).float() - samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples /= 3.0 return samples @@ -231,7 +233,7 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) @@ -255,7 +257,7 @@ class VAE: free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device) + samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() From 2c4e92a98b8338f754855a0db7dce164945e366e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Jan 2024 14:41:33 -0500 Subject: [PATCH 95/98] Fix regression. 
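The regression: the previous commit read out.2.weight unconditionally, which breaks detection for state dicts that do not contain that key, so the lookup is now guarded with the old default of 4. The guarded lookup in isolation:

    def detect_out_channels(state_dict, key_prefix=""):
        out_key = "{}out.2.weight".format(key_prefix)
        # fall back to the historical default when the key is absent
        return state_dict[out_key].shape[0] if out_key in state_dict else 4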
--- comfy/model_detection.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ad16c0fbf..ea824c44c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -48,7 +48,12 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] - out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0] + + out_key = '{}out.2.weight'.format(key_prefix) + if out_key in state_dict: + out_channels = state_dict[out_key].shape[0] + else: + out_channels = 4 num_res_blocks = [] channel_mult = [] From a7874d1a8b88f9e5cc3d37fdba9b763004b6357d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 Jan 2024 03:30:39 -0500 Subject: [PATCH 96/98] Add support for the stable diffusion x4 upscaling model. This is an old model. Load the checkpoint like a regular one and use the new SD_4XUpscale_Conditioning node. --- comfy/latent_formats.py | 4 +++ comfy/model_base.py | 21 +++++++++++++++ comfy/sd.py | 5 ++++ comfy/supported_models.py | 28 +++++++++++++++++++- comfy_extras/nodes_sdupscale.py | 45 +++++++++++++++++++++++++++++++++ nodes.py | 1 + 6 files changed, 103 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_sdupscale.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index c209087e0..2252a075e 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -33,3 +33,7 @@ class SDXL(LatentFormat): [-0.3112, -0.2359, -0.2076] ] self.taesd_decoder_name = "taesdxl_decoder" + +class SD_X4(LatentFormat): + def __init__(self): + self.scale_factor = 0.08333 diff --git a/comfy/model_base.py b/comfy/model_base.py index b3a1fcd51..64a380ff3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -364,3 +364,24 @@ class Stable_Zero123(BaseModel): cross_attn = self.cc_projection(cross_attn) out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) return out + +class SD_X4Upscaler(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): + super().__init__(model_config, model_type, device=device) + + def extra_conds(self, **kwargs): + out = {} + + image = kwargs.get("concat_image", None) + noise = kwargs.get("noise", None) + + if image is None: + image = torch.zeros_like(noise)[:,:3] + + if image.shape[1:] != noise.shape[1:]: + image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + image = utils.resize_to_batch_size(image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(image) + return out diff --git a/comfy/sd.py b/comfy/sd.py index 10a6715a8..1ff25bec6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -174,6 +174,11 @@ class VAE: else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 251bf6ace..e7a6cc179 100644 --- a/comfy/supported_models.py +++ 
b/comfy/supported_models.py @@ -278,6 +278,32 @@ class Stable_Zero123(supported_models_base.BASE): def clip_target(self): return None +class SD_X4Upscaler(SD20): + unet_config = { + "context_dim": 1024, + "model_channels": 256, + 'in_channels': 7, + "use_linear_in_transformer": True, + "adm_in_channels": None, + "use_temporal_attention": False, + } -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega] + unet_extra_config = { + "disable_self_attentions": [True, True, True, False], + "num_heads": 8, + "num_head_channels": -1, + } + + latent_format = latent_formats.SD_X4 + + sampling_settings = { + "linear_start": 0.0001, + "linear_end": 0.02, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SD_X4Upscaler(self, device=device) + return out + +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py new file mode 100644 index 000000000..38a027e0b --- /dev/null +++ b/comfy_extras/nodes_sdupscale.py @@ -0,0 +1,45 @@ +import torch +import nodes +import comfy.utils + +class SD_4XUpscale_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), + # "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/upscale_diffusion" + + def encode(self, images, positive, negative, scale_ratio): + width = max(1, round(images.shape[-2] * scale_ratio)) + height = max(1, round(images.shape[-3] * scale_ratio)) + + pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center") + + out_cp = [] + out_cn = [] + + for t in positive: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + out_cp.append(n) + + for t in negative: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + out_cn.append(n) + + latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) + return (out_cp, out_cn, {"samples":latent}) + +NODE_CLASS_MAPPINGS = { + "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, +} diff --git a/nodes.py b/nodes.py index 8e3ec947c..82244cf76 100644 --- a/nodes.py +++ b/nodes.py @@ -1880,6 +1880,7 @@ def init_custom_nodes(): "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", + "nodes_sdupscale.py", ] for node_file in extras_files: From ef4f6037cbbbd4150c44862eb398428b70f19263 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 3 Jan 2024 12:16:30 -0500 Subject: [PATCH 97/98] Fix model patches not working in custom sampling scheduler nodes. 
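BasicScheduler and SDTurboScheduler previously reached into model.model directly, so ModelPatcher object patches (such as a replaced model_sampling) were never applied. patch_model() now takes patch_weights=False, letting nodes apply object patches without the expensive weight patching. The calling convention, sketched (compute_sigmas is a hypothetical helper; the try/finally is an addition of this sketch, not of the diff):

    inner_model = model.patch_model(patch_weights=False)  # object patches only
    try:
        sigmas = compute_sigmas(inner_model)  # e.g. via inner_model.model_sampling
    finally:
        model.unpatch_model()  # restore the original attributes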
--- comfy/model_patcher.py | 47 ++++++++++++++-------------- comfy_extras/nodes_custom_sampler.py | 8 +++-- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b1b5ea6a8..a88b737cc 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -174,40 +174,41 @@ class ModelPatcher: sd.pop(k) return sd - def patch_model(self, device_to=None): + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: old = getattr(self.model, k) if k not in self.object_patches_backup: self.object_patches_backup[k] = old setattr(self.model, k, self.object_patches[k]) - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) - continue + if patch_weights: + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. key doesn't exist in model:", key) + continue - weight = model_sd[key] + weight = model_sd[key] - inplace_update = self.weight_inplace_update + inplace_update = self.weight_inplace_update - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + + if device_to is not None: + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr(self.model, key, out_weight) + del temp_weight if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr(self.model, key, out_weight) - del temp_weight - - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + self.model.to(device_to) + self.current_device = device_to return self.model diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d5f9ba007..bb0ed57b2 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -26,7 +26,9 @@ class BasicScheduler: if denoise < 1.0: total_steps = int(steps/denoise) - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + inner_model = model.patch_model(patch_weights=False) + sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu() + model.unpatch_model() sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -104,7 +106,9 @@ class SDTurboScheduler: def get_sigmas(self, model, steps, denoise): start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] - sigmas = model.model.model_sampling.sigma(timesteps) + inner_model = model.patch_model(patch_weights=False) + sigmas = inner_model.model_sampling.sigma(timesteps) + model.unpatch_model() sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) From 8c6493578b3dda233e9b9a953feeaf1e6ca434ad Mon Sep 17 
00:00:00 2001 From: comfyanonymous Date: Wed, 3 Jan 2024 14:27:11 -0500 Subject: [PATCH 98/98] Implement noise augmentation for SD 4X upscale model. --- .../modules/diffusionmodules/openaimodel.py | 2 +- .../ldm/modules/diffusionmodules/upscaling.py | 12 ++++++---- comfy/model_base.py | 22 ++++++++++++++----- comfy/samplers.py | 4 ++-- comfy/supported_models.py | 1 + comfy_extras/nodes_sdupscale.py | 6 +++-- 6 files changed, 33 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index cb0a79835..ea936e066 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -498,7 +498,7 @@ class UNetModel(nn.Module): if self.num_classes is not None: if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 768a47f9c..f5ac7c2f9 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -41,8 +41,12 @@ class AbstractLowScaleModel(nn.Module): self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) + def q_sample(self, x_start, t, noise=None, seed=None): + if noise is None: + if seed is None: + noise = torch.randn_like(x_start) + else: + noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device) return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise) @@ -69,12 +73,12 @@ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): super().__init__(noise_schedule_config=noise_schedule_config) self.max_noise_level = max_noise_level - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) return z, noise_level diff --git a/comfy/model_base.py b/comfy/model_base.py index 64a380ff3..f59526204 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,7 +1,7 @@ import torch -from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep +from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management import comfy.conds import comfy.ops @@ -78,8 +78,9 @@ class BaseModel(torch.nn.Module): extra_conds = {} for o in kwargs: extra = kwargs[o] - if hasattr(extra, "to"): - extra = 
extra.to(dtype) + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) extra_conds[o] = extra model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() @@ -368,20 +369,31 @@ class Stable_Zero123(BaseModel): class SD_X4Upscaler(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): super().__init__(model_config, model_type, device=device) + self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350) def extra_conds(self, **kwargs): out = {} image = kwargs.get("concat_image", None) noise = kwargs.get("noise", None) + noise_augment = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + seed = kwargs["seed"] - 10 + + noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment) if image is None: image = torch.zeros_like(noise)[:,:3] if image.shape[1:] != noise.shape[1:]: - image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + noise_level = torch.tensor([noise_level], device=device) + if noise_augment > 0: + image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed) image = utils.resize_to_batch_size(image, noise.shape[0]) out['c_concat'] = comfy.conds.CONDNoiseShape(image) + out['y'] = comfy.conds.CONDRegular(noise_level) return out diff --git a/comfy/samplers.py b/comfy/samplers.py index 0453c1f6f..89d8d4f28 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -603,8 +603,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model latent_image = model.process_latent_in(latent_image) if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) #make sure each cond area has an opposite one with the same area for c in positive: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e7a6cc179..1d442d4dd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -290,6 +290,7 @@ class SD_X4Upscaler(SD20): unet_extra_config = { "disable_self_attentions": [True, True, True, False], + "num_classes": 1000, "num_heads": 8, "num_head_channels": -1, } diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index 38a027e0b..28c1cb0f1 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning: "positive": ("CONDITIONING",), "negative": ("CONDITIONING",), "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), - # "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), #TODO + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), 
}} RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") @@ -18,7 +18,7 @@ class SD_4XUpscale_Conditioning: CATEGORY = "conditioning/upscale_diffusion" - def encode(self, images, positive, negative, scale_ratio): + def encode(self, images, positive, negative, scale_ratio, noise_augmentation): width = max(1, round(images.shape[-2] * scale_ratio)) height = max(1, round(images.shape[-3] * scale_ratio)) @@ -30,11 +30,13 @@ class SD_4XUpscale_Conditioning: for t in positive: n = [t[0], t[1].copy()] n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation out_cp.append(n) for t in negative: n = [t[0], t[1].copy()] n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation out_cn.append(n) latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
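Taken together: the node stores noise_augmentation on the conditioning, extra_conds maps it to a discrete noise level (up to max_noise_level = 350), and the concat image is noised reproducibly from the sampling seed. The essential change in q_sample is the optional seed; a distilled version, with the schedule tensors assumed 1-D and indexed by a scalar t (the real code uses extract_into_tensor for batched gathers):

    import torch

    def q_sample_sketch(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, seed=None):
        if seed is None:
            noise = torch.randn_like(x_start)
        else:
            # deterministic noise so the augmentation follows the sampling seed
            noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout,
                                generator=torch.manual_seed(seed)).to(x_start.device)
        return (sqrt_alphas_cumprod[t].to(x_start.device) * x_start
                + sqrt_one_minus_alphas_cumprod[t].to(x_start.device) * noise)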